259 lines
7.0 KiB
Go
259 lines
7.0 KiB
Go
package quic
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"time"
|
|
|
|
capnp "zombiezen.com/go/capnproto2"
|
|
"zombiezen.com/go/capnproto2/rpc"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/rs/zerolog"
|
|
|
|
"github.com/cloudflare/cloudflared/tunnelrpc"
|
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
|
)
|
|
|
|
// ProtocolSignature is the fixed 6-byte prefix at the start of every stream
// that tells the reader which kind of stream it is. Because the signature is
// written first, whoever performs a handshake does not write data before
// writing the metadata.
type ProtocolSignature [6]byte

var (
	// DataStreamProtocolSignature is the custom signature that opens a data stream.
	DataStreamProtocolSignature = ProtocolSignature{0x0a, 0x36, 0xcd, 0x12, 0xa1, 0x3e}

	// RPCStreamProtocolSignature is the custom signature that opens an RPC stream.
	RPCStreamProtocolSignature = ProtocolSignature{0x52, 0xbb, 0x82, 0x5c, 0xdb, 0x65}
)
|
|
|
|
// protocolVersion is the textual version tag written right after the
// signature in a data-stream preamble.
type protocolVersion string

const (
	// protocolVersionLength is how many bytes the version tag occupies on the wire.
	protocolVersionLength = 2

	// protocolV1 is the only protocol version defined here.
	protocolV1 protocolVersion = "01"
)
|
|
|
|
// RequestServerStream is a stream to serve requests
|
|
type RequestServerStream struct {
|
|
io.ReadWriteCloser
|
|
}
|
|
|
|
func NewRequestServerStream(stream io.ReadWriteCloser, signature ProtocolSignature) (*RequestServerStream, error) {
|
|
if signature != DataStreamProtocolSignature {
|
|
return nil, fmt.Errorf("RequestClientStream can only be created from data stream")
|
|
}
|
|
return &RequestServerStream{stream}, nil
|
|
}
|
|
|
|
// ReadConnectRequestData reads the handshake data from a QUIC stream.
|
|
func (rss *RequestServerStream) ReadConnectRequestData() (*ConnectRequest, error) {
|
|
// This is a NO-OP for now. We could cause a branching if we wanted to use multiple versions.
|
|
if _, err := readVersion(rss); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
msg, err := capnp.NewDecoder(rss).Decode()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
r := &ConnectRequest{}
|
|
if err := r.fromPogs(msg); err != nil {
|
|
return nil, err
|
|
}
|
|
return r, nil
|
|
}
|
|
|
|
// WriteConnectResponseData writes response to a QUIC stream.
|
|
func (rss *RequestServerStream) WriteConnectResponseData(respErr error, metadata ...Metadata) error {
|
|
var connectResponse *ConnectResponse
|
|
if respErr != nil {
|
|
connectResponse = &ConnectResponse{
|
|
Error: respErr.Error(),
|
|
}
|
|
} else {
|
|
connectResponse = &ConnectResponse{
|
|
Metadata: metadata,
|
|
}
|
|
}
|
|
|
|
msg, err := connectResponse.toPogs()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := writeDataStreamPreamble(rss); err != nil {
|
|
return err
|
|
}
|
|
return capnp.NewEncoder(rss).Encode(msg)
|
|
}
|
|
|
|
type RequestClientStream struct {
|
|
io.ReadWriteCloser
|
|
}
|
|
|
|
// WriteConnectRequestData writes requestMeta to a stream.
|
|
func (rcs *RequestClientStream) WriteConnectRequestData(dest string, connectionType ConnectionType, metadata ...Metadata) error {
|
|
connectRequest := &ConnectRequest{
|
|
Dest: dest,
|
|
Type: connectionType,
|
|
Metadata: metadata,
|
|
}
|
|
|
|
msg, err := connectRequest.toPogs()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := writeDataStreamPreamble(rcs); err != nil {
|
|
return err
|
|
}
|
|
return capnp.NewEncoder(rcs).Encode(msg)
|
|
}
|
|
|
|
// ReadConnectResponseData reads the response to a RequestMeta in a stream.
|
|
func (rcs *RequestClientStream) ReadConnectResponseData() (*ConnectResponse, error) {
|
|
signature, err := DetermineProtocol(rcs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if signature != DataStreamProtocolSignature {
|
|
return nil, fmt.Errorf("Wrong protocol signature %v", signature)
|
|
}
|
|
|
|
// This is a NO-OP for now. We could cause a branching if we wanted to use multiple versions.
|
|
if _, err := readVersion(rcs); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
msg, err := capnp.NewDecoder(rcs).Decode()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
r := &ConnectResponse{}
|
|
if err := r.fromPogs(msg); err != nil {
|
|
return nil, err
|
|
}
|
|
return r, nil
|
|
}
|
|
|
|
// RPCServerStream is a stream to serve RPCs. It is closed when the RPC client is done
|
|
type RPCServerStream struct {
|
|
io.ReadWriteCloser
|
|
}
|
|
|
|
func NewRPCServerStream(stream io.ReadWriteCloser, protocol ProtocolSignature) (*RPCServerStream, error) {
|
|
if protocol != RPCStreamProtocolSignature {
|
|
return nil, fmt.Errorf("RPCStream can only be created from rpc stream")
|
|
}
|
|
return &RPCServerStream{stream}, nil
|
|
}
|
|
|
|
func (s *RPCServerStream) Serve(sessionManager tunnelpogs.SessionManager, logger *zerolog.Logger) error {
|
|
rpcTransport := tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(s))
|
|
defer rpcTransport.Close()
|
|
|
|
main := tunnelpogs.SessionManager_ServerToClient(sessionManager)
|
|
rpcConn := rpc.NewConn(
|
|
rpcTransport,
|
|
rpc.MainInterface(main.Client),
|
|
)
|
|
defer rpcConn.Close()
|
|
|
|
return rpcConn.Wait()
|
|
}
|
|
|
|
func DetermineProtocol(stream io.Reader) (ProtocolSignature, error) {
|
|
signature, err := readSignature(stream)
|
|
if err != nil {
|
|
return ProtocolSignature{}, err
|
|
}
|
|
switch signature {
|
|
case DataStreamProtocolSignature:
|
|
return DataStreamProtocolSignature, nil
|
|
case RPCStreamProtocolSignature:
|
|
return RPCStreamProtocolSignature, nil
|
|
default:
|
|
return ProtocolSignature{}, fmt.Errorf("Unknown signature %v", signature)
|
|
}
|
|
}
|
|
|
|
func writeDataStreamPreamble(stream io.Writer) error {
|
|
if err := writeSignature(stream, DataStreamProtocolSignature); err != nil {
|
|
return err
|
|
}
|
|
|
|
return writeVersion(stream)
|
|
}
|
|
|
|
func writeVersion(stream io.Writer) error {
|
|
_, err := stream.Write([]byte(protocolV1)[:protocolVersionLength])
|
|
return err
|
|
}
|
|
|
|
func readVersion(stream io.Reader) (string, error) {
|
|
version := make([]byte, protocolVersionLength)
|
|
_, err := stream.Read(version)
|
|
return string(version), err
|
|
}
|
|
|
|
func readSignature(stream io.Reader) (ProtocolSignature, error) {
|
|
var signature ProtocolSignature
|
|
if _, err := io.ReadFull(stream, signature[:]); err != nil {
|
|
return ProtocolSignature{}, err
|
|
}
|
|
return signature, nil
|
|
}
|
|
|
|
func writeSignature(stream io.Writer, signature ProtocolSignature) error {
|
|
_, err := stream.Write(signature[:])
|
|
return err
|
|
}
|
|
|
|
// RPCClientStream is a stream to call methods of SessionManager
|
|
type RPCClientStream struct {
|
|
client tunnelpogs.SessionManager_PogsClient
|
|
transport rpc.Transport
|
|
}
|
|
|
|
func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, logger *zerolog.Logger) (*RPCClientStream, error) {
|
|
n, err := stream.Write(RPCStreamProtocolSignature[:])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if n != len(RPCStreamProtocolSignature) {
|
|
return nil, fmt.Errorf("expect to write %d bytes for RPC stream protocol signature, wrote %d", len(RPCStreamProtocolSignature), n)
|
|
}
|
|
transport := tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream))
|
|
conn := rpc.NewConn(
|
|
transport,
|
|
tunnelrpc.ConnLog(logger),
|
|
)
|
|
return &RPCClientStream{
|
|
client: tunnelpogs.SessionManager_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn},
|
|
transport: transport,
|
|
}, nil
|
|
}
|
|
|
|
func (rcs *RPCClientStream) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfterHint time.Duration) error {
|
|
resp, err := rcs.client.RegisterUdpSession(ctx, sessionID, dstIP, dstPort, closeIdleAfterHint)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return resp.Err
|
|
}
|
|
|
|
func (rcs *RPCClientStream) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID) error {
|
|
return rcs.client.UnregisterUdpSession(ctx, sessionID)
|
|
}
|
|
|
|
func (rcs *RPCClientStream) Close() {
|
|
_ = rcs.client.Close()
|
|
_ = rcs.transport.Close()
|
|
}
|