// cloudflared-mirror/cmd/cloudflared/updater/workers_update.go
//
// 247 lines
// 6.5 KiB
// Go

package updater
import (
"archive/tar"
"compress/gzip"
"crypto/sha256"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"text/template"
"time"
)
const (
	// clientTimeout is the http.Client timeout used when downloading an update.
	clientTimeout = time.Second * 60

	// windowsUpdateCommandTemplate is the text/template source of the batch
	// script that swaps binaries on Windows. The script:
	//   1. stops the cloudflared service
	//   2. deletes any <binary>.old left over from an earlier update; otherwise
	//      the rename in step 3 fails and the update silently never applies
	//   3. renames the current binary to <binary>.old
	//   4. renames <binary>.new to the binary's name
	//   5. deletes <binary>.old (this may fail if the stopped process has not
	//      exited yet; step 2 cleans it up on the next update)
	//   6. starts the service again
	//   7. deletes the batch file itself
	// Every substituted value is quoted so paths and names containing spaces work.
	windowsUpdateCommandTemplate = `@echo off
sc stop cloudflared >nul 2>&1
del "{{.OldPath}}" >nul 2>&1
rename "{{.TargetPath}}" "{{.OldName}}"
rename "{{.NewPath}}" "{{.BinaryName}}"
del "{{.OldPath}}"
sc start cloudflared >nul 2>&1
del "{{.BatchName}}"`

	// batchFileName is the file name of the Windows update script.
	batchFileName = "cfd_update.bat"
)
// batchData holds the values substituted into windowsUpdateCommandTemplate.
type batchData struct {
	// TargetPath is the full path of the binary being replaced.
	TargetPath string
	// OldName is the base name the current binary is renamed to.
	OldName string
	// NewPath is the full path of the freshly downloaded binary.
	NewPath string
	// OldPath is the full path of the renamed (old) binary.
	OldPath string
	// BinaryName is the base name the new binary is renamed to.
	BinaryName string
	// BatchName is the name of the batch script itself.
	BatchName string
}
// WorkersVersion describes one update offered by the Workers update service
// and carries everything Apply needs to install it. NewWorkersVersion returns
// it as a CheckResult.
type WorkersVersion struct {
	// downloadURL is where the new binary (or compressed archive) is fetched from.
	downloadURL string
	// checksum is the expected hex-encoded SHA-256 of the downloaded binary.
	checksum string
	// version is the release version of this update (e.g. 2020.08.05).
	version string
	// targetPath is the file to replace, normally the running cloudflared's path.
	targetPath string
	// isCompressed reports whether the download is a gzip-compressed tar archive.
	isCompressed bool
	// userMessage is an optional message to relay to the user after check-in.
	userMessage string
}
// NewWorkersVersion returns the CheckResult for an update announced in a
// WorkersService JSON check-in response.
//
//   - url: where the new binary is downloaded from
//   - version: the version of this update
//   - checksum: the expected checksum of the downloaded file
//   - targetPath: the file to replace, normally the running cloudflared's path
//   - userMessage: an optional message to relay to the user after checking in
//     with the Updater Service
//   - isCompressed: whether the downloaded asset is a compressed archive
func NewWorkersVersion(url, version, checksum, targetPath, userMessage string, isCompressed bool) CheckResult {
	wv := new(WorkersVersion)
	wv.downloadURL = url
	wv.version = version
	wv.checksum = checksum
	wv.targetPath = targetPath
	wv.isCompressed = isCompressed
	wv.userMessage = userMessage
	return wv
}
// Apply downloads the new cloudflared binary, verifies its SHA-256 checksum,
// and swaps it in for the binary at targetPath.
//
// The download is staged at "<targetPath>.new" and removed again if it cannot
// be fetched or fails verification, so no unverified executable is left on disk.
//
// On non-Windows platforms the swap proceeds in three steps:
//   - the current binary is renamed to "<targetPath>.old"
//   - the staged binary is renamed to targetPath
//   - the old binary is removed
//
// If the second rename fails, the original binary is restored. When that
// restore also fails, both errors are reported.
//
// On Windows the swap is delegated to a batch script that is only started here.
// A nil error therefore means the script was launched, not that the swap
// finished.
func (v *WorkersVersion) Apply() error {
	newFilePath := fmt.Sprintf("%s.new", v.targetPath)
	// Best effort: clear a staged binary left behind by an earlier failed update.
	os.Remove(newFilePath)

	if err := download(v.downloadURL, newFilePath, v.isCompressed); err != nil {
		// Don't leave a partial download next to the installed binary.
		os.Remove(newFilePath)
		return err
	}
	if err := isValidChecksum(v.checksum, newFilePath); err != nil {
		// Never leave an unverified executable on disk.
		os.Remove(newFilePath)
		return err
	}

	oldFilePath := fmt.Sprintf("%s.old", v.targetPath)

	// Windows requires more effort to self-update, especially when running as a
	// service. The service has to be stopped before the binary can be replaced,
	// but then this process is no longer running. An external process must
	// therefore move the old binary out, move the new one in and restart the
	// service. A batch file is the simplest way to do that for a cross-compiled
	// binary (the alternative, a DLL, gets ugly).
	if runtime.GOOS == "windows" {
		if err := writeBatchFile(v.targetPath, newFilePath, oldFilePath); err != nil {
			return err
		}
		rootDir := filepath.Dir(v.targetPath)
		batchPath := filepath.Join(rootDir, batchFileName)
		return runWindowsBatch(batchPath)
	}

	// Move the current binary out, move the new one in, then delete the old one.
	if err := os.Rename(v.targetPath, oldFilePath); err != nil {
		return err
	}
	if err := os.Rename(newFilePath, v.targetPath); err != nil {
		// Put the original binary back so the install is not left without one.
		if rbErr := os.Rename(oldFilePath, v.targetPath); rbErr != nil {
			return fmt.Errorf("%w (restoring %s from %s also failed: %v)", err, v.targetPath, oldFilePath, rbErr)
		}
		return err
	}
	// Best effort: the new binary is already in place; a leftover backup is harmless.
	os.Remove(oldFilePath)
	return nil
}
// Version returns the version number of this update/release (e.g. 2020.08.05).
func (v *WorkersVersion) Version() string {
	return v.version
}
// UserMessage returns the (possibly empty) message to convey back to the user
// after checking in with the Updater Service, e.g. a warning that the running
// version needs to be updated.
func (v *WorkersVersion) UserMessage() string {
	return v.userMessage
}
// download fetches url and writes the payload to filepath with mode 0755.
//
// The response is treated as a gzip-compressed tar archive when isCompressed
// is set or url names a .tgz file. In that case only the contents of the
// archive's first entry (expected to be the single cloudflared binary) are
// written.
//
// An error is returned in these cases:
//   - a transport failure
//   - any HTTP status other than 200 OK, so an error page is never written out
//     as the binary
//   - an unreadable or empty archive
//   - failure to write or close the destination file
func download(url, filepath string, isCompressed bool) error {
	client := &http.Client{
		Timeout: clientTimeout,
	}
	resp, err := client.Get(url)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("downloading update: unexpected HTTP status %s", resp.Status)
	}

	var r io.Reader = resp.Body
	if isCompressed || isCompressedFile(url) {
		gr, err := gzip.NewReader(resp.Body)
		if err != nil {
			return err
		}
		defer gr.Close()

		tr := tar.NewReader(gr)
		// Advance past the first header. The archive should hold a single file,
		// the binary. An error here (including io.EOF for an empty archive) would
		// otherwise silently produce an empty output file.
		if _, err := tr.Next(); err != nil {
			return fmt.Errorf("reading update archive: %w", err)
		}
		r = tr
	}

	out, err := os.OpenFile(filepath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0755)
	if err != nil {
		return err
	}
	if _, err := io.Copy(out, r); err != nil {
		out.Close()
		return err
	}
	// A failed Close can mean the data never reached disk, so report it.
	return out.Close()
}
// isCompressedFile reports whether urlstring points at a .tgz (tar + gzip)
// asset, such as the macOS release. It is a plain extension check. The raw
// string is checked first, then the parsed URL path, so query strings and
// fragments after the file name are ignored. Unparseable URLs that do not
// literally end in ".tgz" are reported as uncompressed.
func isCompressedFile(urlstring string) bool {
	const tgzExt = ".tgz"
	if strings.HasSuffix(urlstring, tgzExt) {
		return true
	}
	parsed, err := url.Parse(urlstring)
	return err == nil && strings.HasSuffix(parsed.Path, tgzExt)
}
// isValidChecksum reports, as a nil error, whether the SHA-256 of the file at
// filePath matches checksum, the hex digest from the update service's JSON
// response. Hex digits compare case-insensitively and surrounding whitespace
// in checksum is ignored. An empty checksum never matches. I/O errors are
// returned as-is. A mismatch yields "checksum validation failed".
func isValidChecksum(checksum, filePath string) error {
	f, err := os.Open(filePath)
	if err != nil {
		return err
	}
	defer f.Close()

	h := sha256.New()
	if _, err := io.Copy(h, f); err != nil {
		return err
	}
	hash := fmt.Sprintf("%x", h.Sum(nil))
	// %x yields lowercase hex. Accept an uppercase or padded digest too, since it
	// denotes the same bytes.
	if !strings.EqualFold(strings.TrimSpace(checksum), hash) {
		return errors.New("checksum validation failed")
	}
	return nil
}
// writeBatchFile writes a batch file out to disk
// see the dicussion on why it has to be done this way
func writeBatchFile(targetPath string, newPath string, oldPath string) error {
os.Remove(batchFileName) //remove any failed updates before download
f, err := os.Create(batchFileName)
if err != nil {
return err
}
defer f.Close()
cfdName := filepath.Base(targetPath)
oldName := filepath.Base(oldPath)
data := batchData{
TargetPath: targetPath,
OldName: oldName,
NewPath: newPath,
OldPath: oldPath,
BinaryName: cfdName,
BatchName: batchFileName,
}
t, err := template.New("batch").Parse(windowsUpdateCommandTemplate)
if err != nil {
return err
}
return t.Execute(f, data)
}
// runWindowsBatch launches batchFile with "cmd /c" and returns as soon as it
// has started, without waiting for it to finish. The script stops the
// cloudflared service, which may be this very process, so it must outlive the
// caller.
//
// When the path can be made absolute, the script runs with its own directory
// as the working directory. Relative names inside it (such as the final
// "del" of the script itself) then resolve next to the script rather than in
// the caller's working directory.
func runWindowsBatch(batchFile string) error {
	scriptPath, workDir := batchFile, ""
	if abs, err := filepath.Abs(batchFile); err == nil {
		scriptPath, workDir = abs, filepath.Dir(abs)
	}

	cmd := exec.Command("cmd", "/c", scriptPath)
	cmd.Dir = workDir // "" keeps inheriting this process's working directory
	if err := cmd.Start(); err != nil {
		return err
	}
	// The script is never waited on, so release its process handle instead of
	// leaking it. This does not affect the running script, so a failure is ignored.
	_ = cmd.Process.Release()
	return nil
}