cloudflared-mirror/tunnelrpc/pogs/session_manager.go

188 lines
4.7 KiB
Go

package pogs
import (
"context"
"fmt"
"net"
"time"
"github.com/google/uuid"
capnp "zombiezen.com/go/capnproto2"
"zombiezen.com/go/capnproto2/rpc"
"zombiezen.com/go/capnproto2/server"
"github.com/cloudflare/cloudflared/tunnelrpc/proto"
)
// SessionManager is the Go-level handler for the UDP session RPCs that the
// edge sends to cloudflared over capnproto. Wrap an implementation with
// SessionManager_ServerToClient to serve it over RPC.
type SessionManager interface {
	// RegisterUdpSession services an incoming capnproto RegisterUdpSession
	// request from the edge. The request carries:
	//   - the session ID,
	//   - the destination IP and port,
	//   - an idle-timeout hint,
	//   - a trace context string.
	RegisterUdpSession(
		ctx context.Context,
		sessionID uuid.UUID,
		dstIP net.IP,
		dstPort uint16,
		closeAfterIdleHint time.Duration,
		traceContext string,
	) (*RegisterUdpSessionResponse, error)

	// UnregisterUdpSession services an incoming capnproto UnregisterUdpSession
	// request from the edge for the session identified by sessionID, along
	// with a free-form message.
	UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error
}
// SessionManager_PogsImpl adapts a Go SessionManager to the capnp-generated
// server API: each RPC method decodes the capnp parameters and forwards them
// to impl. Instances are handed to proto.SessionManager_ServerToClient.
type SessionManager_PogsImpl struct {
	// impl receives every decoded RPC call.
	impl SessionManager
}
// SessionManager_ServerToClient exposes s as a capnp proto.SessionManager.
// Incoming calls on the returned value are decoded by a SessionManager_PogsImpl
// adapter and then dispatched to s.
func SessionManager_ServerToClient(s SessionManager) proto.SessionManager {
	adapter := SessionManager_PogsImpl{impl: s}
	return proto.SessionManager_ServerToClient(adapter)
}
// RegisterUdpSession decodes an incoming capnproto RegisterUdpSession request,
// passes the decoded parameters to the wrapped SessionManager and writes the
// outcome into the RPC results.
//
// Parameters that cannot be read or decoded fail the RPC itself. This covers
// the session ID, the destination IP and the trace context. An error returned
// by the wrapped implementation does not fail the RPC: it is stored in the
// response's Err field and marshaled into the result.
func (i SessionManager_PogsImpl) RegisterUdpSession(p proto.SessionManager_registerUdpSession) error {
	server.Ack(p.Options)

	sessionIDRaw, err := p.Params.SessionId()
	if err != nil {
		return err
	}
	sessionID, err := uuid.FromBytes(sessionIDRaw)
	if err != nil {
		return err
	}

	dstIPRaw, err := p.Params.DstIp()
	if err != nil {
		return err
	}
	dstIP := net.IP(dstIPRaw)
	// A nil check alone would accept empty or wrongly sized data. To16 returns
	// nil for anything that is not a 4-byte IPv4 or 16-byte IPv6 address.
	if dstIP.To16() == nil {
		return fmt.Errorf("%v is not valid IP", dstIPRaw)
	}

	dstPort := p.Params.DstPort()
	// The hint is carried as an int64 count of nanoseconds (see the client's
	// SetCloseAfterIdleHint(int64(...)) call).
	closeIdleAfterHint := time.Duration(p.Params.CloseAfterIdleHint())
	traceContext, err := p.Params.TraceContext()
	if err != nil {
		return err
	}

	resp, registrationErr := i.impl.RegisterUdpSession(p.Ctx, sessionID, dstIP, dstPort, closeIdleAfterHint, traceContext)
	// Always marshal a response. The implementation may return a nil response,
	// with or without an error, and Marshal dereferences its receiver.
	if resp == nil {
		resp = &RegisterUdpSessionResponse{}
	}
	if registrationErr != nil {
		resp.Err = registrationErr
	}

	result, err := p.Results.NewResult()
	if err != nil {
		return err
	}
	return resp.Marshal(result)
}
// UnregisterUdpSession decodes an incoming capnproto UnregisterUdpSession
// request and forwards the session ID and message to the wrapped
// SessionManager. Decoding errors, and any error from the implementation, are
// returned as the RPC error.
func (i SessionManager_PogsImpl) UnregisterUdpSession(p proto.SessionManager_unregisterUdpSession) error {
	server.Ack(p.Options)

	params := p.Params
	rawID, err := params.SessionId()
	if err != nil {
		return err
	}
	id, err := uuid.FromBytes(rawID)
	if err != nil {
		return err
	}
	msg, err := params.Message()
	if err != nil {
		return err
	}

	return i.impl.UnregisterUdpSession(p.Ctx, id, msg)
}
// RegisterUdpSessionResponse is the Go form of the capnp
// RegisterUdpSessionResponse returned for a RegisterUdpSession call.
type RegisterUdpSessionResponse struct {
	// Err, when non-nil, is sent to the peer as the response's error string.
	Err error
	// Spans holds tracing spans serialized in protobuf format.
	Spans []byte
}
// Marshal writes p into the capnp struct s. An error takes precedence: when
// p.Err is set, only its message is written and Spans is not set on s.
func (p *RegisterUdpSessionResponse) Marshal(s proto.RegisterUdpSessionResponse) error {
	if p.Err == nil {
		return s.SetSpans(p.Spans)
	}
	return s.SetErr(p.Err.Error())
}
// Unmarshal populates p from the capnp struct s. A non-empty error string in s
// becomes p.Err, with its text preserved verbatim. Spans is taken from s as
// returned by s.Spans(). Read errors from s are returned to the caller.
func (p *RegisterUdpSessionResponse) Unmarshal(s proto.RegisterUdpSessionResponse) error {
	respErr, err := s.Err()
	if err != nil {
		return err
	}
	if respErr != "" {
		// The message comes from the remote peer. Passing it as the format
		// string would mangle any '%' verbs it contains, and go vet's printf
		// check flags that usage.
		p.Err = fmt.Errorf("%s", respErr)
	}
	p.Spans, err = s.Spans()
	if err != nil {
		return err
	}
	return nil
}
// SessionManager_PogsClient is the calling side of the SessionManager RPC. It
// issues calls through Client and owns Conn, and Close shuts both down.
type SessionManager_PogsClient struct {
	// Client is the capnp capability that calls are sent through.
	Client capnp.Client
	// Conn is the RPC connection closed by Close.
	Conn *rpc.Conn
}
// NewSessionManager_PogsClient bundles an existing capnp client and the RPC
// connection it belongs to into a SessionManager_PogsClient.
func NewSessionManager_PogsClient(client capnp.Client, conn *rpc.Conn) SessionManager_PogsClient {
	var c SessionManager_PogsClient
	c.Client = client
	c.Conn = conn
	return c
}
// Close releases the capnp client and then closes the RPC connection. Only the
// connection's close error is returned to the caller.
func (c SessionManager_PogsClient) Close() error {
	// The client's close error is not surfaced; the connection is closed
	// regardless of whether releasing the client succeeded.
	_ = c.Client.Close()
	return c.Conn.Close()
}
// RegisterUdpSession sends a RegisterUdpSession RPC and decodes the reply.
//
// closeAfterIdleHint is transmitted as an int64 nanosecond count. Transport
// failures are returned via wrapRPCError. A remote registration failure is
// reported in the returned response's Err field rather than as the error
// result.
func (c SessionManager_PogsClient) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeAfterIdleHint time.Duration, traceContext string) (*RegisterUdpSessionResponse, error) {
	client := proto.SessionManager{Client: c.Client}
	promise := client.RegisterUdpSession(ctx, func(p proto.SessionManager_registerUdpSession_Params) error {
		if err := p.SetSessionId(sessionID[:]); err != nil {
			return err
		}
		if err := p.SetDstIp(dstIP); err != nil {
			return err
		}
		p.SetDstPort(dstPort)
		p.SetCloseAfterIdleHint(int64(closeAfterIdleHint))
		// Setting a text field can fail (e.g. on allocation). Check it like the
		// other pointer-field setters instead of silently sending an unset
		// trace context.
		if err := p.SetTraceContext(traceContext); err != nil {
			return err
		}
		return nil
	})
	result, err := promise.Result().Struct()
	if err != nil {
		return nil, wrapRPCError(err)
	}
	response := new(RegisterUdpSessionResponse)
	if err := response.Unmarshal(result); err != nil {
		return nil, err
	}
	return response, nil
}
// UnregisterUdpSession sends an UnregisterUdpSession RPC for sessionID with the
// given message. It waits for the reply and returns any error from building
// the request or from the call.
func (c SessionManager_PogsClient) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error {
	fillParams := func(p proto.SessionManager_unregisterUdpSession_Params) error {
		if err := p.SetSessionId(sessionID[:]); err != nil {
			return err
		}
		return p.SetMessage(message)
	}

	client := proto.SessionManager{Client: c.Client}
	promise := client.UnregisterUdpSession(ctx, fillParams)
	_, err := promise.Struct()
	return err
}