diff --git a/cmd/cloudflared/updater/workers_service.go b/cmd/cloudflared/updater/workers_service.go index 9e41eabd..23a40968 100644 --- a/cmd/cloudflared/updater/workers_service.go +++ b/cmd/cloudflared/updater/workers_service.go @@ -5,8 +5,6 @@ import ( "errors" "net/http" "runtime" - "strconv" - "strings" ) // Options are the update options supported by the @@ -28,6 +26,7 @@ type VersionResponse struct { Checksum string `json:"checksum"` IsCompressed bool `json:"compressed"` UserMessage string `json:"userMessage"` + ShouldUpdate bool `json:"shouldUpdate"` Error string `json:"error"` } @@ -86,71 +85,10 @@ func (s *WorkersService) Check() (CheckResult, error) { return nil, errors.New(v.Error) } - var versionToUpdate = "" - if s.opts.IsForced || IsNewerVersion(s.currentVersion, v.Version) { + versionToUpdate := "" + if v.ShouldUpdate { versionToUpdate = v.Version } return NewWorkersVersion(v.URL, versionToUpdate, v.Checksum, s.targetPath, v.UserMessage, v.IsCompressed), nil } - -// IsNewerVersion checks semantic versioning for the latest version -// cloudflared tagging is more of a date than a semantic version, -// but the same comparision logic still holds for major.minor.patch -// e.g. 2020.8.2 is newer than 2020.8.1. -func IsNewerVersion(current string, check string) bool { - if strings.Contains(strings.ToLower(current), "dev") { - return false // dev builds shouldn't update - } - - cMajor, cMinor, cPatch, err := SemanticParts(current) - if err != nil { - return false - } - - nMajor, nMinor, nPatch, err := SemanticParts(check) - if err != nil { - return false - } - - if nMajor > cMajor { - return true - } - - if nMajor == cMajor && nMinor > cMinor { - return true - } - - if nMajor == cMajor && nMinor == cMinor && nPatch > cPatch { - return true - } - return false -} - -// SemanticParts gets the major, minor, and patch version of a semantic version string -// e.g. 3.1.2 would return 3, 1, 2, nil -func SemanticParts(version string) (major int, minor int, patch int, err error) { - major = 0 - minor = 0 - patch = 0 - parts := strings.Split(version, ".") - if len(parts) != 3 { - err = errors.New("invalid version") - return - } - major, err = strconv.Atoi(parts[0]) - if err != nil { - return - } - - minor, err = strconv.Atoi(parts[1]) - if err != nil { - return - } - - patch, err = strconv.Atoi(parts[2]) - if err != nil { - return - } - return -} diff --git a/cmd/cloudflared/updater/workers_service_test.go b/cmd/cloudflared/updater/workers_service_test.go index 0476aa4a..ccf95c89 100644 --- a/cmd/cloudflared/updater/workers_service_test.go +++ b/cmd/cloudflared/updater/workers_service_test.go @@ -6,6 +6,7 @@ import ( "compress/gzip" "crypto/sha256" "encoding/json" + "errors" "fmt" "io/ioutil" "log" @@ -13,6 +14,8 @@ import ( "net/http/httptest" "os" "runtime" + "strconv" + "strings" "testing" "github.com/stretchr/testify/require" @@ -68,8 +71,11 @@ func updateHandler(w http.ResponseWriter, r *http.Request) { if query.Get(ClientVersionName) == knownBuggyVersion { userMessage = expectedUserMsg } + shouldUpdate := requestedVersion != "" || IsNewerVersion(query.Get(ClientVersionName), version) - v := VersionResponse{URL: url, Version: version, Checksum: checksum, UserMessage: userMessage} + v := VersionResponse{ + URL: url, Version: version, Checksum: checksum, UserMessage: userMessage, ShouldUpdate: shouldUpdate, + } respondWithJSON(w, v, http.StatusOK) } @@ -81,7 +87,7 @@ func gzipUpdateHandler(w http.ResponseWriter, r *http.Request) { checksum := fmt.Sprintf("%x", h.Sum(nil)) url := fmt.Sprintf("http://%s/gzip-download.tgz", r.Host) - v := VersionResponse{URL: url, Version: version, Checksum: checksum} + v := VersionResponse{URL: url, Version: version, Checksum: checksum, ShouldUpdate: true} respondWithJSON(w, v, http.StatusOK) } @@ -122,6 +128,64 @@ func failureHandler(w http.ResponseWriter, r *http.Request) { respondWithJSON(w, VersionResponse{Error: "unsupported os and architecture"}, http.StatusBadRequest) } +func IsNewerVersion(current string, check string) bool { + if current == "" || check == "" { + return false + } + if strings.Contains(strings.ToLower(current), "dev") { + return false // dev builds shouldn't update + } + + cMajor, cMinor, cPatch, err := SemanticParts(current) + if err != nil { + return false + } + + nMajor, nMinor, nPatch, err := SemanticParts(check) + if err != nil { + return false + } + + if nMajor > cMajor { + return true + } + + if nMajor == cMajor && nMinor > cMinor { + return true + } + + if nMajor == cMajor && nMinor == cMinor && nPatch > cPatch { + return true + } + return false +} + +func SemanticParts(version string) (major int, minor int, patch int, err error) { + major = 0 + minor = 0 + patch = 0 + parts := strings.Split(version, ".") + if len(parts) != 3 { + err = errors.New("invalid version") + return + } + major, err = strconv.Atoi(parts[0]) + if err != nil { + return + } + + minor, err = strconv.Atoi(parts[1]) + if err != nil { + return + } + + patch, err = strconv.Atoi(parts[2]) + if err != nil { + return + } + return +} + func createServer() *httptest.Server { mux := http.NewServeMux() mux.HandleFunc("/updater", updateHandler) @@ -285,15 +349,3 @@ func TestUpdateWhenRunningKnownBuggyVersion(t *testing.T) { require.Equal(t, v.Version(), mostRecentVersion) require.Equal(t, v.UserMessage(), expectedUserMsg) } - -func TestVersionParsing(t *testing.T) { - require.False(t, IsNewerVersion("2020.8.2", "2020.8.2")) - require.True(t, IsNewerVersion("2020.8.2", "2020.8.3")) - require.True(t, IsNewerVersion("2020.8.2", "2021.1.2")) - require.True(t, IsNewerVersion("2020.8.2", "2020.9.1")) - require.True(t, IsNewerVersion("2020.8.2", "2020.12.45")) - require.False(t, IsNewerVersion("2020.8.2", "2020.6.3")) - require.False(t, IsNewerVersion("DEV", "2020.8.5")) - require.False(t, IsNewerVersion("2020.8.2", "asdlkfjasdf")) - require.True(t, IsNewerVersion("3.0.1", "4.2.1")) -}