Added locking APIs for most Conn/Tun fields, save <- Data/ShutdownTun() race

Signed-off-by: Russ Magee <rmagee@gmail.com>
Russ Magee 2019-06-27 22:10:59 -07:00
parent c327b2ec72
commit 8f5366fff4
2 changed files with 70 additions and 61 deletions
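
As context for the diff below: the new helpers route every access to the shared tunnel map and its Died/Data fields through the connection's mutex. The following is a minimal, self-contained sketch of that pattern only; the type and field names here are illustrative stand-ins, not code from this repository.

```go
package main

import (
	"fmt"
	"sync"
)

// Illustrative stand-ins for hkexnet's Conn and TunEndpoint (names assumed).
type tunEndpoint struct {
	Died bool
	Data chan []byte
}

type conn struct {
	m    sync.Mutex // what the new Lock()/Unlock() wrappers delegate to
	tuns map[uint16]*tunEndpoint
}

// tunIsAlive and shutdownTun follow the shape of the commit's
// TunIsAlive/ShutdownTun: the map lookup and the Died/Data accesses
// happen only while the mutex is held.
func (c *conn) tunIsAlive(p uint16) bool {
	c.m.Lock()
	defer c.m.Unlock()
	t, ok := c.tuns[p]
	return ok && !t.Died
}

func (c *conn) shutdownTun(p uint16) {
	c.m.Lock()
	defer c.m.Unlock()
	if t, ok := c.tuns[p]; ok {
		t.Died = true
		if t.Data != nil {
			close(t.Data)
			t.Data = nil
		}
	}
	delete(c.tuns, p)
}

func main() {
	c := &conn{tuns: map[uint16]*tunEndpoint{7: {Data: make(chan []byte, 1)}}}

	var wg sync.WaitGroup
	wg.Add(2)
	go func() { defer wg.Done(); _ = c.tunIsAlive(7) }() // e.g. a tunnel worker polling state
	go func() { defer wg.Done(); c.shutdownTun(7) }()    // e.g. a hangup path tearing it down
	wg.Wait()

	fmt.Println("endpoints remaining:", len(c.tuns))
}
```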


@@ -122,6 +122,14 @@ func Init(d bool, c string, f logger.Priority) {
 	_initLogging(d, c, f)
 }
 
+func (hc *Conn) Lock() {
+	hc.m.Lock()
+}
+
+func (hc *Conn) Unlock() {
+	hc.m.Unlock()
+}
+
 func (hc Conn) GetStatus() CSOType {
 	return *hc.closeStat
 }
@@ -1084,7 +1092,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
 			rport := binary.BigEndian.Uint16(payloadBytes[2:4])
 			logger.LogDebug(fmt.Sprintf("[Client] Got CSOTunRefused [%d:%d]", lport, rport))
 			if _, ok := (*hc.tuns)[rport]; ok {
-				(*hc.tuns)[rport].Died = true
+				hc.MarkTunDead(rport)
 			} else {
 				logger.LogDebug(fmt.Sprintf("[Client] CSOTunRefused on already-closed tun [%d:%d]", lport, rport))
 			}
@@ -1094,7 +1102,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
 			rport := binary.BigEndian.Uint16(payloadBytes[2:4])
 			logger.LogDebug(fmt.Sprintf("[Client] Got CSOTunDisconn [%d:%d]", lport, rport))
 			if _, ok := (*hc.tuns)[rport]; ok {
-				(*hc.tuns)[rport].Died = true
+				hc.MarkTunDead(rport)
 			} else {
 				logger.LogDebug(fmt.Sprintf("[Client] CSOTunDisconn on already-closed tun [%d:%d]", lport, rport))
 			}
@@ -1104,7 +1112,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
 			rport := binary.BigEndian.Uint16(payloadBytes[2:4])
 			logger.LogDebug(fmt.Sprintf("[Server] Got CSOTunHangup [%d:%d]", lport, rport))
 			if _, ok := (*hc.tuns)[rport]; ok {
-				(*hc.tuns)[rport].Died = true
+				hc.MarkTunDead(rport)
 			} else {
 				logger.LogDebug(fmt.Sprintf("[Server] CSOTunHangup to already-closed tun [%d:%d]", lport, rport))
 			}
@@ -1117,7 +1125,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
 					logger.LogDebug(fmt.Sprintf("[Writing data to rport [%d:%d]", lport, rport))
 				}
 				(*hc.tuns)[rport].Data <- payloadBytes[4:]
-				(*hc.tuns)[rport].KeepAlive = 0
+				hc.ResetTunnelAge(rport)
 			} else {
 				logger.LogDebug(fmt.Sprintf("[Attempt to write data to closed tun [%d:%d]", lport, rport))
 			}
@@ -1212,7 +1220,7 @@ func (hc *Conn) WritePacket(b []byte, ctrlStatOp byte) (n int, err error) {
 	//
 	// Would be nice to determine if the mutex scope
 	// could be tightened.
-	hc.m.Lock()
+	hc.Lock()
 	payloadLen = uint32(len(b))
 	//!fmt.Printf(" --== payloadLen:%d\n", payloadLen)
 	if hc.logPlainText {
@@ -1254,7 +1262,7 @@ func (hc *Conn) WritePacket(b []byte, ctrlStatOp byte) (n int, err error) {
 	} else {
 		//fmt.Println("[a]WriteError!")
 	}
-	hc.m.Unlock()
+	hc.Unlock()
 	if err != nil {
 		log.Println(err)

@@ -16,6 +16,7 @@ import (
 	"net"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"blitter.com/go/hkexsh/logger"
@@ -46,7 +47,7 @@ type (
 		Lport     uint16      // ... ie., RPort is on server, LPort is on client
 		Peer      string      //net.Addr
 		Died      bool        // set by client upon receipt of a CSOTunDisconn
-		KeepAlive uint        // must be reset by client to keep server dial() alive
+		KeepAlive uint32      // must be reset by client to keep server dial() alive
 		Ctl       chan rune   //See TunCtl_* consts
 		Data      chan []byte
 	}
@@ -67,6 +68,8 @@ func (hc *Conn) CollapseAllTunnels(client bool) {
 }
 
 func (hc *Conn) InitTunEndpoint(lp uint16, p string /* net.Addr */, rp uint16) {
+	hc.Lock()
+	defer hc.Unlock()
 	if (*hc.tuns) == nil {
 		(*hc.tuns) = make(map[uint16]*TunEndpoint)
 	}
@@ -87,6 +90,7 @@ func (hc *Conn) InitTunEndpoint(lp uint16, p string /* net.Addr */, rp uint16) {
 			// data channel removed on closure. Re-create it
 			(*hc.tuns)[rp].Data = make(chan []byte, 1)
 		}
+		(*hc.tuns)[rp].KeepAlive = 0
 		(*hc.tuns)[rp].Died = false
 	}
 	return
@@ -149,37 +153,23 @@ func (hc *Conn) StartClientTunnel(lport, rport uint16) {
 					if e == io.EOF {
 						logger.LogDebug(fmt.Sprintf("[ClientTun] worker A: lport Disconnected: shutting down tunnel %v", (*hc.tuns)[rport]))
 						// if Died was already set, server-side already is gone.
-						if !(*hc.tuns)[rport].Died {
+						if hc.TunIsAlive(rport) {
 							hc.WritePacket(tunDst.Bytes(), CSOTunHangup)
 						}
-						(*hc.tuns)[rport].Died = true
-						if (*hc.tuns)[rport].Data != nil {
-							close((*hc.tuns)[rport].Data)
-							(*hc.tuns)[rport].Data = nil
-						}
-						delete((*hc.tuns), rport)
+						hc.ShutdownTun(rport)
 						break
 					} else if strings.Contains(e.Error(), "i/o timeout") {
-						if (*hc.tuns)[rport].Died {
+						if !hc.TunIsAlive(rport) {
 							logger.LogDebug(fmt.Sprintf("[ClientTun] worker A: timeout: Server side died, hanging up %v", (*hc.tuns)[rport]))
-							if (*hc.tuns)[rport].Data != nil {
-								close((*hc.tuns)[rport].Data)
-								(*hc.tuns)[rport].Data = nil
-							}
-							delete((*hc.tuns), rport)
+							hc.ShutdownTun(rport)
 							break
 						}
 					} else {
 						logger.LogDebug(fmt.Sprintf("[ClientTun] worker A: Read error from lport of tun %v\n%s", (*hc.tuns)[rport], e))
-						if !(*hc.tuns)[rport].Died {
+						if hc.TunIsAlive(rport) {
 							hc.WritePacket(tunDst.Bytes(), CSOTunHangup)
 						}
-						(*hc.tuns)[rport].Died = true
-						if (*hc.tuns)[rport].Data != nil {
-							close((*hc.tuns)[rport].Data)
-							(*hc.tuns)[rport].Data = nil
-						}
-						delete((*hc.tuns), rport)
+						hc.ShutdownTun(rport)
 						break
 					}
 				}
@@ -232,7 +222,7 @@ func (hc *Conn) StartClientTunnel(lport, rport uint16) {
 			// When both workers have exited due to a disconnect or other
 			// condition, it's safe to remove the tunnel descriptor.
 			logger.LogDebug("[ClientTun] workers exited")
-			delete((*hc.tuns), rport)
+			hc.ShutdownTun(rport)
 		} // end for-accept
 	} // end Listen() block
 }
@@ -240,6 +230,39 @@ func (hc *Conn) StartClientTunnel(lport, rport uint16) {
 	}()
 }
 
+func (hc *Conn) AgeTunnel(endp uint16) uint32 {
+	return atomic.AddUint32(&(*hc.tuns)[endp].KeepAlive, 1)
+}
+
+func (hc *Conn) ResetTunnelAge(endp uint16) {
+	atomic.StoreUint32(&(*hc.tuns)[endp].KeepAlive, 0)
+}
+
+func (hc *Conn) TunIsAlive(endp uint16) bool {
+	hc.Lock()
+	defer hc.Unlock()
+	return !(*hc.tuns)[endp].Died
+}
+
+func (hc *Conn) MarkTunDead(endp uint16) {
+	hc.Lock()
+	defer hc.Unlock()
+	(*hc.tuns)[endp].Died = true
+}
+
+func (hc *Conn) ShutdownTun(endp uint16) {
+	hc.Lock()
+	defer hc.Unlock()
+	if (*hc.tuns)[endp] != nil {
+		(*hc.tuns)[endp].Died = true
+		if (*hc.tuns)[endp].Data != nil {
+			close((*hc.tuns)[endp].Data)
+			(*hc.tuns)[endp].Data = nil
+		}
+	}
+	delete((*hc.tuns), endp)
+}
+
 func (hc *Conn) StartServerTunnel(lport, rport uint16) {
 	hc.InitTunEndpoint(lport, "", rport)
 	var err error
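
The keep-alive counter is the one field handled with sync/atomic rather than the mutex: the server-side worker ages it through AgeTunnel and the client resets it via ResetTunnelAge whenever tunnel data arrives. Below is a rough, standalone illustration of that counter pattern only; the names are placeholders, not taken from the repository.

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// keepAlive plays the role of TunEndpoint.KeepAlive (now uint32 so it can be
// driven with sync/atomic): the server bumps it on each poll tick, the client
// resets it to zero whenever data flows.
type tunEndpoint struct {
	keepAlive uint32
}

func main() {
	ep := &tunEndpoint{}

	// Server-side aging loop, analogous to AgeTunnel(): give up once the
	// counter passes the threshold (25 in the diff above) without a reset.
	for {
		age := atomic.AddUint32(&ep.keepAlive, 1)
		if age > 25 {
			fmt.Println("no keep-alive reset seen; treating peer as dead at age", age)
			break
		}
	}

	// Client-side reset on received data, analogous to ResetTunnelAge().
	atomic.StoreUint32(&ep.keepAlive, 0)
	fmt.Println("counter after reset:", atomic.LoadUint32(&ep.keepAlive))
}
```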
@@ -260,9 +283,9 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) {
 				logger.LogDebug("[ServerTun] worker A: Client endpoint removed.")
 				break
 			}
-			(*hc.tuns)[rport].KeepAlive += 1
-			if (*hc.tuns)[rport].KeepAlive > 25 {
-				(*hc.tuns)[rport].Died = true
+			age := hc.AgeTunnel(rport)
+			if age > 25 {
+				hc.MarkTunDead(rport)
 				logger.LogDebug("[ServerTun] worker A: Client died, hanging up.")
 				break
 			}
@@ -319,37 +342,23 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) {
 				if e != nil {
 					if e == io.EOF {
 						logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: rport Disconnected: shutting down tunnel %v", (*hc.tuns)[rport]))
-						if !(*hc.tuns)[rport].Died {
+						if hc.TunIsAlive(rport) {
 							hc.WritePacket(tunDst.Bytes(), CSOTunDisconn)
 						}
-						(*hc.tuns)[rport].Died = true
-						if (*hc.tuns)[rport].Data != nil {
-							close((*hc.tuns)[rport].Data)
-							(*hc.tuns)[rport].Data = nil
-						}
-						delete((*hc.tuns), rport)
+						hc.ShutdownTun(rport)
 						break
 					} else if strings.Contains(e.Error(), "i/o timeout") {
-						if (*hc.tuns)[rport].Died {
+						if !hc.TunIsAlive(rport) {
 							logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: timeout: Server side died, hanging up %v", (*hc.tuns)[rport]))
-							if (*hc.tuns)[rport].Data != nil {
-								close((*hc.tuns)[rport].Data)
-								(*hc.tuns)[rport].Data = nil
-							}
-							delete((*hc.tuns), rport)
+							hc.ShutdownTun(rport)
 							break
 						}
 					} else {
 						logger.LogDebug(fmt.Sprintf("[ServerTun] worker A: Read error from rport of tun %v: %s", (*hc.tuns)[rport], e))
-						if !(*hc.tuns)[rport].Died {
+						if hc.TunIsAlive(rport) {
 							hc.WritePacket(tunDst.Bytes(), CSOTunDisconn)
 						}
-						(*hc.tuns)[rport].Died = true
-						if (*hc.tuns)[rport].Data != nil {
-							close((*hc.tuns)[rport].Data)
-							(*hc.tuns)[rport].Data = nil
-						}
-						delete((*hc.tuns), rport)
+						hc.ShutdownTun(rport)
 						break
 					}
 				}
@@ -357,14 +366,6 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) {
 				rBuf = append(tunDst.Bytes(), rBuf[:n]...)
 				hc.WritePacket(rBuf[:n+4], CSOTunData)
 			}
-
-			//if (*hc.tuns)[rport].KeepAlive > 50000 {
-			//	(*hc.tuns)[rport].Died = true
-			//	logger.LogDebug("[ServerTun] worker A: Client died, hanging up.")
-			//} else {
-			//	(*hc.tuns)[rport].KeepAlive += 1
-			//}
-
 		}
 		logger.LogDebug("[ServerTun] worker A: exiting")
 	}()
@@ -382,7 +383,7 @@ func (hc *Conn) StartServerTunnel(lport, rport uint16) {
 		logger.LogDebug("[ServerTun] worker B: starting")
 		for {
-			rData, ok := <-(*hc.tuns)[rport].Data
+			rData, ok := <-(*hc.tuns)[rport].Data // FIXME: race w/ShutdownTun() calls
 			if ok {
 				c.SetWriteDeadline(time.Now().Add(200 * time.Millisecond))
 				_, e := c.Write(rData)
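
The FIXME in worker B is the race called out in the commit message: the receive expression re-reads (*hc.tuns)[rport] without holding the lock, so it can collide with ShutdownTun() deleting the entry and setting Data to nil. One possible way to close that gap, offered here only as an assumption and not as part of this commit, is to snapshot the channel under the lock and then block on the local copy; receiving from a channel that the shutdown path later closes is well defined, so only the map/field access needs serialization. A self-contained sketch with placeholder names:

```go
package main

import (
	"fmt"
	"sync"
)

// Sketch (names assumed, not from the repository) of a lock-then-snapshot
// receive: take the Data channel out of the shared map while holding the
// mutex, then receive on the local copy outside the lock.
type tunEndpoint struct {
	Data chan []byte
}

type conn struct {
	m    sync.Mutex
	tuns map[uint16]*tunEndpoint
}

// tunData is a hypothetical helper: it returns the endpoint's channel under
// the lock, or nil once the endpoint has been torn down.
func (c *conn) tunData(p uint16) chan []byte {
	c.m.Lock()
	defer c.m.Unlock()
	if t, ok := c.tuns[p]; ok {
		return t.Data
	}
	return nil
}

func (c *conn) shutdownTun(p uint16) {
	c.m.Lock()
	defer c.m.Unlock()
	if t, ok := c.tuns[p]; ok && t.Data != nil {
		close(t.Data)
		t.Data = nil
	}
	delete(c.tuns, p)
}

func main() {
	c := &conn{tuns: map[uint16]*tunEndpoint{7: {Data: make(chan []byte, 1)}}}
	c.tuns[7].Data <- []byte("hello") // pre-load one payload before any concurrency

	done := make(chan struct{})
	go func() { // stands in for worker B's receive loop
		defer close(done)
		for {
			ch := c.tunData(7)
			if ch == nil {
				return // endpoint already removed
			}
			if data, ok := <-ch; ok {
				fmt.Printf("got %d bytes\n", len(data))
			} else {
				return // channel closed by the shutdown path
			}
		}
	}()

	c.shutdownTun(7) // tear the endpoint down while the reader is running
	<-done
}
```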