cloudflared-mirror/tunnelstate/conntracker.go

104 lines
2.4 KiB
Go

package tunnelstate
import (
"net"
"sync"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/connection"
)
// ConnTracker records the last known state of each tunnel connection, keyed by
// connection index, and is updated through OnTunnelEvent. Its methods guard
// connectionInfo with mutex, so a *ConnTracker may be shared between
// goroutines; because it embeds a sync.RWMutex it must not be copied after
// first use.
type ConnTracker struct {
	mutex sync.RWMutex
	// connectionInfo maps a connection index (uint8) to that connection's last
	// known state. Entries are added on connect and only ever marked as
	// disconnected, never removed.
	connectionInfo map[uint8]ConnectionInfo
	// log reports tunnel events the tracker does not recognize.
	log *zerolog.Logger
}
// ConnectionInfo is the last known state of a single tunnel connection as
// stored by ConnTracker; IndexedConnectionInfo pairs it with its index.
//
// Every field is tagged omitempty, so encoding/json omits a field holding its
// zero value (IsConnected when false, EdgeAddress when nil or empty).
// NOTE(review): if connection.Protocol is a numeric type whose zero value is a
// real protocol, that protocol is also omitted from the JSON — confirm that
// consumers tolerate a missing "protocol" key.
type ConnectionInfo struct {
// IsConnected is set by ConnTracker on a Connected event and cleared on
// Disconnected, Reconnecting, RegisteringTunnel and Unregistering events.
IsConnected bool `json:"isConnected,omitempty"`
// Protocol is the protocol reported by the most recent Connected event.
Protocol connection.Protocol `json:"protocol,omitempty"`
// EdgeAddress is the edge IP reported by the most recent Connected event.
EdgeAddress net.IP `json:"edgeAddress,omitempty"`
}
// IndexedConnectionInfo is a convenience struct that extends a ConnectionInfo
// with the index of the connection it describes. The embedded fields are
// promoted, so they appear at the top level of the JSON object next to "index".
type IndexedConnectionInfo struct {
	ConnectionInfo
	// Index is always serialized. With omitempty, the valid connection index 0
	// was dropped from the JSON output and could not be told apart from a
	// missing field by non-Go consumers.
	Index uint8 `json:"index"`
}
// NewConnTracker returns a ConnTracker with no tracked connections. The given
// logger is kept on the tracker and used to report unrecognized tunnel events.
func NewConnTracker(log *zerolog.Logger) *ConnTracker {
	tracker := &ConnTracker{log: log}
	tracker.connectionInfo = map[uint8]ConnectionInfo{}
	return tracker
}
// OnTunnelEvent updates the tracked state of the connection identified by
// c.Index:
//
//   - Connected records the connection as up, together with its protocol and
//     edge address.
//   - Disconnected, Reconnecting, RegisteringTunnel and Unregistering mark an
//     already-tracked connection as down and keep its protocol and edge address.
//   - Any other event type is logged as an error.
//
// Non-connected events for an index that has never connected are ignored.
// Storing a zero-valued entry for such an index would give it the zero
// Protocol value, and HasConnectedWith would then report a connection with
// that protocol even though none was ever established.
func (ct *ConnTracker) OnTunnelEvent(c connection.Event) {
	switch c.EventType {
	case connection.Connected:
		ct.mutex.Lock()
		defer ct.mutex.Unlock()
		ct.connectionInfo[c.Index] = ConnectionInfo{
			IsConnected: true,
			Protocol:    c.Protocol,
			EdgeAddress: c.EdgeAddress,
		}
	case connection.Disconnected, connection.Reconnecting, connection.RegisteringTunnel, connection.Unregistering:
		ct.mutex.Lock()
		defer ct.mutex.Unlock()
		if ci, ok := ct.connectionInfo[c.Index]; ok {
			ci.IsConnected = false
			ct.connectionInfo[c.Index] = ci
		}
	default:
		ct.log.Error().Msgf("Unknown connection event case %v", c)
	}
}
// CountActiveConns reports how many tracked connections are currently marked
// as connected. The scan runs under the read lock.
func (ct *ConnTracker) CountActiveConns() uint {
	ct.mutex.RLock()
	defer ct.mutex.RUnlock()

	var n uint
	for _, info := range ct.connectionInfo {
		if !info.IsConnected {
			continue
		}
		n++
	}
	return n
}
// HasConnectedWith reports whether any tracked connection has protocol recorded
// as its protocol. Entries are kept after a disconnect (only IsConnected is
// cleared), so this answers whether such a connection has been seen at any
// point, not whether one is currently up.
func (ct *ConnTracker) HasConnectedWith(protocol connection.Protocol) bool {
	ct.mutex.RLock()
	defer ct.mutex.RUnlock()

	found := false
	for _, info := range ct.connectionInfo {
		if info.Protocol != protocol {
			continue
		}
		found = true
		break
	}
	return found
}
// GetActiveConnections returns the ConnectionInfo of every connection that is
// currently marked as connected, each paired with its connection index through
// [IndexedConnectionInfo].
//
// The result is never nil; it is an empty slice when nothing is connected, so it
// encodes to JSON as [] rather than null. Element order follows map iteration
// and is therefore unspecified.
func (ct *ConnTracker) GetActiveConnections() []IndexedConnectionInfo {
	ct.mutex.RLock()
	defer ct.mutex.RUnlock()

	// Upper bound: at most every tracked connection is active.
	connections := make([]IndexedConnectionInfo, 0, len(ct.connectionInfo))
	for index, info := range ct.connectionInfo {
		if !info.IsConnected {
			continue
		}
		// Keyed fields keep this correct if IndexedConnectionInfo is reordered.
		connections = append(connections, IndexedConnectionInfo{
			ConnectionInfo: info,
			Index:          index,
		})
	}
	return connections
}