diff --git a/hkexnet/hkexnet.go b/hkexnet/hkexnet.go index 4337502..0c2505f 100644 --- a/hkexnet/hkexnet.go +++ b/hkexnet/hkexnet.go @@ -54,7 +54,7 @@ type KEX uint8 const ( KEX_HERRADURA = iota // this MUST be first for default if omitted in ctor - //KEX_FOO + KEX_FOO //KEX_DH //KEX_ETC ) @@ -73,7 +73,7 @@ const ( CSONone = iota // No error, normal packet CSOHmacInvalid // HMAC mismatch detected on remote end CSOTermSize // set term size (rows:cols) - CSOExitStatus // Remote cmd exit status (TODO) + CSOExitStatus // Remote cmd exit status CSOChaff // Dummy packet, do not pass beyond decryption ) @@ -101,12 +101,12 @@ type ( // Conn is a HKex connection - a superset of net.Conn Conn struct { - kex KEX + kex KEX // KEX alg (typedef uint8) m *sync.Mutex - c net.Conn // which also implements io.Reader, io.Writer, ... - h *hkex.HerraduraKEx - cipheropts uint32 // post-KEx cipher/hmac options - opts uint32 // post-KEx protocol options (caller-defined) + c net.Conn // which also implements io.Reader, io.Writer, ... + h *hkex.HerraduraKEx // TODO: make an interface? + cipheropts uint32 // post-KEx cipher/hmac options + opts uint32 // post-KEx protocol options (caller-defined) WinCh chan WinSize Rows uint16 Cols uint16 @@ -176,10 +176,10 @@ func (hc *Conn) applyConnExtensions(extensions ...string) { log.Println("[extension arg = KEX_HERRADURA]") hc.kex = KEX_HERRADURA break - //case "KEX_FOO": - // log.Println("[extension arg = KEX_FOO]") - // hc.kex = KEX_FOO - // break + case "KEX_FOO": + log.Println("[extension arg = KEX_FOO]") + hc.kex = KEX_FOO + break case "C_AES_256": log.Println("[extension arg = C_AES_256]") hc.cipheropts &= (0xFFFFFF00) @@ -207,6 +207,65 @@ func (hc *Conn) applyConnExtensions(extensions ...string) { } } +func HKExDialSetup(c net.Conn, hc *Conn) (err error) { + // Send hkexnet.Conn parameters to remote side + // d is value for Herradura key exchange + fmt.Fprintf(c, "0x%s\n%08x:%08x\n", hc.h.D().Text(16), + hc.cipheropts, hc.opts) + + d := big.NewInt(0) + _, err = fmt.Fscanln(c, d) + if err != nil { + return err + } + // Read peer D over net.Conn (c) + _, err = fmt.Fscanf(c, "%08x:%08x\n", + &hc.cipheropts, &hc.opts) + if err != nil { + return err + } + + hc.h.SetPeerD(d) + log.Printf("** local D:%s\n", hc.h.D().Text(16)) + log.Printf("**(c)** peer D:%s\n", hc.h.PeerD().Text(16)) + hc.h.ComputeFA() + log.Printf("**(c)** FA:%s\n", hc.h.FA()) + + hc.r, hc.rm, err = hc.getStream(hc.h.FA()) + hc.w, hc.wm, err = hc.getStream(hc.h.FA()) + return +} + +func HKExAcceptSetup(c net.Conn, hc *Conn) (err error) { + // Read in hkexnet.Conn parameters over raw Conn c + // d is value for Herradura key exchange + d := big.NewInt(0) + _, err = fmt.Fscanln(c, d) + log.Printf("[Got d:%v]", d) + if err != nil { + return err + } + _, err = fmt.Fscanf(c, "%08x:%08x\n", + &hc.cipheropts, &hc.opts) + log.Printf("[Got cipheropts, opts:%v, %v]", hc.cipheropts, hc.opts) + if err != nil { + return err + } + hc.h.SetPeerD(d) + log.Printf("** D:%s\n", hc.h.D().Text(16)) + log.Printf("**(s)** peerD:%s\n", hc.h.PeerD().Text(16)) + hc.h.ComputeFA() + log.Printf("**(s)** FA:%s\n", hc.h.FA()) + + // Send D and cipheropts/conn_opts to peer + fmt.Fprintf(c, "0x%s\n%08x:%08x\n", hc.h.D().Text(16), + hc.cipheropts, hc.opts) + + hc.r, hc.rm, err = hc.getStream(hc.h.FA()) + hc.w, hc.wm, err = hc.getStream(hc.h.FA()) + return +} + // Dial as net.Dial(), but with implicit key exchange to set up secure // channel on connect // @@ -234,32 +293,25 @@ func Dial(protocol string, ipport string, extensions ...string) (hc *Conn, err e fmt.Fprintf(c, "%02x\n", hc.kex) // -- - // Send hkexnet.Conn parameters to remote side - // d is value for Herradura key exchange - fmt.Fprintf(c, "0x%s\n%08x:%08x\n", hc.h.D().Text(16), - hc.cipheropts, hc.opts) - - d := big.NewInt(0) - _, err = fmt.Fscanln(c, d) - if err != nil { - return nil, err - } - // Read peer D over net.Conn (c) - _, err = fmt.Fscanf(c, "%08x:%08x\n", - &hc.cipheropts, &hc.opts) - if err != nil { - return nil, err - } - - hc.h.SetPeerD(d) - log.Printf("** local D:%s\n", hc.h.D().Text(16)) - log.Printf("**(c)** peer D:%s\n", hc.h.PeerD().Text(16)) - hc.h.ComputeFA() - log.Printf("**(c)** FA:%s\n", hc.h.FA()) - - hc.r, hc.rm, err = hc.getStream(hc.h.FA()) - hc.w, hc.wm, err = hc.getStream(hc.h.FA()) *hc.closeStat = CSEStillOpen // open or prematurely-closed status + + // Perform Key Exchange according to client-request algorithm + switch hc.kex { + case KEX_HERRADURA: + if HKExDialSetup(c, hc) != nil { + return hc, nil + } + case KEX_FOO: + // For testing: set up as HKEx anyway, but server via Accept() should + // reject as invalid. + //if FooKExDialSetup(c, hc) != nil { + if HKExDialSetup(c, hc) != nil { + return hc, nil + } + default: + log.Printf("Invalid kex alg (%d), rejecting\n", hc.kex) + return nil, errors.New("Invalid kex alg") + } return } @@ -380,35 +432,19 @@ func (hl *HKExListener) Accept() (hc Conn, err error) { if err != nil { return hc, err } - log.Printf("[KEx alg: %v]\n", kexAlg) + log.Printf("[Client proposed KEx alg: %v]\n", kexAlg) // -- - // Read in hkexnet.Conn parameters over raw Conn c - // d is value for Herradura key exchange - d := big.NewInt(0) - _, err = fmt.Fscanln(c, d) - log.Printf("[Got d:%v]", d) - if err != nil { - return hc, err + switch kexAlg { + case KEX_HERRADURA: + log.Printf("[KEx alg %d accepted]\n", kexAlg) + if HKExAcceptSetup(c, &hc) != nil { + return hc, nil + } + default: + log.Printf("[KEx alg %d rejected]\n", kexAlg) + return hc, errors.New("KEx rejected") } - _, err = fmt.Fscanf(c, "%08x:%08x\n", - &hc.cipheropts, &hc.opts) - log.Printf("[Got cipheropts, opts:%v, %v]", hc.cipheropts, hc.opts) - if err != nil { - return hc, err - } - hc.h.SetPeerD(d) - log.Printf("** D:%s\n", hc.h.D().Text(16)) - log.Printf("**(s)** peerD:%s\n", hc.h.PeerD().Text(16)) - hc.h.ComputeFA() - log.Printf("**(s)** FA:%s\n", hc.h.FA()) - - // Send D and cipheropts/conn_opts to peer - fmt.Fprintf(c, "0x%s\n%08x:%08x\n", hc.h.D().Text(16), - hc.cipheropts, hc.opts) - - hc.r, hc.rm, err = hc.getStream(hc.h.FA()) - hc.w, hc.wm, err = hc.getStream(hc.h.FA()) return } @@ -559,7 +595,11 @@ func (hc *Conn) WritePacket(b []byte, op byte) (n int, err error) { //log.Printf("[Encrypting...]\r\n") var hmacOut []uint8 var payloadLen uint32 - + + if hc.m == nil || hc.wm == nil { + return 0, errors.New("Secure chan not ready for writing") + } + // N.B. Originally this Lock() surrounded only the // calls to binary.Write(hc.c ..) however there appears // to be some other unshareable state in the Conn