cloudflared-mirror/origin/h2mux.go

230 lines
7.1 KiB
Go

package origin
import (
	"bufio"
	"context"
	"io"
	"net"
	"net/http"
	"strconv"
	"strings"

	"github.com/cloudflare/cloudflared/buffer"
	"github.com/cloudflare/cloudflared/connection"
	"github.com/cloudflare/cloudflared/h2mux"
	"github.com/cloudflare/cloudflared/ingress"
	"github.com/cloudflare/cloudflared/logger"
	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
	"github.com/cloudflare/cloudflared/websocket"
	"github.com/pkg/errors"
)
// TunnelHandler serves requests that arrive as streams on one h2mux
// connection to the edge, routing each one to an origin via the ingress rules.
//
// NOTE(review): serveHTTP calls h.httpClient.RoundTrip, but no httpClient
// field is declared here — confirm where the origin transport is meant to live.
type TunnelHandler struct {
// ingressRules picks the rule (and origin service) for each request by host and path.
ingressRules ingress.Ingress
// muxer is the h2mux connection to the edge; assigned by NewTunnelHandler after the handshake.
muxer *h2mux.Muxer
// tags are added to requests as TagHeaderNamePrefix+Name headers by AppendTagHeaders.
tags []tunnelpogs.Tag
// metrics receives request/response counters labelled with connectionID.
metrics *TunnelMetrics
// connectionID is only used by metrics, and prometheus requires labels to be string
connectionID string
// logger is used for request, response and error logging.
logger logger.Service
// bufferPool supplies the copy buffers used to relay origin response bodies.
bufferPool *buffer.Pool
}
// NewTunnelHandler dials the edge at addr, performs the h2mux client handshake
// over that connection and returns the resulting TunnelHandler, the local
// address (host:port) of the edge connection — i.e. the address this host used
// to reach the edge — and an error.
//
// If the handshake fails, the freshly dialed edge connection is closed so a
// failed attempt does not leak a socket.
func NewTunnelHandler(ctx context.Context,
	config *TunnelConfig,
	addr *net.TCPAddr,
	connectionID uint8,
	bufferPool *buffer.Pool,
) (*TunnelHandler, string, error) {
	h := &TunnelHandler{
		ingressRules: config.IngressRules,
		tags:         config.Tags,
		metrics:      config.Metrics,
		connectionID: uint8ToString(connectionID),
		logger:       config.Logger,
		bufferPool:   bufferPool,
	}
	edgeConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr)
	if err != nil {
		return nil, "", err
	}
	// Establish a muxed connection with the edge (client side of the handshake).
	// NOTE(review): config.muxerConfig(h) presumably registers h as the handler
	// for incoming streams — confirm.
	h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, config.muxerConfig(h), h.metrics.activeStreams)
	if err != nil {
		// Nothing else owns edgeConn yet; release it before reporting the failure.
		// The close error is irrelevant next to the handshake error being returned.
		_ = edgeConn.Close()
		return nil, "", errors.Wrap(err, "h2mux handshake with edge error")
	}
	return h, edgeConn.LocalAddr().String(), nil
}
// AppendTagHeaders adds one header to r for every configured tag. Each header
// is named TagHeaderNamePrefix followed by the tag name and carries the tag
// value; existing values for the same name are kept (Header.Add semantics).
func (h *TunnelHandler) AppendTagHeaders(r *http.Request) {
	for i := range h.tags {
		t := &h.tags[i]
		r.Header.Add(TagHeaderNamePrefix+t.Name, t.Value)
	}
}
// ServeStream handles one request multiplexed over the edge connection.
//
// It counts the request in the metrics, then rebuilds the HTTP request from
// the stream and matches an ingress rule. Websocket upgrades go to
// serveWebsocket and everything else to serveHTTP. Any failure is answered
// with a 502 via writeErrorResponse and returned to the caller. On success
// the response is logged and nil is returned.
func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
	h.metrics.incrementRequests(h.connectionID)
	defer h.metrics.decrementConcurrentRequests(h.connectionID)

	req, rule, err := h.createRequest(stream)
	if err != nil {
		h.writeErrorResponse(stream, err)
		return err
	}

	cfRay, lbProbe := findCfRayHeader(req), isLBProbeRequest(req)
	h.logRequest(req, cfRay, lbProbe)

	var resp *http.Response
	switch {
	case websocket.IsWebSocketUpgrade(req):
		resp, err = serveWebsocket(&h2muxWebsocketResp{stream}, req, rule)
	default:
		resp, err = h.serveHTTP(stream, req, rule)
	}
	if err != nil {
		h.writeErrorResponse(stream, err)
		return err
	}

	h.logResponseOk(resp, cfRay, lbProbe)
	return nil
}
// createRequest turns the stream's HTTP/2 headers into an HTTP/1 request whose
// body reads from the stream. It then finds the ingress rule matching the
// request's host and path and lets that rule's service rewrite the request URL.
// It fails if the request cannot be built or the headers cannot be converted.
func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, *ingress.Rule, error) {
	body := h2mux.MuxedStreamReader{MuxedStream: stream}
	// Method and URL are placeholders.
	// NOTE(review): presumably overwritten by H2RequestHeadersToH1Request from
	// the stream's headers — confirm.
	req, err := http.NewRequest("GET", "http://localhost:8080", body)
	if err != nil {
		return nil, nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
	}
	if err := h2mux.H2RequestHeadersToH1Request(stream.Headers, req); err != nil {
		return nil, nil, errors.Wrap(err, "invalid request received")
	}
	rule, _ := h.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
	rule.Service.RewriteOriginURL(req.URL)
	return req, rule, nil
}
// serveHTTP forwards req to the origin chosen by rule and relays the origin's
// response back over stream: first the converted headers (tagged as coming
// from the origin), then the body.
//
// An error is returned only for failures before the body starts: proxying the
// request, or writing the response headers. After the headers are sent the
// status can no longer change, so a failure while copying the body is logged
// rather than returned. The returned response's body is closed by the time
// this function returns; callers may only inspect its status and headers.
func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
	// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
	if rule.Config.DisableChunkedEncoding {
		req.TransferEncoding = []string{"gzip", "deflate"}
		// Parse as 64-bit so large bodies are not rejected on 32-bit platforms,
		// and ignore negative values, which are not a valid Content-Length.
		if cLength, err := strconv.ParseInt(req.Header.Get("Content-Length"), 10, 64); err == nil && cLength >= 0 {
			req.ContentLength = cLength
		}
	}
	// Request origin to keep connection alive to improve performance
	req.Header.Set("Connection", "keep-alive")
	if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
		req.Header.Set("Host", hostHeader)
		req.Host = hostHeader
	}
	// NOTE(review): httpClient is not a declared field of TunnelHandler in this
	// file — confirm which transport is intended to proxy to the origin.
	response, err := h.httpClient.RoundTrip(req)
	if err != nil {
		return nil, errors.Wrap(err, "Error proxying request to origin")
	}
	defer response.Body.Close()

	headers := h2mux.H1ResponseToH2ResponseHeaders(response)
	headers = append(headers, h2mux.CreateResponseMetaHeader(h2mux.ResponseMetaHeaderField, h2mux.ResponseSourceOrigin))
	if err := stream.WriteHeaders(headers); err != nil {
		return nil, errors.Wrap(err, "Error writing response header")
	}

	if h.isEventStream(response) {
		h.writeEventStream(stream, response.Body)
		return response, nil
	}

	// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
	// compression generates dictionary on first write
	buf := h.bufferPool.Get()
	defer h.bufferPool.Put(buf)
	if _, err := io.CopyBuffer(stream, response.Body, buf); err != nil {
		// Headers are already on the wire, so the best we can do is record it.
		h.logger.Debugf("Error copying response body from origin: %s", err)
	}
	return response, nil
}
// writeEventStream relays a server-sent-event body from the origin to stream
// one line at a time. Each event line is therefore forwarded as soon as its
// '\n' arrives, rather than after a full copy buffer fills.
//
// It returns when the origin body ends or fails, or when a write to the stream
// fails. Stopping on write failure matters for event streams: the origin may
// never close them, so without it the loop would keep reading long after the
// receiving side has gone away. The caller remains responsible for closing
// responseBody.
func (h *TunnelHandler) writeEventStream(stream *h2mux.MuxedStream, responseBody io.ReadCloser) {
	reader := bufio.NewReader(responseBody)
	for {
		line, readErr := reader.ReadBytes('\n')
		// ReadBytes returns whatever it read before an error, so a last line
		// without a trailing newline must still be forwarded.
		if len(line) > 0 {
			if _, err := stream.Write(line); err != nil {
				h.logger.Debugf("Error writing event stream to edge: %s", err)
				return
			}
		}
		if readErr != nil {
			if readErr != io.EOF {
				h.logger.Debugf("Error reading event stream from origin: %s", readErr)
			}
			return
		}
	}
}
// isEventStream reports whether the origin response carries server-sent
// events, i.e. whether its Content-Type media type is text/event-stream.
// The comparison is case-insensitive and ignores media-type parameters, so
// "text/event-stream; charset=utf-8" is detected as well.
func (h *TunnelHandler) isEventStream(response *http.Response) bool {
	mediaType := response.Header.Get("content-type")
	// Drop any parameters (e.g. ";charset=utf-8") before comparing the type.
	if i := strings.IndexByte(mediaType, ';'); i >= 0 {
		mediaType = mediaType[:i]
	}
	if !strings.EqualFold(strings.TrimSpace(mediaType), "text/event-stream") {
		return false
	}
	h.logger.Debug("Detected Server-Side Events from Origin")
	return true
}
// writeErrorResponse logs err, counts a 502 in the metrics and answers the
// stream with a 502 Bad Gateway tagged as generated by cloudflared.
//
// If the 502 headers cannot be written, the body is not written either.
// NOTE(review): a header-write failure presumably means the stream is gone or
// a response was already started on it; in the latter case appending
// "502 Bad Gateway" would corrupt the response already in flight.
func (h *TunnelHandler) writeErrorResponse(stream *h2mux.MuxedStream, err error) {
	h.logger.Errorf("HTTP request error: %s", err)
	h.metrics.incrementResponses(h.connectionID, "502")
	headers := []h2mux.Header{
		{Name: ":status", Value: "502"},
		h2mux.CreateResponseMetaHeader(h2mux.ResponseMetaHeaderField, h2mux.ResponseSourceCloudflared),
	}
	if writeErr := stream.WriteHeaders(headers); writeErr != nil {
		h.logger.Debugf("Failed to write 502 response headers: %s", writeErr)
		return
	}
	if _, writeErr := stream.Write([]byte("502 Bad Gateway")); writeErr != nil {
		h.logger.Debugf("Failed to write 502 response body: %s", writeErr)
	}
}
// logRequest logs an incoming request. Requests with a CF-RAY and load
// balancer probes are logged at debug level. A request with neither is logged
// at info level with a hint to contact support. The request headers and
// content length (or that it is unknown, -1) are then logged at debug level.
func (h *TunnelHandler) logRequest(req *http.Request, cfRay string, lbProbe bool) {
	log := h.logger
	switch {
	case cfRay != "":
		log.Debugf("CF-RAY: %s %s %s %s", cfRay, req.Method, req.URL, req.Proto)
	case lbProbe:
		log.Debugf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, req.Method, req.URL, req.Proto)
	default:
		log.Infof("CF-RAY: %s All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", cfRay, req.Method, req.URL, req.Proto)
	}
	log.Debugf("CF-RAY: %s Request Headers %+v", cfRay, req.Header)
	if req.ContentLength == -1 {
		log.Debugf("CF-RAY: %s Request Content length unknown", cfRay)
		return
	}
	log.Debugf("CF-RAY: %s Request content length %d", cfRay, req.ContentLength)
}
// logResponseOk records a response that was proxied without error. It counts
// the response in the per-connection metrics under its actual HTTP status code,
// then logs the status, the response headers and the content length (or that
// it is unknown, -1).
func (h *TunnelHandler) logResponseOk(r *http.Response, cfRay string, lbProbe bool) {
	// Label by the origin's real status (e.g. "404" or "500" from the origin)
	// instead of reporting every proxied response as "200".
	h.metrics.incrementResponses(h.connectionID, strconv.Itoa(r.StatusCode))
	logger := h.logger
	switch {
	case cfRay != "":
		logger.Debugf("CF-RAY: %s %s", cfRay, r.Status)
	case lbProbe:
		logger.Debugf("Response to Load Balancer health check %s", r.Status)
	default:
		logger.Infof("%s", r.Status)
	}
	logger.Debugf("CF-RAY: %s Response Headers %+v", cfRay, r.Header)
	if contentLen := r.ContentLength; contentLen == -1 {
		logger.Debugf("CF-RAY: %s Response content length unknown", cfRay)
	} else {
		logger.Debugf("CF-RAY: %s Response content length %d", cfRay, contentLen)
	}
}
// UpdateMetrics takes a snapshot of the muxer's metrics and publishes it
// under the given connectionID.
func (h *TunnelHandler) UpdateMetrics(connectionID string) {
	snapshot := h.muxer.Metrics()
	h.metrics.updateMuxerMetrics(connectionID, snapshot)
}
// h2muxWebsocketResp wraps an h2mux stream so it can be passed to
// serveWebsocket. It embeds the stream (so all MuxedStream methods are
// promoted) and adds WriteRespHeaders.
type h2muxWebsocketResp struct {
	*h2mux.MuxedStream
}

// WriteRespHeaders converts resp's HTTP/1 status and headers into h2mux
// headers and writes them to the stream. Like serveHTTP, it appends the
// response meta header marking the origin as the response source.
// NOTE(review): the meta header presumably lets the edge distinguish origin
// responses from cloudflared-generated ones (cf. ResponseSourceCloudflared).
func (wr *h2muxWebsocketResp) WriteRespHeaders(resp *http.Response) error {
	headers := h2mux.H1ResponseToH2ResponseHeaders(resp)
	headers = append(headers, h2mux.CreateResponseMetaHeader(h2mux.ResponseMetaHeaderField, h2mux.ResponseSourceOrigin))
	return wr.WriteHeaders(headers)
}