package updater import ( "archive/tar" "compress/gzip" "crypto/sha256" "errors" "fmt" "io" "net/http" "net/url" "os" "os/exec" "path/filepath" "runtime" "strings" "text/template" "time" ) const ( clientTimeout = time.Second * 60 // stop the service // rename cloudflared.exe to cloudflared.exe.old // rename cloudflared.exe.new to cloudflared.exe // delete cloudflared.exe.old // start the service // delete the batch file windowsUpdateCommandTemplate = `@echo off sc stop cloudflared >nul 2>&1 rename "{{.TargetPath}}" {{.OldName}} rename "{{.NewPath}}" {{.BinaryName}} del "{{.OldPath}}" sc start cloudflared >nul 2>&1 del {{.BatchName}}` batchFileName = "cfd_update.bat" ) // Prepare some data to insert into the template. type batchData struct { TargetPath string OldName string NewPath string OldPath string BinaryName string BatchName string } // WorkersVersion implements the Version interface. // It contains everything needed to perform a version upgrade type WorkersVersion struct { downloadURL string checksum string version string targetPath string isCompressed bool userMessage string } // NewWorkersVersion creates a new Version object. This is normally created by a WorkersService JSON checkin response // url is where to download the file // version is the version of this update // checksum is the expected checksum of the downloaded file // target path is where the file should be replace. Normally this the running cloudflared's path // userMessage is a possible message to convey back to the user after having checked in with the Updater Service // isCompressed tells whether the asset to update cloudflared is compressed or not func NewWorkersVersion(url, version, checksum, targetPath, userMessage string, isCompressed bool) CheckResult { return &WorkersVersion{ downloadURL: url, version: version, checksum: checksum, targetPath: targetPath, isCompressed: isCompressed, userMessage: userMessage, } } // Apply does the actual verification and update logic. // This includes signature and checksum validation, // replacing the binary, etc func (v *WorkersVersion) Apply() error { newFilePath := fmt.Sprintf("%s.new", v.targetPath) os.Remove(newFilePath) //remove any failed updates before download // download the file if err := download(v.downloadURL, newFilePath, v.isCompressed); err != nil { return err } // check that the file is what is expected if err := isValidChecksum(v.checksum, newFilePath); err != nil { return err } oldFilePath := fmt.Sprintf("%s.old", v.targetPath) // Windows requires more effort to self update, especially when it is running as a service: // you have to stop the service (if running as one) in order to move/rename the binary // but now the binary isn't running though, so an external process // has to move the old binary out and the new one in then start the service // the easiest way to do this is with a batch file (or with a DLL, but that gets ugly for a cross compiled binary like cloudflared) // a batch file isn't ideal, but it is the simplest path forward for the constraints Windows creates if runtime.GOOS == "windows" { if err := writeBatchFile(v.targetPath, newFilePath, oldFilePath); err != nil { return err } rootDir := filepath.Dir(v.targetPath) batchPath := filepath.Join(rootDir, batchFileName) return runWindowsBatch(batchPath) } // now move the current file out, move the new file in and delete the old file if err := os.Rename(v.targetPath, oldFilePath); err != nil { return err } if err := os.Rename(newFilePath, v.targetPath); err != nil { //attempt rollback os.Rename(oldFilePath, v.targetPath) return err } os.Remove(oldFilePath) return nil } // String returns the version number of this update/release (e.g. 2020.08.05) func (v *WorkersVersion) Version() string { return v.version } // String returns a possible message to convey back to user after having checked in with the Updater Service. E.g. // it can warn about the need to update the version currently running. func (v *WorkersVersion) UserMessage() string { return v.userMessage } // download the file from the link in the json func download(url, filepath string, isCompressed bool) error { client := &http.Client{ Timeout: clientTimeout, } resp, err := client.Get(url) if err != nil { return err } defer resp.Body.Close() var r io.Reader r = resp.Body // compressed macos binary, need to decompress if isCompressed || isCompressedFile(url) { // first the gzip reader gr, err := gzip.NewReader(resp.Body) if err != nil { return err } defer gr.Close() // now the tar tr := tar.NewReader(gr) // advance the reader pass the header, which will be the single binary file tr.Next() r = tr } out, err := os.OpenFile(filepath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0755) if err != nil { return err } defer out.Close() _, err = io.Copy(out, r) return err } // isCompressedFile is a really simple file extension check to see if this is a macos tar and gzipped func isCompressedFile(urlstring string) bool { if strings.HasSuffix(urlstring, ".tgz") { return true } u, err := url.Parse(urlstring) if err != nil { return false } return strings.HasSuffix(u.Path, ".tgz") } // checks if the checksum in the json response matches the checksum of the file download func isValidChecksum(checksum, filePath string) error { f, err := os.Open(filePath) if err != nil { return err } defer f.Close() h := sha256.New() if _, err := io.Copy(h, f); err != nil { return err } hash := fmt.Sprintf("%x", h.Sum(nil)) if checksum != hash { return errors.New("checksum validation failed") } return nil } // writeBatchFile writes a batch file out to disk // see the dicussion on why it has to be done this way func writeBatchFile(targetPath string, newPath string, oldPath string) error { os.Remove(batchFileName) //remove any failed updates before download f, err := os.Create(batchFileName) if err != nil { return err } defer f.Close() cfdName := filepath.Base(targetPath) oldName := filepath.Base(oldPath) data := batchData{ TargetPath: targetPath, OldName: oldName, NewPath: newPath, OldPath: oldPath, BinaryName: cfdName, BatchName: batchFileName, } t, err := template.New("batch").Parse(windowsUpdateCommandTemplate) if err != nil { return err } return t.Execute(f, data) } // run each OS command for windows func runWindowsBatch(batchFile string) error { cmd := exec.Command("cmd", "/c", batchFile) return cmd.Start() }