diff --git a/hkexnet.go b/hkexnet.go index 0405a86..2849781 100644 --- a/hkexnet.go +++ b/hkexnet.go @@ -17,6 +17,7 @@ import ( "crypto/cipher" "encoding/binary" "encoding/hex" + "errors" "fmt" "hash" "io" @@ -298,9 +299,18 @@ func (c Conn) Read(b []byte) (n int, err error) { break } + var ctrlStatOp uint8 var hmacIn [4]uint8 var payloadLen uint32 + // Read ctrl/status opcode (for now, set nonzero on hmac mismatch) + err = binary.Read(c.c, binary.BigEndian, &ctrlStatOp) + if ctrlStatOp != 0 { + // Other side indicated channel tampering, close channel + c.Close() + return 1, errors.New("** ALERT - remote end detected HMAC mismatch - possible channel tampering **") + } + // Read the hmac and payload len first err = binary.Read(c.c, binary.BigEndian, &hmacIn) // Normal client 'exit' from interactive session will cause @@ -362,10 +372,10 @@ func (c Conn) Read(b []byte) (n int, err error) { hTmp := c.rm.Sum(nil)[0:4] log.Printf("<%04x) HMAC:(i)%s (c)%02x\r\n", decryptN, hex.EncodeToString([]byte(hmacIn[0:])), hTmp) - // Puke if hmac didn't match, corrupted channel - if !bytes.Equal(hTmp, []byte(hmacIn[0:])) || hmacIn[0] > 0xf8 { + // Log alert if hmac didn't match, corrupted channel + if !bytes.Equal(hTmp, []byte(hmacIn[0:])) /*|| hmacIn[0] > 0xf8*/ { fmt.Println("** ALERT - hmac mismatch, possible channel tampering **") - c.Close() + _, _ = c.c.Write([]byte{0x1}) } } retN := c.dBuf.Len() @@ -406,7 +416,11 @@ func (c Conn) Write(b []byte) (n int, err error) { panic(err) } log.Printf(" ->ctext:\r\n%s\r\n", hex.Dump(wb.Bytes())) - + + var ctrlStatOp byte + ctrlStatOp = 0x00 + _ = binary.Write(c.c, binary.BigEndian, &ctrlStatOp) + // Write hmac LSB, payloadLen followed by payload _ = binary.Write(c.c, binary.BigEndian, hmacOut) _ = binary.Write(c.c, binary.BigEndian, payloadLen)