TUN-6576: Consume cf-trace-id from incoming TCP requests to create root span

This commit is contained in:
Devin Carr 2022-07-26 14:00:53 -07:00
parent d96c39196d
commit f48a7cd3dd
13 changed files with 168 additions and 62 deletions

View File

@ -123,23 +123,24 @@ func (t Type) String() string {
// OriginProxy is how data flows from cloudflared to the origin services running behind it. // OriginProxy is how data flows from cloudflared to the origin services running behind it.
type OriginProxy interface { type OriginProxy interface {
ProxyHTTP(w ResponseWriter, tr *tracing.TracedRequest, isWebsocket bool) error ProxyHTTP(w ResponseWriter, tr *tracing.TracedHTTPRequest, isWebsocket bool) error
ProxyTCP(ctx context.Context, rwa ReadWriteAcker, req *TCPRequest) error ProxyTCP(ctx context.Context, rwa ReadWriteAcker, req *TCPRequest) error
} }
// TCPRequest defines the input format needed to perform a TCP proxy. // TCPRequest defines the input format needed to perform a TCP proxy.
type TCPRequest struct { type TCPRequest struct {
Dest string Dest string
CFRay string CFRay string
LBProbe bool LBProbe bool
FlowID string FlowID string
CfTraceID string
} }
// ReadWriteAcker is a readwriter with the ability to Acknowledge to the downstream (edge) that the origin has // ReadWriteAcker is a readwriter with the ability to Acknowledge to the downstream (edge) that the origin has
// accepted the connection. // accepted the connection.
type ReadWriteAcker interface { type ReadWriteAcker interface {
io.ReadWriter io.ReadWriter
AckConnection() error AckConnection(tracePropagation string) error
} }
// HTTPResponseReadWriteAcker is an HTTP implementation of ReadWriteAcker. // HTTPResponseReadWriteAcker is an HTTP implementation of ReadWriteAcker.
@ -168,7 +169,7 @@ func (h *HTTPResponseReadWriteAcker) Write(p []byte) (int, error) {
// AckConnection acks an HTTP connection by sending a switch protocols status code that enables the caller to // AckConnection acks an HTTP connection by sending a switch protocols status code that enables the caller to
// upgrade to streams. // upgrade to streams.
func (h *HTTPResponseReadWriteAcker) AckConnection() error { func (h *HTTPResponseReadWriteAcker) AckConnection(tracePropagation string) error {
resp := &http.Response{ resp := &http.Response{
Status: switchingProtocolText, Status: switchingProtocolText,
StatusCode: http.StatusSwitchingProtocols, StatusCode: http.StatusSwitchingProtocols,
@ -179,6 +180,10 @@ func (h *HTTPResponseReadWriteAcker) AckConnection() error {
resp.Header = websocket.NewResponseHeader(h.req) resp.Header = websocket.NewResponseHeader(h.req)
} }
if tracePropagation != "" {
resp.Header.Add(tracing.CanonicalCloudflaredTracingHeader, tracePropagation)
}
return h.w.WriteRespHeaders(resp.StatusCode, resp.Header) return h.w.WriteRespHeaders(resp.StatusCode, resp.Header)
} }

View File

@ -30,6 +30,8 @@ var (
testLargeResp = make([]byte, largeFileSize) testLargeResp = make([]byte, largeFileSize)
) )
var _ ReadWriteAcker = (*HTTPResponseReadWriteAcker)(nil)
type testRequest struct { type testRequest struct {
name string name string
endpoint string endpoint string
@ -60,7 +62,7 @@ type mockOriginProxy struct{}
func (moc *mockOriginProxy) ProxyHTTP( func (moc *mockOriginProxy) ProxyHTTP(
w ResponseWriter, w ResponseWriter,
tr *tracing.TracedRequest, tr *tracing.TracedHTTPRequest,
isWebsocket bool, isWebsocket bool,
) error { ) error {
req := tr.Request req := tr.Request

View File

@ -69,6 +69,7 @@ func NewH2muxConnection(
connIndex uint8, connIndex uint8,
observer *Observer, observer *Observer,
gracefulShutdownC <-chan struct{}, gracefulShutdownC <-chan struct{},
log *zerolog.Logger,
) (*h2muxConnection, error, bool) { ) (*h2muxConnection, error, bool) {
h := &h2muxConnection{ h := &h2muxConnection{
orchestrator: orchestrator, orchestrator: orchestrator,
@ -79,6 +80,7 @@ func NewH2muxConnection(
observer: observer, observer: observer,
gracefulShutdownC: gracefulShutdownC, gracefulShutdownC: gracefulShutdownC,
newRPCClientFunc: newRegistrationRPCClient, newRPCClientFunc: newRegistrationRPCClient,
log: log,
} }
// Establish a muxed connection with the edge // Establish a muxed connection with the edge
@ -234,7 +236,7 @@ func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error {
return err return err
} }
err = originProxy.ProxyHTTP(respWriter, tracing.NewTracedRequest(req), sourceConnectionType == TypeWebsocket) err = originProxy.ProxyHTTP(respWriter, tracing.NewTracedHTTPRequest(req, h.log), sourceConnectionType == TypeWebsocket)
if err != nil { if err != nil {
respWriter.WriteErrorResponse() respWriter.WriteErrorResponse()
} }

View File

@ -48,7 +48,7 @@ func newH2MuxConnection(t require.TestingT) (*h2muxConnection, *h2mux.Muxer) {
}() }()
var connIndex = uint8(0) var connIndex = uint8(0)
testObserver := NewObserver(&log, &log) testObserver := NewObserver(&log, &log)
h2muxConn, err, _ := NewH2muxConnection(testOrchestrator, testGracePeriod, testMuxerConfig, originConn, connIndex, testObserver, nil) h2muxConn, err, _ := NewH2muxConnection(testOrchestrator, testGracePeriod, testMuxerConfig, originConn, connIndex, testObserver, nil, &log)
require.NoError(t, err) require.NoError(t, err)
return h2muxConn, <-edgeMuxChan return h2muxConn, <-edgeMuxChan
} }

View File

@ -132,7 +132,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case TypeWebsocket, TypeHTTP: case TypeWebsocket, TypeHTTP:
stripWebsocketUpgradeHeader(r) stripWebsocketUpgradeHeader(r)
// Check for tracing on request // Check for tracing on request
tr := tracing.NewTracedRequest(r) tr := tracing.NewTracedHTTPRequest(r, c.log)
if err := originProxy.ProxyHTTP(respWriter, tr, connType == TypeWebsocket); err != nil { if err := originProxy.ProxyHTTP(respWriter, tr, connType == TypeWebsocket); err != nil {
err := fmt.Errorf("Failed to proxy HTTP: %w", err) err := fmt.Errorf("Failed to proxy HTTP: %w", err)
c.log.Error().Err(err) c.log.Error().Err(err)

View File

@ -197,7 +197,7 @@ func (q *QUICConnection) dispatchRequest(ctx context.Context, stream *quicpogs.R
switch request.Type { switch request.Type {
case quicpogs.ConnectionTypeHTTP, quicpogs.ConnectionTypeWebsocket: case quicpogs.ConnectionTypeHTTP, quicpogs.ConnectionTypeWebsocket:
tracedReq, err := buildHTTPRequest(ctx, request, stream) tracedReq, err := buildHTTPRequest(ctx, request, stream, q.logger)
if err != nil { if err != nil {
return err return err
} }
@ -208,8 +208,9 @@ func (q *QUICConnection) dispatchRequest(ctx context.Context, stream *quicpogs.R
rwa := &streamReadWriteAcker{stream} rwa := &streamReadWriteAcker{stream}
metadata := request.MetadataMap() metadata := request.MetadataMap()
return originProxy.ProxyTCP(ctx, rwa, &TCPRequest{ return originProxy.ProxyTCP(ctx, rwa, &TCPRequest{
Dest: request.Dest, Dest: request.Dest,
FlowID: metadata[QUICMetadataFlowID], FlowID: metadata[QUICMetadataFlowID],
CfTraceID: metadata[tracing.TracerContextName],
}) })
} }
return nil return nil
@ -296,8 +297,12 @@ type streamReadWriteAcker struct {
} }
// AckConnection acks response back to the proxy. // AckConnection acks response back to the proxy.
func (s *streamReadWriteAcker) AckConnection() error { func (s *streamReadWriteAcker) AckConnection(tracePropagation string) error {
return s.WriteConnectResponseData(nil) metadata := quicpogs.Metadata{
Key: tracing.CanonicalCloudflaredTracingHeader,
Val: tracePropagation,
}
return s.WriteConnectResponseData(nil, metadata)
} }
// httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC. // httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC.
@ -325,7 +330,12 @@ func (hrw httpResponseAdapter) WriteErrorResponse(err error) {
hrw.WriteConnectResponseData(err, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)}) hrw.WriteConnectResponseData(err, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
} }
func buildHTTPRequest(ctx context.Context, connectRequest *quicpogs.ConnectRequest, body io.ReadCloser) (*tracing.TracedRequest, error) { func buildHTTPRequest(
ctx context.Context,
connectRequest *quicpogs.ConnectRequest,
body io.ReadCloser,
log *zerolog.Logger,
) (*tracing.TracedHTTPRequest, error) {
metadata := connectRequest.MetadataMap() metadata := connectRequest.MetadataMap()
dest := connectRequest.Dest dest := connectRequest.Dest
method := metadata[HTTPMethodKey] method := metadata[HTTPMethodKey]
@ -367,7 +377,7 @@ func buildHTTPRequest(ctx context.Context, connectRequest *quicpogs.ConnectReque
stripWebsocketUpgradeHeader(req) stripWebsocketUpgradeHeader(req)
// Check for tracing on request // Check for tracing on request
tracedReq := tracing.NewTracedRequest(req) tracedReq := tracing.NewTracedHTTPRequest(req, log)
return tracedReq, err return tracedReq, err
} }

View File

@ -36,6 +36,8 @@ var (
} }
) )
var _ ReadWriteAcker = (*streamReadWriteAcker)(nil)
// TestQUICServer tests if a quic server accepts and responds to a quic client with the acceptance protocol. // TestQUICServer tests if a quic server accepts and responds to a quic client with the acceptance protocol.
// It also serves as a demonstration for communication with the QUIC connection started by a cloudflared. // It also serves as a demonstration for communication with the QUIC connection started by a cloudflared.
func TestQUICServer(t *testing.T) { func TestQUICServer(t *testing.T) {
@ -220,7 +222,7 @@ func quicServer(
type mockOriginProxyWithRequest struct{} type mockOriginProxyWithRequest struct{}
func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, tr *tracing.TracedRequest, isWebsocket bool) error { func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, tr *tracing.TracedHTTPRequest, isWebsocket bool) error {
// These are a series of crude tests to ensure the headers and http related data is transferred from // These are a series of crude tests to ensure the headers and http related data is transferred from
// metadata. // metadata.
r := tr.Request r := tr.Request
@ -475,9 +477,10 @@ func TestBuildHTTPRequest(t *testing.T) {
}, },
} }
log := zerolog.Nop()
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
req, err := buildHTTPRequest(context.Background(), test.connectRequest, test.body) req, err := buildHTTPRequest(context.Background(), test.connectRequest, test.body, &log)
assert.NoError(t, err) assert.NoError(t, err)
test.req = test.req.WithContext(req.Context()) test.req = test.req.WithContext(req.Context())
assert.Equal(t, test.req, req.Request) assert.Equal(t, test.req, req.Request)
@ -486,7 +489,7 @@ func TestBuildHTTPRequest(t *testing.T) {
} }
func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error { func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error {
rwa.AckConnection() rwa.AckConnection("")
io.Copy(rwa, rwa) io.Copy(rwa, rwa)
return nil return nil
} }

View File

@ -355,7 +355,7 @@ func proxyHTTP(originProxy connection.OriginProxy, hostname string) (*http.Respo
return nil, err return nil, err
} }
err = originProxy.ProxyHTTP(respWriter, tracing.NewTracedRequest(req), false) err = originProxy.ProxyHTTP(respWriter, tracing.NewTracedHTTPRequest(req, &log), false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -604,7 +604,7 @@ func TestPersistentConnection(t *testing.T) {
respWriter, err := connection.NewHTTP2RespWriter(req, wsRespReadWriter, connection.TypeWebsocket, &log) respWriter, err := connection.NewHTTP2RespWriter(req, wsRespReadWriter, connection.TypeWebsocket, &log)
require.NoError(t, err) require.NoError(t, err)
err = originProxy.ProxyHTTP(respWriter, tracing.NewTracedRequest(req), true) err = originProxy.ProxyHTTP(respWriter, tracing.NewTracedHTTPRequest(req, &log), true)
require.NoError(t, err) require.NoError(t, err)
}() }()

View File

@ -63,7 +63,7 @@ func NewOriginProxy(
// a simple roundtrip or a tcp/websocket dial depending on ingres rule setup. // a simple roundtrip or a tcp/websocket dial depending on ingres rule setup.
func (p *Proxy) ProxyHTTP( func (p *Proxy) ProxyHTTP(
w connection.ResponseWriter, w connection.ResponseWriter,
tr *tracing.TracedRequest, tr *tracing.TracedHTTPRequest,
isWebsocket bool, isWebsocket bool,
) error { ) error {
incrementRequests() incrementRequests()
@ -108,7 +108,7 @@ func (p *Proxy) ProxyHTTP(
} }
rws := connection.NewHTTPResponseReadWriterAcker(w, req) rws := connection.NewHTTPResponseReadWriterAcker(w, req)
if err := p.proxyStream(req.Context(), rws, dest, originProxy); err != nil { if err := p.proxyStream(tr.ToTracedContext(), rws, dest, originProxy); err != nil {
rule, srv := ruleField(p.ingressRules, ruleNum) rule, srv := ruleField(p.ingressRules, ruleNum)
p.logRequestError(err, cfRay, "", rule, srv) p.logRequestError(err, cfRay, "", rule, srv)
return err return err
@ -137,9 +137,11 @@ func (p *Proxy) ProxyTCP(
serveCtx, cancel := context.WithCancel(ctx) serveCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
tracedCtx := tracing.NewTracedContext(serveCtx, req.CfTraceID, p.log)
p.log.Debug().Str(LogFieldFlowID, req.FlowID).Msg("tcp proxy stream started") p.log.Debug().Str(LogFieldFlowID, req.FlowID).Msg("tcp proxy stream started")
if err := p.proxyStream(serveCtx, rwa, req.Dest, p.warpRouting.Proxy); err != nil { if err := p.proxyStream(tracedCtx, rwa, req.Dest, p.warpRouting.Proxy); err != nil {
p.logRequestError(err, req.CFRay, req.FlowID, "", ingress.ServiceWarpRouting) p.logRequestError(err, req.CFRay, req.FlowID, "", ingress.ServiceWarpRouting)
return err return err
} }
@ -160,7 +162,7 @@ func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) {
// ProxyHTTPRequest proxies requests of underlying type http and websocket to the origin service. // ProxyHTTPRequest proxies requests of underlying type http and websocket to the origin service.
func (p *Proxy) proxyHTTPRequest( func (p *Proxy) proxyHTTPRequest(
w connection.ResponseWriter, w connection.ResponseWriter,
tr *tracing.TracedRequest, tr *tracing.TracedHTTPRequest,
httpService ingress.HTTPOriginProxy, httpService ingress.HTTPOriginProxy,
isWebsocket bool, isWebsocket bool,
disableChunkedEncoding bool, disableChunkedEncoding bool,
@ -211,7 +213,7 @@ func (p *Proxy) proxyHTTPRequest(
} }
// Add spans to response header (if available) // Add spans to response header (if available)
tr.AddSpans(resp.Header, p.log) tr.AddSpans(resp.Header)
err = w.WriteRespHeaders(resp.StatusCode, resp.Header) err = w.WriteRespHeaders(resp.StatusCode, resp.Header)
if err != nil { if err != nil {
@ -248,17 +250,23 @@ func (p *Proxy) proxyHTTPRequest(
// proxyStream proxies type TCP and other underlying types if the connection is defined as a stream oriented // proxyStream proxies type TCP and other underlying types if the connection is defined as a stream oriented
// ingress rule. // ingress rule.
func (p *Proxy) proxyStream( func (p *Proxy) proxyStream(
ctx context.Context, tr *tracing.TracedContext,
rwa connection.ReadWriteAcker, rwa connection.ReadWriteAcker,
dest string, dest string,
connectionProxy ingress.StreamBasedOriginProxy, connectionProxy ingress.StreamBasedOriginProxy,
) error { ) error {
ctx := tr.Context
_, connectSpan := tr.Tracer().Start(ctx, "stream_connect")
originConn, err := connectionProxy.EstablishConnection(ctx, dest) originConn, err := connectionProxy.EstablishConnection(ctx, dest)
if err != nil { if err != nil {
tracing.EndWithErrorStatus(connectSpan, err)
return err return err
} }
connectSpan.End()
if err := rwa.AckConnection(); err != nil { encodedSpans := tr.GetSpans()
if err := rwa.AckConnection(encodedSpans); err != nil {
return err return err
} }

View File

@ -157,7 +157,8 @@ func testProxyHTTP(proxy connection.OriginProxy) func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil) req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil)
require.NoError(t, err) require.NoError(t, err)
err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedRequest(req), false) log := zerolog.Nop()
err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, &log), false)
require.NoError(t, err) require.NoError(t, err)
for _, tag := range testTags { for _, tag := range testTags {
assert.Equal(t, tag.Value, req.Header.Get(TagHeaderNamePrefix+tag.Name)) assert.Equal(t, tag.Value, req.Header.Get(TagHeaderNamePrefix+tag.Name))
@ -184,7 +185,8 @@ func testProxyWebsocket(proxy connection.OriginProxy) func(t *testing.T) {
errGroup, ctx := errgroup.WithContext(ctx) errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error { errGroup.Go(func() error {
err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedRequest(req), true) log := zerolog.Nop()
err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, &log), true)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, http.StatusSwitchingProtocols, responseWriter.Code) require.Equal(t, http.StatusSwitchingProtocols, responseWriter.Code)
@ -245,7 +247,8 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedRequest(req), false) log := zerolog.Nop()
err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, &log), false)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, http.StatusOK, responseWriter.Code) require.Equal(t, http.StatusOK, responseWriter.Code)
@ -357,7 +360,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat
req, err := http.NewRequest(http.MethodGet, test.url, nil) req, err := http.NewRequest(http.MethodGet, test.url, nil)
require.NoError(t, err) require.NoError(t, err)
err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedRequest(req), false) err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, &log), false)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, test.expectedStatus, responseWriter.Code) assert.Equal(t, test.expectedStatus, responseWriter.Code)
@ -404,7 +407,7 @@ func TestProxyError(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
assert.NoError(t, err) assert.NoError(t, err)
assert.Error(t, proxy.ProxyHTTP(responseWriter, tracing.NewTracedRequest(req), false)) assert.Error(t, proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, &log), false))
} }
type replayer struct { type replayer struct {
@ -682,7 +685,8 @@ func TestConnections(t *testing.T) {
rwa := connection.NewHTTPResponseReadWriterAcker(respWriter, req) rwa := connection.NewHTTPResponseReadWriterAcker(respWriter, req)
err = proxy.ProxyTCP(ctx, rwa, &connection.TCPRequest{Dest: dest}) err = proxy.ProxyTCP(ctx, rwa, &connection.TCPRequest{Dest: dest})
} else { } else {
err = proxy.ProxyHTTP(respWriter, tracing.NewTracedRequest(req), test.args.connectionType == connection.TypeWebsocket) log := zerolog.Nop()
err = proxy.ProxyHTTP(respWriter, tracing.NewTracedHTTPRequest(req, &log), test.args.connectionType == connection.TypeWebsocket)
} }
cancel() cancel()

View File

@ -557,6 +557,7 @@ func ServeH2mux(
connIndex, connIndex,
config.Observer, config.Observer,
gracefulShutdownC, gracefulShutdownC,
config.Log,
) )
if err != nil { if err != nil {
if !recoverable { if !recoverable {

View File

@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"os" "os"
"runtime" "runtime"
"strings"
"github.com/rs/zerolog" "github.com/rs/zerolog"
otelContrib "go.opentelemetry.io/contrib/propagators/jaeger" otelContrib "go.opentelemetry.io/contrib/propagators/jaeger"
@ -33,6 +34,9 @@ const (
MaxErrorDescriptionLen = 100 MaxErrorDescriptionLen = 100
traceHttpStatusCodeKey = "upstreamStatusCode" traceHttpStatusCodeKey = "upstreamStatusCode"
traceID128bitsWidth = 128 / 4
separator = ":"
) )
var ( var (
@ -66,22 +70,50 @@ func Init(version string) {
cloudflaredVersionAttribute = semconv.ProcessRuntimeVersionKey.String(version) cloudflaredVersionAttribute = semconv.ProcessRuntimeVersionKey.String(version)
} }
type TracedRequest struct { type TracedHTTPRequest struct {
*http.Request *http.Request
trace.TracerProvider *cfdTracer
exporter InMemoryClient
} }
// NewTracedRequest creates a new tracer for the current request context. // NewTracedHTTPRequest creates a new tracer for the current HTTP request context.
func NewTracedRequest(req *http.Request) *TracedRequest { func NewTracedHTTPRequest(req *http.Request, log *zerolog.Logger) *TracedHTTPRequest {
ctx, exists := extractTrace(req) ctx, exists := extractTrace(req)
if !exists { if !exists {
return &TracedRequest{req, trace.NewNoopTracerProvider(), &NoopOtlpClient{}} return &TracedHTTPRequest{req, &cfdTracer{trace.NewNoopTracerProvider(), &NoopOtlpClient{}, log}}
} }
return &TracedHTTPRequest{req.WithContext(ctx), newCfdTracer(ctx, log)}
}
func (tr *TracedHTTPRequest) ToTracedContext() *TracedContext {
return &TracedContext{tr.Context(), tr.cfdTracer}
}
type TracedContext struct {
context.Context
*cfdTracer
}
// NewTracedHTTPRequest creates a new tracer for the current HTTP request context.
func NewTracedContext(ctx context.Context, traceContext string, log *zerolog.Logger) *TracedContext {
ctx, exists := extractTraceFromString(ctx, traceContext)
if !exists {
return &TracedContext{ctx, &cfdTracer{trace.NewNoopTracerProvider(), &NoopOtlpClient{}, log}}
}
return &TracedContext{ctx, newCfdTracer(ctx, log)}
}
type cfdTracer struct {
trace.TracerProvider
exporter InMemoryClient
log *zerolog.Logger
}
// NewCfdTracer creates a new tracer for the current request context.
func newCfdTracer(ctx context.Context, log *zerolog.Logger) *cfdTracer {
mc := new(InMemoryOtlpClient) mc := new(InMemoryOtlpClient)
exp, err := otlptrace.New(req.Context(), mc) exp, err := otlptrace.New(ctx, mc)
if err != nil { if err != nil {
return &TracedRequest{req, trace.NewNoopTracerProvider(), &NoopOtlpClient{}} return &cfdTracer{trace.NewNoopTracerProvider(), &NoopOtlpClient{}, log}
} }
tp := tracesdk.NewTracerProvider( tp := tracesdk.NewTracerProvider(
// We want to dump to in-memory exporter immediately // We want to dump to in-memory exporter immediately
@ -98,36 +130,43 @@ func NewTracedRequest(req *http.Request) *TracedRequest {
)), )),
) )
return &TracedRequest{req.WithContext(ctx), tp, mc} return &cfdTracer{tp, mc, log}
} }
func (cft *TracedRequest) Tracer() trace.Tracer { func (cft *cfdTracer) Tracer() trace.Tracer {
return cft.TracerProvider.Tracer(tracerInstrumentName) return cft.TracerProvider.Tracer(tracerInstrumentName)
} }
// Spans returns the spans as base64 encoded protobuf otlp traces. // GetSpans returns the spans as base64 encoded string of protobuf otlp traces.
func (cft *TracedRequest) AddSpans(headers http.Header, log *zerolog.Logger) { func (cft *cfdTracer) GetSpans() (enc string) {
if headers == nil {
log.Error().Msgf("provided headers map is nil")
return
}
enc, err := cft.exporter.Spans() enc, err := cft.exporter.Spans()
switch err { switch err {
case nil: case nil:
break break
case errNoTraces: case errNoTraces:
log.Error().Err(err).Msgf("expected traces to be available") cft.log.Error().Err(err).Msgf("expected traces to be available")
return return
case errNoopTracer: case errNoopTracer:
return // noop tracer has no traces return // noop tracer has no traces
default: default:
log.Error().Err(err) cft.log.Error().Err(err)
return return
} }
return
}
// AddSpans assigns spans as base64 encoded protobuf otlp traces to provided
// HTTP headers.
func (cft *cfdTracer) AddSpans(headers http.Header) {
if headers == nil {
cft.log.Error().Msgf("provided headers map is nil")
return
}
enc := cft.GetSpans()
// No need to add header if no traces // No need to add header if no traces
if enc == "" { if enc == "" {
log.Error().Msgf("no traces provided and no error from exporter") cft.log.Error().Msgf("no traces provided and no error from exporter")
return return
} }
@ -166,6 +205,33 @@ func endSpan(span trace.Span, upstreamStatusCode int, spanStatusCode codes.Code,
span.End() span.End()
} }
// extractTraceFromString will extract the trace information from the provided
// propagated trace string context.
func extractTraceFromString(ctx context.Context, trace string) (context.Context, bool) {
if trace == "" {
return ctx, false
}
// Jaeger specific separator
parts := strings.Split(trace, separator)
if len(parts) != 4 {
return ctx, false
}
if parts[0] == "" {
return ctx, false
}
// Correctly left pad the trace to a length of 32
if len(parts[0]) < traceID128bitsWidth {
left := traceID128bitsWidth - len(parts[0])
parts[0] = strings.Repeat("0", left) + parts[0]
trace = strings.Join(parts, separator)
}
// Override the 'cf-trace-id' as 'uber-trace-id' so the jaeger propagator can extract it.
traceHeader := map[string]string{TracerContextNameOverride: trace}
remoteCtx := otel.GetTextMapPropagator().Extract(ctx, propagation.MapCarrier(traceHeader))
return remoteCtx, true
}
// extractTrace attempts to check for a cf-trace-id from a request and return the // extractTrace attempts to check for a cf-trace-id from a request and return the
// trace context with the provided http.Request. // trace context with the provided http.Request.
func extractTrace(req *http.Request) (context.Context, bool) { func extractTrace(req *http.Request) (context.Context, bool) {

View File

@ -14,38 +14,42 @@ import (
) )
func TestNewCfTracer(t *testing.T) { func TestNewCfTracer(t *testing.T) {
log := zerolog.Nop()
req := httptest.NewRequest("GET", "http://localhost", nil) req := httptest.NewRequest("GET", "http://localhost", nil)
req.Header.Add(TracerContextName, "14cb070dde8e51fc5ae8514e69ba42ca:b38f1bf5eae406f3:0:1") req.Header.Add(TracerContextName, "14cb070dde8e51fc5ae8514e69ba42ca:b38f1bf5eae406f3:0:1")
tr := NewTracedRequest(req) tr := NewTracedHTTPRequest(req, &log)
assert.NotNil(t, tr) assert.NotNil(t, tr)
assert.IsType(t, tracesdk.NewTracerProvider(), tr.TracerProvider) assert.IsType(t, tracesdk.NewTracerProvider(), tr.TracerProvider)
assert.IsType(t, &InMemoryOtlpClient{}, tr.exporter) assert.IsType(t, &InMemoryOtlpClient{}, tr.exporter)
} }
func TestNewCfTracerMultiple(t *testing.T) { func TestNewCfTracerMultiple(t *testing.T) {
log := zerolog.Nop()
req := httptest.NewRequest("GET", "http://localhost", nil) req := httptest.NewRequest("GET", "http://localhost", nil)
req.Header.Add(TracerContextName, "1241ce3ecdefc68854e8514e69ba42ca:b38f1bf5eae406f3:0:1") req.Header.Add(TracerContextName, "1241ce3ecdefc68854e8514e69ba42ca:b38f1bf5eae406f3:0:1")
req.Header.Add(TracerContextName, "14cb070dde8e51fc5ae8514e69ba42ca:b38f1bf5eae406f3:0:1") req.Header.Add(TracerContextName, "14cb070dde8e51fc5ae8514e69ba42ca:b38f1bf5eae406f3:0:1")
tr := NewTracedRequest(req) tr := NewTracedHTTPRequest(req, &log)
assert.NotNil(t, tr) assert.NotNil(t, tr)
assert.IsType(t, tracesdk.NewTracerProvider(), tr.TracerProvider) assert.IsType(t, tracesdk.NewTracerProvider(), tr.TracerProvider)
assert.IsType(t, &InMemoryOtlpClient{}, tr.exporter) assert.IsType(t, &InMemoryOtlpClient{}, tr.exporter)
} }
func TestNewCfTracerNilHeader(t *testing.T) { func TestNewCfTracerNilHeader(t *testing.T) {
log := zerolog.Nop()
req := httptest.NewRequest("GET", "http://localhost", nil) req := httptest.NewRequest("GET", "http://localhost", nil)
req.Header[http.CanonicalHeaderKey(TracerContextName)] = nil req.Header[http.CanonicalHeaderKey(TracerContextName)] = nil
tr := NewTracedRequest(req) tr := NewTracedHTTPRequest(req, &log)
assert.NotNil(t, tr) assert.NotNil(t, tr)
assert.IsType(t, trace.NewNoopTracerProvider(), tr.TracerProvider) assert.IsType(t, trace.NewNoopTracerProvider(), tr.TracerProvider)
assert.IsType(t, &NoopOtlpClient{}, tr.exporter) assert.IsType(t, &NoopOtlpClient{}, tr.exporter)
} }
func TestNewCfTracerInvalidHeaders(t *testing.T) { func TestNewCfTracerInvalidHeaders(t *testing.T) {
log := zerolog.Nop()
req := httptest.NewRequest("GET", "http://localhost", nil) req := httptest.NewRequest("GET", "http://localhost", nil)
for _, test := range [][]string{nil, {""}} { for _, test := range [][]string{nil, {""}} {
req.Header[http.CanonicalHeaderKey(TracerContextName)] = test req.Header[http.CanonicalHeaderKey(TracerContextName)] = test
tr := NewTracedRequest(req) tr := NewTracedHTTPRequest(req, &log)
assert.NotNil(t, tr) assert.NotNil(t, tr)
assert.IsType(t, trace.NewNoopTracerProvider(), tr.TracerProvider) assert.IsType(t, trace.NewNoopTracerProvider(), tr.TracerProvider)
assert.IsType(t, &NoopOtlpClient{}, tr.exporter) assert.IsType(t, &NoopOtlpClient{}, tr.exporter)
@ -53,9 +57,10 @@ func TestNewCfTracerInvalidHeaders(t *testing.T) {
} }
func TestAddingSpansWithNilMap(t *testing.T) { func TestAddingSpansWithNilMap(t *testing.T) {
log := zerolog.Nop()
req := httptest.NewRequest("GET", "http://localhost", nil) req := httptest.NewRequest("GET", "http://localhost", nil)
req.Header.Add(TracerContextName, "14cb070dde8e51fc5ae8514e69ba42ca:b38f1bf5eae406f3:0:1") req.Header.Add(TracerContextName, "14cb070dde8e51fc5ae8514e69ba42ca:b38f1bf5eae406f3:0:1")
tr := NewTracedRequest(req) tr := NewTracedHTTPRequest(req, &log)
exporter := tr.exporter.(*InMemoryOtlpClient) exporter := tr.exporter.(*InMemoryOtlpClient)
@ -65,5 +70,5 @@ func TestAddingSpansWithNilMap(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
// a panic shouldn't occur // a panic shouldn't occur
tr.AddSpans(nil, &zerolog.Logger{}) tr.AddSpans(nil)
} }