222 lines
5.8 KiB
Go
222 lines
5.8 KiB
Go
package connection
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"math"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/cloudflare/cloudflared/h2mux"
|
|
"github.com/cloudflare/cloudflared/logger"
|
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
|
|
|
"golang.org/x/net/http2"
|
|
)
|
|
|
|
const (
	// internalUpgradeHeader is the request header whose value tells this
	// connection how to treat an incoming stream.
	internalUpgradeHeader = "Cf-Cloudflared-Proxy-Connection-Upgrade"
	// websocketUpgrade is the internalUpgradeHeader value (compared in
	// lowercase) marking a stream that is proxied as a websocket.
	websocketUpgrade = "websocket"
	// controlStreamUpgrade is the internalUpgradeHeader value (compared in
	// lowercase) marking the stream used as the control stream.
	controlStreamUpgrade = "control-stream"
)
|
|
|
|
// errNotFlusher reports that an http.ResponseWriter does not also implement
// http.Flusher.
var errNotFlusher = errors.New("ResponseWriter doesn't implement http.Flusher")
|
|
|
|
type http2Connection struct {
|
|
conn net.Conn
|
|
server *http2.Server
|
|
config *Config
|
|
namedTunnel *NamedTunnelConfig
|
|
connOptions *tunnelpogs.ConnectionOptions
|
|
observer *Observer
|
|
connIndexStr string
|
|
connIndex uint8
|
|
wg *sync.WaitGroup
|
|
// newRPCClientFunc allows us to mock RPCs during testing
|
|
newRPCClientFunc func(context.Context, io.ReadWriteCloser, logger.Service) NamedTunnelRPCClient
|
|
connectedFuse ConnectedFuse
|
|
}
|
|
|
|
func NewHTTP2Connection(
|
|
conn net.Conn,
|
|
config *Config,
|
|
namedTunnelConfig *NamedTunnelConfig,
|
|
connOptions *tunnelpogs.ConnectionOptions,
|
|
observer *Observer,
|
|
connIndex uint8,
|
|
connectedFuse ConnectedFuse,
|
|
) *http2Connection {
|
|
return &http2Connection{
|
|
conn: conn,
|
|
server: &http2.Server{
|
|
MaxConcurrentStreams: math.MaxUint32,
|
|
},
|
|
config: config,
|
|
namedTunnel: namedTunnelConfig,
|
|
connOptions: connOptions,
|
|
observer: observer,
|
|
connIndexStr: uint8ToString(connIndex),
|
|
connIndex: connIndex,
|
|
wg: &sync.WaitGroup{},
|
|
newRPCClientFunc: newRegistrationRPCClient,
|
|
connectedFuse: connectedFuse,
|
|
}
|
|
}
|
|
|
|
func (c *http2Connection) Serve(ctx context.Context) {
|
|
go func() {
|
|
<-ctx.Done()
|
|
c.close()
|
|
}()
|
|
c.server.ServeConn(c.conn, &http2.ServeConnOpts{
|
|
Context: ctx,
|
|
Handler: c,
|
|
})
|
|
}
|
|
|
|
func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
c.wg.Add(1)
|
|
defer c.wg.Done()
|
|
|
|
respWriter := &http2RespWriter{
|
|
r: r.Body,
|
|
w: w,
|
|
}
|
|
flusher, isFlusher := w.(http.Flusher)
|
|
if !isFlusher {
|
|
c.observer.Errorf("%T doesn't implement http.Flusher", w)
|
|
respWriter.WriteErrorResponse()
|
|
return
|
|
}
|
|
respWriter.flusher = flusher
|
|
var err error
|
|
if isControlStreamUpgrade(r) {
|
|
respWriter.shouldFlush = true
|
|
err = c.serveControlStream(r.Context(), respWriter)
|
|
} else if isWebsocketUpgrade(r) {
|
|
respWriter.shouldFlush = true
|
|
stripWebsocketUpgradeHeader(r)
|
|
err = c.config.OriginClient.Proxy(respWriter, r, true)
|
|
} else {
|
|
err = c.config.OriginClient.Proxy(respWriter, r, false)
|
|
}
|
|
|
|
if err != nil {
|
|
respWriter.WriteErrorResponse()
|
|
}
|
|
}
|
|
|
|
func (c *http2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error {
|
|
rpcClient := c.newRPCClientFunc(ctx, respWriter, c.observer)
|
|
defer rpcClient.Close()
|
|
|
|
if err := rpcClient.RegisterConnection(ctx, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil {
|
|
return err
|
|
}
|
|
c.connectedFuse.Connected()
|
|
|
|
<-ctx.Done()
|
|
rpcClient.GracefulShutdown(ctx, c.config.GracePeriod)
|
|
return nil
|
|
}
|
|
|
|
func (c *http2Connection) close() {
|
|
// Wait for all serve HTTP handlers to return
|
|
c.wg.Wait()
|
|
c.conn.Close()
|
|
}
|
|
|
|
// http2RespWriter pairs the request body with the response writer of a single
// HTTP/2 stream so the stream can be used for reading, writing and closing.
type http2RespWriter struct {
	// r is the source for Read; ServeHTTP sets it to the request body.
	r io.Reader
	// w receives response headers, status and body bytes.
	w http.ResponseWriter
	// flusher is w viewed as an http.Flusher; ServeHTTP sets it after the
	// type assertion succeeds.
	flusher http.Flusher
	// shouldFlush makes Write (and WriteRespHeaders) flush immediately.
	shouldFlush bool
}
|
|
|
|
func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error {
|
|
dest := rp.w.Header()
|
|
userHeaders := make(http.Header, len(resp.Header))
|
|
for header, values := range resp.Header {
|
|
// Since these are http2 headers, they're required to be lowercase
|
|
h2name := strings.ToLower(header)
|
|
for _, v := range values {
|
|
if h2name == "content-length" {
|
|
// This header has meaning in HTTP/2 and will be used by the edge,
|
|
// so it should be sent as an HTTP/2 response header.
|
|
dest.Add(h2name, v)
|
|
// Since these are http2 headers, they're required to be lowercase
|
|
} else if !h2mux.IsControlHeader(h2name) || h2mux.IsWebsocketClientHeader(h2name) {
|
|
// User headers, on the other hand, must all be serialized so that
|
|
// HTTP/2 header validation won't be applied to HTTP/1 header values
|
|
userHeaders.Add(h2name, v)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Perform user header serialization and set them in the single header
|
|
dest.Set(canonicalResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders))
|
|
rp.setResponseMetaHeader(responseMetaHeaderOrigin)
|
|
status := resp.StatusCode
|
|
// HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1
|
|
if status == http.StatusSwitchingProtocols {
|
|
status = http.StatusOK
|
|
}
|
|
rp.w.WriteHeader(status)
|
|
if IsServerSentEvent(resp.Header) {
|
|
rp.shouldFlush = true
|
|
}
|
|
if rp.shouldFlush {
|
|
rp.flusher.Flush()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (rp *http2RespWriter) WriteErrorResponse() {
|
|
rp.setResponseMetaHeader(responseMetaHeaderCfd)
|
|
rp.w.WriteHeader(http.StatusBadGateway)
|
|
}
|
|
|
|
func (rp *http2RespWriter) setResponseMetaHeader(value string) {
|
|
rp.w.Header().Set(canonicalResponseMetaHeaderField, value)
|
|
}
|
|
|
|
func (rp *http2RespWriter) Read(p []byte) (n int, err error) {
|
|
return rp.r.Read(p)
|
|
}
|
|
|
|
func (rp *http2RespWriter) Write(p []byte) (n int, err error) {
|
|
defer func() {
|
|
// Implementer of OriginClient should make sure it doesn't write to the connection after Proxy returns
|
|
// Register a recover routine just in case.
|
|
if r := recover(); r != nil {
|
|
println("Recover from http2 response writer panic, error", r)
|
|
}
|
|
}()
|
|
n, err = rp.w.Write(p)
|
|
if err == nil && rp.shouldFlush {
|
|
rp.flusher.Flush()
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
func (rp *http2RespWriter) Close() error {
|
|
return nil
|
|
}
|
|
|
|
func isControlStreamUpgrade(r *http.Request) bool {
|
|
return strings.ToLower(r.Header.Get(internalUpgradeHeader)) == controlStreamUpgrade
|
|
}
|
|
|
|
func isWebsocketUpgrade(r *http.Request) bool {
|
|
return strings.ToLower(r.Header.Get(internalUpgradeHeader)) == websocketUpgrade
|
|
}
|
|
|
|
func stripWebsocketUpgradeHeader(r *http.Request) {
|
|
r.Header.Del(internalUpgradeHeader)
|
|
}
|