TUN-6517: Use QUIC stream context while proxying HTTP requests and TCP connections

This commit is contained in:
Igor Postelnik 2022-07-07 18:01:37 -05:00
parent 06f7ba4523
commit 1733fe8c65
2 changed files with 12 additions and 11 deletions

View File

@ -143,15 +143,16 @@ func (q *QUICConnection) acceptStream(ctx context.Context) error {
}
func (q *QUICConnection) runStream(quicStream quic.Stream) {
ctx := quicStream.Context()
stream := quicpogs.NewSafeStreamCloser(quicStream)
defer stream.Close()
if err := q.handleStream(stream); err != nil {
if err := q.handleStream(ctx, stream); err != nil {
q.logger.Err(err).Msg("Failed to handle QUIC stream")
}
}
func (q *QUICConnection) handleStream(stream io.ReadWriteCloser) error {
func (q *QUICConnection) handleStream(ctx context.Context, stream io.ReadWriteCloser) error {
signature, err := quicpogs.DetermineProtocol(stream)
if err != nil {
return err
@ -162,7 +163,7 @@ func (q *QUICConnection) handleStream(stream io.ReadWriteCloser) error {
if err != nil {
return err
}
return q.handleDataStream(reqServerStream)
return q.handleDataStream(ctx, reqServerStream)
case quicpogs.RPCStreamProtocolSignature:
rpcStream, err := quicpogs.NewRPCServerStream(stream, signature)
if err != nil {
@ -174,13 +175,13 @@ func (q *QUICConnection) handleStream(stream io.ReadWriteCloser) error {
}
}
func (q *QUICConnection) handleDataStream(stream *quicpogs.RequestServerStream) error {
func (q *QUICConnection) handleDataStream(ctx context.Context, stream *quicpogs.RequestServerStream) error {
request, err := stream.ReadConnectRequestData()
if err != nil {
return err
}
if err := q.dispatchRequest(stream, err, request); err != nil {
if err := q.dispatchRequest(ctx, stream, err, request); err != nil {
_ = stream.WriteConnectResponseData(err)
q.logger.Err(err).Str("type", request.Type.String()).Str("dest", request.Dest).Msg("Request failed")
}
@ -188,7 +189,7 @@ func (q *QUICConnection) handleDataStream(stream *quicpogs.RequestServerStream)
return nil
}
func (q *QUICConnection) dispatchRequest(stream *quicpogs.RequestServerStream, err error, request *quicpogs.ConnectRequest) error {
func (q *QUICConnection) dispatchRequest(ctx context.Context, stream *quicpogs.RequestServerStream, err error, request *quicpogs.ConnectRequest) error {
originProxy, err := q.orchestrator.GetOriginProxy()
if err != nil {
return err
@ -196,7 +197,7 @@ func (q *QUICConnection) dispatchRequest(stream *quicpogs.RequestServerStream, e
switch request.Type {
case quicpogs.ConnectionTypeHTTP, quicpogs.ConnectionTypeWebsocket:
tracedReq, err := buildHTTPRequest(request, stream)
tracedReq, err := buildHTTPRequest(ctx, request, stream)
if err != nil {
return err
}
@ -206,7 +207,7 @@ func (q *QUICConnection) dispatchRequest(stream *quicpogs.RequestServerStream, e
case quicpogs.ConnectionTypeTCP:
rwa := &streamReadWriteAcker{stream}
metadata := request.MetadataMap()
return originProxy.ProxyTCP(context.Background(), rwa, &TCPRequest{
return originProxy.ProxyTCP(ctx, rwa, &TCPRequest{
Dest: request.Dest,
FlowID: metadata[QUICMetadataFlowID],
})
@ -324,14 +325,14 @@ func (hrw httpResponseAdapter) WriteErrorResponse(err error) {
hrw.WriteConnectResponseData(err, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
}
func buildHTTPRequest(connectRequest *quicpogs.ConnectRequest, body io.ReadCloser) (*tracing.TracedRequest, error) {
func buildHTTPRequest(ctx context.Context, connectRequest *quicpogs.ConnectRequest, body io.ReadCloser) (*tracing.TracedRequest, error) {
metadata := connectRequest.MetadataMap()
dest := connectRequest.Dest
method := metadata[HTTPMethodKey]
host := metadata[HTTPHostKey]
isWebsocket := connectRequest.Type == quicpogs.ConnectionTypeWebsocket
req, err := http.NewRequest(method, dest, body)
req, err := http.NewRequestWithContext(ctx, method, dest, body)
if err != nil {
return nil, err
}

View File

@ -477,7 +477,7 @@ func TestBuildHTTPRequest(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
req, err := buildHTTPRequest(test.connectRequest, test.body)
req, err := buildHTTPRequest(context.Background(), test.connectRequest, test.body)
assert.NoError(t, err)
test.req = test.req.WithContext(req.Context())
assert.Equal(t, test.req, req.Request)