AUTH-2993 added workers updater logic
This commit is contained in:
parent
2c9b7361b7
commit
ba4c8d8849
|
@ -81,16 +81,37 @@ func main() {
|
|||
|
||||
tunnel.Init(Version, shutdownC, graceShutdownC) // we need this to support the tunnel sub command...
|
||||
access.Init(shutdownC, graceShutdownC)
|
||||
updater.Init(Version)
|
||||
runApp(app, shutdownC, graceShutdownC)
|
||||
}
|
||||
|
||||
func commands(version func(c *cli.Context)) []*cli.Command {
|
||||
cmds := []*cli.Command{
|
||||
{
|
||||
Name: "update",
|
||||
Action: cliutil.ErrorHandler(updater.Update),
|
||||
Usage: "Update the agent if a new version exists",
|
||||
ArgsUsage: " ",
|
||||
Name: "update",
|
||||
Action: cliutil.ErrorHandler(updater.Update),
|
||||
Usage: "Update the agent if a new version exists",
|
||||
Flags: []cli.Flag{
|
||||
&cli.BoolFlag{
|
||||
Name: "beta",
|
||||
Usage: "specify if you wish to update to the latest beta version",
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: "force",
|
||||
Usage: "specify if you wish to force an upgrade to the latest version regardless of the current version",
|
||||
Hidden: true,
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: "staging",
|
||||
Usage: "specify if you wish to use the staging url for updating",
|
||||
Hidden: true,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "version",
|
||||
Usage: "specify a version you wish to upgrade or downgrade to",
|
||||
Hidden: false,
|
||||
},
|
||||
},
|
||||
Description: `Looks for a new version on the official download server.
|
||||
If a new version exists, updates the agent binary and quits.
|
||||
Otherwise, does nothing.
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
package updater
|
||||
|
||||
// Version is the functions needed to perform an update
|
||||
type Version interface {
|
||||
Apply() error
|
||||
String() string
|
||||
}
|
||||
|
||||
// Service is the functions to get check for new updates
|
||||
type Service interface {
|
||||
Check() (Version, error)
|
||||
}
|
||||
|
||||
const (
|
||||
// OSKeyName is the url parameter key to send to the checkin API for the operating system of the local cloudflared (e.g. windows, darwin, linux)
|
||||
OSKeyName = "os"
|
||||
|
||||
// ArchitectureKeyName is the url parameter key to send to the checkin API for the architecture of the local cloudflared (e.g. amd64, x86)
|
||||
ArchitectureKeyName = "arch"
|
||||
|
||||
// BetaKeyName is the url parameter key to send to the checkin API to signal if the update should be a beta version or not
|
||||
BetaKeyName = "beta"
|
||||
|
||||
// VersionKeyName is the url parameter key to send to the checkin API to specific what version to upgrade or downgrade to
|
||||
VersionKeyName = "version"
|
||||
)
|
|
@ -13,7 +13,6 @@ import (
|
|||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/equinox-io/equinox"
|
||||
"github.com/facebookgo/grace/gracenet"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
@ -25,16 +24,12 @@ const (
|
|||
noUpdateOnWindowsMessage = "cloudflared will not automatically update on Windows systems."
|
||||
noUpdateManagedPackageMessage = "cloudflared will not automatically update if installed by a package manager."
|
||||
isManagedInstallFile = ".installedFromPackageManager"
|
||||
UpdateURL = "https://update.argotunnel.com"
|
||||
StagingUpdateURL = "https://staging-update.argotunnel.com"
|
||||
)
|
||||
|
||||
var (
|
||||
publicKey = []byte(`
|
||||
-----BEGIN ECDSA PUBLIC KEY-----
|
||||
MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAE4OWZocTVZ8Do/L6ScLdkV+9A0IYMHoOf
|
||||
dsCmJ/QZ6aw0w9qkkwEpne1Lmo6+0pGexZzFZOH6w5amShn+RXt7qkSid9iWlzGq
|
||||
EKx0BZogHSor9Wy5VztdFaAaVbsJiCbO
|
||||
-----END ECDSA PUBLIC KEY-----
|
||||
`)
|
||||
version string
|
||||
)
|
||||
|
||||
// BinaryUpdated implements ExitCoder interface, the app will exit with status code 11
|
||||
|
@ -64,6 +59,13 @@ func (e *statusErr) ExitCode() int {
|
|||
return 10
|
||||
}
|
||||
|
||||
type updateOptions struct {
|
||||
isBeta bool
|
||||
isStaging bool
|
||||
isForced bool
|
||||
version string
|
||||
}
|
||||
|
||||
type UpdateOutcome struct {
|
||||
Updated bool
|
||||
Version string
|
||||
|
@ -74,29 +76,44 @@ func (uo *UpdateOutcome) noUpdate() bool {
|
|||
return uo.Error == nil && uo.Updated == false
|
||||
}
|
||||
|
||||
func checkForUpdateAndApply() UpdateOutcome {
|
||||
var opts equinox.Options
|
||||
if err := opts.SetPublicKeyPEM(publicKey); err != nil {
|
||||
return UpdateOutcome{Error: err}
|
||||
}
|
||||
func Init(v string) {
|
||||
version = v
|
||||
}
|
||||
|
||||
resp, err := equinox.Check(appID, opts)
|
||||
switch {
|
||||
case err == equinox.NotAvailableErr:
|
||||
return UpdateOutcome{}
|
||||
case err != nil:
|
||||
return UpdateOutcome{Error: err}
|
||||
}
|
||||
|
||||
err = resp.Apply()
|
||||
func checkForUpdateAndApply(options updateOptions) UpdateOutcome {
|
||||
cfdPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return UpdateOutcome{Error: err}
|
||||
}
|
||||
|
||||
return UpdateOutcome{Updated: true, Version: resp.ReleaseVersion}
|
||||
url := UpdateURL
|
||||
if options.isStaging {
|
||||
url = StagingUpdateURL
|
||||
}
|
||||
|
||||
s := NewWorkersService(version, url, cfdPath, Options{IsBeta: options.isBeta,
|
||||
IsForced: options.isForced, RequestedVersion: options.version})
|
||||
|
||||
v, err := s.Check()
|
||||
if err != nil {
|
||||
return UpdateOutcome{Error: err}
|
||||
}
|
||||
|
||||
//already on the latest version
|
||||
if v == nil {
|
||||
return UpdateOutcome{}
|
||||
}
|
||||
|
||||
err = v.Apply()
|
||||
if err != nil {
|
||||
return UpdateOutcome{Error: err}
|
||||
}
|
||||
|
||||
return UpdateOutcome{Updated: true, Version: v.String()}
|
||||
}
|
||||
|
||||
func Update(_ *cli.Context) error {
|
||||
// Update is the handler for the update command from the command line
|
||||
func Update(c *cli.Context) error {
|
||||
logger, err := logger.New()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error setting up logger")
|
||||
|
@ -107,7 +124,22 @@ func Update(_ *cli.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
updateOutcome := loggedUpdate(logger)
|
||||
isBeta := c.Bool("beta")
|
||||
if isBeta {
|
||||
logger.Info("cloudflared is set to update to the latest beta version")
|
||||
}
|
||||
|
||||
isStaging := c.Bool("staging")
|
||||
if isStaging {
|
||||
logger.Info("cloudflared is set to update from staging")
|
||||
}
|
||||
|
||||
isForced := c.Bool("force")
|
||||
if isForced {
|
||||
logger.Info("cloudflared is set to upgrade to the latest publish version regardless of the current version")
|
||||
}
|
||||
|
||||
updateOutcome := loggedUpdate(logger, updateOptions{isBeta: isBeta, isStaging: isStaging, isForced: isForced, version: c.String("version")})
|
||||
if updateOutcome.Error != nil {
|
||||
return &statusErr{updateOutcome.Error}
|
||||
}
|
||||
|
@ -121,8 +153,8 @@ func Update(_ *cli.Context) error {
|
|||
}
|
||||
|
||||
// Checks for an update and applies it if one is available
|
||||
func loggedUpdate(logger logger.Service) UpdateOutcome {
|
||||
updateOutcome := checkForUpdateAndApply()
|
||||
func loggedUpdate(logger logger.Service, options updateOptions) UpdateOutcome {
|
||||
updateOutcome := checkForUpdateAndApply(options)
|
||||
if updateOutcome.Updated {
|
||||
logger.Infof("cloudflared has been updated to version %s", updateOutcome.Version)
|
||||
}
|
||||
|
@ -168,7 +200,7 @@ func (a *AutoUpdater) Run(ctx context.Context) error {
|
|||
ticker := time.NewTicker(a.configurable.freq)
|
||||
for {
|
||||
if a.configurable.enabled {
|
||||
updateOutcome := loggedUpdate(a.logger)
|
||||
updateOutcome := loggedUpdate(a.logger, updateOptions{})
|
||||
if updateOutcome.Updated {
|
||||
os.Args = append(os.Args, "--is-autoupdated=true")
|
||||
if IsSysV() {
|
||||
|
|
|
@ -0,0 +1,154 @@
|
|||
package updater
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Options are the update options supported by the
|
||||
type Options struct {
|
||||
// IsBeta is for beta updates to be installed if available
|
||||
IsBeta bool
|
||||
|
||||
// IsForced is to forcibly download the latest version regardless of the current version
|
||||
IsForced bool
|
||||
|
||||
// RequestedVersion is the specific version to upgrade or downgrade to
|
||||
RequestedVersion string
|
||||
}
|
||||
|
||||
// VersionResponse is the JSON response from the Workers API endpoint
|
||||
type VersionResponse struct {
|
||||
URL string `json:"url"`
|
||||
Version string `json:"version"`
|
||||
Checksum string `json:"checksum"`
|
||||
IsCompressed bool `json:"compressed"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
// WorkersService implements Service.
|
||||
// It contains everything needed to check in with the WorkersAPI and download and apply the updates
|
||||
type WorkersService struct {
|
||||
currentVersion string
|
||||
url string
|
||||
targetPath string
|
||||
opts Options
|
||||
}
|
||||
|
||||
// NewWorkersService creates a new updater Service object.
|
||||
func NewWorkersService(currentVersion, url, targetPath string, opts Options) Service {
|
||||
return &WorkersService{
|
||||
currentVersion: currentVersion,
|
||||
url: url,
|
||||
targetPath: targetPath,
|
||||
opts: opts,
|
||||
}
|
||||
}
|
||||
|
||||
// Check does a check in with the Workers API to get a new version update
|
||||
func (s *WorkersService) Check() (Version, error) {
|
||||
client := &http.Client{
|
||||
Timeout: time.Second * 5,
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, s.url, nil)
|
||||
q := req.URL.Query()
|
||||
q.Add(OSKeyName, runtime.GOOS)
|
||||
q.Add(ArchitectureKeyName, runtime.GOARCH)
|
||||
|
||||
if s.opts.IsBeta {
|
||||
q.Add(BetaKeyName, "true")
|
||||
}
|
||||
|
||||
if s.opts.RequestedVersion != "" {
|
||||
q.Add(VersionKeyName, s.opts.RequestedVersion)
|
||||
}
|
||||
|
||||
req.URL.RawQuery = q.Encode()
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var v VersionResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if v.Error != "" {
|
||||
return nil, errors.New(v.Error)
|
||||
}
|
||||
|
||||
if !s.opts.IsForced && !IsNewerVersion(s.currentVersion, v.Version) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return NewWorkersVersion(v.URL, v.Version, v.Checksum, s.targetPath, v.IsCompressed), nil
|
||||
}
|
||||
|
||||
// IsNewerVersion checks semantic versioning for the latest version
|
||||
// cloudflared tagging is more of a date than a semantic version,
|
||||
// but the same comparision logic still holds for major.minor.patch
|
||||
// e.g. 2020.8.2 is newer than 2020.8.1.
|
||||
func IsNewerVersion(current string, check string) bool {
|
||||
if strings.Contains(strings.ToLower(current), "dev") {
|
||||
return false // dev builds shouldn't update
|
||||
}
|
||||
|
||||
cMajor, cMinor, cPatch, err := SemanticParts(current)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
nMajor, nMinor, nPatch, err := SemanticParts(check)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if nMajor > cMajor {
|
||||
return true
|
||||
}
|
||||
|
||||
if nMajor == cMajor && nMinor > cMinor {
|
||||
return true
|
||||
}
|
||||
|
||||
if nMajor == cMajor && nMinor == cMinor && nPatch > cPatch {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SemanticParts gets the major, minor, and patch version of a semantic version string
|
||||
// e.g. 3.1.2 would return 3, 1, 2, nil
|
||||
func SemanticParts(version string) (major int, minor int, patch int, err error) {
|
||||
major = 0
|
||||
minor = 0
|
||||
patch = 0
|
||||
parts := strings.Split(version, ".")
|
||||
if len(parts) != 3 {
|
||||
err = errors.New("invalid version")
|
||||
return
|
||||
}
|
||||
major, err = strconv.Atoi(parts[0])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
minor, err = strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
patch, err = strconv.Atoi(parts[2])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
|
@ -0,0 +1,263 @@
|
|||
package updater
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func respondWithJSON(w http.ResponseWriter, v interface{}, status int) {
|
||||
data, _ := json.Marshal(v)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
func respondWithData(w http.ResponseWriter, b []byte, status int) {
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
w.WriteHeader(status)
|
||||
w.Write(b)
|
||||
}
|
||||
|
||||
func updateHandler(w http.ResponseWriter, r *http.Request) {
|
||||
version := "2020.08.05"
|
||||
host := "http://localhost:8090"
|
||||
url := host + "/download"
|
||||
|
||||
query := r.URL.Query()
|
||||
|
||||
if query.Get(BetaKeyName) == "true" {
|
||||
version = "2020.08.06"
|
||||
url = host + "/beta"
|
||||
}
|
||||
|
||||
requestedVersion := query.Get(VersionKeyName)
|
||||
if requestedVersion != "" {
|
||||
version = requestedVersion
|
||||
url = fmt.Sprintf("%s?version=%s", url, requestedVersion)
|
||||
}
|
||||
|
||||
if query.Get(ArchitectureKeyName) != runtime.GOARCH || query.Get(OSKeyName) != runtime.GOOS {
|
||||
respondWithJSON(w, VersionResponse{Error: "unsupported os and architecture"}, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
h := sha256.New()
|
||||
fmt.Fprint(h, version)
|
||||
checksum := fmt.Sprintf("%x", h.Sum(nil))
|
||||
|
||||
v := VersionResponse{URL: url, Version: version, Checksum: checksum}
|
||||
respondWithJSON(w, v, http.StatusOK)
|
||||
}
|
||||
|
||||
func gzipUpdateHandler(w http.ResponseWriter, r *http.Request) {
|
||||
log.Println("got a request!")
|
||||
version := "2020.09.02"
|
||||
h := sha256.New()
|
||||
fmt.Fprint(h, version)
|
||||
checksum := fmt.Sprintf("%x", h.Sum(nil))
|
||||
|
||||
v := VersionResponse{URL: "http://localhost:8090/gzip-download.tgz", Version: version, Checksum: checksum}
|
||||
respondWithJSON(w, v, http.StatusOK)
|
||||
}
|
||||
|
||||
func compressedDownloadHandler(w http.ResponseWriter, r *http.Request) {
|
||||
version := "2020.09.02"
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
gw := gzip.NewWriter(buf)
|
||||
tw := tar.NewWriter(gw)
|
||||
|
||||
header := &tar.Header{
|
||||
Size: int64(len(version)),
|
||||
Name: "download",
|
||||
}
|
||||
tw.WriteHeader(header)
|
||||
tw.Write([]byte(version))
|
||||
|
||||
tw.Close()
|
||||
gw.Close()
|
||||
|
||||
respondWithData(w, buf.Bytes(), http.StatusOK)
|
||||
}
|
||||
|
||||
func downloadHandler(w http.ResponseWriter, r *http.Request) {
|
||||
version := "2020.08.05"
|
||||
requestedVersion := r.URL.Query().Get(VersionKeyName)
|
||||
if requestedVersion != "" {
|
||||
version = requestedVersion
|
||||
}
|
||||
respondWithData(w, []byte(version), http.StatusOK)
|
||||
}
|
||||
|
||||
func betaHandler(w http.ResponseWriter, r *http.Request) {
|
||||
respondWithData(w, []byte("2020.08.06"), http.StatusOK)
|
||||
}
|
||||
|
||||
func failureHandler(w http.ResponseWriter, r *http.Request) {
|
||||
respondWithJSON(w, VersionResponse{Error: "unsupported os and architecture"}, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
func startServer() {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/updater", updateHandler)
|
||||
mux.HandleFunc("/download", downloadHandler)
|
||||
mux.HandleFunc("/beta", betaHandler)
|
||||
mux.HandleFunc("/fail", failureHandler)
|
||||
mux.HandleFunc("/compressed", gzipUpdateHandler)
|
||||
mux.HandleFunc("/gzip-download.tgz", compressedDownloadHandler)
|
||||
http.ListenAndServe(":8090", mux)
|
||||
}
|
||||
|
||||
func createTestFile(t *testing.T, path string) {
|
||||
f, err := os.Create("tmpfile")
|
||||
assert.NoError(t, err)
|
||||
fmt.Fprint(f, "2020.08.04")
|
||||
f.Close()
|
||||
}
|
||||
|
||||
func TestUpdateService(t *testing.T) {
|
||||
go startServer()
|
||||
|
||||
testFilePath := "tmpfile"
|
||||
createTestFile(t, testFilePath)
|
||||
defer os.Remove(testFilePath)
|
||||
|
||||
s := NewWorkersService("2020.8.2", "http://localhost:8090/updater", testFilePath, Options{})
|
||||
v, err := s.Check()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, v.String(), "2020.08.05")
|
||||
|
||||
assert.NoError(t, v.Apply())
|
||||
dat, err := ioutil.ReadFile(testFilePath)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, string(dat), "2020.08.05")
|
||||
}
|
||||
|
||||
func TestBetaUpdateService(t *testing.T) {
|
||||
go startServer()
|
||||
|
||||
testFilePath := "tmpfile"
|
||||
createTestFile(t, testFilePath)
|
||||
defer os.Remove(testFilePath)
|
||||
|
||||
s := NewWorkersService("2020.8.2", "http://localhost:8090/updater", testFilePath, Options{IsBeta: true})
|
||||
v, err := s.Check()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, v.String(), "2020.08.06")
|
||||
|
||||
assert.NoError(t, v.Apply())
|
||||
dat, err := ioutil.ReadFile(testFilePath)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, string(dat), "2020.08.06")
|
||||
}
|
||||
|
||||
func TestFailUpdateService(t *testing.T) {
|
||||
go startServer()
|
||||
|
||||
testFilePath := "tmpfile"
|
||||
createTestFile(t, testFilePath)
|
||||
defer os.Remove(testFilePath)
|
||||
|
||||
s := NewWorkersService("2020.8.2", "http://localhost:8090/fail", testFilePath, Options{})
|
||||
v, err := s.Check()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, v)
|
||||
}
|
||||
|
||||
func TestNoUpdateService(t *testing.T) {
|
||||
go startServer()
|
||||
|
||||
testFilePath := "tmpfile"
|
||||
createTestFile(t, testFilePath)
|
||||
defer os.Remove(testFilePath)
|
||||
|
||||
s := NewWorkersService("2020.8.5", "http://localhost:8090/updater", testFilePath, Options{})
|
||||
v, err := s.Check()
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, v)
|
||||
}
|
||||
|
||||
func TestForcedUpdateService(t *testing.T) {
|
||||
go startServer()
|
||||
|
||||
testFilePath := "tmpfile"
|
||||
createTestFile(t, testFilePath)
|
||||
defer os.Remove(testFilePath)
|
||||
|
||||
s := NewWorkersService("2020.8.5", "http://localhost:8090/updater", testFilePath, Options{IsForced: true})
|
||||
v, err := s.Check()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, v.String(), "2020.08.05")
|
||||
|
||||
assert.NoError(t, v.Apply())
|
||||
dat, err := ioutil.ReadFile(testFilePath)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, string(dat), "2020.08.05")
|
||||
}
|
||||
|
||||
func TestUpdateSpecificVersionService(t *testing.T) {
|
||||
go startServer()
|
||||
|
||||
testFilePath := "tmpfile"
|
||||
createTestFile(t, testFilePath)
|
||||
defer os.Remove(testFilePath)
|
||||
reqVersion := "2020.9.1"
|
||||
|
||||
s := NewWorkersService("2020.8.2", "http://localhost:8090/updater", testFilePath, Options{RequestedVersion: reqVersion})
|
||||
v, err := s.Check()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, reqVersion, v.String())
|
||||
|
||||
assert.NoError(t, v.Apply())
|
||||
dat, err := ioutil.ReadFile(testFilePath)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, reqVersion, string(dat))
|
||||
}
|
||||
|
||||
func TestCompressedUpdateService(t *testing.T) {
|
||||
go startServer()
|
||||
|
||||
testFilePath := "tmpfile"
|
||||
createTestFile(t, testFilePath)
|
||||
defer os.Remove(testFilePath)
|
||||
|
||||
s := NewWorkersService("2020.8.2", "http://localhost:8090/compressed", testFilePath, Options{})
|
||||
v, err := s.Check()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "2020.09.02", v.String())
|
||||
|
||||
assert.NoError(t, v.Apply())
|
||||
dat, err := ioutil.ReadFile(testFilePath)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "2020.09.02", string(dat))
|
||||
}
|
||||
|
||||
func TestVersionParsing(t *testing.T) {
|
||||
assert.False(t, IsNewerVersion("2020.8.2", "2020.8.2"))
|
||||
assert.True(t, IsNewerVersion("2020.8.2", "2020.8.3"))
|
||||
assert.True(t, IsNewerVersion("2020.8.2", "2021.1.2"))
|
||||
assert.True(t, IsNewerVersion("2020.8.2", "2020.9.1"))
|
||||
assert.True(t, IsNewerVersion("2020.8.2", "2020.12.45"))
|
||||
assert.False(t, IsNewerVersion("2020.8.2", "2020.6.3"))
|
||||
assert.False(t, IsNewerVersion("DEV", "2020.8.5"))
|
||||
assert.False(t, IsNewerVersion("2020.8.2", "asdlkfjasdf"))
|
||||
assert.True(t, IsNewerVersion("3.0.1", "4.2.1"))
|
||||
}
|
|
@ -0,0 +1,227 @@
|
|||
package updater
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
)
|
||||
|
||||
// stop the service
|
||||
// rename cloudflared.exe to cloudflared.exe.old
|
||||
// rename cloudflared.exe.new to cloudflared.exe
|
||||
// delete cloudflared.exe.old
|
||||
// start the service
|
||||
// delete the batch file
|
||||
const windowsUpdateCommandTemplate = `@echo off
|
||||
sc stop cloudflared >nul 2>&1
|
||||
rename "{{.TargetPath}}" {{.OldName}}
|
||||
rename "{{.NewPath}}" {{.BinaryName}}
|
||||
del "{{.OldPath}}"
|
||||
sc start cloudflared >nul 2>&1
|
||||
del {{.BatchName}}`
|
||||
const batchFileName = "cfd_update.bat"
|
||||
|
||||
// Prepare some data to insert into the template.
|
||||
type batchData struct {
|
||||
TargetPath string
|
||||
OldName string
|
||||
NewPath string
|
||||
OldPath string
|
||||
BinaryName string
|
||||
BatchName string
|
||||
}
|
||||
|
||||
// WorkersVersion implements the Version interface.
|
||||
// It contains everything needed to preform a version upgrade
|
||||
type WorkersVersion struct {
|
||||
downloadURL string
|
||||
checksum string
|
||||
version string
|
||||
targetPath string
|
||||
isCompressed bool
|
||||
}
|
||||
|
||||
// NewWorkersVersion creates a new Version object. This is normally created by a WorkersService JSON checkin response
|
||||
// url is where to download the file
|
||||
// version is the version of this update
|
||||
// checksum is the expected checksum of the downloaded file
|
||||
// target path is where the file should be replace. Normally this the running cloudflared's path
|
||||
func NewWorkersVersion(url, version, checksum, targetPath string, isCompressed bool) Version {
|
||||
return &WorkersVersion{downloadURL: url, version: version, checksum: checksum, targetPath: targetPath, isCompressed: isCompressed}
|
||||
}
|
||||
|
||||
// Apply does the actual verification and update logic.
|
||||
// This includes signature and checksum validation,
|
||||
// replacing the binary, etc
|
||||
func (v *WorkersVersion) Apply() error {
|
||||
newFilePath := fmt.Sprintf("%s.new", v.targetPath)
|
||||
os.Remove(newFilePath) //remove any failed updates before download
|
||||
|
||||
// download the file
|
||||
if err := download(v.downloadURL, newFilePath, v.isCompressed); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check that the file is what is expected
|
||||
if err := isValidChecksum(v.checksum, newFilePath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldFilePath := fmt.Sprintf("%s.old", v.targetPath)
|
||||
// Windows requires more effort to self update, especially when it is running as a service:
|
||||
// you have to stop the service (if running as one) in order to move/rename the binary
|
||||
// but now the binary isn't running though, so an external process
|
||||
// has to move the old binary out and the new one in then start the service
|
||||
// the easiest way to do this is with a batch file (or with a DLL, but that gets ugly for a cross compiled binary like cloudflared)
|
||||
// a batch file isn't ideal, but it is the simplest path forward for the constraints Windows creates
|
||||
if runtime.GOOS == "windows" {
|
||||
if err := writeBatchFile(v.targetPath, newFilePath, oldFilePath); err != nil {
|
||||
return err
|
||||
}
|
||||
rootDir := filepath.Dir(v.targetPath)
|
||||
batchPath := filepath.Join(rootDir, batchFileName)
|
||||
return runWindowsBatch(batchPath)
|
||||
}
|
||||
|
||||
// now move the current file out, move the new file in and delete the old file
|
||||
if err := os.Rename(v.targetPath, oldFilePath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.Rename(newFilePath, v.targetPath); err != nil {
|
||||
//attempt rollback
|
||||
os.Rename(oldFilePath, v.targetPath)
|
||||
return err
|
||||
}
|
||||
os.Remove(oldFilePath)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// String returns the version number of this update/release (e.g. 2020.08.05)
|
||||
func (v *WorkersVersion) String() string {
|
||||
return v.version
|
||||
}
|
||||
|
||||
// download the file from the link in the json
|
||||
func download(url, filepath string, isCompressed bool) error {
|
||||
client := &http.Client{
|
||||
Timeout: time.Second * 5,
|
||||
}
|
||||
resp, err := client.Get(url)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var r io.Reader
|
||||
r = resp.Body
|
||||
|
||||
// compressed macos binary, need to decompress
|
||||
if isCompressed || isCompressedFile(url) {
|
||||
// first the gzip reader
|
||||
gr, err := gzip.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer gr.Close()
|
||||
|
||||
// now the tar
|
||||
tr := tar.NewReader(gr)
|
||||
|
||||
// advance the reader pass the header, which will be the single binary file
|
||||
tr.Next()
|
||||
|
||||
r = tr
|
||||
}
|
||||
|
||||
out, err := os.OpenFile(filepath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0755)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
_, err = io.Copy(out, r)
|
||||
return err
|
||||
}
|
||||
|
||||
// isCompressedFile is a really simple file extension check to see if this is a macos tar and gzipped
|
||||
func isCompressedFile(urlstring string) bool {
|
||||
if strings.HasSuffix(urlstring, ".tgz") {
|
||||
return true
|
||||
}
|
||||
|
||||
u, err := url.Parse(urlstring)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return strings.HasSuffix(u.Path, ".tgz")
|
||||
}
|
||||
|
||||
// checks if the checksum in the json response matches the checksum of the file download
|
||||
func isValidChecksum(checksum, filePath string) error {
|
||||
f, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, f); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hash := fmt.Sprintf("%x", h.Sum(nil))
|
||||
|
||||
if checksum != hash {
|
||||
return errors.New("checksum validation failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeBatchFile writes a batch file out to disk
|
||||
// see the dicussion on why it has to be done this way
|
||||
func writeBatchFile(targetPath string, newPath string, oldPath string) error {
|
||||
os.Remove(batchFileName) //remove any failed updates before download
|
||||
f, err := os.Create(batchFileName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
cfdName := filepath.Base(targetPath)
|
||||
oldName := filepath.Base(oldPath)
|
||||
|
||||
data := batchData{
|
||||
TargetPath: targetPath,
|
||||
OldName: oldName,
|
||||
NewPath: newPath,
|
||||
OldPath: oldPath,
|
||||
BinaryName: cfdName,
|
||||
BatchName: batchFileName,
|
||||
}
|
||||
|
||||
t, err := template.New("batch").Parse(windowsUpdateCommandTemplate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return t.Execute(f, data)
|
||||
}
|
||||
|
||||
// run each OS command for windows
|
||||
func runWindowsBatch(batchFile string) error {
|
||||
cmd := exec.Command("cmd", "/c", batchFile)
|
||||
return cmd.Start()
|
||||
}
|
|
@ -97,7 +97,7 @@ def main():
|
|||
|
||||
msg = release.body
|
||||
|
||||
for filename in glob.glob(".artifacts/*.*"):
|
||||
for filename in glob.glob("artifacts/*.*"):
|
||||
pkg_hash = get_sha256(filename)
|
||||
# add the sha256 of the new artifact to the release message body
|
||||
msg = update_or_add_message(msg, filename, pkg_hash)
|
||||
|
|
|
@ -9,6 +9,7 @@ import os
|
|||
import shutil
|
||||
import hashlib
|
||||
import requests
|
||||
import tarfile
|
||||
|
||||
from github import Github, GithubException, UnknownObjectException
|
||||
|
||||
|
@ -175,8 +176,21 @@ def main():
|
|||
|
||||
release.upload_asset(args.path, name=args.name)
|
||||
|
||||
# check and extract if the file is a tar and gzipped file (as is the case with the macos builds)
|
||||
binary_path = args.path
|
||||
if binary_path.endswith("tgz"):
|
||||
try:
|
||||
shutil.rmtree('cfd')
|
||||
except OSError as e:
|
||||
pass
|
||||
zipfile = tarfile.open(binary_path, "r:gz")
|
||||
zipfile.extractall('cfd') # specify which folder to extract to
|
||||
zipfile.close()
|
||||
|
||||
binary_path = os.path.join(os.getcwd(), 'cfd', 'cloudflared')
|
||||
|
||||
# send the sha256 (the checksum) to workers kv
|
||||
pkg_hash = get_sha256(args.path)
|
||||
pkg_hash = get_sha256(binary_path)
|
||||
send_hash(pkg_hash, args.name, args.release_version, args.kv_account_id, args.namespace_id, args.kv_api_token)
|
||||
|
||||
# create the artifacts directory if it doesn't exist
|
||||
|
@ -186,7 +200,8 @@ def main():
|
|||
|
||||
# copy the binary to the path
|
||||
copy_path = os.path.join(artifact_path, args.name)
|
||||
shutil.copy(args.path, copy_path)
|
||||
if args.path != copy_path:
|
||||
shutil.copy(args.path, copy_path)
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue