-Added error checking for all stages of hkex.Conn.Accept() and GetStream()

-Server will log such errors without panic/exit
-Const added but not yet used for 'chaff' packets
This commit is contained in:
Russ Magee 2018-04-28 16:05:33 -07:00
parent c56d4d9ad9
commit 50f0433579
5 changed files with 132 additions and 110 deletions

View File

@ -15,11 +15,11 @@ import (
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"hash" "hash"
"log" "log"
"math/big" "math/big"
"os"
"golang.org/x/crypto/blowfish" "golang.org/x/crypto/blowfish"
"golang.org/x/crypto/twofish" "golang.org/x/crypto/twofish"
@ -46,12 +46,11 @@ const (
/* Support functionality to set up encryption after a channel has /* Support functionality to set up encryption after a channel has
been negotiated via hkexnet.go been negotiated via hkexnet.go
*/ */
func (hc Conn) getStream(keymat *big.Int) (rc cipher.Stream, mc hash.Hash) { func (hc Conn) getStream(keymat *big.Int) (rc cipher.Stream, mc hash.Hash, err error) {
var key []byte var key []byte
var block cipher.Block var block cipher.Block
var iv []byte var iv []byte
var ivlen int var ivlen int
var err error
copts := hc.cipheropts & 0xFF copts := hc.cipheropts & 0xFF
// TODO: each cipher alg case should ensure len(keymat.Bytes()) // TODO: each cipher alg case should ensure len(keymat.Bytes())
@ -93,7 +92,8 @@ func (hc Conn) getStream(keymat *big.Int) (rc cipher.Stream, mc hash.Hash) {
default: default:
log.Printf("[invalid cipher (%d)]\n", copts) log.Printf("[invalid cipher (%d)]\n", copts)
fmt.Printf("DOOFUS SET A VALID CIPHER ALG (%d)\n", copts) fmt.Printf("DOOFUS SET A VALID CIPHER ALG (%d)\n", copts)
os.Exit(1) err = errors.New("hkexchan: INVALID CIPHER ALG")
//os.Exit(1)
} }
hopts := (hc.cipheropts >> 8) & 0xFF hopts := (hc.cipheropts >> 8) & 0xFF
@ -109,13 +109,12 @@ func (hc Conn) getStream(keymat *big.Int) (rc cipher.Stream, mc hash.Hash) {
default: default:
log.Printf("[invalid hmac (%d)]\n", hopts) log.Printf("[invalid hmac (%d)]\n", hopts)
fmt.Printf("DOOFUS SET A VALID HMAC ALG (%d)\n", hopts) fmt.Printf("DOOFUS SET A VALID HMAC ALG (%d)\n", hopts)
os.Exit(1) err = errors.New("hkexchan: INVALID HMAC ALG")
return
//os.Exit(1)
} }
if err != nil { if err != nil {
panic(err)
}
// Feed the IV into the hmac: all traffic in the connection must // Feed the IV into the hmac: all traffic in the connection must
// feed its data into the hmac afterwards, so both ends can xor // feed its data into the hmac afterwards, so both ends can xor
// that with the stream to detect corruption. // that with the stream to detect corruption.
@ -123,6 +122,6 @@ func (hc Conn) getStream(keymat *big.Int) (rc cipher.Stream, mc hash.Hash) {
var currentHash []byte var currentHash []byte
currentHash = mc.Sum(currentHash) currentHash = mc.Sum(currentHash)
log.Printf("Channel init hmac(iv):%s\n", hex.EncodeToString(currentHash)) log.Printf("Channel init hmac(iv):%s\n", hex.EncodeToString(currentHash))
}
return return
} }

View File

@ -28,6 +28,12 @@ import (
"time" "time"
) )
const (
csoNone = iota // No error, normal packet
csoHmacInvalid // HMAC mismatch detect on remote end
csoChaff // This packet is a dummy, do not process beyond decryption
)
/*---------------------------------------------------------------------*/ /*---------------------------------------------------------------------*/
// Conn is a HKex connection - a drop-in replacement for net.Conn // Conn is a HKex connection - a drop-in replacement for net.Conn
@ -149,8 +155,8 @@ func Dial(protocol string, ipport string, extensions ...string) (hc *Conn, err e
hc.h.FA() hc.h.FA()
log.Printf("**(c)** FA:%s\n", hc.h.fa) log.Printf("**(c)** FA:%s\n", hc.h.fa)
hc.r, hc.rm = hc.getStream(hc.h.fa) hc.r, hc.rm, err = hc.getStream(hc.h.fa)
hc.w, hc.wm = hc.getStream(hc.h.fa) hc.w, hc.wm, err = hc.getStream(hc.h.fa)
return return
} }
@ -262,11 +268,13 @@ func (hl HKExListener) Accept() (hc Conn, err error) {
// d is value for Herradura key exchange // d is value for Herradura key exchange
d := big.NewInt(0) d := big.NewInt(0)
_, err = fmt.Fscanln(c, d) _, err = fmt.Fscanln(c, d)
log.Printf("[Got d:%v]", d)
if err != nil { if err != nil {
return hc, err return hc, err
} }
_, err = fmt.Fscanf(c, "%08x:%08x\n", _, err = fmt.Fscanf(c, "%08x:%08x\n",
&hc.cipheropts, &hc.opts) &hc.cipheropts, &hc.opts)
log.Printf("[Got cipheropts, opts:%v, %v]", hc.cipheropts, hc.opts)
if err != nil { if err != nil {
return hc, err return hc, err
} }
@ -279,8 +287,8 @@ func (hl HKExListener) Accept() (hc Conn, err error) {
fmt.Fprintf(c, "0x%s\n%08x:%08x\n", hc.h.d.Text(16), fmt.Fprintf(c, "0x%s\n%08x:%08x\n", hc.h.d.Text(16),
hc.cipheropts, hc.opts) hc.cipheropts, hc.opts)
hc.r, hc.rm = hc.getStream(hc.h.fa) hc.r, hc.rm, err = hc.getStream(hc.h.fa)
hc.w, hc.wm = hc.getStream(hc.h.fa) hc.w, hc.wm, err = hc.getStream(hc.h.fa)
return return
} }
@ -303,9 +311,9 @@ func (c Conn) Read(b []byte) (n int, err error) {
var hmacIn [4]uint8 var hmacIn [4]uint8
var payloadLen uint32 var payloadLen uint32
// Read ctrl/status opcode (for now, set nonzero on hmac mismatch) // Read ctrl/status opcode (csoHmacInvalid on hmac mismatch)
err = binary.Read(c.c, binary.BigEndian, &ctrlStatOp) err = binary.Read(c.c, binary.BigEndian, &ctrlStatOp)
if ctrlStatOp != 0 { if ctrlStatOp == csoHmacInvalid {
// Other side indicated channel tampering, close channel // Other side indicated channel tampering, close channel
c.Close() c.Close()
return 1, errors.New("** ALERT - remote end detected HMAC mismatch - possible channel tampering **") return 1, errors.New("** ALERT - remote end detected HMAC mismatch - possible channel tampering **")
@ -328,14 +336,19 @@ func (c Conn) Read(b []byte) (n int, err error) {
if err != nil { if err != nil {
if err.Error() != "EOF" { if err.Error() != "EOF" {
panic(err) panic(err)
} // else { // Cannot just return 0, err here - client won't hang up properly
// return 0, err // when 'exit' from shell. TODO: try server sending ctrlStatOp to
//} // indicate to Reader? -rlm 20180428
} }
}
if payloadLen > 16384 { if payloadLen > 16384 {
panic("Insane payloadLen") log.Printf("[Insane payloadLen:%v]\n", payloadLen)
c.Close()
return 1, errors.New("Insane payloadLen")
} }
//log.Println("payloadLen:", payloadLen) //log.Println("payloadLen:", payloadLen)
var payloadBytes = make([]byte, payloadLen) var payloadBytes = make([]byte, payloadLen)
n, err = io.ReadFull(c.c, payloadBytes) n, err = io.ReadFull(c.c, payloadBytes)
//log.Print(" << Read ", n, " payloadBytes") //log.Print(" << Read ", n, " payloadBytes")
@ -364,8 +377,14 @@ func (c Conn) Read(b []byte) (n int, err error) {
if err != nil { if err != nil {
panic(err) panic(err)
} }
// Throw away pkt if it's chaff (ie., caller to Read() won't see this data)
if ctrlStatOp == csoChaff {
log.Printf("[Chaff pkt]\n")
} else {
c.dBuf.Write(payloadBytes) c.dBuf.Write(payloadBytes)
//log.Printf("c.dBuf: %s\n", hex.Dump(c.dBuf.Bytes())) //log.Printf("c.dBuf: %s\n", hex.Dump(c.dBuf.Bytes()))
}
// Re-calculate hmac, compare with received value // Re-calculate hmac, compare with received value
c.rm.Write(payloadBytes) c.rm.Write(payloadBytes)
@ -374,10 +393,11 @@ func (c Conn) Read(b []byte) (n int, err error) {
// Log alert if hmac didn't match, corrupted channel // Log alert if hmac didn't match, corrupted channel
if !bytes.Equal(hTmp, []byte(hmacIn[0:])) /*|| hmacIn[0] > 0xf8*/ { if !bytes.Equal(hTmp, []byte(hmacIn[0:])) /*|| hmacIn[0] > 0xf8*/ {
fmt.Println("** ALERT - hmac mismatch, possible channel tampering **") fmt.Println("** ALERT - detected HMAC mismatch, possible channel tampering **")
_, _ = c.c.Write([]byte{0x1}) _, _ = c.c.Write([]byte{csoHmacInvalid})
} }
} }
retN := c.dBuf.Len() retN := c.dBuf.Len()
if retN > len(b) { if retN > len(b) {
retN = len(b) retN = len(b)
@ -418,7 +438,7 @@ func (c Conn) Write(b []byte) (n int, err error) {
log.Printf(" ->ctext:\r\n%s\r\n", hex.Dump(wb.Bytes())) log.Printf(" ->ctext:\r\n%s\r\n", hex.Dump(wb.Bytes()))
var ctrlStatOp byte var ctrlStatOp byte
ctrlStatOp = 0x00 ctrlStatOp = csoNone
_ = binary.Write(c.c, binary.BigEndian, &ctrlStatOp) _ = binary.Write(c.c, binary.BigEndian, &ctrlStatOp)
// Write hmac LSB, payloadLen followed by payload // Write hmac LSB, payloadLen followed by payload

View File

@ -19,7 +19,7 @@ import (
"os/user" "os/user"
"github.com/jameskeane/bcrypt" "github.com/jameskeane/bcrypt"
hkexsh "blitter.com/hkexsh" hkexsh "blitter.com/go/hkexsh"
) )
func main() { func main() {

View File

@ -18,7 +18,7 @@ import (
"strings" "strings"
"sync" "sync"
hkexsh "blitter.com/hkexsh" hkexsh "blitter.com/go/hkexsh"
isatty "github.com/mattn/go-isatty" isatty "github.com/mattn/go-isatty"
) )

View File

@ -18,8 +18,8 @@ import (
"os/user" "os/user"
"syscall" "syscall"
hkexsh "blitter.com/hkexsh" hkexsh "blitter.com/go/hkexsh"
"blitter.com/hkexsh/spinsult" "blitter.com/go/hkexsh/spinsult"
"github.com/kr/pty" "github.com/kr/pty"
) )
@ -171,8 +171,10 @@ func main() {
// Wait for a connection. // Wait for a connection.
conn, err := l.Accept() conn, err := l.Accept()
if err != nil { if err != nil {
log.Fatal(err) log.Printf("Accept() got error(%v), hanging up.\n", err)
} conn.Close()
//log.Fatal(err)
} else {
log.Println("Accepted client") log.Println("Accepted client")
// Handle the connection in a new goroutine. // Handle the connection in a new goroutine.
@ -255,6 +257,7 @@ func main() {
} }
return return
}(conn) }(conn)
} // Accept() success
} //endfor } //endfor
log.Println("[Exiting]") log.Println("[Exiting]")
} }