cloudflared-mirror/tunnelrpc/pogs/sessionrpc.go

175 lines
4.2 KiB
Go
Raw Normal View History

package pogs
import (
"context"
"fmt"
"net"
"time"
"github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/google/uuid"
capnp "zombiezen.com/go/capnproto2"
"zombiezen.com/go/capnproto2/rpc"
"zombiezen.com/go/capnproto2/server"
)
// SessionManager is the Go-typed interface for registering and unregistering
// UDP sessions. Sessions are identified by a UUID.
type SessionManager interface {
// RegisterUdpSession registers the session sessionID with destination
// dstIP and dstPort. closeAfterIdleHint and traceContext are forwarded
// to the implementation; their interpretation is up to the implementation.
RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeAfterIdleHint time.Duration, traceContext string) error
// UnregisterUdpSession unregisters the session sessionID. message is a
// free-form string forwarded to the implementation.
UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error
}
// SessionManager_PogsImpl adapts a SessionManager to the tunnelrpc server
// method signatures: its methods decode the RPC call parameters and delegate
// to impl.
type SessionManager_PogsImpl struct {
// impl is the SessionManager that receives the decoded calls.
impl SessionManager
}
// SessionManager_ServerToClient wraps s in a SessionManager_PogsImpl and passes
// the wrapper to tunnelrpc.SessionManager_ServerToClient, returning the
// resulting tunnelrpc.SessionManager.
func SessionManager_ServerToClient(s SessionManager) tunnelrpc.SessionManager {
	adapter := SessionManager_PogsImpl{impl: s}
	return tunnelrpc.SessionManager_ServerToClient(adapter)
}
// RegisterUdpSession handles a registerUdpSession RPC call. It decodes the
// session ID, destination IP and port, idle hint and trace context from the
// call parameters, forwards them to the wrapped SessionManager, and marshals a
// RegisterUdpSessionResponse into the call results.
//
// An error returned by the wrapped SessionManager is not returned from this
// method; it is carried back to the caller in the response's Err field.
// Errors from decoding parameters, from an invalid session ID or destination
// IP, or from building the result are returned directly.
func (i SessionManager_PogsImpl) RegisterUdpSession(p tunnelrpc.SessionManager_registerUdpSession) error {
	server.Ack(p.Options)

	sessionIDRaw, err := p.Params.SessionId()
	if err != nil {
		return err
	}
	sessionID, err := uuid.FromBytes(sessionIDRaw)
	if err != nil {
		return err
	}

	dstIPRaw, err := p.Params.DstIp()
	if err != nil {
		return err
	}
	// A net.IP is only meaningful as a 4-byte (IPv4) or 16-byte (IPv6) slice.
	// Checking the length also rejects a nil or empty value, which a plain
	// nil comparison on the converted slice would only partly cover.
	if len(dstIPRaw) != net.IPv4len && len(dstIPRaw) != net.IPv6len {
		return fmt.Errorf("%v is not valid IP", dstIPRaw)
	}
	dstIP := net.IP(dstIPRaw)

	dstPort := p.Params.DstPort()
	// The hint is an int64 in the parameters; it is converted as-is, so it is
	// interpreted as a nanosecond count.
	closeAfterIdleHint := time.Duration(p.Params.CloseAfterIdleHint())

	traceContext, err := p.Params.TraceContext()
	if err != nil {
		return err
	}

	// A registration failure is reported inside the response rather than as
	// an RPC-level error.
	resp := RegisterUdpSessionResponse{
		Err: i.impl.RegisterUdpSession(p.Ctx, sessionID, dstIP, dstPort, closeAfterIdleHint, traceContext),
	}

	// TUN-6689: Add spans to return path for RegisterUdpSession
	result, err := p.Results.NewResult()
	if err != nil {
		return err
	}
	return resp.Marshal(result)
}
// UnregisterUdpSession handles an unregisterUdpSession RPC call. It decodes the
// session ID and message from the call parameters and forwards them to the
// wrapped SessionManager, returning that call's error. Any decoding error is
// returned before the wrapped SessionManager is invoked.
func (i SessionManager_PogsImpl) UnregisterUdpSession(p tunnelrpc.SessionManager_unregisterUdpSession) error {
	server.Ack(p.Options)

	rawID, err := p.Params.SessionId()
	if err != nil {
		return err
	}

	id, err := uuid.FromBytes(rawID)
	if err != nil {
		return err
	}

	reason, err := p.Params.Message()
	if err != nil {
		return err
	}

	return i.impl.UnregisterUdpSession(p.Ctx, id, reason)
}
// RegisterUdpSessionResponse is the plain Go form of
// tunnelrpc.RegisterUdpSessionResponse.
type RegisterUdpSessionResponse struct {
// Err is the registration error, or nil when registration succeeded.
// Only its message text is marshaled, so the concrete error type is not
// preserved across Marshal and Unmarshal.
Err error
Spans []byte // Spans in protobuf format
}
// Marshal writes p into s. When p.Err is set, only the error's message text is
// written and Spans is not; otherwise Spans is written. The error from the
// setter that was called is returned.
func (p *RegisterUdpSessionResponse) Marshal(s tunnelrpc.RegisterUdpSessionResponse) error {
	if registrationErr := p.Err; registrationErr != nil {
		return s.SetErr(registrationErr.Error())
	}
	return s.SetSpans(p.Spans)
}
// Unmarshal fills p from s. A non-empty error string in s becomes a new error
// in p.Err carrying that exact text; an empty string leaves p.Err untouched.
// p.Spans is set from s. An error reading either field of s is returned.
func (p *RegisterUdpSessionResponse) Unmarshal(s tunnelrpc.RegisterUdpSessionResponse) error {
	respErr, err := s.Err()
	if err != nil {
		return err
	}
	if respErr != "" {
		// respErr is remote-supplied text, not a format string. Passing it
		// through an explicit %s keeps any '%' characters in it intact
		// instead of having them parsed as formatting verbs.
		p.Err = fmt.Errorf("%s", respErr)
	}
	p.Spans, err = s.Spans()
	return err
}
// SessionManager_PogsClient is the client-side counterpart of
// SessionManager_PogsImpl: its methods take plain Go arguments and issue the
// corresponding tunnelrpc.SessionManager calls over Client.
type SessionManager_PogsClient struct {
// Client is the capnp client used to issue the RPC calls.
Client capnp.Client
// Conn is the RPC connection; Close closes it after closing Client.
Conn *rpc.Conn
}
// Close closes the capnp client and then the RPC connection. Only the error
// from closing the connection is returned; the error from closing the client
// is discarded.
func (c SessionManager_PogsClient) Close() error {
	_ = c.Client.Close()
	return c.Conn.Close()
}
// RegisterUdpSession issues a registerUdpSession RPC call with the given
// arguments and returns the decoded response.
//
// A failure of the RPC call itself is returned as an error (passed through
// wrapRPCError) with a nil response, and so is a failure to decode the
// response. A registration failure reported by the remote side is not
// returned as the error value; it is available in the response's Err field.
func (c SessionManager_PogsClient) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeAfterIdleHint time.Duration, traceContext string) (*RegisterUdpSessionResponse, error) {
	client := tunnelrpc.SessionManager{Client: c.Client}
	promise := client.RegisterUdpSession(ctx, func(p tunnelrpc.SessionManager_registerUdpSession_Params) error {
		if err := p.SetSessionId(sessionID[:]); err != nil {
			return err
		}
		if err := p.SetDstIp(dstIP); err != nil {
			return err
		}
		p.SetDstPort(dstPort)
		// The duration is sent as its raw int64 value (nanoseconds).
		p.SetCloseAfterIdleHint(int64(closeAfterIdleHint))
		// Report a failure to set the trace context instead of silently
		// sending the call without it, matching the other setters above.
		return p.SetTraceContext(traceContext)
	})

	result, err := promise.Result().Struct()
	if err != nil {
		return nil, wrapRPCError(err)
	}

	response := new(RegisterUdpSessionResponse)
	if err := response.Unmarshal(result); err != nil {
		return nil, err
	}
	return response, nil
}
// UnregisterUdpSession issues an unregisterUdpSession RPC call carrying
// sessionID and message, waits for the call to complete, and returns the
// call's error unmodified.
func (c SessionManager_PogsClient) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error {
	rpcClient := tunnelrpc.SessionManager{Client: c.Client}

	// fillParams populates the outgoing call parameters.
	fillParams := func(params tunnelrpc.SessionManager_unregisterUdpSession_Params) error {
		if err := params.SetSessionId(sessionID[:]); err != nil {
			return err
		}
		return params.SetMessage(message)
	}

	_, err := rpcClient.UnregisterUdpSession(ctx, fillParams).Struct()
	return err
}