Set up to handle Kyber768 KEM

This commit is contained in:
Russ Magee 2018-10-08 21:31:11 -07:00
parent 767ae7bd07
commit 4c286ae6c1
3 changed files with 147 additions and 89 deletions

View File

@ -9,10 +9,11 @@ package hkexnet
const ( const (
KEX_HERRADURA = iota // this MUST be first for default if omitted in ctor KEX_HERRADURA = iota // this MUST be first for default if omitted in ctor
KEX_FOO KEX_KYBER768
//KEX_DH //KEX_DH
//KEX_ETC //KEX_ETC
) )
// Sent from client to server in order to specify which // Sent from client to server in order to specify which
// algo shall be used (eg., HerraduraKEx, [TODO: others...]) // algo shall be used (eg., HerraduraKEx, [TODO: others...])
type KEXAlg uint8 type KEXAlg uint8
@ -39,6 +40,7 @@ const (
// Channel status type // Channel status type
type CSOType uint32 type CSOType uint32
//TODO: this should be small (max unfragmented packet size?)
const MAX_PAYLOAD_LEN = 4*1024*1024*1024 - 1 const MAX_PAYLOAD_LEN = 4*1024*1024*1024 - 1
const ( const (

View File

@ -69,14 +69,15 @@ type (
szMax uint // max size in bytes szMax uint // max size in bytes
} }
// Conn is a HKex connection - a superset of net.Conn //h *hkex.HerraduraKEx // TODO: make an interface?
// Conn is a connection wrapping net.Conn with KEX & session state
Conn struct { Conn struct {
kex KEXAlg kex KEXAlg // KEX/KEM propsal (client -> server)
m *sync.Mutex m *sync.Mutex // (internal)
c net.Conn // which also implements io.Reader, io.Writer, ... c *net.Conn // which also implements io.Reader, io.Writer, ...
h *hkex.HerraduraKEx // TODO: make an interface? cipheropts uint32 // post-KEx cipher/hmac options
cipheropts uint32 // post-KEx cipher/hmac options opts uint32 // post-KEx protocol options (caller-defined)
opts uint32 // post-KEx protocol options (caller-defined)
WinCh chan WinSize WinCh chan WinSize
Rows uint16 Rows uint16
Cols uint16 Cols uint16
@ -138,6 +139,48 @@ func (hc *Conn) SetOpts(opts uint32) {
hc.opts = opts hc.opts = opts
} }
func getkexalgnum(extensions ...string) (k KEXAlg) {
for _, s := range extensions {
switch s {
case "KEX_HERRADURA":
default:
log.Println("[extension arg = KEX_HERRADURA]")
k = KEX_HERRADURA
case "KEX_KYBER768":
log.Println("[extension arg = KEX_KYBER768]")
k = KEX_KYBER768
}
}
return
}
// Return a new hkexnet.Conn
//
// Note this is internal: use Dial() or Accept()
func _new(kexAlg KEXAlg, conn *net.Conn) (hc *Conn, e error) {
// Set up stuff common to all KEx/KEM types
hc = &Conn{kex: kexAlg,
m: &sync.Mutex{},
c: conn,
closeStat: new(CSOType),
WinCh: make(chan WinSize, 1),
dBuf: new(bytes.Buffer)}
*hc.closeStat = CSEStillOpen // open or prematurely-closed status
// Set up KEx/KEM-specifics
switch hc.kex {
case KEX_HERRADURA:
default:
return hc, HKExAcceptSetup(hc.c, hc)
log.Printf("[KEx alg %d accepted]\n", kexAlg)
case KEX_KYBER768:
fmt.Println("KYBER768: TODO")
return nil, errors.New("KEx Setup failed")
}
return
}
func (hc *Conn) applyConnExtensions(extensions ...string) { func (hc *Conn) applyConnExtensions(extensions ...string) {
//fmt.Printf("CSENone:%d CSEBadAuth:%d CSETruncCSO:%d CSEStillOpen:%d CSEExecFail:%d CSEPtyExecFail:%d\n", //fmt.Printf("CSENone:%d CSEBadAuth:%d CSETruncCSO:%d CSEStillOpen:%d CSEExecFail:%d CSEPtyExecFail:%d\n",
// CSENone, CSEBadAuth, CSETruncCSO, CSEStillOpen, CSEExecFail, CSEPtyExecFail) // CSENone, CSEBadAuth, CSETruncCSO, CSEStillOpen, CSEExecFail, CSEPtyExecFail)
@ -147,14 +190,6 @@ func (hc *Conn) applyConnExtensions(extensions ...string) {
for _, s := range extensions { for _, s := range extensions {
switch s { switch s {
case "KEX_HERRADURA":
log.Println("[extension arg = KEX_HERRADURA]")
hc.kex = KEX_HERRADURA
break
case "KEX_FOO":
log.Println("[extension arg = KEX_FOO]")
hc.kex = KEX_FOO
break
case "C_AES_256": case "C_AES_256":
log.Println("[extension arg = C_AES_256]") log.Println("[extension arg = C_AES_256]")
hc.cipheropts &= (0xFFFFFF00) hc.cipheropts &= (0xFFFFFF00)
@ -187,10 +222,15 @@ func (hc *Conn) applyConnExtensions(extensions ...string) {
} }
} }
func Kyber768DialSetup(c net.Conn, hc *Conn) (err error) {
return errors.New("NOT IMPLEMENTED")
}
func HKExDialSetup(c net.Conn, hc *Conn) (err error) { func HKExDialSetup(c net.Conn, hc *Conn) (err error) {
h := hkex.New(0, 0)
// Send hkexnet.Conn parameters to remote side // Send hkexnet.Conn parameters to remote side
// d is value for Herradura key exchange // d is value for Herradura key exchange
fmt.Fprintf(c, "0x%s\n%08x:%08x\n", hc.h.D().Text(16), fmt.Fprintf(c, "0x%s\n%08x:%08x\n", h.D().Text(16),
hc.cipheropts, hc.opts) hc.cipheropts, hc.opts)
d := big.NewInt(0) d := big.NewInt(0)
@ -205,44 +245,49 @@ func HKExDialSetup(c net.Conn, hc *Conn) (err error) {
return err return err
} }
hc.h.SetPeerD(d) h.SetPeerD(d)
log.Printf("** local D:%s\n", hc.h.D().Text(16)) log.Printf("** local D:%s\n", h.D().Text(16))
log.Printf("**(c)** peer D:%s\n", hc.h.PeerD().Text(16)) log.Printf("**(c)** peer D:%s\n", h.PeerD().Text(16))
hc.h.ComputeFA() h.ComputeFA()
log.Printf("**(c)** FA:%s\n", hc.h.FA()) log.Printf("**(c)** FA:%s\n", h.FA())
hc.r, hc.rm, err = hc.getStream(hc.h.FA()) hc.r, hc.rm, err = hc.getStream(h.FA())
hc.w, hc.wm, err = hc.getStream(hc.h.FA()) hc.w, hc.wm, err = hc.getStream(h.FA())
return return
} }
func HKExAcceptSetup(c net.Conn, hc *Conn) (err error) { func Kyber768AcceptSetup(c *net.Conn, hc *Conn) (err error) {
return errors.New("NOT IMPLEMENTED")
}
func HKExAcceptSetup(c *net.Conn, hc *Conn) (err error) {
h := hkex.New(0, 0)
// Read in hkexnet.Conn parameters over raw Conn c // Read in hkexnet.Conn parameters over raw Conn c
// d is value for Herradura key exchange // d is value for Herradura key exchange
d := big.NewInt(0) d := big.NewInt(0)
_, err = fmt.Fscanln(c, d) _, err = fmt.Fscanln(*c, d)
log.Printf("[Got d:%v]", d) log.Printf("[Got d:%v]", d)
if err != nil { if err != nil {
return err return err
} }
_, err = fmt.Fscanf(c, "%08x:%08x\n", _, err = fmt.Fscanf(*c, "%08x:%08x\n",
&hc.cipheropts, &hc.opts) &hc.cipheropts, &hc.opts)
log.Printf("[Got cipheropts, opts:%v, %v]", hc.cipheropts, hc.opts) log.Printf("[Got cipheropts, opts:%v, %v]", hc.cipheropts, hc.opts)
if err != nil { if err != nil {
return err return err
} }
hc.h.SetPeerD(d) h.SetPeerD(d)
log.Printf("** D:%s\n", hc.h.D().Text(16)) log.Printf("** D:%s\n", h.D().Text(16))
log.Printf("**(s)** peerD:%s\n", hc.h.PeerD().Text(16)) log.Printf("**(s)** peerD:%s\n", h.PeerD().Text(16))
hc.h.ComputeFA() h.ComputeFA()
log.Printf("**(s)** FA:%s\n", hc.h.FA()) log.Printf("**(s)** FA:%s\n", h.FA())
// Send D and cipheropts/conn_opts to peer // Send D and cipheropts/conn_opts to peer
fmt.Fprintf(c, "0x%s\n%08x:%08x\n", hc.h.D().Text(16), fmt.Fprintf(*c, "0x%s\n%08x:%08x\n", h.D().Text(16),
hc.cipheropts, hc.opts) hc.cipheropts, hc.opts)
hc.r, hc.rm, err = hc.getStream(hc.h.FA()) hc.r, hc.rm, err = hc.getStream(h.FA())
hc.w, hc.wm, err = hc.getStream(hc.h.FA()) hc.w, hc.wm, err = hc.getStream(h.FA())
return return
} }
@ -259,38 +304,38 @@ func Dial(protocol string, ipport string, extensions ...string) (hc Conn, err er
// Open raw Conn c // Open raw Conn c
c, err := net.Dial(protocol, ipport) c, err := net.Dial(protocol, ipport)
if err != nil { if err != nil {
return hc, err return Conn{}, err
} }
// Init hkexnet.Conn hc over net.Conn c // Init hkexnet.Conn hc over net.Conn c
// NOTE: kex default of KEX_HERRADURA may be overridden by ret, err := _new(getkexalgnum(extensions...), &c)
// future extension args to applyConnExtensions(), which is if err != nil {
// called prior to Dial() return Conn{}, err
hc = Conn{m: &sync.Mutex{}, c: c, closeStat: new(CSOType), h: hkex.New(0, 0), dBuf: new(bytes.Buffer)} }
hc = *ret
// Client has full control over Conn extensions. It's the server's
// responsibility to accept or reject the proposed parameters.
hc.applyConnExtensions(extensions...) hc.applyConnExtensions(extensions...)
// TODO: Factor out ALL params following this to helpers for
// specific KEx algs
fmt.Fprintf(c, "%02x\n", hc.kex)
// --
*hc.closeStat = CSEStillOpen // open or prematurely-closed status
// Perform Key Exchange according to client-request algorithm // Perform Key Exchange according to client-request algorithm
fmt.Fprintf(c, "%02x\n", hc.kex)
switch hc.kex { switch hc.kex {
case KEX_HERRADURA: case KEX_HERRADURA:
fmt.Println("[HKExDialSetup()]")
if HKExDialSetup(c, &hc) != nil { if HKExDialSetup(c, &hc) != nil {
return hc, nil return Conn{}, nil
} }
case KEX_FOO: case KEX_KYBER768:
// For testing: set up as HKEx anyway, but server via Accept() should fmt.Println("[Kyber768DialSetup()]")
// reject as invalid. if Kyber768DialSetup(c, &hc) != nil {
//if FooKExDialSetup(c, hc) != nil { return Conn{}, nil
if HKExDialSetup(c, &hc) != nil {
return hc, nil
} }
default: default:
log.Printf("Invalid kex alg (%d), rejecting\n", hc.kex) fmt.Println("[Default HKExDialSetup()]")
return hc, errors.New("Invalid kex alg") if HKExDialSetup(c, &hc) != nil {
return Conn{}, nil
}
} }
return return
} }
@ -302,19 +347,19 @@ func (hc *Conn) Close() (err error) {
binary.BigEndian.PutUint32(s, uint32(*hc.closeStat)) binary.BigEndian.PutUint32(s, uint32(*hc.closeStat))
log.Printf("** Writing closeStat %d at Close()\n", *hc.closeStat) log.Printf("** Writing closeStat %d at Close()\n", *hc.closeStat)
hc.WritePacket(s, CSOExitStatus) hc.WritePacket(s, CSOExitStatus)
err = hc.c.Close() err = (*hc.c).Close()
log.Println("[Conn Closing]") log.Println("[Conn Closing]")
return return
} }
// LocalAddr returns the local network address. // LocalAddr returns the local network address.
func (hc *Conn) LocalAddr() net.Addr { func (hc *Conn) LocalAddr() net.Addr {
return hc.c.LocalAddr() return (*hc.c).LocalAddr()
} }
// RemoteAddr returns the remote network address. // RemoteAddr returns the remote network address.
func (hc *Conn) RemoteAddr() net.Addr { func (hc *Conn) RemoteAddr() net.Addr {
return hc.c.RemoteAddr() return (*hc.c).RemoteAddr()
} }
// SetDeadline sets the read and write deadlines associated // SetDeadline sets the read and write deadlines associated
@ -333,7 +378,7 @@ func (hc *Conn) RemoteAddr() net.Addr {
// //
// A zero value for t means I/O operations will not time out. // A zero value for t means I/O operations will not time out.
func (hc *Conn) SetDeadline(t time.Time) error { func (hc *Conn) SetDeadline(t time.Time) error {
return hc.c.SetDeadline(t) return (*hc.c).SetDeadline(t)
} }
// SetWriteDeadline sets the deadline for future Write calls // SetWriteDeadline sets the deadline for future Write calls
@ -342,14 +387,14 @@ func (hc *Conn) SetDeadline(t time.Time) error {
// some of the data was successfully written. // some of the data was successfully written.
// A zero value for t means Write will not time out. // A zero value for t means Write will not time out.
func (hc *Conn) SetWriteDeadline(t time.Time) error { func (hc *Conn) SetWriteDeadline(t time.Time) error {
return hc.c.SetWriteDeadline(t) return (*hc.c).SetWriteDeadline(t)
} }
// SetReadDeadline sets the deadline for future Read calls // SetReadDeadline sets the deadline for future Read calls
// and any currently-blocked Read call. // and any currently-blocked Read call.
// A zero value for t means Read will not time out. // A zero value for t means Read will not time out.
func (hc *Conn) SetReadDeadline(t time.Time) error { func (hc *Conn) SetReadDeadline(t time.Time) error {
return hc.c.SetReadDeadline(t) return (*hc.c).SetReadDeadline(t)
} }
/*---------------------------------------------------------------------*/ /*---------------------------------------------------------------------*/
@ -397,35 +442,46 @@ func (hl *HKExListener) Accept() (hc Conn, err error) {
// Open raw Conn c // Open raw Conn c
c, err := hl.l.Accept() c, err := hl.l.Accept()
if err != nil { if err != nil {
hc := Conn{m: &sync.Mutex{}, c: nil, h: nil, closeStat: new(CSOType), cipheropts: 0, opts: 0, return Conn{}, err
r: nil, w: nil}
return hc, err
} }
log.Println("[Accepted]") log.Println("[net.Listener Accepted]")
hc = Conn{ /*kex: from client,*/ m: &sync.Mutex{}, c: c, h: hkex.New(0, 0), closeStat: new(CSOType), WinCh: make(chan WinSize, 1), // Read KEx alg proposed by client
dBuf: new(bytes.Buffer)} var kexAlg KEXAlg
// TODO: Factor out ALL params following this to helpers for
// specific KEx algs
var kexAlg uint8
_, err = fmt.Fscanln(c, &kexAlg) _, err = fmt.Fscanln(c, &kexAlg)
if err != nil { if err != nil {
return hc, err return Conn{}, err
} }
log.Printf("[Client proposed KEx alg: %v]\n", kexAlg) log.Printf("[Client proposed KEx alg: %v]\n", kexAlg)
// -- // --
switch kexAlg { ret, err := _new(kexAlg, &c)
if err != nil {
return Conn{}, err
}
hc = *ret
switch hc.kex {
case KEX_HERRADURA: case KEX_HERRADURA:
log.Printf("[KEx alg %d accepted]\n", kexAlg) log.Println("[Setting up for KEX_HERRADURA]")
if HKExAcceptSetup(c, &hc) != nil { if HKExAcceptSetup(&c, &hc) != nil {
return hc, nil log.Println("[ERROR - KEX_HERRADURA]")
return Conn{}, nil
}
case KEX_KYBER768:
log.Println("[Setting up for KEX_KYBER768]")
if Kyber768AcceptSetup(&c, &hc) != nil {
log.Println("[ERROR - KEX_KYBER768]")
return Conn{}, nil
} }
default: default:
log.Printf("[KEx alg %d rejected]\n", kexAlg) log.Println("[unknown alg, Setting up for KEX_HERRADURA]")
return hc, errors.New("KEx rejected") if HKExAcceptSetup(&c, &hc) != nil {
log.Println("[ERROR - default KEX_HERRADURA]")
return Conn{}, nil
}
} }
log.Println("[hc.Accept successful]")
return return
} }
@ -447,7 +503,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
var payloadLen uint32 var payloadLen uint32
// Read ctrl/status opcode (CSOHmacInvalid on hmac mismatch) // Read ctrl/status opcode (CSOHmacInvalid on hmac mismatch)
err = binary.Read(hc.c, binary.BigEndian, &ctrlStatOp) err = binary.Read(*hc.c, binary.BigEndian, &ctrlStatOp)
log.Printf("[ctrlStatOp: %v]\n", ctrlStatOp) log.Printf("[ctrlStatOp: %v]\n", ctrlStatOp)
if ctrlStatOp == CSOHmacInvalid { if ctrlStatOp == CSOHmacInvalid {
// Other side indicated channel tampering, close channel // Other side indicated channel tampering, close channel
@ -456,7 +512,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
} }
// Read the hmac and payload len first // Read the hmac and payload len first
err = binary.Read(hc.c, binary.BigEndian, &hmacIn) err = binary.Read(*hc.c, binary.BigEndian, &hmacIn)
// Normal client 'exit' from interactive session will cause // Normal client 'exit' from interactive session will cause
// (on server side) err.Error() == "<iface/addr info ...>: use of closed network connection" // (on server side) err.Error() == "<iface/addr info ...>: use of closed network connection"
if err != nil { if err != nil {
@ -468,7 +524,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
return 0, err return 0, err
} }
err = binary.Read(hc.c, binary.BigEndian, &payloadLen) err = binary.Read(*hc.c, binary.BigEndian, &payloadLen)
if err != nil { if err != nil {
if err.Error() != "EOF" { if err.Error() != "EOF" {
log.Println("[2]unexpected Read() err:", err) log.Println("[2]unexpected Read() err:", err)
@ -482,7 +538,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
} }
var payloadBytes = make([]byte, payloadLen) var payloadBytes = make([]byte, payloadLen)
n, err = io.ReadFull(hc.c, payloadBytes) n, err = io.ReadFull(*hc.c, payloadBytes)
// Normal client 'exit' from interactive session will cause // Normal client 'exit' from interactive session will cause
// (on server side) err.Error() == "<iface/addr info ...>: use of closed network connection" // (on server side) err.Error() == "<iface/addr info ...>: use of closed network connection"
@ -553,7 +609,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
// Log alert if hmac didn't match, corrupted channel // Log alert if hmac didn't match, corrupted channel
if !bytes.Equal(hTmp, []byte(hmacIn[0:])) /*|| hmacIn[0] > 0xf8*/ { if !bytes.Equal(hTmp, []byte(hmacIn[0:])) /*|| hmacIn[0] > 0xf8*/ {
fmt.Println("** ALERT - detected HMAC mismatch, possible channel tampering **") fmt.Println("** ALERT - detected HMAC mismatch, possible channel tampering **")
_, _ = hc.c.Write([]byte{CSOHmacInvalid}) _, _ = (*hc.c).Write([]byte{CSOHmacInvalid})
} }
} }
} }
@ -642,14 +698,14 @@ func (hc *Conn) WritePacket(b []byte, op byte) (n int, err error) {
log.Printf(" ->ctext:\r\n%s\r\n", hex.Dump(wb.Bytes())) log.Printf(" ->ctext:\r\n%s\r\n", hex.Dump(wb.Bytes()))
ctrlStatOp := op ctrlStatOp := op
err = binary.Write(hc.c, binary.BigEndian, &ctrlStatOp) err = binary.Write(*hc.c, binary.BigEndian, &ctrlStatOp)
if err == nil { if err == nil {
// Write hmac LSB, payloadLen followed by payload // Write hmac LSB, payloadLen followed by payload
err = binary.Write(hc.c, binary.BigEndian, hmacOut) err = binary.Write(*hc.c, binary.BigEndian, hmacOut)
if err == nil { if err == nil {
err = binary.Write(hc.c, binary.BigEndian, payloadLen) err = binary.Write(*hc.c, binary.BigEndian, payloadLen)
if err == nil { if err == nil {
n, err = hc.c.Write(wb.Bytes()) n, err = (*hc.c).Write(wb.Bytes())
} else { } else {
//fmt.Println("[c]WriteError!") //fmt.Println("[c]WriteError!")
} }

View File

@ -543,7 +543,7 @@ func main() {
} }
} }
conn, err := hkexnet.Dial("tcp", server /*[kexAlg eg. "KEX_HERRADURA"], */, cAlg, hAlg) conn, err := hkexnet.Dial("tcp", server, "KEX_HERRADURA", cAlg, hAlg)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
panic(err) panic(err)