TUN-3989: Check in with Updater service in more situations and convey messages to user

This commit is contained in:
Nuno Diegues 2021-02-28 23:24:38 +00:00
parent 5c7b451e17
commit bcd71b56e9
11 changed files with 245 additions and 107 deletions

View File

@ -16,6 +16,9 @@
- Nested commands, such as `cloudflared tunnel run`, now consider CLI arguments even if they appear earlier on the - Nested commands, such as `cloudflared tunnel run`, now consider CLI arguments even if they appear earlier on the
command. For instance, `cloudflared --config config.yaml tunnel run` will now behave the same as command. For instance, `cloudflared --config config.yaml tunnel run` will now behave the same as
`cloudflared tunnel --config config.yaml run` `cloudflared tunnel --config config.yaml run`
- Warnings are now shown in the output logs whenever cloudflared is running without the most recent version and
`no-autoupdate` is `true`.
### Bug Fixes ### Bug Fixes

View File

@ -286,16 +286,14 @@ func StartServer(
} }
// update needs to be after DNS proxy is up to resolve equinox server address // update needs to be after DNS proxy is up to resolve equinox server address
if updater.IsAutoupdateEnabled(c, log) { wg.Add(1)
autoupdateFreq := c.Duration("autoupdate-freq") go func() {
log.Info().Dur("autoupdateFreq", autoupdateFreq).Msg("Autoupdate frequency is set") defer wg.Done()
wg.Add(1) autoupdater := updater.NewAutoUpdater(
go func() { c.Bool("no-autoupdate"), c.Duration("autoupdate-freq"), &listeners, log,
defer wg.Done() )
autoupdater := updater.NewAutoUpdater(c.Duration("autoupdate-freq"), &listeners, log) errC <- autoupdater.Run(ctx)
errC <- autoupdater.Run(ctx) }()
}()
}
// Serve DNS proxy stand-alone if no hostname or tag or app is going to run // Serve DNS proxy stand-alone if no hostname or tag or app is going to run
if dnsProxyStandAlone(c) { if dnsProxyStandAlone(c) {

View File

@ -13,6 +13,7 @@ import (
"text/tabwriter" "text/tabwriter"
"time" "time"
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/mitchellh/go-homedir" "github.com/mitchellh/go-homedir"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -149,6 +150,9 @@ func createCommand(c *cli.Context) error {
} }
name := c.Args().First() name := c.Args().First()
warningChecker := updater.StartWarningCheck(c)
defer warningChecker.LogWarningIfAny(sc.log)
_, err = sc.create(name, c.String(CredFileFlag)) _, err = sc.create(name, c.String(CredFileFlag))
return errors.Wrap(err, "failed to create tunnel") return errors.Wrap(err, "failed to create tunnel")
} }
@ -227,6 +231,9 @@ func listCommand(c *cli.Context) error {
filter.ByTunnelID(tunnelID) filter.ByTunnelID(tunnelID)
} }
warningChecker := updater.StartWarningCheck(c)
defer warningChecker.LogWarningIfAny(sc.log)
tunnels, err := sc.list(filter) tunnels, err := sc.list(filter)
if err != nil { if err != nil {
return err return err
@ -271,6 +278,7 @@ func listCommand(c *cli.Context) error {
} else { } else {
fmt.Println("You have no tunnels, use 'cloudflared tunnel create' to define a new tunnel") fmt.Println("You have no tunnels, use 'cloudflared tunnel create' to define a new tunnel")
} }
return nil return nil
} }
@ -349,6 +357,9 @@ func deleteCommand(c *cli.Context) error {
return cliutil.UsageError(`"cloudflared tunnel delete" requires at least 1 argument, the ID or name of the tunnel to delete.`) return cliutil.UsageError(`"cloudflared tunnel delete" requires at least 1 argument, the ID or name of the tunnel to delete.`)
} }
warningChecker := updater.StartWarningCheck(c)
defer warningChecker.LogWarningIfAny(sc.log)
tunnelIDs, err := sc.findIDs(c.Args().Slice()) tunnelIDs, err := sc.findIDs(c.Args().Slice())
if err != nil { if err != nil {
return err return err

View File

@ -7,6 +7,7 @@ import (
"text/tabwriter" "text/tabwriter"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
"github.com/cloudflare/cloudflared/teamnet" "github.com/cloudflare/cloudflared/teamnet"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -81,6 +82,9 @@ func showRoutesCommand(c *cli.Context) error {
return errors.Wrap(err, "invalid config for routing filters") return errors.Wrap(err, "invalid config for routing filters")
} }
warningChecker := updater.StartWarningCheck(c)
defer warningChecker.LogWarningIfAny(sc.log)
routes, err := sc.listRoutes(filter) routes, err := sc.listRoutes(filter)
if err != nil { if err != nil {
return err return err
@ -95,6 +99,7 @@ func showRoutesCommand(c *cli.Context) error {
} else { } else {
fmt.Println("You have no routes, use 'cloudflared tunnel route ip add' to add a route") fmt.Println("You have no routes, use 'cloudflared tunnel route ip add' to add a route")
} }
return nil return nil
} }

View File

@ -0,0 +1,49 @@
package updater
import (
"github.com/rs/zerolog"
"github.com/urfave/cli/v2"
)
type VersionWarningChecker struct {
warningChan chan string
}
func StartWarningCheck(c *cli.Context) VersionWarningChecker {
checker := VersionWarningChecker{
warningChan: make(chan string),
}
go func() {
options := updateOptions{
updateDisabled: true,
isBeta: c.Bool("beta"),
isStaging: c.Bool("staging"),
isForced: false,
intendedVersion: "",
}
checkResult, err := CheckForUpdate(options)
if err == nil {
checker.warningChan <- checkResult.UserMessage()
}
close(checker.warningChan)
}()
return checker
}
func (checker VersionWarningChecker) getWarning() string {
select {
case message := <-checker.warningChan:
return message
default:
// No feedback on time, we don't wait for it, since this is best-effort.
return ""
}
}
func (checker VersionWarningChecker) LogWarningIfAny(log *zerolog.Logger) {
if warning := checker.getWarning(); warning != "" {
log.Warn().Msg(warning)
}
}

View File

@ -1,14 +1,15 @@
package updater package updater
// Version is the functions needed to perform an update // CheckResult is the behaviour resulting from checking in with the Update Service
type Version interface { type CheckResult interface {
Apply() error Apply() error
String() string Version() string
UserMessage() string
} }
// Service is the functions to get check for new updates // Service is the functions to get check for new updates
type Service interface { type Service interface {
Check() (Version, error) Check() (CheckResult, error)
} }
const ( const (
@ -23,4 +24,7 @@ const (
// VersionKeyName is the url parameter key to send to the checkin API to specific what version to upgrade or downgrade to // VersionKeyName is the url parameter key to send to the checkin API to specific what version to upgrade or downgrade to
VersionKeyName = "version" VersionKeyName = "version"
// ClientVersionName is the url parameter key to send the version that this cloudflared is currently running with
ClientVersionName = "clientVersion"
) )

View File

@ -3,18 +3,17 @@ package updater
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/rs/zerolog"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"time" "time"
"github.com/urfave/cli/v2"
"golang.org/x/crypto/ssh/terminal"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config" "github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/facebookgo/grace/gracenet" "github.com/facebookgo/grace/gracenet"
"github.com/rs/zerolog"
"github.com/urfave/cli/v2"
"golang.org/x/crypto/ssh/terminal"
) )
const ( const (
@ -61,16 +60,18 @@ func (e *statusErr) ExitCode() int {
} }
type updateOptions struct { type updateOptions struct {
isBeta bool updateDisabled bool
isStaging bool isBeta bool
isForced bool isStaging bool
version string isForced bool
intendedVersion string
} }
type UpdateOutcome struct { type UpdateOutcome struct {
Updated bool Updated bool
Version string Version string
Error error UserMessage string
Error error
} }
func (uo *UpdateOutcome) noUpdate() bool { func (uo *UpdateOutcome) noUpdate() bool {
@ -81,10 +82,10 @@ func Init(v string) {
version = v version = v
} }
func checkForUpdateAndApply(options updateOptions) UpdateOutcome { func CheckForUpdate(options updateOptions) (CheckResult, error) {
cfdPath, err := os.Executable() cfdPath, err := os.Executable()
if err != nil { if err != nil {
return UpdateOutcome{Error: err} return nil, err
} }
url := UpdateURL url := UpdateURL
@ -93,24 +94,22 @@ func checkForUpdateAndApply(options updateOptions) UpdateOutcome {
} }
s := NewWorkersService(version, url, cfdPath, Options{IsBeta: options.isBeta, s := NewWorkersService(version, url, cfdPath, Options{IsBeta: options.isBeta,
IsForced: options.isForced, RequestedVersion: options.version}) IsForced: options.isForced, RequestedVersion: options.intendedVersion})
v, err := s.Check() return s.Check()
}
func applyUpdate(options updateOptions, update CheckResult) UpdateOutcome {
if update.Version() == "" || options.updateDisabled {
return UpdateOutcome{UserMessage: update.UserMessage()}
}
err := update.Apply()
if err != nil { if err != nil {
return UpdateOutcome{Error: err} return UpdateOutcome{Error: err}
} }
//already on the latest version return UpdateOutcome{Updated: true, Version: update.Version(), UserMessage: update.UserMessage()}
if v == nil {
return UpdateOutcome{}
}
err = v.Apply()
if err != nil {
return UpdateOutcome{Error: err}
}
return UpdateOutcome{Updated: true, Version: v.String()}
} }
// Update is the handler for the update command from the command line // Update is the handler for the update command from the command line
@ -137,7 +136,13 @@ func Update(c *cli.Context) error {
log.Info().Msg("cloudflared is set to upgrade to the latest publish version regardless of the current version") log.Info().Msg("cloudflared is set to upgrade to the latest publish version regardless of the current version")
} }
updateOutcome := loggedUpdate(log, updateOptions{isBeta: isBeta, isStaging: isStaging, isForced: isForced, version: c.String("version")}) updateOutcome := loggedUpdate(log, updateOptions{
updateDisabled: false,
isBeta: isBeta,
isStaging: isStaging,
isForced: isForced,
intendedVersion: c.String("version"),
})
if updateOutcome.Error != nil { if updateOutcome.Error != nil {
return &statusErr{updateOutcome.Error} return &statusErr{updateOutcome.Error}
} }
@ -152,12 +157,18 @@ func Update(c *cli.Context) error {
// Checks for an update and applies it if one is available // Checks for an update and applies it if one is available
func loggedUpdate(log *zerolog.Logger, options updateOptions) UpdateOutcome { func loggedUpdate(log *zerolog.Logger, options updateOptions) UpdateOutcome {
updateOutcome := checkForUpdateAndApply(options) checkResult, err := CheckForUpdate(options)
if err != nil {
log.Err(err).Msg("update check failed")
return UpdateOutcome{Error: err}
}
updateOutcome := applyUpdate(options, checkResult)
if updateOutcome.Updated { if updateOutcome.Updated {
log.Info().Str(LogFieldVersion, updateOutcome.Version).Msg("cloudflared has been updated") log.Info().Str(LogFieldVersion, updateOutcome.Version).Msg("cloudflared has been updated")
} }
if updateOutcome.Error != nil { if updateOutcome.Error != nil {
log.Err(updateOutcome.Error).Msg("update check failed: %s") log.Err(updateOutcome.Error).Msg("update failed to apply")
} }
return updateOutcome return updateOutcome
@ -177,44 +188,53 @@ type configurable struct {
freq time.Duration freq time.Duration
} }
func NewAutoUpdater(freq time.Duration, listeners *gracenet.Net, log *zerolog.Logger) *AutoUpdater { func NewAutoUpdater(updateDisabled bool, freq time.Duration, listeners *gracenet.Net, log *zerolog.Logger) *AutoUpdater {
updaterConfigurable := &configurable{
enabled: true,
freq: freq,
}
if freq == 0 {
updaterConfigurable.enabled = false
updaterConfigurable.freq = DefaultCheckUpdateFreq
}
return &AutoUpdater{ return &AutoUpdater{
configurable: updaterConfigurable, configurable: createUpdateConfig(updateDisabled, freq, log),
listeners: listeners, listeners: listeners,
updateConfigChan: make(chan *configurable), updateConfigChan: make(chan *configurable),
log: log, log: log,
} }
} }
func createUpdateConfig(updateDisabled bool, freq time.Duration, log *zerolog.Logger) *configurable {
if isAutoupdateEnabled(log, updateDisabled, freq) {
log.Info().Dur("autoupdateFreq", freq).Msg("Autoupdate frequency is set")
return &configurable{
enabled: true,
freq: freq,
}
} else {
return &configurable{
enabled: false,
freq: DefaultCheckUpdateFreq,
}
}
}
func (a *AutoUpdater) Run(ctx context.Context) error { func (a *AutoUpdater) Run(ctx context.Context) error {
ticker := time.NewTicker(a.configurable.freq) ticker := time.NewTicker(a.configurable.freq)
for { for {
if a.configurable.enabled { updateOutcome := loggedUpdate(a.log, updateOptions{updateDisabled: !a.configurable.enabled})
updateOutcome := loggedUpdate(a.log, updateOptions{}) if updateOutcome.Updated {
if updateOutcome.Updated { Init(updateOutcome.Version)
if IsSysV() { if IsSysV() {
// SysV doesn't have a mechanism to keep service alive, we have to restart the process // SysV doesn't have a mechanism to keep service alive, we have to restart the process
a.log.Info().Msg("Restarting service managed by SysV...") a.log.Info().Msg("Restarting service managed by SysV...")
pid, err := a.listeners.StartProcess() pid, err := a.listeners.StartProcess()
if err != nil { if err != nil {
a.log.Err(err).Msg("Unable to restart server automatically") a.log.Err(err).Msg("Unable to restart server automatically")
return &statusErr{err: err} return &statusErr{err: err}
}
// stop old process after autoupdate. Otherwise we create a new process
// after each update
a.log.Info().Msgf("PID of the new process is %d", pid)
} }
return &statusSuccess{newVersion: updateOutcome.Version} // stop old process after autoupdate. Otherwise we create a new process
// after each update
a.log.Info().Msgf("PID of the new process is %d", pid)
} }
return &statusSuccess{newVersion: updateOutcome.Version}
} else if updateOutcome.UserMessage != "" {
a.log.Warn().Msg(updateOutcome.UserMessage)
} }
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
@ -229,27 +249,18 @@ func (a *AutoUpdater) Run(ctx context.Context) error {
} }
// Update is the method to pass new AutoUpdaterConfigurable to a running AutoUpdater. It is safe to be called concurrently // Update is the method to pass new AutoUpdaterConfigurable to a running AutoUpdater. It is safe to be called concurrently
func (a *AutoUpdater) Update(newFreq time.Duration) { func (a *AutoUpdater) Update(updateDisabled bool, newFreq time.Duration) {
newConfigurable := &configurable{ a.updateConfigChan <- createUpdateConfig(updateDisabled, newFreq, a.log)
enabled: true,
freq: newFreq,
}
// A ero duration means autoupdate is disabled
if newFreq == 0 {
newConfigurable.enabled = false
newConfigurable.freq = DefaultCheckUpdateFreq
}
a.updateConfigChan <- newConfigurable
} }
func IsAutoupdateEnabled(c *cli.Context, log *zerolog.Logger) bool { func isAutoupdateEnabled(log *zerolog.Logger, updateDisabled bool, updateFreq time.Duration) bool {
if !SupportAutoUpdate(log) { if !supportAutoUpdate(log) {
return false return false
} }
return !c.Bool("no-autoupdate") && c.Duration("autoupdate-freq") != 0 return !updateDisabled && updateFreq != 0
} }
func SupportAutoUpdate(log *zerolog.Logger) bool { func supportAutoUpdate(log *zerolog.Logger) bool {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
log.Info().Msg(noUpdateOnWindowsMessage) log.Info().Msg(noUpdateOnWindowsMessage)
return false return false

View File

@ -2,17 +2,19 @@ package updater
import ( import (
"context" "context"
"flag"
"testing" "testing"
"github.com/facebookgo/grace/gracenet" "github.com/facebookgo/grace/gracenet"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/urfave/cli/v2"
) )
func TestDisabledAutoUpdater(t *testing.T) { func TestDisabledAutoUpdater(t *testing.T) {
listeners := &gracenet.Net{} listeners := &gracenet.Net{}
log := zerolog.Nop() log := zerolog.Nop()
autoupdater := NewAutoUpdater(0, listeners, &log) autoupdater := NewAutoUpdater(false, 0, listeners, &log)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
errC := make(chan error) errC := make(chan error)
go func() { go func() {
@ -26,3 +28,13 @@ func TestDisabledAutoUpdater(t *testing.T) {
// Make sure that autoupdater terminates after canceling the context // Make sure that autoupdater terminates after canceling the context
assert.Equal(t, context.Canceled, <-errC) assert.Equal(t, context.Canceled, <-errC)
} }
func TestCheckInWithUpdater(t *testing.T) {
flagSet := flag.NewFlagSet(t.Name(), flag.PanicOnError)
cliCtx := cli.NewContext(cli.NewApp(), flagSet, nil)
warningChecker := StartWarningCheck(cliCtx)
warning := warningChecker.getWarning()
// Assuming this runs either on a release or development version, then the Worker will never have anything to tell us.
assert.Empty(t, warning)
}

View File

@ -27,6 +27,7 @@ type VersionResponse struct {
Version string `json:"version"` Version string `json:"version"`
Checksum string `json:"checksum"` Checksum string `json:"checksum"`
IsCompressed bool `json:"compressed"` IsCompressed bool `json:"compressed"`
UserMessage string `json:"userMessage"`
Error string `json:"error"` Error string `json:"error"`
} }
@ -50,7 +51,7 @@ func NewWorkersService(currentVersion, url, targetPath string, opts Options) Ser
} }
// Check does a check in with the Workers API to get a new version update // Check does a check in with the Workers API to get a new version update
func (s *WorkersService) Check() (Version, error) { func (s *WorkersService) Check() (CheckResult, error) {
client := &http.Client{ client := &http.Client{
Timeout: clientTimeout, Timeout: clientTimeout,
} }
@ -59,6 +60,7 @@ func (s *WorkersService) Check() (Version, error) {
q := req.URL.Query() q := req.URL.Query()
q.Add(OSKeyName, runtime.GOOS) q.Add(OSKeyName, runtime.GOOS)
q.Add(ArchitectureKeyName, runtime.GOARCH) q.Add(ArchitectureKeyName, runtime.GOARCH)
q.Add(ClientVersionName, s.currentVersion)
if s.opts.IsBeta { if s.opts.IsBeta {
q.Add(BetaKeyName, "true") q.Add(BetaKeyName, "true")
@ -84,11 +86,12 @@ func (s *WorkersService) Check() (Version, error) {
return nil, errors.New(v.Error) return nil, errors.New(v.Error)
} }
if !s.opts.IsForced && !IsNewerVersion(s.currentVersion, v.Version) { var versionToUpdate = ""
return nil, nil if s.opts.IsForced || IsNewerVersion(s.currentVersion, v.Version) {
versionToUpdate = v.Version
} }
return NewWorkersVersion(v.URL, v.Version, v.Checksum, s.targetPath, v.IsCompressed), nil return NewWorkersVersion(v.URL, versionToUpdate, v.Checksum, s.targetPath, v.UserMessage, v.IsCompressed), nil
} }
// IsNewerVersion checks semantic versioning for the latest version // IsNewerVersion checks semantic versioning for the latest version

View File

@ -32,15 +32,20 @@ func respondWithData(w http.ResponseWriter, b []byte, status int) {
w.Write(b) w.Write(b)
} }
const mostRecentVersion = "2021.2.5"
const mostRecentBetaVersion = "2021.3.0"
const knownBuggyVersion = "2020.12.0"
const expectedUserMsg = "This message is expected when running a known buggy version"
func updateHandler(w http.ResponseWriter, r *http.Request) { func updateHandler(w http.ResponseWriter, r *http.Request) {
version := "2020.08.05" version := mostRecentVersion
host := fmt.Sprintf("http://%s", r.Host) host := fmt.Sprintf("http://%s", r.Host)
url := host + "/download" url := host + "/download"
query := r.URL.Query() query := r.URL.Query()
if query.Get(BetaKeyName) == "true" { if query.Get(BetaKeyName) == "true" {
version = "2020.08.06" version = mostRecentBetaVersion
url = host + "/beta" url = host + "/beta"
} }
@ -59,7 +64,12 @@ func updateHandler(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(h, version) fmt.Fprint(h, version)
checksum := fmt.Sprintf("%x", h.Sum(nil)) checksum := fmt.Sprintf("%x", h.Sum(nil))
v := VersionResponse{URL: url, Version: version, Checksum: checksum} var userMessage = ""
if query.Get(ClientVersionName) == knownBuggyVersion {
userMessage = expectedUserMsg
}
v := VersionResponse{URL: url, Version: version, Checksum: checksum, UserMessage: userMessage}
respondWithJSON(w, v, http.StatusOK) respondWithJSON(w, v, http.StatusOK)
} }
@ -96,7 +106,7 @@ func compressedDownloadHandler(w http.ResponseWriter, r *http.Request) {
} }
func downloadHandler(w http.ResponseWriter, r *http.Request) { func downloadHandler(w http.ResponseWriter, r *http.Request) {
version := "2020.08.05" version := mostRecentVersion
requestedVersion := r.URL.Query().Get(VersionKeyName) requestedVersion := r.URL.Query().Get(VersionKeyName)
if requestedVersion != "" { if requestedVersion != "" {
version = requestedVersion version = requestedVersion
@ -105,7 +115,7 @@ func downloadHandler(w http.ResponseWriter, r *http.Request) {
} }
func betaHandler(w http.ResponseWriter, r *http.Request) { func betaHandler(w http.ResponseWriter, r *http.Request) {
respondWithData(w, []byte("2020.08.06"), http.StatusOK) respondWithData(w, []byte(mostRecentBetaVersion), http.StatusOK)
} }
func failureHandler(w http.ResponseWriter, r *http.Request) { func failureHandler(w http.ResponseWriter, r *http.Request) {
@ -124,7 +134,7 @@ func createServer() *httptest.Server {
} }
func createTestFile(t *testing.T, path string) { func createTestFile(t *testing.T, path string) {
f, err := os.Create("tmpfile") f, err := os.Create(path)
require.NoError(t, err) require.NoError(t, err)
fmt.Fprint(f, "2020.08.04") fmt.Fprint(f, "2020.08.04")
f.Close() f.Close()
@ -142,13 +152,13 @@ func TestUpdateService(t *testing.T) {
s := NewWorkersService("2020.8.2", fmt.Sprintf("%s/updater", ts.URL), testFilePath, Options{}) s := NewWorkersService("2020.8.2", fmt.Sprintf("%s/updater", ts.URL), testFilePath, Options{})
v, err := s.Check() v, err := s.Check()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, v.String(), "2020.08.05") require.Equal(t, v.Version(), mostRecentVersion)
require.NoError(t, v.Apply()) require.NoError(t, v.Apply())
dat, err := ioutil.ReadFile(testFilePath) dat, err := ioutil.ReadFile(testFilePath)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, string(dat), "2020.08.05") require.Equal(t, string(dat), mostRecentVersion)
} }
func TestBetaUpdateService(t *testing.T) { func TestBetaUpdateService(t *testing.T) {
@ -162,13 +172,13 @@ func TestBetaUpdateService(t *testing.T) {
s := NewWorkersService("2020.8.2", fmt.Sprintf("%s/updater", ts.URL), testFilePath, Options{IsBeta: true}) s := NewWorkersService("2020.8.2", fmt.Sprintf("%s/updater", ts.URL), testFilePath, Options{IsBeta: true})
v, err := s.Check() v, err := s.Check()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, v.String(), "2020.08.06") require.Equal(t, v.Version(), mostRecentBetaVersion)
require.NoError(t, v.Apply()) require.NoError(t, v.Apply())
dat, err := ioutil.ReadFile(testFilePath) dat, err := ioutil.ReadFile(testFilePath)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, string(dat), "2020.08.06") require.Equal(t, string(dat), mostRecentBetaVersion)
} }
func TestFailUpdateService(t *testing.T) { func TestFailUpdateService(t *testing.T) {
@ -193,10 +203,11 @@ func TestNoUpdateService(t *testing.T) {
createTestFile(t, testFilePath) createTestFile(t, testFilePath)
defer os.Remove(testFilePath) defer os.Remove(testFilePath)
s := NewWorkersService("2020.8.5", fmt.Sprintf("%s/updater", ts.URL), testFilePath, Options{}) s := NewWorkersService(mostRecentVersion, fmt.Sprintf("%s/updater", ts.URL), testFilePath, Options{})
v, err := s.Check() v, err := s.Check()
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, v) require.NotNil(t, v)
require.Empty(t, v.Version())
} }
func TestForcedUpdateService(t *testing.T) { func TestForcedUpdateService(t *testing.T) {
@ -210,13 +221,13 @@ func TestForcedUpdateService(t *testing.T) {
s := NewWorkersService("2020.8.5", fmt.Sprintf("%s/updater", ts.URL), testFilePath, Options{IsForced: true}) s := NewWorkersService("2020.8.5", fmt.Sprintf("%s/updater", ts.URL), testFilePath, Options{IsForced: true})
v, err := s.Check() v, err := s.Check()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, v.String(), "2020.08.05") require.Equal(t, v.Version(), mostRecentVersion)
require.NoError(t, v.Apply()) require.NoError(t, v.Apply())
dat, err := ioutil.ReadFile(testFilePath) dat, err := ioutil.ReadFile(testFilePath)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, string(dat), "2020.08.05") require.Equal(t, string(dat), mostRecentVersion)
} }
func TestUpdateSpecificVersionService(t *testing.T) { func TestUpdateSpecificVersionService(t *testing.T) {
@ -231,7 +242,7 @@ func TestUpdateSpecificVersionService(t *testing.T) {
s := NewWorkersService("2020.8.2", fmt.Sprintf("%s/updater", ts.URL), testFilePath, Options{RequestedVersion: reqVersion}) s := NewWorkersService("2020.8.2", fmt.Sprintf("%s/updater", ts.URL), testFilePath, Options{RequestedVersion: reqVersion})
v, err := s.Check() v, err := s.Check()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, reqVersion, v.String()) require.Equal(t, reqVersion, v.Version())
require.NoError(t, v.Apply()) require.NoError(t, v.Apply())
dat, err := ioutil.ReadFile(testFilePath) dat, err := ioutil.ReadFile(testFilePath)
@ -251,7 +262,7 @@ func TestCompressedUpdateService(t *testing.T) {
s := NewWorkersService("2020.8.2", fmt.Sprintf("%s/compressed", ts.URL), testFilePath, Options{}) s := NewWorkersService("2020.8.2", fmt.Sprintf("%s/compressed", ts.URL), testFilePath, Options{})
v, err := s.Check() v, err := s.Check()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "2020.09.02", v.String()) require.Equal(t, "2020.09.02", v.Version())
require.NoError(t, v.Apply()) require.NoError(t, v.Apply())
dat, err := ioutil.ReadFile(testFilePath) dat, err := ioutil.ReadFile(testFilePath)
@ -260,6 +271,21 @@ func TestCompressedUpdateService(t *testing.T) {
require.Equal(t, "2020.09.02", string(dat)) require.Equal(t, "2020.09.02", string(dat))
} }
func TestUpdateWhenRunningKnownBuggyVersion(t *testing.T) {
ts := createServer()
defer ts.Close()
testFilePath := "tmpfile"
createTestFile(t, testFilePath)
defer os.Remove(testFilePath)
s := NewWorkersService(knownBuggyVersion, fmt.Sprintf("%s/updater", ts.URL), testFilePath, Options{})
v, err := s.Check()
require.NoError(t, err)
require.Equal(t, v.Version(), mostRecentVersion)
require.Equal(t, v.UserMessage(), expectedUserMsg)
}
func TestVersionParsing(t *testing.T) { func TestVersionParsing(t *testing.T) {
require.False(t, IsNewerVersion("2020.8.2", "2020.8.2")) require.False(t, IsNewerVersion("2020.8.2", "2020.8.2"))
require.True(t, IsNewerVersion("2020.8.2", "2020.8.3")) require.True(t, IsNewerVersion("2020.8.2", "2020.8.3"))

View File

@ -54,6 +54,7 @@ type WorkersVersion struct {
version string version string
targetPath string targetPath string
isCompressed bool isCompressed bool
userMessage string
} }
// NewWorkersVersion creates a new Version object. This is normally created by a WorkersService JSON checkin response // NewWorkersVersion creates a new Version object. This is normally created by a WorkersService JSON checkin response
@ -61,8 +62,17 @@ type WorkersVersion struct {
// version is the version of this update // version is the version of this update
// checksum is the expected checksum of the downloaded file // checksum is the expected checksum of the downloaded file
// target path is where the file should be replace. Normally this the running cloudflared's path // target path is where the file should be replace. Normally this the running cloudflared's path
func NewWorkersVersion(url, version, checksum, targetPath string, isCompressed bool) Version { // userMessage is a possible message to convey back to the user after having checked in with the Updater Service
return &WorkersVersion{downloadURL: url, version: version, checksum: checksum, targetPath: targetPath, isCompressed: isCompressed} // isCompressed tells whether the asset to update cloudflared is compressed or not
func NewWorkersVersion(url, version, checksum, targetPath, userMessage string, isCompressed bool) CheckResult {
return &WorkersVersion{
downloadURL: url,
version: version,
checksum: checksum,
targetPath: targetPath,
isCompressed: isCompressed,
userMessage: userMessage,
}
} }
// Apply does the actual verification and update logic. // Apply does the actual verification and update logic.
@ -114,10 +124,16 @@ func (v *WorkersVersion) Apply() error {
} }
// String returns the version number of this update/release (e.g. 2020.08.05) // String returns the version number of this update/release (e.g. 2020.08.05)
func (v *WorkersVersion) String() string { func (v *WorkersVersion) Version() string {
return v.version return v.version
} }
// String returns a possible message to convey back to user after having checked in with the Updater Service. E.g.
// it can warn about the need to update the version currently running.
func (v *WorkersVersion) UserMessage() string {
return v.userMessage
}
// download the file from the link in the json // download the file from the link in the json
func download(url, filepath string, isCompressed bool) error { func download(url, filepath string, isCompressed bool) error {
client := &http.Client{ client := &http.Client{