Rewrite fingerprint checking

This commit is contained in:
blankie 2023-01-02 00:11:02 +07:00
parent 72e92e1219
commit 104cc5476c
Signed by: blankie
GPG Key ID: CC15FC822C7F61F5
2 changed files with 126 additions and 32 deletions

124
hosts.go Normal file
View File

@ -0,0 +1,124 @@
package main
import (
"bufio"
"crypto/sha256"
"crypto/x509"
"encoding/hex"
"errors"
"os"
"path/filepath"
"strconv"
"strings"
"time"
)
var hosts map[string]Host
type Host struct {
// string since why convert it to bytes and handle errors when you can just not
Fingerprint string
NotAfter time.Time
}
func populateHosts() error {
file, err := os.OpenFile(filepath.Join(xdgDataHome(), "konbata", "known_hosts_2"), os.O_RDONLY|os.O_CREATE, 0600)
if err != nil {
return err
}
defer file.Close()
scanner := bufio.NewScanner(file)
hosts = make(map[string]Host)
for scanner.Scan() {
if scanner.Text() == "" {
continue
}
values := strings.Split(scanner.Text(), " ")
if len(values) != 3 {
return errors.New("malformed host line encountered")
}
hostname := values[0]
fingerprint := values[1]
timeInt, err := strconv.ParseInt(values[2], 10, 0)
if err != nil {
return err
}
hosts[hostname] = Host{
Fingerprint: fingerprint,
NotAfter: time.UnixMicro(timeInt),
}
}
return nil
}
func saveHostsTmp(filename string) error {
file, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE, 0600)
if err != nil {
return err
}
defer file.Close()
err = file.Truncate(0)
if err != nil {
return err
}
writer := bufio.NewWriter(file)
for hostname, host := range hosts {
_, err = writer.WriteString(hostname)
if err != nil {
return err
}
_, err = writer.WriteRune(' ')
if err != nil {
return err
}
_, err = writer.WriteString(host.Fingerprint)
if err != nil {
return err
}
_, err = writer.WriteRune(' ')
if err != nil {
return err
}
_, err = writer.WriteString(strconv.FormatInt(host.NotAfter.UnixMicro(), 10))
if err != nil {
return err
}
_, err = writer.WriteRune('\n')
if err != nil {
return err
}
}
return writer.Flush()
}
func saveHosts() error {
filename := filepath.Join(xdgDataHome(), "konbata", "known_hosts_2")
tmpFilename := filepath.Join(xdgDataHome(), "konbata", "known_hosts_2.tmp")
err := saveHostsTmp(tmpFilename)
if err != nil {
return err
}
return os.Rename(tmpFilename, filename)
}
func TrustCertificate(hostname string, cert *x509.Certificate) error {
hash := sha256.New()
hash.Write(cert.Raw)
fingerprint := hex.EncodeToString(hash.Sum(nil))
currentTime := time.Now()
host, ok := hosts[hostname]
if !ok || currentTime.After(host.NotAfter) {
if currentTime.Before(cert.NotBefore) {
return errors.New("cert is used before its not before value")
}
hosts[hostname] = Host{
Fingerprint: fingerprint,
NotAfter: cert.NotAfter,
}
return saveHosts()
} else if host.Fingerprint != fingerprint {
return errors.New("fingerprint does not match!")
}
return nil
}

34
main.go
View File

@ -2,7 +2,6 @@ package main
import ( import (
"context" "context"
"crypto/x509"
"encoding/xml" "encoding/xml"
"errors" "errors"
"fmt" "fmt"
@ -10,31 +9,18 @@ import (
"log" "log"
"net/url" "net/url"
"os" "os"
"path/filepath"
"sort" "sort"
"strings" "strings"
"time" "time"
"git.sr.ht/~adnano/go-gemini" "git.sr.ht/~adnano/go-gemini"
"git.sr.ht/~adnano/go-gemini/tofu"
) )
const TIMEOUT_NS time.Duration = time.Duration(60_000_000_000) const TIMEOUT_NS time.Duration = time.Duration(60_000_000_000)
const MAX_FILE_SIZE int64 = 512 * 1024 const MAX_FILE_SIZE int64 = 512 * 1024
var (
hosts tofu.KnownHosts
hostsfile *tofu.HostWriter
)
func init() { func init() {
path := filepath.Join(xdgDataHome(), "konbata", "known_hosts") err := populateHosts()
err := hosts.Load(path)
if err != nil {
log.Fatal(err)
}
hostsfile, err = tofu.OpenHostsFile(path)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -50,22 +36,6 @@ func init() {
} }
} }
func trustCertificate(hostname string, cert *x509.Certificate) error {
host := tofu.NewHost(hostname, cert.Raw)
knownHost, ok := hosts.Lookup(hostname)
if ok {
// Check fingerprint
if knownHost.Fingerprint != host.Fingerprint {
return errors.New("fingerprint does not match!")
}
return nil
}
hosts.Add(host)
hostsfile.WriteHost(host)
return nil
}
func do(client gemini.Client, ctx context.Context, req *gemini.Request, via []*gemini.Request) (*gemini.Response, *gemini.Request, error) { func do(client gemini.Client, ctx context.Context, req *gemini.Request, via []*gemini.Request) (*gemini.Response, *gemini.Request, error) {
if target, exists := predirs[req.URL.String()]; exists { if target, exists := predirs[req.URL.String()]; exists {
via = append(via, req) via = append(via, req)
@ -151,7 +121,7 @@ func main() {
} }
client := gemini.Client{ client := gemini.Client{
TrustCertificate: trustCertificate, TrustCertificate: TrustCertificate,
} }
ctx, cancel := context.WithTimeout(context.Background(), TIMEOUT_NS) ctx, cancel := context.WithTimeout(context.Background(), TIMEOUT_NS)
defer cancel() defer cancel()