TUN-4866: Add Control Stream for QUIC

This commit adds support to Register and Unregister Connections via RPC
on the QUIC transport protocol
This commit is contained in:
Sudarsan Reddy 2021-08-17 15:30:02 +01:00
parent 1082ac1c36
commit 12ad264eb3
8 changed files with 390 additions and 96 deletions

View File

@ -248,10 +248,17 @@ func prepareTunnelConfig(
edgeTLSConfigs := make(map[connection.Protocol]*tls.Config, len(connection.ProtocolList))
for _, p := range connection.ProtocolList {
edgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, p.ServerName())
tlsSettings := p.TLSSettings()
if tlsSettings == nil {
return nil, ingress.Ingress{}, fmt.Errorf("%s has unknown TLS settings", p)
}
edgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, tlsSettings.ServerName)
if err != nil {
return nil, ingress.Ingress{}, errors.Wrap(err, "unable to create TLS config to connect with edge")
}
if len(edgeTLSConfig.NextProtos) > 0 {
edgeTLSConfig.NextProtos = tlsSettings.NextProtos
}
edgeTLSConfigs[p] = edgeTLSConfig
}

89
connection/control.go Normal file
View File

@ -0,0 +1,89 @@
package connection
import (
"context"
"io"
"time"
"github.com/rs/zerolog"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
// RPCClientFunc derives a named tunnel rpc client that can then be used to register and unregister connections.
type RPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
type controlStream struct {
observer *Observer
connectedFuse ConnectedFuse
namedTunnelConfig *NamedTunnelConfig
connIndex uint8
newRPCClientFunc RPCClientFunc
gracefulShutdownC <-chan struct{}
gracePeriod time.Duration
stoppedGracefully bool
}
// ControlStreamHandler registers connections with origintunneld and initiates graceful shutdown.
type ControlStreamHandler interface {
ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions) error
IsStopped() bool
}
// NewControlStream returns a new instance of ControlStreamHandler
func NewControlStream(
observer *Observer,
connectedFuse ConnectedFuse,
namedTunnelConfig *NamedTunnelConfig,
connIndex uint8,
newRPCClientFunc RPCClientFunc,
gracefulShutdownC <-chan struct{},
gracePeriod time.Duration,
) ControlStreamHandler {
if newRPCClientFunc == nil {
newRPCClientFunc = newRegistrationRPCClient
}
return &controlStream{
observer: observer,
connectedFuse: connectedFuse,
namedTunnelConfig: namedTunnelConfig,
newRPCClientFunc: newRPCClientFunc,
connIndex: connIndex,
gracefulShutdownC: gracefulShutdownC,
gracePeriod: gracePeriod,
}
}
func (c *controlStream) ServeControlStream(
ctx context.Context,
rw io.ReadWriteCloser,
connOptions *tunnelpogs.ConnectionOptions,
) error {
rpcClient := c.newRPCClientFunc(ctx, rw, c.observer.log)
defer rpcClient.Close()
if err := rpcClient.RegisterConnection(ctx, c.namedTunnelConfig, connOptions, c.connIndex, c.observer); err != nil {
return err
}
c.connectedFuse.Connected()
// wait for connection termination or start of graceful shutdown
select {
case <-ctx.Done():
break
case <-c.gracefulShutdownC:
c.stoppedGracefully = true
}
c.observer.sendUnregisteringEvent(c.connIndex)
rpcClient.GracefulShutdown(ctx, c.gracePeriod)
c.observer.log.Info().Uint8(LogFieldConnIndex, c.connIndex).Msg("Unregistered tunnel connection")
return nil
}
func (c *controlStream) IsStopped() bool {
return c.stoppedGracefully
}

View File

@ -30,52 +30,44 @@ var errEdgeConnectionClosed = fmt.Errorf("connection with edge closed")
// HTTP2Connection represents a net.Conn that uses HTTP2 frames to proxy traffic from the edge to cloudflared on the
// origin.
type HTTP2Connection struct {
conn net.Conn
server *http2.Server
config *Config
namedTunnel *NamedTunnelConfig
connOptions *tunnelpogs.ConnectionOptions
observer *Observer
connIndexStr string
connIndex uint8
conn net.Conn
server *http2.Server
config *Config
connOptions *tunnelpogs.ConnectionOptions
observer *Observer
connIndex uint8
// newRPCClientFunc allows us to mock RPCs during testing
newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
log *zerolog.Logger
activeRequestsWG sync.WaitGroup
connectedFuse ConnectedFuse
gracefulShutdownC <-chan struct{}
stoppedGracefully bool
controlStreamErr error // result of running control stream handler
log *zerolog.Logger
activeRequestsWG sync.WaitGroup
controlStreamHandler ControlStreamHandler
stoppedGracefully bool
controlStreamErr error // result of running control stream handler
}
// NewHTTP2Connection returns a new instance of HTTP2Connection.
func NewHTTP2Connection(
conn net.Conn,
config *Config,
namedTunnelConfig *NamedTunnelConfig,
connOptions *tunnelpogs.ConnectionOptions,
observer *Observer,
connIndex uint8,
connectedFuse ConnectedFuse,
controlStreamHandler ControlStreamHandler,
log *zerolog.Logger,
gracefulShutdownC <-chan struct{},
) *HTTP2Connection {
return &HTTP2Connection{
conn: conn,
server: &http2.Server{
MaxConcurrentStreams: math.MaxUint32,
},
config: config,
namedTunnel: namedTunnelConfig,
connOptions: connOptions,
observer: observer,
connIndexStr: uint8ToString(connIndex),
connIndex: connIndex,
newRPCClientFunc: newRegistrationRPCClient,
connectedFuse: connectedFuse,
log: log,
gracefulShutdownC: gracefulShutdownC,
config: config,
connOptions: connOptions,
observer: observer,
connIndex: connIndex,
newRPCClientFunc: newRegistrationRPCClient,
controlStreamHandler: controlStreamHandler,
log: log,
}
}
@ -91,7 +83,7 @@ func (c *HTTP2Connection) Serve(ctx context.Context) error {
})
switch {
case c.stoppedGracefully:
case c.controlStreamHandler.IsStopped():
return nil
case c.controlStreamErr != nil:
return c.controlStreamErr
@ -116,7 +108,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch connType {
case TypeControlStream:
if err := c.serveControlStream(r.Context(), respWriter); err != nil {
if err := c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions); err != nil {
c.controlStreamErr = err
c.log.Error().Err(err)
respWriter.WriteErrorResponse()
@ -154,29 +146,6 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
func (c *HTTP2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error {
rpcClient := c.newRPCClientFunc(ctx, respWriter, c.observer.log)
defer rpcClient.Close()
if err := rpcClient.RegisterConnection(ctx, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil {
return err
}
c.connectedFuse.Connected()
// wait for connection termination or start of graceful shutdown
select {
case <-ctx.Done():
break
case <-c.gracefulShutdownC:
c.stoppedGracefully = true
}
c.observer.sendUnregisteringEvent(c.connIndex)
rpcClient.GracefulShutdown(ctx, c.config.GracePeriod)
c.observer.log.Info().Uint8(LogFieldConnIndex, c.connIndex).Msg("Unregistered tunnel connection")
return nil
}
func (c *HTTP2Connection) close() {
// Wait for all serve HTTP handlers to return
c.activeRequestsWG.Wait()

View File

@ -30,16 +30,24 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
edgeConn, originConn := net.Pipe()
var connIndex = uint8(0)
log := zerolog.Nop()
obs := NewObserver(&log, &log, false)
controlStream := NewControlStream(
obs,
mockConnectedFuse{},
&NamedTunnelConfig{},
connIndex,
nil,
nil,
1*time.Second,
)
return NewHTTP2Connection(
originConn,
testConfig,
&NamedTunnelConfig{},
&pogs.ConnectionOptions{},
NewObserver(&log, &log, false),
obs,
connIndex,
mockConnectedFuse{},
controlStream,
&log,
nil,
), edgeConn
}
@ -225,7 +233,18 @@ func TestServeControlStream(t *testing.T) {
registered: make(chan struct{}),
unregistered: make(chan struct{}),
}
http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient
obs := NewObserver(&log, &log, false)
controlStream := NewControlStream(
obs,
mockConnectedFuse{},
&NamedTunnelConfig{},
1,
rpcClientFactory.newMockRPCClient,
nil,
1*time.Second,
)
http2Conn.controlStreamHandler = controlStream
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
@ -264,7 +283,18 @@ func TestFailRegistration(t *testing.T) {
registered: make(chan struct{}),
unregistered: make(chan struct{}),
}
http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient
obs := NewObserver(&log, &log, false)
controlStream := NewControlStream(
obs,
mockConnectedFuse{},
&NamedTunnelConfig{},
http2Conn.connIndex,
rpcClientFactory.newMockRPCClient,
nil,
1*time.Second,
)
http2Conn.controlStreamHandler = controlStream
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
@ -297,10 +327,21 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
unregistered: make(chan struct{}),
}
events := &eventCollectorSink{}
http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient
http2Conn.observer.RegisterSink(events)
shutdownC := make(chan struct{})
http2Conn.gracefulShutdownC = shutdownC
obs := NewObserver(&log, &log, false)
obs.RegisterSink(events)
controlStream := NewControlStream(
obs,
mockConnectedFuse{},
&NamedTunnelConfig{},
http2Conn.connIndex,
rpcClientFactory.newMockRPCClient,
shutdownC,
1*time.Second,
)
http2Conn.controlStreamHandler = controlStream
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
@ -339,7 +380,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
case <-time.Tick(time.Second):
t.Fatal("timeout out waiting for unregistered signal")
}
assert.True(t, http2Conn.stoppedGracefully)
assert.True(t, controlStream.IsStopped())
cancel()
wg.Wait()

View File

@ -15,33 +15,29 @@ const (
edgeH2muxTLSServerName = "cftunnel.com"
// edgeH2TLSServerName is the server name to establish http2 connection with edge
edgeH2TLSServerName = "h2.cftunnel.com"
// edgeQUICServerName is the server name to establish quic connection with edge.
edgeQUICServerName = "quic.cftunnel.com"
// threshold to switch back to h2mux when the user intentionally pick --protocol http2
explicitHTTP2FallbackThreshold = -1
autoSelectFlag = "auto"
)
var (
ProtocolList = []Protocol{H2mux, HTTP2}
// ProtocolList represents a list of supported protocols for communication with the edge.
ProtocolList = []Protocol{H2mux, HTTP2, QUIC}
)
type Protocol int64
const (
// H2mux protocol can be used both with Classic and Named Tunnels. .
H2mux Protocol = iota
// HTTP2 is used only with named tunnels. It's more efficient than H2mux for L4 proxying.
HTTP2
// QUIC is used only with named tunnels.
QUIC
)
func (p Protocol) ServerName() string {
switch p {
case H2mux:
return edgeH2muxTLSServerName
case HTTP2:
return edgeH2TLSServerName
default:
return ""
}
}
// Fallback returns the fallback protocol and whether the protocol has a fallback
func (p Protocol) fallback() (Protocol, bool) {
switch p {
@ -49,6 +45,8 @@ func (p Protocol) fallback() (Protocol, bool) {
return 0, false
case HTTP2:
return H2mux, true
case QUIC:
return HTTP2, true
default:
return 0, false
}
@ -60,11 +58,39 @@ func (p Protocol) String() string {
return "h2mux"
case HTTP2:
return "http2"
case QUIC:
return "quic"
default:
return fmt.Sprintf("unknown protocol")
}
}
func (p Protocol) TLSSettings() *TLSSettings {
switch p {
case H2mux:
return &TLSSettings{
ServerName: edgeH2muxTLSServerName,
}
case HTTP2:
return &TLSSettings{
ServerName: edgeH2TLSServerName,
}
case QUIC:
fmt.Println("returning this?")
return &TLSSettings{
ServerName: edgeQUICServerName,
NextProtos: []string{"argotunnel"},
}
default:
return nil
}
}
type TLSSettings struct {
ServerName string
NextProtos []string
}
type ProtocolSelector interface {
Current() Protocol
Fallback() (Protocol, bool)
@ -184,6 +210,10 @@ func NewProtocolSelector(
return newAutoProtocolSelector(HTTP2, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
}
if protocolFlag == QUIC.String() {
return newAutoProtocolSelector(QUIC, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
}
if protocolFlag != autoSelectFlag {
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
}

View File

@ -15,6 +15,7 @@ import (
"github.com/rs/zerolog"
quicpogs "github.com/cloudflare/cloudflared/quic"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
const (
@ -28,9 +29,11 @@ const (
// QUICConnection represents the type that facilitates Proxying via QUIC streams.
type QUICConnection struct {
session quic.Session
logger zerolog.Logger
httpProxy OriginProxy
session quic.Session
logger *zerolog.Logger
httpProxy OriginProxy
gracefulShutdownC <-chan struct{}
stoppedGracefully bool
}
// NewQUICConnection returns a new instance of QUICConnection.
@ -40,19 +43,26 @@ func NewQUICConnection(
edgeAddr net.Addr,
tlsConfig *tls.Config,
httpProxy OriginProxy,
logger zerolog.Logger,
connOptions *tunnelpogs.ConnectionOptions,
controlStreamHandler ControlStreamHandler,
observer *Observer,
) (*QUICConnection, error) {
session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig)
if err != nil {
return nil, errors.Wrap(err, "failed to dial to edge")
}
//TODO: RegisterConnectionRPC here.
registrationStream, err := session.OpenStream()
if err != nil {
return nil, errors.Wrap(err, "failed to open a registration stream")
}
go controlStreamHandler.ServeControlStream(ctx, registrationStream, connOptions)
return &QUICConnection{
session: session,
httpProxy: httpProxy,
logger: logger,
logger: observer.log,
}, nil
}
@ -96,11 +106,26 @@ func (q *QUICConnection) handleStream(stream quic.Stream) error {
w := newHTTPResponseAdapter(stream)
return q.httpProxy.ProxyHTTP(w, req, connectRequest.Type == quicpogs.ConnectionTypeWebsocket)
case quicpogs.ConnectionTypeTCP:
return errors.New("not implemented")
// TODO: This is a placeholder for testing completion. TUN-4865 will add proper TCP support.
rwa := &streamReadWriteAcker{
ReadWriter: stream,
}
return q.httpProxy.ProxyTCP(context.Background(), rwa, &TCPRequest{Dest: connectRequest.Dest})
}
return nil
}
// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
// the client.
type streamReadWriteAcker struct {
io.ReadWriter
}
// AckConnection acks response back to the proxy.
func (s *streamReadWriteAcker) AckConnection() error {
return quicpogs.WriteConnectResponseData(s, nil)
}
// httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC.
type httpResponseAdapter struct {
io.Writer

View File

@ -16,6 +16,7 @@ import (
"os"
"sync"
"testing"
"time"
"github.com/gobwas/ws/wsutil"
"github.com/lucas-clemente/quic-go"
@ -25,6 +26,7 @@ import (
"github.com/stretchr/testify/require"
quicpogs "github.com/cloudflare/cloudflared/quic"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
// TestQUICServer tests if a quic server accepts and responds to a quic client with the acceptance protocol.
@ -134,6 +136,13 @@ func TestQUICServer(t *testing.T) {
message: wsBuf.Bytes(),
expectedResponse: []byte{0x81, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f},
},
{
desc: "test tcp proxy",
connectionType: quicpogs.ConnectionTypeTCP,
metadata: []quicpogs.Metadata{},
message: []byte("Here is some tcp data"),
expectedResponse: []byte("Here is some tcp data"),
},
}
for _, test := range tests {
@ -149,11 +158,43 @@ func TestQUICServer(t *testing.T) {
)
}()
qC, err := NewQUICConnection(ctx, quicConfig, udpListener.LocalAddr(), tlsClientConfig, originProxy, log)
rpcClientFactory := mockRPCClientFactory{
registered: make(chan struct{}),
unregistered: make(chan struct{}),
}
obs := NewObserver(&log, &log, false)
controlStream := NewControlStream(
obs,
mockConnectedFuse{},
&NamedTunnelConfig{},
1,
rpcClientFactory.newMockRPCClient,
nil,
1*time.Second,
)
qC, err := NewQUICConnection(
ctx,
quicConfig,
udpListener.LocalAddr(),
tlsClientConfig,
originProxy,
&pogs.ConnectionOptions{},
controlStream,
NewObserver(&log, &log, false),
)
require.NoError(t, err)
go qC.Serve(ctx)
wg.Wait()
select {
case <-rpcClientFactory.registered:
break //ok
case <-time.Tick(time.Second):
t.Fatal("timeout out waiting for registration")
}
cancel()
})
}
@ -174,7 +215,7 @@ func quicServer(
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
earlyListener, err := quic.ListenEarly(conn, tlsConf, config)
earlyListener, err := quic.Listen(conn, tlsConf, config)
require.NoError(t, err)
session, err := earlyListener.Accept(ctx)
@ -183,7 +224,6 @@ func quicServer(
stream, err := session.OpenStreamSync(context.Background())
require.NoError(t, err)
// Start off ALPN
err = quicpogs.WriteConnectRequestData(stream, dest, connectionType, metadata...)
require.NoError(t, err)
@ -264,5 +304,7 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Reque
}
func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error {
rwa.AckConnection()
io.Copy(rwa, rwa)
return nil
}

View File

@ -11,6 +11,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/lucas-clemente/quic-go"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"golang.org/x/sync/errgroup"
@ -271,15 +272,38 @@ func ServeTunnel(
defer config.Observer.SendDisconnect(connIndex)
connectedFuse := &connectedFuse{
fuse: fuse,
backoff: backoff,
}
controlStream := connection.NewControlStream(
config.Observer,
connectedFuse,
config.NamedTunnel,
connIndex,
nil,
gracefulShutdownC,
config.ConnectionConfig.GracePeriod,
)
if protocol == connection.QUIC {
connOptions := config.ConnectionOptions(addr.UDP.String(), uint8(backoff.Retries()))
return ServeQUIC(ctx,
addr.UDP,
config,
connOptions,
controlStream,
connectedFuse,
reconnectCh,
gracefulShutdownC)
}
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr.TCP)
if err != nil {
connLog.Err(err).Msg("Unable to establish connection with Cloudflare edge")
return err, true
}
connectedFuse := &connectedFuse{
fuse: fuse,
backoff: backoff,
}
if protocol == connection.HTTP2 {
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries()))
@ -289,10 +313,10 @@ func ServeTunnel(
config,
edgeConn,
connOptions,
controlStream,
connIndex,
connectedFuse,
reconnectCh,
gracefulShutdownC,
reconnectCh,
)
} else {
err = ServeH2mux(
@ -403,22 +427,20 @@ func ServeHTTP2(
config *TunnelConfig,
tlsServerConn net.Conn,
connOptions *tunnelpogs.ConnectionOptions,
controlStreamHandler connection.ControlStreamHandler,
connIndex uint8,
connectedFuse connection.ConnectedFuse,
reconnectCh chan ReconnectSignal,
gracefulShutdownC <-chan struct{},
reconnectCh chan ReconnectSignal,
) error {
connLog.Debug().Msgf("Connecting via http2")
h2conn := connection.NewHTTP2Connection(
tlsServerConn,
config.ConnectionConfig,
config.NamedTunnel,
connOptions,
config.Observer,
connIndex,
connectedFuse,
controlStreamHandler,
config.Log,
gracefulShutdownC,
)
errGroup, serveCtx := errgroup.WithContext(ctx)
@ -438,6 +460,75 @@ func ServeHTTP2(
return errGroup.Wait()
}
func ServeQUIC(
ctx context.Context,
edgeAddr *net.UDPAddr,
config *TunnelConfig,
connOptions *tunnelpogs.ConnectionOptions,
controlStreamHandler connection.ControlStreamHandler,
connectedFuse connection.ConnectedFuse,
reconnectCh chan ReconnectSignal,
gracefulShutdownC <-chan struct{},
) (err error, recoverable bool) {
tlsConfig := config.EdgeTLSConfigs[connection.QUIC]
quicConfig := &quic.Config{
HandshakeIdleTimeout: time.Second * 10,
KeepAlive: true,
}
for {
select {
case <-ctx.Done():
return
default:
quicConn, err := connection.NewQUICConnection(
ctx,
quicConfig,
edgeAddr,
tlsConfig,
config.ConnectionConfig.OriginProxy,
connOptions,
controlStreamHandler,
config.Observer)
if err != nil {
config.Log.Error().Msgf("Failed to create new quic connection, err: %v", err)
return err, true
}
errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
err := quicConn.Serve(ctx)
if err != nil {
config.Log.Error().Msgf("Failed to serve quic connection, err: %v", err)
}
return fmt.Errorf("Connection with edge closed")
})
errGroup.Go(func() error {
return listenReconnect(serveCtx, reconnectCh, gracefulShutdownC)
})
err = errGroup.Wait()
if err == nil {
return nil, false
}
config.Log.Info().Msg("Reconnecting with the same udp conn")
}
}
}
type quicLogger struct {
*zerolog.Logger
}
func (ql *quicLogger) Write(p []byte) (n int, err error) {
ql.Debug().Msgf("quic log: %v", string(p))
return len(p), nil
}
func (ql *quicLogger) Close() error {
return nil
}
func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gracefulShutdownCh <-chan struct{}) error {
select {
case reconnect := <-reconnectCh: