diff --git a/hosts.go b/hosts.go new file mode 100644 index 0000000..5b21c4c --- /dev/null +++ b/hosts.go @@ -0,0 +1,124 @@ +package main + +import ( + "bufio" + "crypto/sha256" + "crypto/x509" + "encoding/hex" + "errors" + "os" + "path/filepath" + "strconv" + "strings" + "time" +) + +var hosts map[string]Host + +type Host struct { + // string since why convert it to bytes and handle errors when you can just not + Fingerprint string + NotAfter time.Time +} + +func populateHosts() error { + file, err := os.OpenFile(filepath.Join(xdgDataHome(), "konbata", "known_hosts_2"), os.O_RDONLY|os.O_CREATE, 0600) + if err != nil { + return err + } + defer file.Close() + scanner := bufio.NewScanner(file) + hosts = make(map[string]Host) + for scanner.Scan() { + if scanner.Text() == "" { + continue + } + values := strings.Split(scanner.Text(), " ") + if len(values) != 3 { + return errors.New("malformed host line encountered") + } + hostname := values[0] + fingerprint := values[1] + timeInt, err := strconv.ParseInt(values[2], 10, 0) + if err != nil { + return err + } + hosts[hostname] = Host{ + Fingerprint: fingerprint, + NotAfter: time.UnixMicro(timeInt), + } + } + return nil +} + +func saveHostsTmp(filename string) error { + file, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE, 0600) + if err != nil { + return err + } + defer file.Close() + err = file.Truncate(0) + if err != nil { + return err + } + writer := bufio.NewWriter(file) + for hostname, host := range hosts { + _, err = writer.WriteString(hostname) + if err != nil { + return err + } + _, err = writer.WriteRune(' ') + if err != nil { + return err + } + _, err = writer.WriteString(host.Fingerprint) + if err != nil { + return err + } + _, err = writer.WriteRune(' ') + if err != nil { + return err + } + _, err = writer.WriteString(strconv.FormatInt(host.NotAfter.UnixMicro(), 10)) + if err != nil { + return err + } + _, err = writer.WriteRune('\n') + if err != nil { + return err + } + } + return writer.Flush() +} + +func saveHosts() error { + filename := filepath.Join(xdgDataHome(), "konbata", "known_hosts_2") + tmpFilename := filepath.Join(xdgDataHome(), "konbata", "known_hosts_2.tmp") + err := saveHostsTmp(tmpFilename) + if err != nil { + return err + } + return os.Rename(tmpFilename, filename) +} + +func TrustCertificate(hostname string, cert *x509.Certificate) error { + hash := sha256.New() + hash.Write(cert.Raw) + fingerprint := hex.EncodeToString(hash.Sum(nil)) + currentTime := time.Now() + + host, ok := hosts[hostname] + if !ok || currentTime.After(host.NotAfter) { + if currentTime.Before(cert.NotBefore) { + return errors.New("cert is used before its not before value") + } + hosts[hostname] = Host{ + Fingerprint: fingerprint, + NotAfter: cert.NotAfter, + } + return saveHosts() + } else if host.Fingerprint != fingerprint { + return errors.New("fingerprint does not match!") + } + return nil +} diff --git a/main.go b/main.go index 74d1078..d5817a9 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,6 @@ package main import ( "context" - "crypto/x509" "encoding/xml" "errors" "fmt" @@ -10,31 +9,18 @@ import ( "log" "net/url" "os" - "path/filepath" "sort" "strings" "time" "git.sr.ht/~adnano/go-gemini" - "git.sr.ht/~adnano/go-gemini/tofu" ) const TIMEOUT_NS time.Duration = time.Duration(60_000_000_000) const MAX_FILE_SIZE int64 = 512 * 1024 -var ( - hosts tofu.KnownHosts - hostsfile *tofu.HostWriter -) - func init() { - path := filepath.Join(xdgDataHome(), "konbata", "known_hosts") - err := hosts.Load(path) - if err != nil { - log.Fatal(err) - } - - hostsfile, err = tofu.OpenHostsFile(path) + err := populateHosts() if err != nil { log.Fatal(err) } @@ -50,22 +36,6 @@ func init() { } } -func trustCertificate(hostname string, cert *x509.Certificate) error { - host := tofu.NewHost(hostname, cert.Raw) - knownHost, ok := hosts.Lookup(hostname) - if ok { - // Check fingerprint - if knownHost.Fingerprint != host.Fingerprint { - return errors.New("fingerprint does not match!") - } - return nil - } - - hosts.Add(host) - hostsfile.WriteHost(host) - return nil -} - func do(client gemini.Client, ctx context.Context, req *gemini.Request, via []*gemini.Request) (*gemini.Response, *gemini.Request, error) { if target, exists := predirs[req.URL.String()]; exists { via = append(via, req) @@ -151,7 +121,7 @@ func main() { } client := gemini.Client{ - TrustCertificate: trustCertificate, + TrustCertificate: TrustCertificate, } ctx, cancel := context.WithTimeout(context.Background(), TIMEOUT_NS) defer cancel()