diff --git a/hkexchan.go b/hkexchan.go index 52c79f9..e2ecc76 100644 --- a/hkexchan.go +++ b/hkexchan.go @@ -15,11 +15,11 @@ import ( "crypto/aes" "crypto/cipher" "encoding/hex" + "errors" "fmt" "hash" "log" "math/big" - "os" "golang.org/x/crypto/blowfish" "golang.org/x/crypto/twofish" @@ -46,12 +46,11 @@ const ( /* Support functionality to set up encryption after a channel has been negotiated via hkexnet.go */ -func (hc Conn) getStream(keymat *big.Int) (rc cipher.Stream, mc hash.Hash) { +func (hc Conn) getStream(keymat *big.Int) (rc cipher.Stream, mc hash.Hash, err error) { var key []byte var block cipher.Block var iv []byte var ivlen int - var err error copts := hc.cipheropts & 0xFF // TODO: each cipher alg case should ensure len(keymat.Bytes()) @@ -93,7 +92,8 @@ func (hc Conn) getStream(keymat *big.Int) (rc cipher.Stream, mc hash.Hash) { default: log.Printf("[invalid cipher (%d)]\n", copts) fmt.Printf("DOOFUS SET A VALID CIPHER ALG (%d)\n", copts) - os.Exit(1) + err = errors.New("hkexchan: INVALID CIPHER ALG") + //os.Exit(1) } hopts := (hc.cipheropts >> 8) & 0xFF @@ -109,20 +109,19 @@ func (hc Conn) getStream(keymat *big.Int) (rc cipher.Stream, mc hash.Hash) { default: log.Printf("[invalid hmac (%d)]\n", hopts) fmt.Printf("DOOFUS SET A VALID HMAC ALG (%d)\n", hopts) - os.Exit(1) + err = errors.New("hkexchan: INVALID HMAC ALG") + return + //os.Exit(1) } if err != nil { - panic(err) + // Feed the IV into the hmac: all traffic in the connection must + // feed its data into the hmac afterwards, so both ends can xor + // that with the stream to detect corruption. + _, _ = mc.Write(iv) + var currentHash []byte + currentHash = mc.Sum(currentHash) + log.Printf("Channel init hmac(iv):%s\n", hex.EncodeToString(currentHash)) } - - // Feed the IV into the hmac: all traffic in the connection must - // feed its data into the hmac afterwards, so both ends can xor - // that with the stream to detect corruption. - _, _ = mc.Write(iv) - var currentHash []byte - currentHash = mc.Sum(currentHash) - log.Printf("Channel init hmac(iv):%s\n", hex.EncodeToString(currentHash)) - return } diff --git a/hkexnet.go b/hkexnet.go index 2849781..e967868 100644 --- a/hkexnet.go +++ b/hkexnet.go @@ -28,6 +28,12 @@ import ( "time" ) +const ( + csoNone = iota // No error, normal packet + csoHmacInvalid // HMAC mismatch detect on remote end + csoChaff // This packet is a dummy, do not process beyond decryption +) + /*---------------------------------------------------------------------*/ // Conn is a HKex connection - a drop-in replacement for net.Conn @@ -149,8 +155,8 @@ func Dial(protocol string, ipport string, extensions ...string) (hc *Conn, err e hc.h.FA() log.Printf("**(c)** FA:%s\n", hc.h.fa) - hc.r, hc.rm = hc.getStream(hc.h.fa) - hc.w, hc.wm = hc.getStream(hc.h.fa) + hc.r, hc.rm, err = hc.getStream(hc.h.fa) + hc.w, hc.wm, err = hc.getStream(hc.h.fa) return } @@ -262,11 +268,13 @@ func (hl HKExListener) Accept() (hc Conn, err error) { // d is value for Herradura key exchange d := big.NewInt(0) _, err = fmt.Fscanln(c, d) + log.Printf("[Got d:%v]", d) if err != nil { return hc, err } _, err = fmt.Fscanf(c, "%08x:%08x\n", &hc.cipheropts, &hc.opts) + log.Printf("[Got cipheropts, opts:%v, %v]", hc.cipheropts, hc.opts) if err != nil { return hc, err } @@ -279,8 +287,8 @@ func (hl HKExListener) Accept() (hc Conn, err error) { fmt.Fprintf(c, "0x%s\n%08x:%08x\n", hc.h.d.Text(16), hc.cipheropts, hc.opts) - hc.r, hc.rm = hc.getStream(hc.h.fa) - hc.w, hc.wm = hc.getStream(hc.h.fa) + hc.r, hc.rm, err = hc.getStream(hc.h.fa) + hc.w, hc.wm, err = hc.getStream(hc.h.fa) return } @@ -303,9 +311,9 @@ func (c Conn) Read(b []byte) (n int, err error) { var hmacIn [4]uint8 var payloadLen uint32 - // Read ctrl/status opcode (for now, set nonzero on hmac mismatch) + // Read ctrl/status opcode (csoHmacInvalid on hmac mismatch) err = binary.Read(c.c, binary.BigEndian, &ctrlStatOp) - if ctrlStatOp != 0 { + if ctrlStatOp == csoHmacInvalid { // Other side indicated channel tampering, close channel c.Close() return 1, errors.New("** ALERT - remote end detected HMAC mismatch - possible channel tampering **") @@ -328,14 +336,19 @@ func (c Conn) Read(b []byte) (n int, err error) { if err != nil { if err.Error() != "EOF" { panic(err) - } // else { - // return 0, err - //} + // Cannot just return 0, err here - client won't hang up properly + // when 'exit' from shell. TODO: try server sending ctrlStatOp to + // indicate to Reader? -rlm 20180428 + } } + if payloadLen > 16384 { - panic("Insane payloadLen") + log.Printf("[Insane payloadLen:%v]\n", payloadLen) + c.Close() + return 1, errors.New("Insane payloadLen") } //log.Println("payloadLen:", payloadLen) + var payloadBytes = make([]byte, payloadLen) n, err = io.ReadFull(c.c, payloadBytes) //log.Print(" << Read ", n, " payloadBytes") @@ -364,8 +377,14 @@ func (c Conn) Read(b []byte) (n int, err error) { if err != nil { panic(err) } - c.dBuf.Write(payloadBytes) - //log.Printf("c.dBuf: %s\n", hex.Dump(c.dBuf.Bytes())) + + // Throw away pkt if it's chaff (ie., caller to Read() won't see this data) + if ctrlStatOp == csoChaff { + log.Printf("[Chaff pkt]\n") + } else { + c.dBuf.Write(payloadBytes) + //log.Printf("c.dBuf: %s\n", hex.Dump(c.dBuf.Bytes())) + } // Re-calculate hmac, compare with received value c.rm.Write(payloadBytes) @@ -374,10 +393,11 @@ func (c Conn) Read(b []byte) (n int, err error) { // Log alert if hmac didn't match, corrupted channel if !bytes.Equal(hTmp, []byte(hmacIn[0:])) /*|| hmacIn[0] > 0xf8*/ { - fmt.Println("** ALERT - hmac mismatch, possible channel tampering **") - _, _ = c.c.Write([]byte{0x1}) + fmt.Println("** ALERT - detected HMAC mismatch, possible channel tampering **") + _, _ = c.c.Write([]byte{csoHmacInvalid}) } } + retN := c.dBuf.Len() if retN > len(b) { retN = len(b) @@ -416,11 +436,11 @@ func (c Conn) Write(b []byte) (n int, err error) { panic(err) } log.Printf(" ->ctext:\r\n%s\r\n", hex.Dump(wb.Bytes())) - + var ctrlStatOp byte - ctrlStatOp = 0x00 + ctrlStatOp = csoNone _ = binary.Write(c.c, binary.BigEndian, &ctrlStatOp) - + // Write hmac LSB, payloadLen followed by payload _ = binary.Write(c.c, binary.BigEndian, hmacOut) _ = binary.Write(c.c, binary.BigEndian, payloadLen) diff --git a/hkexpasswd/hkexpasswd.go b/hkexpasswd/hkexpasswd.go index e4ce1c5..2d20875 100644 --- a/hkexpasswd/hkexpasswd.go +++ b/hkexpasswd/hkexpasswd.go @@ -19,7 +19,7 @@ import ( "os/user" "github.com/jameskeane/bcrypt" - hkexsh "blitter.com/hkexsh" + hkexsh "blitter.com/go/hkexsh" ) func main() { diff --git a/hkexsh/hkexsh.go b/hkexsh/hkexsh.go index e0657d7..c90e20b 100644 --- a/hkexsh/hkexsh.go +++ b/hkexsh/hkexsh.go @@ -18,7 +18,7 @@ import ( "strings" "sync" - hkexsh "blitter.com/hkexsh" + hkexsh "blitter.com/go/hkexsh" isatty "github.com/mattn/go-isatty" ) diff --git a/hkexshd/hkexshd.go b/hkexshd/hkexshd.go index 83e8557..d5906f9 100644 --- a/hkexshd/hkexshd.go +++ b/hkexshd/hkexshd.go @@ -18,8 +18,8 @@ import ( "os/user" "syscall" - hkexsh "blitter.com/hkexsh" - "blitter.com/hkexsh/spinsult" + hkexsh "blitter.com/go/hkexsh" + "blitter.com/go/hkexsh/spinsult" "github.com/kr/pty" ) @@ -171,90 +171,93 @@ func main() { // Wait for a connection. conn, err := l.Accept() if err != nil { - log.Fatal(err) - } - log.Println("Accepted client") + log.Printf("Accept() got error(%v), hanging up.\n", err) + conn.Close() + //log.Fatal(err) + } else { + log.Println("Accepted client") - // Handle the connection in a new goroutine. - // The loop then returns to accepting, so that - // multiple connections may be served concurrently. - go func(c hkexsh.Conn) (e error) { - defer c.Close() + // Handle the connection in a new goroutine. + // The loop then returns to accepting, so that + // multiple connections may be served concurrently. + go func(c hkexsh.Conn) (e error) { + defer c.Close() - //We use io.ReadFull() here to guarantee we consume - //just the data we want for the cmdSpec, and no more. - //Otherwise data will be sitting in the channel that isn't - //passed down to the command handlers. - var rec cmdSpec - var len1, len2, len3, len4 uint32 + //We use io.ReadFull() here to guarantee we consume + //just the data we want for the cmdSpec, and no more. + //Otherwise data will be sitting in the channel that isn't + //passed down to the command handlers. + var rec cmdSpec + var len1, len2, len3, len4 uint32 - n, err := fmt.Fscanf(c, "%d %d %d %d\n", &len1, &len2, &len3, &len4) - log.Printf("cmdSpec read:%d %d %d %d\n", len1, len2, len3, len4) + n, err := fmt.Fscanf(c, "%d %d %d %d\n", &len1, &len2, &len3, &len4) + log.Printf("cmdSpec read:%d %d %d %d\n", len1, len2, len3, len4) - if err != nil || n < 4 { - log.Println("[Bad cmdSpec fmt]") - return err - } - //fmt.Printf(" lens:%d %d %d %d\n", len1, len2, len3, len4) + if err != nil || n < 4 { + log.Println("[Bad cmdSpec fmt]") + return err + } + //fmt.Printf(" lens:%d %d %d %d\n", len1, len2, len3, len4) - rec.op = make([]byte, len1, len1) - _, err = io.ReadFull(c, rec.op) - if err != nil { - log.Println("[Bad cmdSpec.op]") - return err - } - rec.who = make([]byte, len2, len2) - _, err = io.ReadFull(c, rec.who) - if err != nil { - log.Println("[Bad cmdSpec.who]") - return err - } + rec.op = make([]byte, len1, len1) + _, err = io.ReadFull(c, rec.op) + if err != nil { + log.Println("[Bad cmdSpec.op]") + return err + } + rec.who = make([]byte, len2, len2) + _, err = io.ReadFull(c, rec.who) + if err != nil { + log.Println("[Bad cmdSpec.who]") + return err + } - rec.cmd = make([]byte, len3, len3) - _, err = io.ReadFull(c, rec.cmd) - if err != nil { - log.Println("[Bad cmdSpec.cmd]") - return err - } + rec.cmd = make([]byte, len3, len3) + _, err = io.ReadFull(c, rec.cmd) + if err != nil { + log.Println("[Bad cmdSpec.cmd]") + return err + } - rec.authCookie = make([]byte, len4, len4) - _, err = io.ReadFull(c, rec.authCookie) - if err != nil { - log.Println("[Bad cmdSpec.authCookie]") - return err - } + rec.authCookie = make([]byte, len4, len4) + _, err = io.ReadFull(c, rec.authCookie) + if err != nil { + log.Println("[Bad cmdSpec.authCookie]") + return err + } - log.Printf("[cmdSpec: op:%c who:%s cmd:%s auth:****]\n", - rec.op[0], string(rec.who), string(rec.cmd)) + log.Printf("[cmdSpec: op:%c who:%s cmd:%s auth:****]\n", + rec.op[0], string(rec.who), string(rec.cmd)) - valid, allowedCmds := hkexsh.AuthUser(string(rec.who), string(rec.authCookie), "/etc/hkexsh.passwd") - if !valid { - log.Println("Invalid user", string(rec.who)) - c.Write([]byte(rejectUserMsg())) + valid, allowedCmds := hkexsh.AuthUser(string(rec.who), string(rec.authCookie), "/etc/hkexsh.passwd") + if !valid { + log.Println("Invalid user", string(rec.who)) + c.Write([]byte(rejectUserMsg())) + return + } + log.Printf("[allowedCmds:%s]\n", allowedCmds) + + if rec.op[0] == 'c' { + // Non-interactive command + log.Println("[Running command]") + runShellAs(string(rec.who), string(rec.cmd), false, conn) + // Returned hopefully via an EOF or exit/logout; + // Clear current op so user can enter next, or EOF + rec.op[0] = 0 + log.Println("[Command complete]") + } else if rec.op[0] == 's' { + log.Println("[Running shell]") + runShellAs(string(rec.who), string(rec.cmd), true, conn) + // Returned hopefully via an EOF or exit/logout; + // Clear current op so user can enter next, or EOF + rec.op[0] = 0 + log.Println("[Exiting shell]") + } else { + log.Println("[Bad cmdSpec]") + } return - } - log.Printf("[allowedCmds:%s]\n", allowedCmds) - - if rec.op[0] == 'c' { - // Non-interactive command - log.Println("[Running command]") - runShellAs(string(rec.who), string(rec.cmd), false, conn) - // Returned hopefully via an EOF or exit/logout; - // Clear current op so user can enter next, or EOF - rec.op[0] = 0 - log.Println("[Command complete]") - } else if rec.op[0] == 's' { - log.Println("[Running shell]") - runShellAs(string(rec.who), string(rec.cmd), true, conn) - // Returned hopefully via an EOF or exit/logout; - // Clear current op so user can enter next, or EOF - rec.op[0] = 0 - log.Println("[Exiting shell]") - } else { - log.Println("[Bad cmdSpec]") - } - return - }(conn) + }(conn) + } // Accept() success } //endfor log.Println("[Exiting]") }