TUN-6517: Use QUIC stream context while proxying HTTP requests and TCP connections
This commit is contained in:
parent
06f7ba4523
commit
1733fe8c65
|
@ -143,15 +143,16 @@ func (q *QUICConnection) acceptStream(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (q *QUICConnection) runStream(quicStream quic.Stream) {
|
||||
ctx := quicStream.Context()
|
||||
stream := quicpogs.NewSafeStreamCloser(quicStream)
|
||||
defer stream.Close()
|
||||
|
||||
if err := q.handleStream(stream); err != nil {
|
||||
if err := q.handleStream(ctx, stream); err != nil {
|
||||
q.logger.Err(err).Msg("Failed to handle QUIC stream")
|
||||
}
|
||||
}
|
||||
|
||||
func (q *QUICConnection) handleStream(stream io.ReadWriteCloser) error {
|
||||
func (q *QUICConnection) handleStream(ctx context.Context, stream io.ReadWriteCloser) error {
|
||||
signature, err := quicpogs.DetermineProtocol(stream)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -162,7 +163,7 @@ func (q *QUICConnection) handleStream(stream io.ReadWriteCloser) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return q.handleDataStream(reqServerStream)
|
||||
return q.handleDataStream(ctx, reqServerStream)
|
||||
case quicpogs.RPCStreamProtocolSignature:
|
||||
rpcStream, err := quicpogs.NewRPCServerStream(stream, signature)
|
||||
if err != nil {
|
||||
|
@ -174,13 +175,13 @@ func (q *QUICConnection) handleStream(stream io.ReadWriteCloser) error {
|
|||
}
|
||||
}
|
||||
|
||||
func (q *QUICConnection) handleDataStream(stream *quicpogs.RequestServerStream) error {
|
||||
func (q *QUICConnection) handleDataStream(ctx context.Context, stream *quicpogs.RequestServerStream) error {
|
||||
request, err := stream.ReadConnectRequestData()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := q.dispatchRequest(stream, err, request); err != nil {
|
||||
if err := q.dispatchRequest(ctx, stream, err, request); err != nil {
|
||||
_ = stream.WriteConnectResponseData(err)
|
||||
q.logger.Err(err).Str("type", request.Type.String()).Str("dest", request.Dest).Msg("Request failed")
|
||||
}
|
||||
|
@ -188,7 +189,7 @@ func (q *QUICConnection) handleDataStream(stream *quicpogs.RequestServerStream)
|
|||
return nil
|
||||
}
|
||||
|
||||
func (q *QUICConnection) dispatchRequest(stream *quicpogs.RequestServerStream, err error, request *quicpogs.ConnectRequest) error {
|
||||
func (q *QUICConnection) dispatchRequest(ctx context.Context, stream *quicpogs.RequestServerStream, err error, request *quicpogs.ConnectRequest) error {
|
||||
originProxy, err := q.orchestrator.GetOriginProxy()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -196,7 +197,7 @@ func (q *QUICConnection) dispatchRequest(stream *quicpogs.RequestServerStream, e
|
|||
|
||||
switch request.Type {
|
||||
case quicpogs.ConnectionTypeHTTP, quicpogs.ConnectionTypeWebsocket:
|
||||
tracedReq, err := buildHTTPRequest(request, stream)
|
||||
tracedReq, err := buildHTTPRequest(ctx, request, stream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -206,7 +207,7 @@ func (q *QUICConnection) dispatchRequest(stream *quicpogs.RequestServerStream, e
|
|||
case quicpogs.ConnectionTypeTCP:
|
||||
rwa := &streamReadWriteAcker{stream}
|
||||
metadata := request.MetadataMap()
|
||||
return originProxy.ProxyTCP(context.Background(), rwa, &TCPRequest{
|
||||
return originProxy.ProxyTCP(ctx, rwa, &TCPRequest{
|
||||
Dest: request.Dest,
|
||||
FlowID: metadata[QUICMetadataFlowID],
|
||||
})
|
||||
|
@ -324,14 +325,14 @@ func (hrw httpResponseAdapter) WriteErrorResponse(err error) {
|
|||
hrw.WriteConnectResponseData(err, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
|
||||
}
|
||||
|
||||
func buildHTTPRequest(connectRequest *quicpogs.ConnectRequest, body io.ReadCloser) (*tracing.TracedRequest, error) {
|
||||
func buildHTTPRequest(ctx context.Context, connectRequest *quicpogs.ConnectRequest, body io.ReadCloser) (*tracing.TracedRequest, error) {
|
||||
metadata := connectRequest.MetadataMap()
|
||||
dest := connectRequest.Dest
|
||||
method := metadata[HTTPMethodKey]
|
||||
host := metadata[HTTPHostKey]
|
||||
isWebsocket := connectRequest.Type == quicpogs.ConnectionTypeWebsocket
|
||||
|
||||
req, err := http.NewRequest(method, dest, body)
|
||||
req, err := http.NewRequestWithContext(ctx, method, dest, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -477,7 +477,7 @@ func TestBuildHTTPRequest(t *testing.T) {
|
|||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
req, err := buildHTTPRequest(test.connectRequest, test.body)
|
||||
req, err := buildHTTPRequest(context.Background(), test.connectRequest, test.body)
|
||||
assert.NoError(t, err)
|
||||
test.req = test.req.WithContext(req.Context())
|
||||
assert.Equal(t, test.req, req.Request)
|
||||
|
|
Loading…
Reference in New Issue