TUN-4866: Add Control Stream for QUIC

This commit adds support to Register and Unregister Connections via RPC
on the QUIC transport protocol
This commit is contained in:
Sudarsan Reddy 2021-08-17 15:30:02 +01:00
parent 1082ac1c36
commit 12ad264eb3
8 changed files with 390 additions and 96 deletions

View File

@ -248,10 +248,17 @@ func prepareTunnelConfig(
edgeTLSConfigs := make(map[connection.Protocol]*tls.Config, len(connection.ProtocolList)) edgeTLSConfigs := make(map[connection.Protocol]*tls.Config, len(connection.ProtocolList))
for _, p := range connection.ProtocolList { for _, p := range connection.ProtocolList {
edgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, p.ServerName()) tlsSettings := p.TLSSettings()
if tlsSettings == nil {
return nil, ingress.Ingress{}, fmt.Errorf("%s has unknown TLS settings", p)
}
edgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, tlsSettings.ServerName)
if err != nil { if err != nil {
return nil, ingress.Ingress{}, errors.Wrap(err, "unable to create TLS config to connect with edge") return nil, ingress.Ingress{}, errors.Wrap(err, "unable to create TLS config to connect with edge")
} }
if len(edgeTLSConfig.NextProtos) > 0 {
edgeTLSConfig.NextProtos = tlsSettings.NextProtos
}
edgeTLSConfigs[p] = edgeTLSConfig edgeTLSConfigs[p] = edgeTLSConfig
} }

89
connection/control.go Normal file
View File

@ -0,0 +1,89 @@
package connection
import (
"context"
"io"
"time"
"github.com/rs/zerolog"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
// RPCClientFunc derives a named tunnel rpc client that can then be used to register and unregister connections.
type RPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
type controlStream struct {
observer *Observer
connectedFuse ConnectedFuse
namedTunnelConfig *NamedTunnelConfig
connIndex uint8
newRPCClientFunc RPCClientFunc
gracefulShutdownC <-chan struct{}
gracePeriod time.Duration
stoppedGracefully bool
}
// ControlStreamHandler registers connections with origintunneld and initiates graceful shutdown.
type ControlStreamHandler interface {
ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions) error
IsStopped() bool
}
// NewControlStream returns a new instance of ControlStreamHandler
func NewControlStream(
observer *Observer,
connectedFuse ConnectedFuse,
namedTunnelConfig *NamedTunnelConfig,
connIndex uint8,
newRPCClientFunc RPCClientFunc,
gracefulShutdownC <-chan struct{},
gracePeriod time.Duration,
) ControlStreamHandler {
if newRPCClientFunc == nil {
newRPCClientFunc = newRegistrationRPCClient
}
return &controlStream{
observer: observer,
connectedFuse: connectedFuse,
namedTunnelConfig: namedTunnelConfig,
newRPCClientFunc: newRPCClientFunc,
connIndex: connIndex,
gracefulShutdownC: gracefulShutdownC,
gracePeriod: gracePeriod,
}
}
func (c *controlStream) ServeControlStream(
ctx context.Context,
rw io.ReadWriteCloser,
connOptions *tunnelpogs.ConnectionOptions,
) error {
rpcClient := c.newRPCClientFunc(ctx, rw, c.observer.log)
defer rpcClient.Close()
if err := rpcClient.RegisterConnection(ctx, c.namedTunnelConfig, connOptions, c.connIndex, c.observer); err != nil {
return err
}
c.connectedFuse.Connected()
// wait for connection termination or start of graceful shutdown
select {
case <-ctx.Done():
break
case <-c.gracefulShutdownC:
c.stoppedGracefully = true
}
c.observer.sendUnregisteringEvent(c.connIndex)
rpcClient.GracefulShutdown(ctx, c.gracePeriod)
c.observer.log.Info().Uint8(LogFieldConnIndex, c.connIndex).Msg("Unregistered tunnel connection")
return nil
}
func (c *controlStream) IsStopped() bool {
return c.stoppedGracefully
}

View File

@ -33,18 +33,15 @@ type HTTP2Connection struct {
conn net.Conn conn net.Conn
server *http2.Server server *http2.Server
config *Config config *Config
namedTunnel *NamedTunnelConfig
connOptions *tunnelpogs.ConnectionOptions connOptions *tunnelpogs.ConnectionOptions
observer *Observer observer *Observer
connIndexStr string
connIndex uint8 connIndex uint8
// newRPCClientFunc allows us to mock RPCs during testing // newRPCClientFunc allows us to mock RPCs during testing
newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
log *zerolog.Logger log *zerolog.Logger
activeRequestsWG sync.WaitGroup activeRequestsWG sync.WaitGroup
connectedFuse ConnectedFuse controlStreamHandler ControlStreamHandler
gracefulShutdownC <-chan struct{}
stoppedGracefully bool stoppedGracefully bool
controlStreamErr error // result of running control stream handler controlStreamErr error // result of running control stream handler
} }
@ -53,13 +50,11 @@ type HTTP2Connection struct {
func NewHTTP2Connection( func NewHTTP2Connection(
conn net.Conn, conn net.Conn,
config *Config, config *Config,
namedTunnelConfig *NamedTunnelConfig,
connOptions *tunnelpogs.ConnectionOptions, connOptions *tunnelpogs.ConnectionOptions,
observer *Observer, observer *Observer,
connIndex uint8, connIndex uint8,
connectedFuse ConnectedFuse, controlStreamHandler ControlStreamHandler,
log *zerolog.Logger, log *zerolog.Logger,
gracefulShutdownC <-chan struct{},
) *HTTP2Connection { ) *HTTP2Connection {
return &HTTP2Connection{ return &HTTP2Connection{
conn: conn, conn: conn,
@ -67,15 +62,12 @@ func NewHTTP2Connection(
MaxConcurrentStreams: math.MaxUint32, MaxConcurrentStreams: math.MaxUint32,
}, },
config: config, config: config,
namedTunnel: namedTunnelConfig,
connOptions: connOptions, connOptions: connOptions,
observer: observer, observer: observer,
connIndexStr: uint8ToString(connIndex),
connIndex: connIndex, connIndex: connIndex,
newRPCClientFunc: newRegistrationRPCClient, newRPCClientFunc: newRegistrationRPCClient,
connectedFuse: connectedFuse, controlStreamHandler: controlStreamHandler,
log: log, log: log,
gracefulShutdownC: gracefulShutdownC,
} }
} }
@ -91,7 +83,7 @@ func (c *HTTP2Connection) Serve(ctx context.Context) error {
}) })
switch { switch {
case c.stoppedGracefully: case c.controlStreamHandler.IsStopped():
return nil return nil
case c.controlStreamErr != nil: case c.controlStreamErr != nil:
return c.controlStreamErr return c.controlStreamErr
@ -116,7 +108,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch connType { switch connType {
case TypeControlStream: case TypeControlStream:
if err := c.serveControlStream(r.Context(), respWriter); err != nil { if err := c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions); err != nil {
c.controlStreamErr = err c.controlStreamErr = err
c.log.Error().Err(err) c.log.Error().Err(err)
respWriter.WriteErrorResponse() respWriter.WriteErrorResponse()
@ -154,29 +146,6 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
func (c *HTTP2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error {
rpcClient := c.newRPCClientFunc(ctx, respWriter, c.observer.log)
defer rpcClient.Close()
if err := rpcClient.RegisterConnection(ctx, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil {
return err
}
c.connectedFuse.Connected()
// wait for connection termination or start of graceful shutdown
select {
case <-ctx.Done():
break
case <-c.gracefulShutdownC:
c.stoppedGracefully = true
}
c.observer.sendUnregisteringEvent(c.connIndex)
rpcClient.GracefulShutdown(ctx, c.config.GracePeriod)
c.observer.log.Info().Uint8(LogFieldConnIndex, c.connIndex).Msg("Unregistered tunnel connection")
return nil
}
func (c *HTTP2Connection) close() { func (c *HTTP2Connection) close() {
// Wait for all serve HTTP handlers to return // Wait for all serve HTTP handlers to return
c.activeRequestsWG.Wait() c.activeRequestsWG.Wait()

View File

@ -30,16 +30,24 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
edgeConn, originConn := net.Pipe() edgeConn, originConn := net.Pipe()
var connIndex = uint8(0) var connIndex = uint8(0)
log := zerolog.Nop() log := zerolog.Nop()
obs := NewObserver(&log, &log, false)
controlStream := NewControlStream(
obs,
mockConnectedFuse{},
&NamedTunnelConfig{},
connIndex,
nil,
nil,
1*time.Second,
)
return NewHTTP2Connection( return NewHTTP2Connection(
originConn, originConn,
testConfig, testConfig,
&NamedTunnelConfig{},
&pogs.ConnectionOptions{}, &pogs.ConnectionOptions{},
NewObserver(&log, &log, false), obs,
connIndex, connIndex,
mockConnectedFuse{}, controlStream,
&log, &log,
nil,
), edgeConn ), edgeConn
} }
@ -225,7 +233,18 @@ func TestServeControlStream(t *testing.T) {
registered: make(chan struct{}), registered: make(chan struct{}),
unregistered: make(chan struct{}), unregistered: make(chan struct{}),
} }
http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient
obs := NewObserver(&log, &log, false)
controlStream := NewControlStream(
obs,
mockConnectedFuse{},
&NamedTunnelConfig{},
1,
rpcClientFactory.newMockRPCClient,
nil,
1*time.Second,
)
http2Conn.controlStreamHandler = controlStream
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup var wg sync.WaitGroup
@ -264,7 +283,18 @@ func TestFailRegistration(t *testing.T) {
registered: make(chan struct{}), registered: make(chan struct{}),
unregistered: make(chan struct{}), unregistered: make(chan struct{}),
} }
http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient
obs := NewObserver(&log, &log, false)
controlStream := NewControlStream(
obs,
mockConnectedFuse{},
&NamedTunnelConfig{},
http2Conn.connIndex,
rpcClientFactory.newMockRPCClient,
nil,
1*time.Second,
)
http2Conn.controlStreamHandler = controlStream
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup var wg sync.WaitGroup
@ -297,10 +327,21 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
unregistered: make(chan struct{}), unregistered: make(chan struct{}),
} }
events := &eventCollectorSink{} events := &eventCollectorSink{}
http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient
http2Conn.observer.RegisterSink(events)
shutdownC := make(chan struct{}) shutdownC := make(chan struct{})
http2Conn.gracefulShutdownC = shutdownC obs := NewObserver(&log, &log, false)
obs.RegisterSink(events)
controlStream := NewControlStream(
obs,
mockConnectedFuse{},
&NamedTunnelConfig{},
http2Conn.connIndex,
rpcClientFactory.newMockRPCClient,
shutdownC,
1*time.Second,
)
http2Conn.controlStreamHandler = controlStream
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup var wg sync.WaitGroup
@ -339,7 +380,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
case <-time.Tick(time.Second): case <-time.Tick(time.Second):
t.Fatal("timeout out waiting for unregistered signal") t.Fatal("timeout out waiting for unregistered signal")
} }
assert.True(t, http2Conn.stoppedGracefully) assert.True(t, controlStream.IsStopped())
cancel() cancel()
wg.Wait() wg.Wait()

View File

@ -15,33 +15,29 @@ const (
edgeH2muxTLSServerName = "cftunnel.com" edgeH2muxTLSServerName = "cftunnel.com"
// edgeH2TLSServerName is the server name to establish http2 connection with edge // edgeH2TLSServerName is the server name to establish http2 connection with edge
edgeH2TLSServerName = "h2.cftunnel.com" edgeH2TLSServerName = "h2.cftunnel.com"
// edgeQUICServerName is the server name to establish quic connection with edge.
edgeQUICServerName = "quic.cftunnel.com"
// threshold to switch back to h2mux when the user intentionally pick --protocol http2 // threshold to switch back to h2mux when the user intentionally pick --protocol http2
explicitHTTP2FallbackThreshold = -1 explicitHTTP2FallbackThreshold = -1
autoSelectFlag = "auto" autoSelectFlag = "auto"
) )
var ( var (
ProtocolList = []Protocol{H2mux, HTTP2} // ProtocolList represents a list of supported protocols for communication with the edge.
ProtocolList = []Protocol{H2mux, HTTP2, QUIC}
) )
type Protocol int64 type Protocol int64
const ( const (
// H2mux protocol can be used both with Classic and Named Tunnels. .
H2mux Protocol = iota H2mux Protocol = iota
// HTTP2 is used only with named tunnels. It's more efficient than H2mux for L4 proxying.
HTTP2 HTTP2
// QUIC is used only with named tunnels.
QUIC
) )
func (p Protocol) ServerName() string {
switch p {
case H2mux:
return edgeH2muxTLSServerName
case HTTP2:
return edgeH2TLSServerName
default:
return ""
}
}
// Fallback returns the fallback protocol and whether the protocol has a fallback // Fallback returns the fallback protocol and whether the protocol has a fallback
func (p Protocol) fallback() (Protocol, bool) { func (p Protocol) fallback() (Protocol, bool) {
switch p { switch p {
@ -49,6 +45,8 @@ func (p Protocol) fallback() (Protocol, bool) {
return 0, false return 0, false
case HTTP2: case HTTP2:
return H2mux, true return H2mux, true
case QUIC:
return HTTP2, true
default: default:
return 0, false return 0, false
} }
@ -60,11 +58,39 @@ func (p Protocol) String() string {
return "h2mux" return "h2mux"
case HTTP2: case HTTP2:
return "http2" return "http2"
case QUIC:
return "quic"
default: default:
return fmt.Sprintf("unknown protocol") return fmt.Sprintf("unknown protocol")
} }
} }
func (p Protocol) TLSSettings() *TLSSettings {
switch p {
case H2mux:
return &TLSSettings{
ServerName: edgeH2muxTLSServerName,
}
case HTTP2:
return &TLSSettings{
ServerName: edgeH2TLSServerName,
}
case QUIC:
fmt.Println("returning this?")
return &TLSSettings{
ServerName: edgeQUICServerName,
NextProtos: []string{"argotunnel"},
}
default:
return nil
}
}
type TLSSettings struct {
ServerName string
NextProtos []string
}
type ProtocolSelector interface { type ProtocolSelector interface {
Current() Protocol Current() Protocol
Fallback() (Protocol, bool) Fallback() (Protocol, bool)
@ -184,6 +210,10 @@ func NewProtocolSelector(
return newAutoProtocolSelector(HTTP2, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil return newAutoProtocolSelector(HTTP2, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
} }
if protocolFlag == QUIC.String() {
return newAutoProtocolSelector(QUIC, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
}
if protocolFlag != autoSelectFlag { if protocolFlag != autoSelectFlag {
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage) return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
} }

View File

@ -15,6 +15,7 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
quicpogs "github.com/cloudflare/cloudflared/quic" quicpogs "github.com/cloudflare/cloudflared/quic"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
) )
const ( const (
@ -29,8 +30,10 @@ const (
// QUICConnection represents the type that facilitates Proxying via QUIC streams. // QUICConnection represents the type that facilitates Proxying via QUIC streams.
type QUICConnection struct { type QUICConnection struct {
session quic.Session session quic.Session
logger zerolog.Logger logger *zerolog.Logger
httpProxy OriginProxy httpProxy OriginProxy
gracefulShutdownC <-chan struct{}
stoppedGracefully bool
} }
// NewQUICConnection returns a new instance of QUICConnection. // NewQUICConnection returns a new instance of QUICConnection.
@ -40,19 +43,26 @@ func NewQUICConnection(
edgeAddr net.Addr, edgeAddr net.Addr,
tlsConfig *tls.Config, tlsConfig *tls.Config,
httpProxy OriginProxy, httpProxy OriginProxy,
logger zerolog.Logger, connOptions *tunnelpogs.ConnectionOptions,
controlStreamHandler ControlStreamHandler,
observer *Observer,
) (*QUICConnection, error) { ) (*QUICConnection, error) {
session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig) session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to dial to edge") return nil, errors.Wrap(err, "failed to dial to edge")
} }
//TODO: RegisterConnectionRPC here. registrationStream, err := session.OpenStream()
if err != nil {
return nil, errors.Wrap(err, "failed to open a registration stream")
}
go controlStreamHandler.ServeControlStream(ctx, registrationStream, connOptions)
return &QUICConnection{ return &QUICConnection{
session: session, session: session,
httpProxy: httpProxy, httpProxy: httpProxy,
logger: logger, logger: observer.log,
}, nil }, nil
} }
@ -96,11 +106,26 @@ func (q *QUICConnection) handleStream(stream quic.Stream) error {
w := newHTTPResponseAdapter(stream) w := newHTTPResponseAdapter(stream)
return q.httpProxy.ProxyHTTP(w, req, connectRequest.Type == quicpogs.ConnectionTypeWebsocket) return q.httpProxy.ProxyHTTP(w, req, connectRequest.Type == quicpogs.ConnectionTypeWebsocket)
case quicpogs.ConnectionTypeTCP: case quicpogs.ConnectionTypeTCP:
return errors.New("not implemented") // TODO: This is a placeholder for testing completion. TUN-4865 will add proper TCP support.
rwa := &streamReadWriteAcker{
ReadWriter: stream,
}
return q.httpProxy.ProxyTCP(context.Background(), rwa, &TCPRequest{Dest: connectRequest.Dest})
} }
return nil return nil
} }
// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
// the client.
type streamReadWriteAcker struct {
io.ReadWriter
}
// AckConnection acks response back to the proxy.
func (s *streamReadWriteAcker) AckConnection() error {
return quicpogs.WriteConnectResponseData(s, nil)
}
// httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC. // httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC.
type httpResponseAdapter struct { type httpResponseAdapter struct {
io.Writer io.Writer

View File

@ -16,6 +16,7 @@ import (
"os" "os"
"sync" "sync"
"testing" "testing"
"time"
"github.com/gobwas/ws/wsutil" "github.com/gobwas/ws/wsutil"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
@ -25,6 +26,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
quicpogs "github.com/cloudflare/cloudflared/quic" quicpogs "github.com/cloudflare/cloudflared/quic"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
) )
// TestQUICServer tests if a quic server accepts and responds to a quic client with the acceptance protocol. // TestQUICServer tests if a quic server accepts and responds to a quic client with the acceptance protocol.
@ -134,6 +136,13 @@ func TestQUICServer(t *testing.T) {
message: wsBuf.Bytes(), message: wsBuf.Bytes(),
expectedResponse: []byte{0x81, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f}, expectedResponse: []byte{0x81, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f},
}, },
{
desc: "test tcp proxy",
connectionType: quicpogs.ConnectionTypeTCP,
metadata: []quicpogs.Metadata{},
message: []byte("Here is some tcp data"),
expectedResponse: []byte("Here is some tcp data"),
},
} }
for _, test := range tests { for _, test := range tests {
@ -149,11 +158,43 @@ func TestQUICServer(t *testing.T) {
) )
}() }()
qC, err := NewQUICConnection(ctx, quicConfig, udpListener.LocalAddr(), tlsClientConfig, originProxy, log) rpcClientFactory := mockRPCClientFactory{
registered: make(chan struct{}),
unregistered: make(chan struct{}),
}
obs := NewObserver(&log, &log, false)
controlStream := NewControlStream(
obs,
mockConnectedFuse{},
&NamedTunnelConfig{},
1,
rpcClientFactory.newMockRPCClient,
nil,
1*time.Second,
)
qC, err := NewQUICConnection(
ctx,
quicConfig,
udpListener.LocalAddr(),
tlsClientConfig,
originProxy,
&pogs.ConnectionOptions{},
controlStream,
NewObserver(&log, &log, false),
)
require.NoError(t, err) require.NoError(t, err)
go qC.Serve(ctx) go qC.Serve(ctx)
wg.Wait() wg.Wait()
select {
case <-rpcClientFactory.registered:
break //ok
case <-time.Tick(time.Second):
t.Fatal("timeout out waiting for registration")
}
cancel() cancel()
}) })
} }
@ -174,7 +215,7 @@ func quicServer(
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
earlyListener, err := quic.ListenEarly(conn, tlsConf, config) earlyListener, err := quic.Listen(conn, tlsConf, config)
require.NoError(t, err) require.NoError(t, err)
session, err := earlyListener.Accept(ctx) session, err := earlyListener.Accept(ctx)
@ -183,7 +224,6 @@ func quicServer(
stream, err := session.OpenStreamSync(context.Background()) stream, err := session.OpenStreamSync(context.Background())
require.NoError(t, err) require.NoError(t, err)
// Start off ALPN
err = quicpogs.WriteConnectRequestData(stream, dest, connectionType, metadata...) err = quicpogs.WriteConnectRequestData(stream, dest, connectionType, metadata...)
require.NoError(t, err) require.NoError(t, err)
@ -264,5 +304,7 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Reque
} }
func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error { func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error {
rwa.AckConnection()
io.Copy(rwa, rwa)
return nil return nil
} }

View File

@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/lucas-clemente/quic-go"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
@ -271,15 +272,38 @@ func ServeTunnel(
defer config.Observer.SendDisconnect(connIndex) defer config.Observer.SendDisconnect(connIndex)
connectedFuse := &connectedFuse{
fuse: fuse,
backoff: backoff,
}
controlStream := connection.NewControlStream(
config.Observer,
connectedFuse,
config.NamedTunnel,
connIndex,
nil,
gracefulShutdownC,
config.ConnectionConfig.GracePeriod,
)
if protocol == connection.QUIC {
connOptions := config.ConnectionOptions(addr.UDP.String(), uint8(backoff.Retries()))
return ServeQUIC(ctx,
addr.UDP,
config,
connOptions,
controlStream,
connectedFuse,
reconnectCh,
gracefulShutdownC)
}
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr.TCP) edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr.TCP)
if err != nil { if err != nil {
connLog.Err(err).Msg("Unable to establish connection with Cloudflare edge") connLog.Err(err).Msg("Unable to establish connection with Cloudflare edge")
return err, true return err, true
} }
connectedFuse := &connectedFuse{
fuse: fuse,
backoff: backoff,
}
if protocol == connection.HTTP2 { if protocol == connection.HTTP2 {
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries())) connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries()))
@ -289,10 +313,10 @@ func ServeTunnel(
config, config,
edgeConn, edgeConn,
connOptions, connOptions,
controlStream,
connIndex, connIndex,
connectedFuse,
reconnectCh,
gracefulShutdownC, gracefulShutdownC,
reconnectCh,
) )
} else { } else {
err = ServeH2mux( err = ServeH2mux(
@ -403,22 +427,20 @@ func ServeHTTP2(
config *TunnelConfig, config *TunnelConfig,
tlsServerConn net.Conn, tlsServerConn net.Conn,
connOptions *tunnelpogs.ConnectionOptions, connOptions *tunnelpogs.ConnectionOptions,
controlStreamHandler connection.ControlStreamHandler,
connIndex uint8, connIndex uint8,
connectedFuse connection.ConnectedFuse,
reconnectCh chan ReconnectSignal,
gracefulShutdownC <-chan struct{}, gracefulShutdownC <-chan struct{},
reconnectCh chan ReconnectSignal,
) error { ) error {
connLog.Debug().Msgf("Connecting via http2") connLog.Debug().Msgf("Connecting via http2")
h2conn := connection.NewHTTP2Connection( h2conn := connection.NewHTTP2Connection(
tlsServerConn, tlsServerConn,
config.ConnectionConfig, config.ConnectionConfig,
config.NamedTunnel,
connOptions, connOptions,
config.Observer, config.Observer,
connIndex, connIndex,
connectedFuse, controlStreamHandler,
config.Log, config.Log,
gracefulShutdownC,
) )
errGroup, serveCtx := errgroup.WithContext(ctx) errGroup, serveCtx := errgroup.WithContext(ctx)
@ -438,6 +460,75 @@ func ServeHTTP2(
return errGroup.Wait() return errGroup.Wait()
} }
func ServeQUIC(
ctx context.Context,
edgeAddr *net.UDPAddr,
config *TunnelConfig,
connOptions *tunnelpogs.ConnectionOptions,
controlStreamHandler connection.ControlStreamHandler,
connectedFuse connection.ConnectedFuse,
reconnectCh chan ReconnectSignal,
gracefulShutdownC <-chan struct{},
) (err error, recoverable bool) {
tlsConfig := config.EdgeTLSConfigs[connection.QUIC]
quicConfig := &quic.Config{
HandshakeIdleTimeout: time.Second * 10,
KeepAlive: true,
}
for {
select {
case <-ctx.Done():
return
default:
quicConn, err := connection.NewQUICConnection(
ctx,
quicConfig,
edgeAddr,
tlsConfig,
config.ConnectionConfig.OriginProxy,
connOptions,
controlStreamHandler,
config.Observer)
if err != nil {
config.Log.Error().Msgf("Failed to create new quic connection, err: %v", err)
return err, true
}
errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
err := quicConn.Serve(ctx)
if err != nil {
config.Log.Error().Msgf("Failed to serve quic connection, err: %v", err)
}
return fmt.Errorf("Connection with edge closed")
})
errGroup.Go(func() error {
return listenReconnect(serveCtx, reconnectCh, gracefulShutdownC)
})
err = errGroup.Wait()
if err == nil {
return nil, false
}
config.Log.Info().Msg("Reconnecting with the same udp conn")
}
}
}
type quicLogger struct {
*zerolog.Logger
}
func (ql *quicLogger) Write(p []byte) (n int, err error) {
ql.Debug().Msgf("quic log: %v", string(p))
return len(p), nil
}
func (ql *quicLogger) Close() error {
return nil
}
func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gracefulShutdownCh <-chan struct{}) error { func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gracefulShutdownCh <-chan struct{}) error {
select { select {
case reconnect := <-reconnectCh: case reconnect := <-reconnectCh: