diff --git a/connection/quic.go b/connection/quic.go index 4c530b19..7adc03c2 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -143,15 +143,16 @@ func (q *QUICConnection) acceptStream(ctx context.Context) error { } func (q *QUICConnection) runStream(quicStream quic.Stream) { + ctx := quicStream.Context() stream := quicpogs.NewSafeStreamCloser(quicStream) defer stream.Close() - if err := q.handleStream(stream); err != nil { + if err := q.handleStream(ctx, stream); err != nil { q.logger.Err(err).Msg("Failed to handle QUIC stream") } } -func (q *QUICConnection) handleStream(stream io.ReadWriteCloser) error { +func (q *QUICConnection) handleStream(ctx context.Context, stream io.ReadWriteCloser) error { signature, err := quicpogs.DetermineProtocol(stream) if err != nil { return err @@ -162,7 +163,7 @@ func (q *QUICConnection) handleStream(stream io.ReadWriteCloser) error { if err != nil { return err } - return q.handleDataStream(reqServerStream) + return q.handleDataStream(ctx, reqServerStream) case quicpogs.RPCStreamProtocolSignature: rpcStream, err := quicpogs.NewRPCServerStream(stream, signature) if err != nil { @@ -174,13 +175,13 @@ func (q *QUICConnection) handleStream(stream io.ReadWriteCloser) error { } } -func (q *QUICConnection) handleDataStream(stream *quicpogs.RequestServerStream) error { +func (q *QUICConnection) handleDataStream(ctx context.Context, stream *quicpogs.RequestServerStream) error { request, err := stream.ReadConnectRequestData() if err != nil { return err } - if err := q.dispatchRequest(stream, err, request); err != nil { + if err := q.dispatchRequest(ctx, stream, err, request); err != nil { _ = stream.WriteConnectResponseData(err) q.logger.Err(err).Str("type", request.Type.String()).Str("dest", request.Dest).Msg("Request failed") } @@ -188,7 +189,7 @@ func (q *QUICConnection) handleDataStream(stream *quicpogs.RequestServerStream) return nil } -func (q *QUICConnection) dispatchRequest(stream *quicpogs.RequestServerStream, err error, request *quicpogs.ConnectRequest) error { +func (q *QUICConnection) dispatchRequest(ctx context.Context, stream *quicpogs.RequestServerStream, err error, request *quicpogs.ConnectRequest) error { originProxy, err := q.orchestrator.GetOriginProxy() if err != nil { return err @@ -196,7 +197,7 @@ func (q *QUICConnection) dispatchRequest(stream *quicpogs.RequestServerStream, e switch request.Type { case quicpogs.ConnectionTypeHTTP, quicpogs.ConnectionTypeWebsocket: - tracedReq, err := buildHTTPRequest(request, stream) + tracedReq, err := buildHTTPRequest(ctx, request, stream) if err != nil { return err } @@ -206,7 +207,7 @@ func (q *QUICConnection) dispatchRequest(stream *quicpogs.RequestServerStream, e case quicpogs.ConnectionTypeTCP: rwa := &streamReadWriteAcker{stream} metadata := request.MetadataMap() - return originProxy.ProxyTCP(context.Background(), rwa, &TCPRequest{ + return originProxy.ProxyTCP(ctx, rwa, &TCPRequest{ Dest: request.Dest, FlowID: metadata[QUICMetadataFlowID], }) @@ -324,14 +325,14 @@ func (hrw httpResponseAdapter) WriteErrorResponse(err error) { hrw.WriteConnectResponseData(err, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)}) } -func buildHTTPRequest(connectRequest *quicpogs.ConnectRequest, body io.ReadCloser) (*tracing.TracedRequest, error) { +func buildHTTPRequest(ctx context.Context, connectRequest *quicpogs.ConnectRequest, body io.ReadCloser) (*tracing.TracedRequest, error) { metadata := connectRequest.MetadataMap() dest := connectRequest.Dest method := metadata[HTTPMethodKey] host := metadata[HTTPHostKey] isWebsocket := connectRequest.Type == quicpogs.ConnectionTypeWebsocket - req, err := http.NewRequest(method, dest, body) + req, err := http.NewRequestWithContext(ctx, method, dest, body) if err != nil { return nil, err } diff --git a/connection/quic_test.go b/connection/quic_test.go index cbfaf714..33da98ac 100644 --- a/connection/quic_test.go +++ b/connection/quic_test.go @@ -477,7 +477,7 @@ func TestBuildHTTPRequest(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - req, err := buildHTTPRequest(test.connectRequest, test.body) + req, err := buildHTTPRequest(context.Background(), test.connectRequest, test.body) assert.NoError(t, err) test.req = test.req.WithContext(req.Context()) assert.Equal(t, test.req, req.Request)