TUN-6517: Use QUIC stream context while proxying HTTP requests and TCP connections
This commit is contained in:
parent
06f7ba4523
commit
1733fe8c65
|
@ -143,15 +143,16 @@ func (q *QUICConnection) acceptStream(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *QUICConnection) runStream(quicStream quic.Stream) {
|
func (q *QUICConnection) runStream(quicStream quic.Stream) {
|
||||||
|
ctx := quicStream.Context()
|
||||||
stream := quicpogs.NewSafeStreamCloser(quicStream)
|
stream := quicpogs.NewSafeStreamCloser(quicStream)
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
|
|
||||||
if err := q.handleStream(stream); err != nil {
|
if err := q.handleStream(ctx, stream); err != nil {
|
||||||
q.logger.Err(err).Msg("Failed to handle QUIC stream")
|
q.logger.Err(err).Msg("Failed to handle QUIC stream")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *QUICConnection) handleStream(stream io.ReadWriteCloser) error {
|
func (q *QUICConnection) handleStream(ctx context.Context, stream io.ReadWriteCloser) error {
|
||||||
signature, err := quicpogs.DetermineProtocol(stream)
|
signature, err := quicpogs.DetermineProtocol(stream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -162,7 +163,7 @@ func (q *QUICConnection) handleStream(stream io.ReadWriteCloser) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return q.handleDataStream(reqServerStream)
|
return q.handleDataStream(ctx, reqServerStream)
|
||||||
case quicpogs.RPCStreamProtocolSignature:
|
case quicpogs.RPCStreamProtocolSignature:
|
||||||
rpcStream, err := quicpogs.NewRPCServerStream(stream, signature)
|
rpcStream, err := quicpogs.NewRPCServerStream(stream, signature)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -174,13 +175,13 @@ func (q *QUICConnection) handleStream(stream io.ReadWriteCloser) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *QUICConnection) handleDataStream(stream *quicpogs.RequestServerStream) error {
|
func (q *QUICConnection) handleDataStream(ctx context.Context, stream *quicpogs.RequestServerStream) error {
|
||||||
request, err := stream.ReadConnectRequestData()
|
request, err := stream.ReadConnectRequestData()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := q.dispatchRequest(stream, err, request); err != nil {
|
if err := q.dispatchRequest(ctx, stream, err, request); err != nil {
|
||||||
_ = stream.WriteConnectResponseData(err)
|
_ = stream.WriteConnectResponseData(err)
|
||||||
q.logger.Err(err).Str("type", request.Type.String()).Str("dest", request.Dest).Msg("Request failed")
|
q.logger.Err(err).Str("type", request.Type.String()).Str("dest", request.Dest).Msg("Request failed")
|
||||||
}
|
}
|
||||||
|
@ -188,7 +189,7 @@ func (q *QUICConnection) handleDataStream(stream *quicpogs.RequestServerStream)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *QUICConnection) dispatchRequest(stream *quicpogs.RequestServerStream, err error, request *quicpogs.ConnectRequest) error {
|
func (q *QUICConnection) dispatchRequest(ctx context.Context, stream *quicpogs.RequestServerStream, err error, request *quicpogs.ConnectRequest) error {
|
||||||
originProxy, err := q.orchestrator.GetOriginProxy()
|
originProxy, err := q.orchestrator.GetOriginProxy()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -196,7 +197,7 @@ func (q *QUICConnection) dispatchRequest(stream *quicpogs.RequestServerStream, e
|
||||||
|
|
||||||
switch request.Type {
|
switch request.Type {
|
||||||
case quicpogs.ConnectionTypeHTTP, quicpogs.ConnectionTypeWebsocket:
|
case quicpogs.ConnectionTypeHTTP, quicpogs.ConnectionTypeWebsocket:
|
||||||
tracedReq, err := buildHTTPRequest(request, stream)
|
tracedReq, err := buildHTTPRequest(ctx, request, stream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -206,7 +207,7 @@ func (q *QUICConnection) dispatchRequest(stream *quicpogs.RequestServerStream, e
|
||||||
case quicpogs.ConnectionTypeTCP:
|
case quicpogs.ConnectionTypeTCP:
|
||||||
rwa := &streamReadWriteAcker{stream}
|
rwa := &streamReadWriteAcker{stream}
|
||||||
metadata := request.MetadataMap()
|
metadata := request.MetadataMap()
|
||||||
return originProxy.ProxyTCP(context.Background(), rwa, &TCPRequest{
|
return originProxy.ProxyTCP(ctx, rwa, &TCPRequest{
|
||||||
Dest: request.Dest,
|
Dest: request.Dest,
|
||||||
FlowID: metadata[QUICMetadataFlowID],
|
FlowID: metadata[QUICMetadataFlowID],
|
||||||
})
|
})
|
||||||
|
@ -324,14 +325,14 @@ func (hrw httpResponseAdapter) WriteErrorResponse(err error) {
|
||||||
hrw.WriteConnectResponseData(err, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
|
hrw.WriteConnectResponseData(err, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildHTTPRequest(connectRequest *quicpogs.ConnectRequest, body io.ReadCloser) (*tracing.TracedRequest, error) {
|
func buildHTTPRequest(ctx context.Context, connectRequest *quicpogs.ConnectRequest, body io.ReadCloser) (*tracing.TracedRequest, error) {
|
||||||
metadata := connectRequest.MetadataMap()
|
metadata := connectRequest.MetadataMap()
|
||||||
dest := connectRequest.Dest
|
dest := connectRequest.Dest
|
||||||
method := metadata[HTTPMethodKey]
|
method := metadata[HTTPMethodKey]
|
||||||
host := metadata[HTTPHostKey]
|
host := metadata[HTTPHostKey]
|
||||||
isWebsocket := connectRequest.Type == quicpogs.ConnectionTypeWebsocket
|
isWebsocket := connectRequest.Type == quicpogs.ConnectionTypeWebsocket
|
||||||
|
|
||||||
req, err := http.NewRequest(method, dest, body)
|
req, err := http.NewRequestWithContext(ctx, method, dest, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -477,7 +477,7 @@ func TestBuildHTTPRequest(t *testing.T) {
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
req, err := buildHTTPRequest(test.connectRequest, test.body)
|
req, err := buildHTTPRequest(context.Background(), test.connectRequest, test.body)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
test.req = test.req.WithContext(req.Context())
|
test.req = test.req.WithContext(req.Context())
|
||||||
assert.Equal(t, test.req, req.Request)
|
assert.Equal(t, test.req, req.Request)
|
||||||
|
|
Loading…
Reference in New Issue