diff --git a/h2mux/muxreader.go b/h2mux/muxreader.go index 3bdc8216..d97fcd8c 100644 --- a/h2mux/muxreader.go +++ b/h2mux/muxreader.go @@ -11,6 +11,10 @@ import ( "golang.org/x/net/http2" ) +const ( + CloudflaredProxyTunnelHostnameHeader = "cf-cloudflared-proxy-tunnel-hostname" +) + type MuxReader struct { // f is used to read HTTP2 frames. f *http2.Framer @@ -235,6 +239,8 @@ func (r *MuxReader) receiveHeaderData(frame *http2.MetaHeadersFrame) error { if r.dictionaries.write != nil { continue } + case CloudflaredProxyTunnelHostnameHeader: + stream.tunnelHostname = TunnelHostname(header.Value) } headers = append(headers, Header{Name: header.Name, Value: header.Value}) } diff --git a/h2mux/muxreader_test.go b/h2mux/muxreader_test.go new file mode 100644 index 00000000..dd3bf440 --- /dev/null +++ b/h2mux/muxreader_test.go @@ -0,0 +1,107 @@ +package h2mux + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var ( + methodHeader = Header{ + Name: ":method", + Value: "GET", + } + schemeHeader = Header{ + Name: ":scheme", + Value: "https", + } + pathHeader = Header{ + Name: ":path", + Value: "/api/tunnels", + } + tunnelHostnameHeader = Header{ + Name: CloudflaredProxyTunnelHostnameHeader, + Value: "tunnel.example.com", + } + respStatusHeader = Header{ + Name: ":status", + Value: "200", + } +) + +type mockOriginStreamHandler struct { + stream *MuxedStream +} + +func (mosh *mockOriginStreamHandler) ServeStream(stream *MuxedStream) error { + mosh.stream = stream + // Echo tunnel hostname in header + stream.WriteHeaders([]Header{respStatusHeader}) + return nil +} + +func getCloudflaredProxyTunnelHostnameHeader(stream *MuxedStream) string { + for _, header := range stream.Headers { + if header.Name == CloudflaredProxyTunnelHostnameHeader { + return header.Value + } + } + return "" +} + +func assertOpenStreamSucceed(t *testing.T, stream *MuxedStream, err error) { + assert.NoError(t, err) + assert.Len(t, stream.Headers, 1) + assert.Equal(t, respStatusHeader, stream.Headers[0]) +} + +func TestMissingHeaders(t *testing.T) { + originHandler := &mockOriginStreamHandler{} + muxPair := NewDefaultMuxerPair(t, originHandler.ServeStream) + muxPair.Serve(t) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + reqHeaders := []Header{ + { + Name: "content-type", + Value: "application/json", + }, + } + + // Request doesn't contain CloudflaredProxyTunnelHostnameHeader + stream, err := muxPair.EdgeMux.OpenStream(ctx, reqHeaders, nil) + assertOpenStreamSucceed(t, stream, err) + + assert.Empty(t, originHandler.stream.method) + assert.Empty(t, originHandler.stream.path) + assert.False(t, originHandler.stream.TunnelHostname().IsSet()) +} + +func TestReceiveHeaderData(t *testing.T) { + originHandler := &mockOriginStreamHandler{} + muxPair := NewDefaultMuxerPair(t, originHandler.ServeStream) + muxPair.Serve(t) + + reqHeaders := []Header{ + methodHeader, + schemeHeader, + pathHeader, + tunnelHostnameHeader, + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + reqHeaders = append(reqHeaders, tunnelHostnameHeader) + stream, err := muxPair.EdgeMux.OpenStream(ctx, reqHeaders, nil) + assertOpenStreamSucceed(t, stream, err) + + assert.Equal(t, methodHeader.Value, originHandler.stream.method) + assert.Equal(t, pathHeader.Value, originHandler.stream.path) + assert.True(t, originHandler.stream.TunnelHostname().IsSet()) + assert.Equal(t, tunnelHostnameHeader.Value, originHandler.stream.TunnelHostname().String()) +}