diff --git a/hkexnet/hkexnet.go b/hkexnet/hkexnet.go index 955591e..00acdbe 100644 --- a/hkexnet/hkexnet.go +++ b/hkexnet/hkexnet.go @@ -842,6 +842,11 @@ func (hc Conn) Read(b []byte) (n int, err error) { lport := binary.BigEndian.Uint16(payloadBytes[0:2]) rport := binary.BigEndian.Uint16(payloadBytes[2:4]) logger.LogDebug(fmt.Sprintf("[Client] Got CSOTunRefused [%d:%d]", lport, rport)) + if _, ok := (*hc.tuns)[rport]; ok { + (*hc.tuns)[rport].Died = true + } else { + logger.LogDebug(fmt.Sprintf("[Client] CSOTunRefused on already-closed tun [%d:%d]", lport, rport)) + } } else if ctrlStatOp == CSOTunDisconn { // server side's rport has disconnected (server lost) lport := binary.BigEndian.Uint16(payloadBytes[0:2]) diff --git a/hkexnet/hkextun.go b/hkexnet/hkextun.go index 96042db..915bb12 100644 --- a/hkexnet/hkextun.go +++ b/hkexnet/hkextun.go @@ -52,7 +52,7 @@ type ( ) func (hc *Conn) CollapseAllTunnels(client bool) { - for k,t := range *hc.tuns { + for k, t := range *hc.tuns { var tunDst bytes.Buffer binary.Write(&tunDst, binary.BigEndian, t.Lport) binary.Write(&tunDst, binary.BigEndian, t.Rport) @@ -80,7 +80,7 @@ func (hc *Conn) InitTunEndpoint(lp uint16, p string /* net.Addr */, rp uint16) { Ctl: make(chan rune, 1)} logger.LogDebug(fmt.Sprintf("InitTunEndpoint [%d:%s:%d]", lp, p, rp)) } else { - logger.LogDebug(fmt.Sprintf("InitTunEndpoint [reusing] [%d:%s:%d]", (*hc.tuns)[rp].Lport, (*hc.tuns)[rp].Peer, (*hc.tuns)[rp].Rport)) + logger.LogDebug(fmt.Sprintf("InitTunEndpoint [reusing] %v", (*hc.tuns)[rp])) if (*hc.tuns)[rp].Data == nil { // When re-using a tunnel it will have its // data channel removed on closure. Re-create it @@ -93,26 +93,26 @@ func (hc *Conn) InitTunEndpoint(lp uint16, p string /* net.Addr */, rp uint16) { func (hc *Conn) StartClientTunnel(lport, rport uint16) { hc.InitTunEndpoint(lport, "", rport) - var l HKExListener + go func() { var wg sync.WaitGroup weAreListening := false - for cmd := range (*hc.tuns)[rport].Ctl { - logger.LogDebug(fmt.Sprintf("[ClientTun] Listening for client tunnel port %d", lport)) + for cmd := range (*hc.tuns)[rport].Ctl { if cmd == 'a' && !weAreListening { l, e := net.Listen("tcp4", fmt.Sprintf(":%d", lport)) if e != nil { logger.LogDebug(fmt.Sprintf("[ClientTun] Could not get lport %d! (%s)", lport, e)) } else { weAreListening = true + logger.LogDebug(fmt.Sprintf("[ClientTun] Listening for client tunnel port %d", lport)) + for { + c, e := l.Accept() // If tunnel is being re-used, re-init it if (*hc.tuns)[rport] == nil { hc.InitTunEndpoint(lport, "", rport) } - c, e := l.Accept() - // ask server to dial() its side, rport var tunDst bytes.Buffer binary.Write(&tunDst, binary.BigEndian, lport) @@ -121,7 +121,6 @@ func (hc *Conn) StartClientTunnel(lport, rport uint16) { if e != nil { logger.LogDebug(fmt.Sprintf("[ClientTun] Accept() got error(%v), hanging up.", e)) - //break } else { logger.LogDebug(fmt.Sprintf("[ClientTun] Accepted tunnel client %v", (*hc.tuns)[rport])) @@ -145,7 +144,7 @@ func (hc *Conn) StartClientTunnel(lport, rport uint16) { for { rBuf := make([]byte, 1024) //Read data from c, encrypt/write via hc to client(lport) - c.SetReadDeadline(time.Now().Add(20 * time.Second)) + c.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) n, e := c.Read(rBuf) if e != nil { if e == io.EOF { @@ -211,7 +210,7 @@ func (hc *Conn) StartClientTunnel(lport, rport uint16) { for { bytes, ok := <-(*hc.tuns)[rport].Data if ok { - c.SetWriteDeadline(time.Now().Add(20 * time.Second)) + c.SetWriteDeadline(time.Now().Add(200 * time.Millisecond)) _, e := c.Write(bytes) if e != nil { logger.LogDebug(fmt.Sprintf("[ClientTun] worker B: lport conn closed")) @@ -234,11 +233,6 @@ func (hc *Conn) StartClientTunnel(lport, rport uint16) { } else if cmd == 'r' { logger.LogDebug(fmt.Sprintf("[ClientTun] Server replied TunRefused %v\n", (*hc.tuns)[rport])) } - _ = l //else if cmd == 'x' { - //logger.LogDebug(fmt.Sprintf("[ClientTun] Server replied TunDisconn, closing lport %v\n", t)) - //l.Close() - //weAreListening = false - //} } // end t.Ctl for }() } @@ -298,7 +292,7 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) { for { rBuf := make([]byte, 1024) // Read data from c, encrypt/write via hc to client(lport) - c.SetReadDeadline(time.Now().Add(20 * time.Second)) + c.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) n, e := c.Read(rBuf) if e != nil { if e == io.EOF { @@ -315,8 +309,6 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) { } else if strings.Contains(e.Error(), "i/o timeout") { if (*hc.tuns)[rport].Died { logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: timeout: Server side died, hanging up %v", (*hc.tuns)[rport])) - //hc.WritePacket(tunDst.Bytes(), CSOTunDisconn) - //(*hc.tuns)[rport].Died = true if (*hc.tuns)[rport].Data != nil { close((*hc.tuns)[rport].Data) (*hc.tuns)[rport].Data = nil @@ -360,7 +352,7 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) { for { rData, ok := <-(*hc.tuns)[rport].Data if ok { - c.SetWriteDeadline(time.Now().Add(20 * time.Second)) + c.SetWriteDeadline(time.Now().Add(200 * time.Millisecond)) _, e := c.Write(rData) if e != nil { logger.LogDebug(fmt.Sprintf("[ServerTun] worker B: ERROR writing to rport conn")) @@ -379,6 +371,5 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) { } } // t.Ctl read loop logger.LogDebug("[ServerTun] Tunnel exiting t.Ctl read loop - channel closed??") - //wg.Wait() }() } diff --git a/hkexsh/hkexsh.go b/hkexsh/hkexsh.go index 5c1ea4a..d6bc24e 100755 --- a/hkexsh/hkexsh.go +++ b/hkexsh/hkexsh.go @@ -273,7 +273,6 @@ func doShellMode(isInteractive bool, conn *hkexnet.Conn, oldState *hkexsh.State, // gracefully here if !strings.HasSuffix(inerr.Error(), "use of closed network connection") { log.Println(inerr) - conn.CollapseAllTunnels(true) os.Exit(1) } } @@ -311,7 +310,6 @@ func doShellMode(isInteractive bool, conn *hkexnet.Conn, oldState *hkexsh.State, fmt.Println(outerr) _ = hkexsh.Restore(int(os.Stdin.Fd()), oldState) // Best effort. log.Println("[Hanging up]") - conn.CollapseAllTunnels(true) os.Exit(0) } }() @@ -648,7 +646,6 @@ func main() { doShellMode(isInteractive, &conn, oldState, rec) } else { // copyMode _, s := doCopyMode(&conn, pathIsDest, fileArgs, rec) - conn.CollapseAllTunnels(true) rec.SetStatus(s) }