diff --git a/hkexnet/consts.go b/hkexnet/consts.go index 9b19429..f37ea29 100644 --- a/hkexnet/consts.go +++ b/hkexnet/consts.go @@ -56,7 +56,6 @@ const ( // Tunnel setup/control/status CSOTunSetup // client -> server tunnel setup request (dstport) - CSOTunInUse // server -> client: tunnel rport is in use CSOTunSetupAck // server -> client tunnel setup ack CSOTunAccept // client -> server: tunnel client got an Accept() // (Do we need a CSOTunAcceptAck server->client?) @@ -66,23 +65,26 @@ const ( CSOTunHangup // client -> server: tunnel lport hung up ) -// TunEndpoint.tunCtl control values +// TunEndpoint.tunCtl control values - used to control workers for client or server tunnels +// depending on the code const ( - TunCtl_AcceptedClient = 'a' // client side has accept()ed a conn + TunCtl_Client_Listen = 'a' + + TunCtl_Server_Dial = 'd' // server has dialled OK, client side can accept() conns // [CSOTunAccept] // status: client listen() worker accepted conn on lport // action:server side should dial() rport on client's behalf - TunCtl_LostClient = 'h' // client side has hung up + TunCtl_Info_Hangup = 'h' // client side has hung up // [CSOTunHangup] // status: client side conn hung up from lport // action:server side should hang up on rport, on client's behalf - TunCtl_ConnRefused = 'r' // server side couldn't complete tunnel + TunCtl_Info_ConnRefused = 'r' // server side couldn't complete tunnel // [CSOTunRefused] // status:server side could not dial() remote side - - TunCtl_LostConn = 'l' // server side disconnected + + TunCtl_Info_LostConn = 'x' // server side disconnected // [CSOTunDisconn] // status:server side lost connection to rport // action:client should disconnect accepted lport connection diff --git a/hkexnet/hkexnet.go b/hkexnet/hkexnet.go index 2d7656b..270ec11 100644 --- a/hkexnet/hkexnet.go +++ b/hkexnet/hkexnet.go @@ -85,7 +85,7 @@ type ( Cols uint16 chaff ChaffConfig - tuns map[uint16]*TunEndpoint + tuns *map[uint16](*TunEndpoint) closeStat *CSOType // close status (CSOExitStatus) r cipher.Stream //read cipherStream @@ -208,6 +208,8 @@ func _new(kexAlg KEXAlg, conn *net.Conn) (hc *Conn, e error) { closeStat: new(CSOType), WinCh: make(chan WinSize, 1), dBuf: new(bytes.Buffer)} + tempMap := make(map[uint16]*TunEndpoint) + hc.tuns = &tempMap *hc.closeStat = CSEStillOpen // open or prematurely-closed status @@ -814,39 +816,58 @@ func (hc Conn) Read(b []byte) (n int, err error) { // server side tunnel setup in response to client lport := binary.BigEndian.Uint16(payloadBytes[0:2]) rport := binary.BigEndian.Uint16(payloadBytes[2:4]) - logger.LogDebug(fmt.Sprintf("Read(): Tunnel setup [%d:%d]", lport, rport)) - hc.StartServerTunnel(lport, rport) - hc.tuns[rport].Ctl <- 'a' // Dial() rport + logger.LogDebug(fmt.Sprintf("mapkey is %d", rport)) + if _, ok := (*hc.tuns)[rport]; !ok { + // tunnel first-time open + logger.LogDebug(fmt.Sprintf("[Server] Got Initial CSOTunSetup [%d:%d]", lport, rport)) + hc.StartServerTunnel(lport, rport) + } else { + logger.LogDebug(fmt.Sprintf("[Server] Got CSOTunSetup [%d:%d]", lport, rport)) + } + (*hc.tuns)[rport].Ctl <- 'd' // Dial() rport } else if ctrlStatOp == CSOTunSetupAck { lport := binary.BigEndian.Uint16(payloadBytes[0:2]) rport := binary.BigEndian.Uint16(payloadBytes[2:4]) - logger.LogDebug(fmt.Sprintf("Read(): Tunnel setup ack [%d:%d]", lport, rport)) - hc.dBuf.Write(payloadBytes) + logger.LogDebug(fmt.Sprintf("mapkey is %d\n", rport)) + if _, ok := (*hc.tuns)[rport]; !ok { + // tunnel first-time open + logger.LogDebug(fmt.Sprintf("[Client] Got Initial CSOTunSetupAck [%d:%d]", lport, rport)) + hc.StartClientTunnel(lport, rport) + } else { + logger.LogDebug(fmt.Sprintf("[Client] Got CSOTunSetupAck [%d:%d]", lport, rport)) + } + (*hc.tuns)[rport].Ctl <- 'a' // Listen() for lport connection } else if ctrlStatOp == CSOTunRefused { - // client side has been told nothing is listening on rport + // client side receiving CSOTunRefused means the remote side + // could not dial() rport. So we cannot yet listen() + // for client-side on lport. lport := binary.BigEndian.Uint16(payloadBytes[0:2]) rport := binary.BigEndian.Uint16(payloadBytes[2:4]) - logger.LogDebug(fmt.Sprintf("Read(): Tunnel refused [%d:%d]", lport, rport)) - hc.dBuf.Write(payloadBytes) + logger.LogDebug(fmt.Sprintf("mapkey is %d\n", rport)) + logger.LogDebug(fmt.Sprintf("[Client] Got CSOTunRefused [%d:%d]", lport, rport)) + (*hc.tuns)[rport].Ctl <- 'r' // client should NOT Listen() } else if ctrlStatOp == CSOTunDisconn { // server side's rport has disconnected (server lost) lport := binary.BigEndian.Uint16(payloadBytes[0:2]) rport := binary.BigEndian.Uint16(payloadBytes[2:4]) - logger.LogDebug(fmt.Sprintf("Read(): Tunnel server disconnected [%d:%d]", lport, rport)) - hc.dBuf.Write(payloadBytes) + logger.LogDebug(fmt.Sprintf("mapkey is %d\n", rport)) + logger.LogDebug(fmt.Sprintf("[Client] Got CSOTunDisconn [%d:%d]", lport, rport)) + (*hc.tuns)[rport].Ctl <- 'x' // client should hangup on current lport conn } else if ctrlStatOp == CSOTunHangup { // client side's lport has hung up lport := binary.BigEndian.Uint16(payloadBytes[0:2]) rport := binary.BigEndian.Uint16(payloadBytes[2:4]) - logger.LogDebug(fmt.Sprintf("Read(): Tunnel client hung up [%d:%d]", lport, rport)) - hc.dBuf.Write(payloadBytes) + logger.LogDebug(fmt.Sprintf("mapkey is %d\n", rport)) + logger.LogDebug(fmt.Sprintf("[Server] Got CSOTunHangup [%d:%d]", lport, rport)) + (*hc.tuns)[rport].Ctl <- 'h' // server should hang up on currently-dialled rport } else if ctrlStatOp == CSOTunData { lport := binary.BigEndian.Uint16(payloadBytes[0:2]) rport := binary.BigEndian.Uint16(payloadBytes[2:4]) + logger.LogDebug(fmt.Sprintf("mapkey is %d\n", rport)) //fmt.Printf("[Got CSOTunData: [lport %d:rport %d] data:%v\n", lport, rport, payloadBytes[4:]) - if hc.tuns[rport] != nil { + if _, ok := (*hc.tuns)[rport]; ok { logger.LogDebug(fmt.Sprintf("[Writing data to rport [%d:%d]", lport, rport)) - hc.tuns[rport].Data <- payloadBytes[4:] + (*hc.tuns)[rport].Data <- payloadBytes[4:] } else { logger.LogDebug(fmt.Sprintf("[Attempt to write data to closed tun [%d:%d]", lport, rport)) } diff --git a/hkexnet/hkextun.go b/hkexnet/hkextun.go index 1cfe66c..1947ce8 100644 --- a/hkexnet/hkextun.go +++ b/hkexnet/hkextun.go @@ -48,118 +48,147 @@ type ( ) func (hc *Conn) InitTunEndpoint(lp uint16, p string /* net.Addr */, rp uint16) { - if hc.tuns == nil { - hc.tuns = make(map[uint16]*TunEndpoint) + if (*hc.tuns) == nil { + (*hc.tuns) = make(map[uint16]*TunEndpoint) } - if hc.tuns[rp] == nil { + if (*hc.tuns)[rp] == nil { var addrs []net.Addr if p == "" { addrs, _ = net.InterfaceAddrs() p = addrs[0].String() } - hc.tuns[rp] = &TunEndpoint{ /*Status: CSOTunSetup,*/ Peer: p, + (*hc.tuns)[rp] = &TunEndpoint{ /*Status: CSOTunSetup,*/ Peer: p, Lport: lp, Rport: rp, Data: make(chan []byte, 1), Ctl: make(chan rune, 1)} - logger.LogDebug(fmt.Sprintf("InitTunEndpoint [%d:%s:%d]\n", lp, p, rp)) + logger.LogDebug(fmt.Sprintf("InitTunEndpoint [%d:%s:%d]", lp, p, rp)) + } else { + logger.LogDebug(fmt.Sprintf("InitTunEndpoint [reusing] [%d:%s:%d]", (*hc.tuns)[rp].Lport, (*hc.tuns)[rp].Peer, (*hc.tuns)[rp].Rport)) } return } func (hc *Conn) StartClientTunnel(lport, rport uint16) { hc.InitTunEndpoint(lport, "", rport) - t := hc.tuns[rport] // for convenience - + t := (*hc.tuns)[rport] // for convenience + var l HKExListener go func() { - logger.LogDebug(fmt.Sprintf("Listening for client tunnel port %d", lport)) - l, e := net.Listen("tcp", fmt.Sprintf(":%d", lport)) - if e != nil { - logger.LogDebug(fmt.Sprintf("[Could not get lport %d! (%s)", lport, e)) - } else { - defer l.Close() - for { - c, e := l.Accept() - - defer func() { - c.Close() - }() + weAreListening := false + for cmd := range t.Ctl { + logger.LogDebug(fmt.Sprintf("[ClientTun] Listening for client tunnel port %d", lport)) + if cmd == 'a' && !weAreListening { + l, e := net.Listen("tcp", fmt.Sprintf(":%d", lport)) if e != nil { - logger.LogDebug(fmt.Sprintf("Accept() got error(%v), hanging up.", e)) - break + logger.LogDebug(fmt.Sprintf("[ClientTun] Could not get lport %d! (%s)", lport, e)) } else { - logger.LogDebug(fmt.Sprintln("Accepted tunnel client")) - - // outside client -> tunnel lport - go func() { + weAreListening = true + for { + c, e := l.Accept() var tunDst bytes.Buffer + // ask server to dial() its side, rport binary.Write(&tunDst, binary.BigEndian, lport) binary.Write(&tunDst, binary.BigEndian, rport) - for { - rBuf := make([]byte, 1024) - //Read data from c, encrypt/write via hc to client(lport) - n, e := c.Read(rBuf) - if e != nil { - if e == io.EOF { - logger.LogDebug(fmt.Sprintf("lport Disconnected: shutting down tunnel [%d:%d]", lport, rport)) - } else { - logger.LogDebug(fmt.Sprintf("Read error from lport of tun [%d:%d]\n%s", lport, rport, e)) + hc.WritePacket(tunDst.Bytes(), CSOTunSetup) + + if e != nil { + logger.LogDebug(fmt.Sprintf("[ClientTun] Accept() got error(%v), hanging up.", e)) + break + } else { + logger.LogDebug(fmt.Sprintf("[ClientTun] Accepted tunnel client %v", t)) + + // outside client -> tunnel lport + go func() { + defer func() { + c.Close() + }() + + var tunDst bytes.Buffer + binary.Write(&tunDst, binary.BigEndian, lport) + binary.Write(&tunDst, binary.BigEndian, rport) + for { + rBuf := make([]byte, 1024) + //Read data from c, encrypt/write via hc to client(lport) + n, e := c.Read(rBuf) + if e != nil { + if e == io.EOF { + logger.LogDebug(fmt.Sprintf("[ClientTun] lport Disconnected: shutting down tunnel %v", t)) + } else { + logger.LogDebug(fmt.Sprintf("[ClientTun] Read error from lport of tun %v\n%s", t, e)) + } + hc.WritePacket(tunDst.Bytes(), CSOTunHangup) + break + } + if n > 0 { + rBuf = append(tunDst.Bytes(), rBuf[:n]...) + _, de := hc.WritePacket(rBuf[:n+4], CSOTunData) + if de != nil { + logger.LogDebug(fmt.Sprintf("[ClientTun] Error writing to tunnel %v, %s]\n", t, de)) + break + } + } } - hc.WritePacket(tunDst.Bytes(), CSOTunHangup) - break - } - if n > 0 { - rBuf = append(tunDst.Bytes(), rBuf[:n]...) - hc.WritePacket(rBuf[:n+4], CSOTunData) - } - } - }() + }() - // tunnel lport -> outside client (c) - go func() { - defer func() { - c.Close() - }() + // tunnel lport -> outside client (c) + go func() { + defer func() { + c.Close() + }() - for { - bytes, ok := <-t.Data - if ok { - c.Write(bytes) - } else { - logger.LogDebug(fmt.Sprintf("[Channel closed?]\n")) - break - } - } - }() + for { + bytes, ok := <-t.Data + if ok { + _, e := c.Write(bytes) + if e != nil { + logger.LogDebug(fmt.Sprintf("[ClientTun] lport conn closed")) + break + } + } else { + logger.LogDebug(fmt.Sprintf("[ClientTun] Channel closed?")) + break + } + } + }() - } + } // end Accept() worker block + } // end for-accept + } // end Listen() block + } else if cmd == 'r' { + logger.LogDebug(fmt.Sprintf("[ClientTun] Server replied TunRefused %v\n", t)) + } else if cmd == 'x' { + logger.LogDebug(fmt.Sprintf("[ClientTun] Server replied TunDisconn, closing lport %v\n", t)) + l.Close() + weAreListening = false } - } + } // end t.Ctl for }() } func (hc *Conn) StartServerTunnel(lport, rport uint16) { hc.InitTunEndpoint(lport, "", rport) - t := hc.tuns[rport] // for convenience + t := (*hc.tuns)[rport] // for convenience var err error go func() { + weAreDialled := false for cmd := range t.Ctl { var c net.Conn - if cmd == 'a' { - logger.LogDebug("Server dialling...") + if cmd == 'd' && !weAreDialled { + logger.LogDebug("[ServerTun] dialling...") c, err = net.Dial("tcp", fmt.Sprintf(":%d", rport)) if err != nil { - logger.LogDebug(fmt.Sprintf("Nothing is serving at rport :%d!", rport)) + logger.LogDebug(fmt.Sprintf("[ServerTun] Dial() error for tun %v: %s", t, err)) var resp bytes.Buffer binary.Write(&resp, binary.BigEndian /*lport*/, uint16(0)) binary.Write(&resp, binary.BigEndian, rport) hc.WritePacket(resp.Bytes(), CSOTunRefused) } else { - logger.LogDebug(fmt.Sprintf("[Tunnel Opened - %d:%s:%d]", lport, t.Peer, rport)) + logger.LogDebug(fmt.Sprintf("[ServerTun] Tunnel Opened - %v", t)) + weAreDialled = true var resp bytes.Buffer binary.Write(&resp, binary.BigEndian, lport) binary.Write(&resp, binary.BigEndian, rport) - logger.LogDebug(fmt.Sprintf("[Writing CSOTunSetupAck[%d:%d]", lport, rport)) + logger.LogDebug(fmt.Sprintf("[ServerTun] Writing CSOTunSetupAck %v", t)) hc.WritePacket(resp.Bytes(), CSOTunSetupAck) // @@ -167,7 +196,9 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) { // go func() { defer func() { + logger.LogDebug("[ServerTun] (deferred hangup workerA)") c.Close() + weAreDialled = false }() var tunDst bytes.Buffer @@ -179,15 +210,15 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) { n, e := c.Read(rBuf) if e != nil { if e == io.EOF { - logger.LogDebug(fmt.Sprintf("rport Disconnected: shutting down tunnel %v\n", t)) + logger.LogDebug(fmt.Sprintf("[ServerTun] rport Disconnected: shutting down tunnel %v", t)) } else { - logger.LogDebug(fmt.Sprintf("Read error from rport of tun %v\n%s", t, e)) + logger.LogDebug(fmt.Sprintf("[ServerTun] Read error from rport of tun %v: %s", t, e)) } var resp bytes.Buffer binary.Write(&resp, binary.BigEndian, lport) binary.Write(&resp, binary.BigEndian, rport) hc.WritePacket(resp.Bytes(), CSOTunDisconn) - logger.LogDebug(fmt.Sprintf("Closing server rport %d net.Dial()", t.Rport)) + logger.LogDebug(fmt.Sprintf("[ServerTun] Closing server rport %d net.Dial()", t.Rport)) break } if n > 0 { @@ -200,22 +231,32 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) { // worker to read data from client (already decrypted) & fwd to rport go func() { defer func() { + logger.LogDebug("[ServerTun] (deferred hangup workerB)") c.Close() + weAreDialled = false }() for { rData, ok := <-t.Data if ok { - c.Write(rData) + _, e := c.Write(rData) + if e != nil { + logger.LogDebug(fmt.Sprintf("[ServerTun] ERROR writing to rport conn")) + break + } } else { - logger.LogDebug("[ERROR reading from hc.tuns[] channel - closed?]") + logger.LogDebug("[ServerTun] ERROR reading from hc.tuns[] channel - closed?") break } } }() } - } // TODO: elseifs for other state transtions driven by client - } - }() // t.Ctl read loop - logger.LogDebug("[ServerTunnel() exiting t.Ctl read loop - channel closed??]") + } else if cmd == 'h' { + // client side has hung up + logger.LogDebug(fmt.Sprintf("[ServerTun] Client hung up: hanging up on rport %v", t)) + weAreDialled = false + } + } // t.Ctl read loop + logger.LogDebug("[ServerTun] Tunnel exiting t.Ctl read loop - channel closed??") + }() } diff --git a/hkexsh/hkexsh.go b/hkexsh/hkexsh.go index 6ef531c..d6bc24e 100755 --- a/hkexsh/hkexsh.go +++ b/hkexsh/hkexsh.go @@ -349,26 +349,7 @@ func reqTunnel(hc *hkexnet.Conn, lp uint16, p string /*net.Addr*/, rp uint16) { fmt.Printf("bTmp:%x\n", bTmp.Bytes()) logger.LogDebug(fmt.Sprintln("[Client sending CSOTunSetup]")) hc.WritePacket(bTmp.Bytes(), hkexnet.CSOTunSetup) - - // Server should reply immediately with CSOTunSetupAck[lport:rport] - // hkexnet.Read() on server side handles server side tun setup. - resp := make([]byte, 4) - var lpResp, rpResp uint16 - n, e := io.ReadFull(hc, resp) - if n < 4 || e != nil { - logger.LogErr(fmt.Sprintf("[Client tun response len %d, %s\n", n, e)) - } else { - lpResp = binary.BigEndian.Uint16(resp[0:2]) - rpResp = binary.BigEndian.Uint16(resp[2:4]) - } - if lpResp == lp && rpResp == rp { - logger.LogDebug("[Client got tun setup ack OK]") - hc.StartClientTunnel(lp, rp) - } else { - logger.LogDebug(fmt.Sprintf("[Client tun response ports [%d:%d]\n", lpResp, rpResp)) - logger.LogDebug(fmt.Sprintln("[Client tun setup FAILED]")) - } return }