TUN-4866: Add Control Stream for QUIC
This commit adds support to Register and Unregister Connections via RPC on the QUIC transport protocol
This commit is contained in:
parent
1082ac1c36
commit
12ad264eb3
|
@ -248,10 +248,17 @@ func prepareTunnelConfig(
|
||||||
|
|
||||||
edgeTLSConfigs := make(map[connection.Protocol]*tls.Config, len(connection.ProtocolList))
|
edgeTLSConfigs := make(map[connection.Protocol]*tls.Config, len(connection.ProtocolList))
|
||||||
for _, p := range connection.ProtocolList {
|
for _, p := range connection.ProtocolList {
|
||||||
edgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, p.ServerName())
|
tlsSettings := p.TLSSettings()
|
||||||
|
if tlsSettings == nil {
|
||||||
|
return nil, ingress.Ingress{}, fmt.Errorf("%s has unknown TLS settings", p)
|
||||||
|
}
|
||||||
|
edgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, tlsSettings.ServerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ingress.Ingress{}, errors.Wrap(err, "unable to create TLS config to connect with edge")
|
return nil, ingress.Ingress{}, errors.Wrap(err, "unable to create TLS config to connect with edge")
|
||||||
}
|
}
|
||||||
|
if len(edgeTLSConfig.NextProtos) > 0 {
|
||||||
|
edgeTLSConfig.NextProtos = tlsSettings.NextProtos
|
||||||
|
}
|
||||||
edgeTLSConfigs[p] = edgeTLSConfig
|
edgeTLSConfigs[p] = edgeTLSConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,89 @@
|
||||||
|
package connection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RPCClientFunc derives a named tunnel rpc client that can then be used to register and unregister connections.
|
||||||
|
type RPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
|
||||||
|
|
||||||
|
type controlStream struct {
|
||||||
|
observer *Observer
|
||||||
|
|
||||||
|
connectedFuse ConnectedFuse
|
||||||
|
namedTunnelConfig *NamedTunnelConfig
|
||||||
|
connIndex uint8
|
||||||
|
|
||||||
|
newRPCClientFunc RPCClientFunc
|
||||||
|
|
||||||
|
gracefulShutdownC <-chan struct{}
|
||||||
|
gracePeriod time.Duration
|
||||||
|
stoppedGracefully bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ControlStreamHandler registers connections with origintunneld and initiates graceful shutdown.
|
||||||
|
type ControlStreamHandler interface {
|
||||||
|
ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions) error
|
||||||
|
IsStopped() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewControlStream returns a new instance of ControlStreamHandler
|
||||||
|
func NewControlStream(
|
||||||
|
observer *Observer,
|
||||||
|
connectedFuse ConnectedFuse,
|
||||||
|
namedTunnelConfig *NamedTunnelConfig,
|
||||||
|
connIndex uint8,
|
||||||
|
newRPCClientFunc RPCClientFunc,
|
||||||
|
gracefulShutdownC <-chan struct{},
|
||||||
|
gracePeriod time.Duration,
|
||||||
|
) ControlStreamHandler {
|
||||||
|
if newRPCClientFunc == nil {
|
||||||
|
newRPCClientFunc = newRegistrationRPCClient
|
||||||
|
}
|
||||||
|
return &controlStream{
|
||||||
|
observer: observer,
|
||||||
|
connectedFuse: connectedFuse,
|
||||||
|
namedTunnelConfig: namedTunnelConfig,
|
||||||
|
newRPCClientFunc: newRPCClientFunc,
|
||||||
|
connIndex: connIndex,
|
||||||
|
gracefulShutdownC: gracefulShutdownC,
|
||||||
|
gracePeriod: gracePeriod,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *controlStream) ServeControlStream(
|
||||||
|
ctx context.Context,
|
||||||
|
rw io.ReadWriteCloser,
|
||||||
|
connOptions *tunnelpogs.ConnectionOptions,
|
||||||
|
) error {
|
||||||
|
rpcClient := c.newRPCClientFunc(ctx, rw, c.observer.log)
|
||||||
|
defer rpcClient.Close()
|
||||||
|
|
||||||
|
if err := rpcClient.RegisterConnection(ctx, c.namedTunnelConfig, connOptions, c.connIndex, c.observer); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.connectedFuse.Connected()
|
||||||
|
|
||||||
|
// wait for connection termination or start of graceful shutdown
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
break
|
||||||
|
case <-c.gracefulShutdownC:
|
||||||
|
c.stoppedGracefully = true
|
||||||
|
}
|
||||||
|
|
||||||
|
c.observer.sendUnregisteringEvent(c.connIndex)
|
||||||
|
rpcClient.GracefulShutdown(ctx, c.gracePeriod)
|
||||||
|
c.observer.log.Info().Uint8(LogFieldConnIndex, c.connIndex).Msg("Unregistered tunnel connection")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *controlStream) IsStopped() bool {
|
||||||
|
return c.stoppedGracefully
|
||||||
|
}
|
|
@ -33,18 +33,15 @@ type HTTP2Connection struct {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
server *http2.Server
|
server *http2.Server
|
||||||
config *Config
|
config *Config
|
||||||
namedTunnel *NamedTunnelConfig
|
|
||||||
connOptions *tunnelpogs.ConnectionOptions
|
connOptions *tunnelpogs.ConnectionOptions
|
||||||
observer *Observer
|
observer *Observer
|
||||||
connIndexStr string
|
|
||||||
connIndex uint8
|
connIndex uint8
|
||||||
// newRPCClientFunc allows us to mock RPCs during testing
|
// newRPCClientFunc allows us to mock RPCs during testing
|
||||||
newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
|
newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
|
||||||
|
|
||||||
log *zerolog.Logger
|
log *zerolog.Logger
|
||||||
activeRequestsWG sync.WaitGroup
|
activeRequestsWG sync.WaitGroup
|
||||||
connectedFuse ConnectedFuse
|
controlStreamHandler ControlStreamHandler
|
||||||
gracefulShutdownC <-chan struct{}
|
|
||||||
stoppedGracefully bool
|
stoppedGracefully bool
|
||||||
controlStreamErr error // result of running control stream handler
|
controlStreamErr error // result of running control stream handler
|
||||||
}
|
}
|
||||||
|
@ -53,13 +50,11 @@ type HTTP2Connection struct {
|
||||||
func NewHTTP2Connection(
|
func NewHTTP2Connection(
|
||||||
conn net.Conn,
|
conn net.Conn,
|
||||||
config *Config,
|
config *Config,
|
||||||
namedTunnelConfig *NamedTunnelConfig,
|
|
||||||
connOptions *tunnelpogs.ConnectionOptions,
|
connOptions *tunnelpogs.ConnectionOptions,
|
||||||
observer *Observer,
|
observer *Observer,
|
||||||
connIndex uint8,
|
connIndex uint8,
|
||||||
connectedFuse ConnectedFuse,
|
controlStreamHandler ControlStreamHandler,
|
||||||
log *zerolog.Logger,
|
log *zerolog.Logger,
|
||||||
gracefulShutdownC <-chan struct{},
|
|
||||||
) *HTTP2Connection {
|
) *HTTP2Connection {
|
||||||
return &HTTP2Connection{
|
return &HTTP2Connection{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
|
@ -67,15 +62,12 @@ func NewHTTP2Connection(
|
||||||
MaxConcurrentStreams: math.MaxUint32,
|
MaxConcurrentStreams: math.MaxUint32,
|
||||||
},
|
},
|
||||||
config: config,
|
config: config,
|
||||||
namedTunnel: namedTunnelConfig,
|
|
||||||
connOptions: connOptions,
|
connOptions: connOptions,
|
||||||
observer: observer,
|
observer: observer,
|
||||||
connIndexStr: uint8ToString(connIndex),
|
|
||||||
connIndex: connIndex,
|
connIndex: connIndex,
|
||||||
newRPCClientFunc: newRegistrationRPCClient,
|
newRPCClientFunc: newRegistrationRPCClient,
|
||||||
connectedFuse: connectedFuse,
|
controlStreamHandler: controlStreamHandler,
|
||||||
log: log,
|
log: log,
|
||||||
gracefulShutdownC: gracefulShutdownC,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -91,7 +83,7 @@ func (c *HTTP2Connection) Serve(ctx context.Context) error {
|
||||||
})
|
})
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case c.stoppedGracefully:
|
case c.controlStreamHandler.IsStopped():
|
||||||
return nil
|
return nil
|
||||||
case c.controlStreamErr != nil:
|
case c.controlStreamErr != nil:
|
||||||
return c.controlStreamErr
|
return c.controlStreamErr
|
||||||
|
@ -116,7 +108,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
switch connType {
|
switch connType {
|
||||||
case TypeControlStream:
|
case TypeControlStream:
|
||||||
if err := c.serveControlStream(r.Context(), respWriter); err != nil {
|
if err := c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions); err != nil {
|
||||||
c.controlStreamErr = err
|
c.controlStreamErr = err
|
||||||
c.log.Error().Err(err)
|
c.log.Error().Err(err)
|
||||||
respWriter.WriteErrorResponse()
|
respWriter.WriteErrorResponse()
|
||||||
|
@ -154,29 +146,6 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HTTP2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error {
|
|
||||||
rpcClient := c.newRPCClientFunc(ctx, respWriter, c.observer.log)
|
|
||||||
defer rpcClient.Close()
|
|
||||||
|
|
||||||
if err := rpcClient.RegisterConnection(ctx, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.connectedFuse.Connected()
|
|
||||||
|
|
||||||
// wait for connection termination or start of graceful shutdown
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
break
|
|
||||||
case <-c.gracefulShutdownC:
|
|
||||||
c.stoppedGracefully = true
|
|
||||||
}
|
|
||||||
|
|
||||||
c.observer.sendUnregisteringEvent(c.connIndex)
|
|
||||||
rpcClient.GracefulShutdown(ctx, c.config.GracePeriod)
|
|
||||||
c.observer.log.Info().Uint8(LogFieldConnIndex, c.connIndex).Msg("Unregistered tunnel connection")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HTTP2Connection) close() {
|
func (c *HTTP2Connection) close() {
|
||||||
// Wait for all serve HTTP handlers to return
|
// Wait for all serve HTTP handlers to return
|
||||||
c.activeRequestsWG.Wait()
|
c.activeRequestsWG.Wait()
|
||||||
|
|
|
@ -30,16 +30,24 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
|
||||||
edgeConn, originConn := net.Pipe()
|
edgeConn, originConn := net.Pipe()
|
||||||
var connIndex = uint8(0)
|
var connIndex = uint8(0)
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
|
obs := NewObserver(&log, &log, false)
|
||||||
|
controlStream := NewControlStream(
|
||||||
|
obs,
|
||||||
|
mockConnectedFuse{},
|
||||||
|
&NamedTunnelConfig{},
|
||||||
|
connIndex,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
1*time.Second,
|
||||||
|
)
|
||||||
return NewHTTP2Connection(
|
return NewHTTP2Connection(
|
||||||
originConn,
|
originConn,
|
||||||
testConfig,
|
testConfig,
|
||||||
&NamedTunnelConfig{},
|
|
||||||
&pogs.ConnectionOptions{},
|
&pogs.ConnectionOptions{},
|
||||||
NewObserver(&log, &log, false),
|
obs,
|
||||||
connIndex,
|
connIndex,
|
||||||
mockConnectedFuse{},
|
controlStream,
|
||||||
&log,
|
&log,
|
||||||
nil,
|
|
||||||
), edgeConn
|
), edgeConn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -225,7 +233,18 @@ func TestServeControlStream(t *testing.T) {
|
||||||
registered: make(chan struct{}),
|
registered: make(chan struct{}),
|
||||||
unregistered: make(chan struct{}),
|
unregistered: make(chan struct{}),
|
||||||
}
|
}
|
||||||
http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient
|
|
||||||
|
obs := NewObserver(&log, &log, false)
|
||||||
|
controlStream := NewControlStream(
|
||||||
|
obs,
|
||||||
|
mockConnectedFuse{},
|
||||||
|
&NamedTunnelConfig{},
|
||||||
|
1,
|
||||||
|
rpcClientFactory.newMockRPCClient,
|
||||||
|
nil,
|
||||||
|
1*time.Second,
|
||||||
|
)
|
||||||
|
http2Conn.controlStreamHandler = controlStream
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
@ -264,7 +283,18 @@ func TestFailRegistration(t *testing.T) {
|
||||||
registered: make(chan struct{}),
|
registered: make(chan struct{}),
|
||||||
unregistered: make(chan struct{}),
|
unregistered: make(chan struct{}),
|
||||||
}
|
}
|
||||||
http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient
|
|
||||||
|
obs := NewObserver(&log, &log, false)
|
||||||
|
controlStream := NewControlStream(
|
||||||
|
obs,
|
||||||
|
mockConnectedFuse{},
|
||||||
|
&NamedTunnelConfig{},
|
||||||
|
http2Conn.connIndex,
|
||||||
|
rpcClientFactory.newMockRPCClient,
|
||||||
|
nil,
|
||||||
|
1*time.Second,
|
||||||
|
)
|
||||||
|
http2Conn.controlStreamHandler = controlStream
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
@ -297,10 +327,21 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
|
||||||
unregistered: make(chan struct{}),
|
unregistered: make(chan struct{}),
|
||||||
}
|
}
|
||||||
events := &eventCollectorSink{}
|
events := &eventCollectorSink{}
|
||||||
http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient
|
|
||||||
http2Conn.observer.RegisterSink(events)
|
|
||||||
shutdownC := make(chan struct{})
|
shutdownC := make(chan struct{})
|
||||||
http2Conn.gracefulShutdownC = shutdownC
|
obs := NewObserver(&log, &log, false)
|
||||||
|
obs.RegisterSink(events)
|
||||||
|
controlStream := NewControlStream(
|
||||||
|
obs,
|
||||||
|
mockConnectedFuse{},
|
||||||
|
&NamedTunnelConfig{},
|
||||||
|
http2Conn.connIndex,
|
||||||
|
rpcClientFactory.newMockRPCClient,
|
||||||
|
shutdownC,
|
||||||
|
1*time.Second,
|
||||||
|
)
|
||||||
|
|
||||||
|
http2Conn.controlStreamHandler = controlStream
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
@ -339,7 +380,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
|
||||||
case <-time.Tick(time.Second):
|
case <-time.Tick(time.Second):
|
||||||
t.Fatal("timeout out waiting for unregistered signal")
|
t.Fatal("timeout out waiting for unregistered signal")
|
||||||
}
|
}
|
||||||
assert.True(t, http2Conn.stoppedGracefully)
|
assert.True(t, controlStream.IsStopped())
|
||||||
|
|
||||||
cancel()
|
cancel()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
|
@ -15,33 +15,29 @@ const (
|
||||||
edgeH2muxTLSServerName = "cftunnel.com"
|
edgeH2muxTLSServerName = "cftunnel.com"
|
||||||
// edgeH2TLSServerName is the server name to establish http2 connection with edge
|
// edgeH2TLSServerName is the server name to establish http2 connection with edge
|
||||||
edgeH2TLSServerName = "h2.cftunnel.com"
|
edgeH2TLSServerName = "h2.cftunnel.com"
|
||||||
|
// edgeQUICServerName is the server name to establish quic connection with edge.
|
||||||
|
edgeQUICServerName = "quic.cftunnel.com"
|
||||||
// threshold to switch back to h2mux when the user intentionally pick --protocol http2
|
// threshold to switch back to h2mux when the user intentionally pick --protocol http2
|
||||||
explicitHTTP2FallbackThreshold = -1
|
explicitHTTP2FallbackThreshold = -1
|
||||||
autoSelectFlag = "auto"
|
autoSelectFlag = "auto"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ProtocolList = []Protocol{H2mux, HTTP2}
|
// ProtocolList represents a list of supported protocols for communication with the edge.
|
||||||
|
ProtocolList = []Protocol{H2mux, HTTP2, QUIC}
|
||||||
)
|
)
|
||||||
|
|
||||||
type Protocol int64
|
type Protocol int64
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// H2mux protocol can be used both with Classic and Named Tunnels. .
|
||||||
H2mux Protocol = iota
|
H2mux Protocol = iota
|
||||||
|
// HTTP2 is used only with named tunnels. It's more efficient than H2mux for L4 proxying.
|
||||||
HTTP2
|
HTTP2
|
||||||
|
// QUIC is used only with named tunnels.
|
||||||
|
QUIC
|
||||||
)
|
)
|
||||||
|
|
||||||
func (p Protocol) ServerName() string {
|
|
||||||
switch p {
|
|
||||||
case H2mux:
|
|
||||||
return edgeH2muxTLSServerName
|
|
||||||
case HTTP2:
|
|
||||||
return edgeH2TLSServerName
|
|
||||||
default:
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback returns the fallback protocol and whether the protocol has a fallback
|
// Fallback returns the fallback protocol and whether the protocol has a fallback
|
||||||
func (p Protocol) fallback() (Protocol, bool) {
|
func (p Protocol) fallback() (Protocol, bool) {
|
||||||
switch p {
|
switch p {
|
||||||
|
@ -49,6 +45,8 @@ func (p Protocol) fallback() (Protocol, bool) {
|
||||||
return 0, false
|
return 0, false
|
||||||
case HTTP2:
|
case HTTP2:
|
||||||
return H2mux, true
|
return H2mux, true
|
||||||
|
case QUIC:
|
||||||
|
return HTTP2, true
|
||||||
default:
|
default:
|
||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
|
@ -60,11 +58,39 @@ func (p Protocol) String() string {
|
||||||
return "h2mux"
|
return "h2mux"
|
||||||
case HTTP2:
|
case HTTP2:
|
||||||
return "http2"
|
return "http2"
|
||||||
|
case QUIC:
|
||||||
|
return "quic"
|
||||||
default:
|
default:
|
||||||
return fmt.Sprintf("unknown protocol")
|
return fmt.Sprintf("unknown protocol")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p Protocol) TLSSettings() *TLSSettings {
|
||||||
|
switch p {
|
||||||
|
case H2mux:
|
||||||
|
return &TLSSettings{
|
||||||
|
ServerName: edgeH2muxTLSServerName,
|
||||||
|
}
|
||||||
|
case HTTP2:
|
||||||
|
return &TLSSettings{
|
||||||
|
ServerName: edgeH2TLSServerName,
|
||||||
|
}
|
||||||
|
case QUIC:
|
||||||
|
fmt.Println("returning this?")
|
||||||
|
return &TLSSettings{
|
||||||
|
ServerName: edgeQUICServerName,
|
||||||
|
NextProtos: []string{"argotunnel"},
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type TLSSettings struct {
|
||||||
|
ServerName string
|
||||||
|
NextProtos []string
|
||||||
|
}
|
||||||
|
|
||||||
type ProtocolSelector interface {
|
type ProtocolSelector interface {
|
||||||
Current() Protocol
|
Current() Protocol
|
||||||
Fallback() (Protocol, bool)
|
Fallback() (Protocol, bool)
|
||||||
|
@ -184,6 +210,10 @@ func NewProtocolSelector(
|
||||||
return newAutoProtocolSelector(HTTP2, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
|
return newAutoProtocolSelector(HTTP2, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if protocolFlag == QUIC.String() {
|
||||||
|
return newAutoProtocolSelector(QUIC, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
|
||||||
|
}
|
||||||
|
|
||||||
if protocolFlag != autoSelectFlag {
|
if protocolFlag != autoSelectFlag {
|
||||||
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
|
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
quicpogs "github.com/cloudflare/cloudflared/quic"
|
quicpogs "github.com/cloudflare/cloudflared/quic"
|
||||||
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -29,8 +30,10 @@ const (
|
||||||
// QUICConnection represents the type that facilitates Proxying via QUIC streams.
|
// QUICConnection represents the type that facilitates Proxying via QUIC streams.
|
||||||
type QUICConnection struct {
|
type QUICConnection struct {
|
||||||
session quic.Session
|
session quic.Session
|
||||||
logger zerolog.Logger
|
logger *zerolog.Logger
|
||||||
httpProxy OriginProxy
|
httpProxy OriginProxy
|
||||||
|
gracefulShutdownC <-chan struct{}
|
||||||
|
stoppedGracefully bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewQUICConnection returns a new instance of QUICConnection.
|
// NewQUICConnection returns a new instance of QUICConnection.
|
||||||
|
@ -40,19 +43,26 @@ func NewQUICConnection(
|
||||||
edgeAddr net.Addr,
|
edgeAddr net.Addr,
|
||||||
tlsConfig *tls.Config,
|
tlsConfig *tls.Config,
|
||||||
httpProxy OriginProxy,
|
httpProxy OriginProxy,
|
||||||
logger zerolog.Logger,
|
connOptions *tunnelpogs.ConnectionOptions,
|
||||||
|
controlStreamHandler ControlStreamHandler,
|
||||||
|
observer *Observer,
|
||||||
) (*QUICConnection, error) {
|
) (*QUICConnection, error) {
|
||||||
session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig)
|
session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "failed to dial to edge")
|
return nil, errors.Wrap(err, "failed to dial to edge")
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: RegisterConnectionRPC here.
|
registrationStream, err := session.OpenStream()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to open a registration stream")
|
||||||
|
}
|
||||||
|
|
||||||
|
go controlStreamHandler.ServeControlStream(ctx, registrationStream, connOptions)
|
||||||
|
|
||||||
return &QUICConnection{
|
return &QUICConnection{
|
||||||
session: session,
|
session: session,
|
||||||
httpProxy: httpProxy,
|
httpProxy: httpProxy,
|
||||||
logger: logger,
|
logger: observer.log,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -96,11 +106,26 @@ func (q *QUICConnection) handleStream(stream quic.Stream) error {
|
||||||
w := newHTTPResponseAdapter(stream)
|
w := newHTTPResponseAdapter(stream)
|
||||||
return q.httpProxy.ProxyHTTP(w, req, connectRequest.Type == quicpogs.ConnectionTypeWebsocket)
|
return q.httpProxy.ProxyHTTP(w, req, connectRequest.Type == quicpogs.ConnectionTypeWebsocket)
|
||||||
case quicpogs.ConnectionTypeTCP:
|
case quicpogs.ConnectionTypeTCP:
|
||||||
return errors.New("not implemented")
|
// TODO: This is a placeholder for testing completion. TUN-4865 will add proper TCP support.
|
||||||
|
rwa := &streamReadWriteAcker{
|
||||||
|
ReadWriter: stream,
|
||||||
|
}
|
||||||
|
return q.httpProxy.ProxyTCP(context.Background(), rwa, &TCPRequest{Dest: connectRequest.Dest})
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
|
||||||
|
// the client.
|
||||||
|
type streamReadWriteAcker struct {
|
||||||
|
io.ReadWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
// AckConnection acks response back to the proxy.
|
||||||
|
func (s *streamReadWriteAcker) AckConnection() error {
|
||||||
|
return quicpogs.WriteConnectResponseData(s, nil)
|
||||||
|
}
|
||||||
|
|
||||||
// httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC.
|
// httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC.
|
||||||
type httpResponseAdapter struct {
|
type httpResponseAdapter struct {
|
||||||
io.Writer
|
io.Writer
|
||||||
|
|
|
@ -16,6 +16,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gobwas/ws/wsutil"
|
"github.com/gobwas/ws/wsutil"
|
||||||
"github.com/lucas-clemente/quic-go"
|
"github.com/lucas-clemente/quic-go"
|
||||||
|
@ -25,6 +26,7 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
quicpogs "github.com/cloudflare/cloudflared/quic"
|
quicpogs "github.com/cloudflare/cloudflared/quic"
|
||||||
|
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestQUICServer tests if a quic server accepts and responds to a quic client with the acceptance protocol.
|
// TestQUICServer tests if a quic server accepts and responds to a quic client with the acceptance protocol.
|
||||||
|
@ -134,6 +136,13 @@ func TestQUICServer(t *testing.T) {
|
||||||
message: wsBuf.Bytes(),
|
message: wsBuf.Bytes(),
|
||||||
expectedResponse: []byte{0x81, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f},
|
expectedResponse: []byte{0x81, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
desc: "test tcp proxy",
|
||||||
|
connectionType: quicpogs.ConnectionTypeTCP,
|
||||||
|
metadata: []quicpogs.Metadata{},
|
||||||
|
message: []byte("Here is some tcp data"),
|
||||||
|
expectedResponse: []byte("Here is some tcp data"),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
|
@ -149,11 +158,43 @@ func TestQUICServer(t *testing.T) {
|
||||||
)
|
)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
qC, err := NewQUICConnection(ctx, quicConfig, udpListener.LocalAddr(), tlsClientConfig, originProxy, log)
|
rpcClientFactory := mockRPCClientFactory{
|
||||||
|
registered: make(chan struct{}),
|
||||||
|
unregistered: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
obs := NewObserver(&log, &log, false)
|
||||||
|
controlStream := NewControlStream(
|
||||||
|
obs,
|
||||||
|
mockConnectedFuse{},
|
||||||
|
&NamedTunnelConfig{},
|
||||||
|
1,
|
||||||
|
rpcClientFactory.newMockRPCClient,
|
||||||
|
nil,
|
||||||
|
1*time.Second,
|
||||||
|
)
|
||||||
|
|
||||||
|
qC, err := NewQUICConnection(
|
||||||
|
ctx,
|
||||||
|
quicConfig,
|
||||||
|
udpListener.LocalAddr(),
|
||||||
|
tlsClientConfig,
|
||||||
|
originProxy,
|
||||||
|
&pogs.ConnectionOptions{},
|
||||||
|
controlStream,
|
||||||
|
NewObserver(&log, &log, false),
|
||||||
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
go qC.Serve(ctx)
|
go qC.Serve(ctx)
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-rpcClientFactory.registered:
|
||||||
|
break //ok
|
||||||
|
case <-time.Tick(time.Second):
|
||||||
|
t.Fatal("timeout out waiting for registration")
|
||||||
|
}
|
||||||
cancel()
|
cancel()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -174,7 +215,7 @@ func quicServer(
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
earlyListener, err := quic.ListenEarly(conn, tlsConf, config)
|
earlyListener, err := quic.Listen(conn, tlsConf, config)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
session, err := earlyListener.Accept(ctx)
|
session, err := earlyListener.Accept(ctx)
|
||||||
|
@ -183,7 +224,6 @@ func quicServer(
|
||||||
stream, err := session.OpenStreamSync(context.Background())
|
stream, err := session.OpenStreamSync(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Start off ALPN
|
|
||||||
err = quicpogs.WriteConnectRequestData(stream, dest, connectionType, metadata...)
|
err = quicpogs.WriteConnectRequestData(stream, dest, connectionType, metadata...)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -264,5 +304,7 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Reque
|
||||||
}
|
}
|
||||||
|
|
||||||
func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error {
|
func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error {
|
||||||
|
rwa.AckConnection()
|
||||||
|
io.Copy(rwa, rwa)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
113
origin/tunnel.go
113
origin/tunnel.go
|
@ -11,6 +11,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/lucas-clemente/quic-go"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
@ -271,15 +272,38 @@ func ServeTunnel(
|
||||||
|
|
||||||
defer config.Observer.SendDisconnect(connIndex)
|
defer config.Observer.SendDisconnect(connIndex)
|
||||||
|
|
||||||
|
connectedFuse := &connectedFuse{
|
||||||
|
fuse: fuse,
|
||||||
|
backoff: backoff,
|
||||||
|
}
|
||||||
|
|
||||||
|
controlStream := connection.NewControlStream(
|
||||||
|
config.Observer,
|
||||||
|
connectedFuse,
|
||||||
|
config.NamedTunnel,
|
||||||
|
connIndex,
|
||||||
|
nil,
|
||||||
|
gracefulShutdownC,
|
||||||
|
config.ConnectionConfig.GracePeriod,
|
||||||
|
)
|
||||||
|
|
||||||
|
if protocol == connection.QUIC {
|
||||||
|
connOptions := config.ConnectionOptions(addr.UDP.String(), uint8(backoff.Retries()))
|
||||||
|
return ServeQUIC(ctx,
|
||||||
|
addr.UDP,
|
||||||
|
config,
|
||||||
|
connOptions,
|
||||||
|
controlStream,
|
||||||
|
connectedFuse,
|
||||||
|
reconnectCh,
|
||||||
|
gracefulShutdownC)
|
||||||
|
}
|
||||||
|
|
||||||
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr.TCP)
|
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr.TCP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
connLog.Err(err).Msg("Unable to establish connection with Cloudflare edge")
|
connLog.Err(err).Msg("Unable to establish connection with Cloudflare edge")
|
||||||
return err, true
|
return err, true
|
||||||
}
|
}
|
||||||
connectedFuse := &connectedFuse{
|
|
||||||
fuse: fuse,
|
|
||||||
backoff: backoff,
|
|
||||||
}
|
|
||||||
|
|
||||||
if protocol == connection.HTTP2 {
|
if protocol == connection.HTTP2 {
|
||||||
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries()))
|
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries()))
|
||||||
|
@ -289,10 +313,10 @@ func ServeTunnel(
|
||||||
config,
|
config,
|
||||||
edgeConn,
|
edgeConn,
|
||||||
connOptions,
|
connOptions,
|
||||||
|
controlStream,
|
||||||
connIndex,
|
connIndex,
|
||||||
connectedFuse,
|
|
||||||
reconnectCh,
|
|
||||||
gracefulShutdownC,
|
gracefulShutdownC,
|
||||||
|
reconnectCh,
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
err = ServeH2mux(
|
err = ServeH2mux(
|
||||||
|
@ -403,22 +427,20 @@ func ServeHTTP2(
|
||||||
config *TunnelConfig,
|
config *TunnelConfig,
|
||||||
tlsServerConn net.Conn,
|
tlsServerConn net.Conn,
|
||||||
connOptions *tunnelpogs.ConnectionOptions,
|
connOptions *tunnelpogs.ConnectionOptions,
|
||||||
|
controlStreamHandler connection.ControlStreamHandler,
|
||||||
connIndex uint8,
|
connIndex uint8,
|
||||||
connectedFuse connection.ConnectedFuse,
|
|
||||||
reconnectCh chan ReconnectSignal,
|
|
||||||
gracefulShutdownC <-chan struct{},
|
gracefulShutdownC <-chan struct{},
|
||||||
|
reconnectCh chan ReconnectSignal,
|
||||||
) error {
|
) error {
|
||||||
connLog.Debug().Msgf("Connecting via http2")
|
connLog.Debug().Msgf("Connecting via http2")
|
||||||
h2conn := connection.NewHTTP2Connection(
|
h2conn := connection.NewHTTP2Connection(
|
||||||
tlsServerConn,
|
tlsServerConn,
|
||||||
config.ConnectionConfig,
|
config.ConnectionConfig,
|
||||||
config.NamedTunnel,
|
|
||||||
connOptions,
|
connOptions,
|
||||||
config.Observer,
|
config.Observer,
|
||||||
connIndex,
|
connIndex,
|
||||||
connectedFuse,
|
controlStreamHandler,
|
||||||
config.Log,
|
config.Log,
|
||||||
gracefulShutdownC,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
errGroup, serveCtx := errgroup.WithContext(ctx)
|
errGroup, serveCtx := errgroup.WithContext(ctx)
|
||||||
|
@ -438,6 +460,75 @@ func ServeHTTP2(
|
||||||
return errGroup.Wait()
|
return errGroup.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ServeQUIC(
|
||||||
|
ctx context.Context,
|
||||||
|
edgeAddr *net.UDPAddr,
|
||||||
|
config *TunnelConfig,
|
||||||
|
connOptions *tunnelpogs.ConnectionOptions,
|
||||||
|
controlStreamHandler connection.ControlStreamHandler,
|
||||||
|
connectedFuse connection.ConnectedFuse,
|
||||||
|
reconnectCh chan ReconnectSignal,
|
||||||
|
gracefulShutdownC <-chan struct{},
|
||||||
|
) (err error, recoverable bool) {
|
||||||
|
tlsConfig := config.EdgeTLSConfigs[connection.QUIC]
|
||||||
|
quicConfig := &quic.Config{
|
||||||
|
HandshakeIdleTimeout: time.Second * 10,
|
||||||
|
KeepAlive: true,
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
quicConn, err := connection.NewQUICConnection(
|
||||||
|
ctx,
|
||||||
|
quicConfig,
|
||||||
|
edgeAddr,
|
||||||
|
tlsConfig,
|
||||||
|
config.ConnectionConfig.OriginProxy,
|
||||||
|
connOptions,
|
||||||
|
controlStreamHandler,
|
||||||
|
config.Observer)
|
||||||
|
if err != nil {
|
||||||
|
config.Log.Error().Msgf("Failed to create new quic connection, err: %v", err)
|
||||||
|
return err, true
|
||||||
|
}
|
||||||
|
|
||||||
|
errGroup, serveCtx := errgroup.WithContext(ctx)
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
err := quicConn.Serve(ctx)
|
||||||
|
if err != nil {
|
||||||
|
config.Log.Error().Msgf("Failed to serve quic connection, err: %v", err)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("Connection with edge closed")
|
||||||
|
})
|
||||||
|
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
return listenReconnect(serveCtx, reconnectCh, gracefulShutdownC)
|
||||||
|
})
|
||||||
|
|
||||||
|
err = errGroup.Wait()
|
||||||
|
if err == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
config.Log.Info().Msg("Reconnecting with the same udp conn")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type quicLogger struct {
|
||||||
|
*zerolog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ql *quicLogger) Write(p []byte) (n int, err error) {
|
||||||
|
ql.Debug().Msgf("quic log: %v", string(p))
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ql *quicLogger) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gracefulShutdownCh <-chan struct{}) error {
|
func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gracefulShutdownCh <-chan struct{}) error {
|
||||||
select {
|
select {
|
||||||
case reconnect := <-reconnectCh:
|
case reconnect := <-reconnectCh:
|
||||||
|
|
Loading…
Reference in New Issue