package pogs import ( "context" "fmt" "time" "github.com/cloudflare/cloudflared/tunnelrpc" "github.com/google/uuid" log "github.com/sirupsen/logrus" capnp "zombiezen.com/go/capnproto2" "zombiezen.com/go/capnproto2/pogs" "zombiezen.com/go/capnproto2/rpc" "zombiezen.com/go/capnproto2/server" ) type Authentication struct { Key string Email string OriginCAKey string } func MarshalAuthentication(s tunnelrpc.Authentication, p *Authentication) error { return pogs.Insert(tunnelrpc.Authentication_TypeID, s.Struct, p) } func UnmarshalAuthentication(s tunnelrpc.Authentication) (*Authentication, error) { p := new(Authentication) err := pogs.Extract(p, tunnelrpc.Authentication_TypeID, s.Struct) return p, err } type TunnelRegistration struct { Err string Url string LogLines []string PermanentFailure bool TunnelID string `capnp:"tunnelID"` } func MarshalTunnelRegistration(s tunnelrpc.TunnelRegistration, p *TunnelRegistration) error { return pogs.Insert(tunnelrpc.TunnelRegistration_TypeID, s.Struct, p) } func UnmarshalTunnelRegistration(s tunnelrpc.TunnelRegistration) (*TunnelRegistration, error) { p := new(TunnelRegistration) err := pogs.Extract(p, tunnelrpc.TunnelRegistration_TypeID, s.Struct) return p, err } type RegistrationOptions struct { ClientID string `capnp:"clientId"` Version string OS string `capnp:"os"` ExistingTunnelPolicy tunnelrpc.ExistingTunnelPolicy PoolName string `capnp:"poolName"` Tags []Tag ConnectionID uint8 `capnp:"connectionId"` OriginLocalIP string `capnp:"originLocalIp"` IsAutoupdated bool `capnp:"isAutoupdated"` RunFromTerminal bool `capnp:"runFromTerminal"` CompressionQuality uint64 `capnp:"compressionQuality"` UUID string `capnp:"uuid"` } func MarshalRegistrationOptions(s tunnelrpc.RegistrationOptions, p *RegistrationOptions) error { return pogs.Insert(tunnelrpc.RegistrationOptions_TypeID, s.Struct, p) } func UnmarshalRegistrationOptions(s tunnelrpc.RegistrationOptions) (*RegistrationOptions, error) { p := new(RegistrationOptions) err := pogs.Extract(p, tunnelrpc.RegistrationOptions_TypeID, s.Struct) return p, err } type ConnectResult struct { Err *ConnectError ServerInfo ServerInfo ClientConfig ClientConfig } func MarshalConnectResult(s tunnelrpc.ConnectResult, p *ConnectResult) error { return pogs.Insert(tunnelrpc.ConnectResult_TypeID, s.Struct, p) } func UnmarshalConnectResult(s tunnelrpc.ConnectResult) (*ConnectResult, error) { p := new(ConnectResult) err := pogs.Extract(p, tunnelrpc.ConnectResult_TypeID, s.Struct) return p, err } type ConnectError struct { Cause string RetryAfter time.Duration ShouldRetry bool } func MarshalConnectError(s tunnelrpc.ConnectError, p *ConnectError) error { return pogs.Insert(tunnelrpc.ConnectError_TypeID, s.Struct, p) } func UnmarshalConnectError(s tunnelrpc.ConnectError) (*ConnectError, error) { p := new(ConnectError) err := pogs.Extract(p, tunnelrpc.ConnectError_TypeID, s.Struct) return p, err } func (e *ConnectError) Error() string { return e.Cause } type Tag struct { Name string `json:"name"` Value string `json:"value"` } type ServerInfo struct { LocationName string } func MarshalServerInfo(s tunnelrpc.ServerInfo, p *ServerInfo) error { return pogs.Insert(tunnelrpc.ServerInfo_TypeID, s.Struct, p) } func UnmarshalServerInfo(s tunnelrpc.ServerInfo) (*ServerInfo, error) { p := new(ServerInfo) err := pogs.Extract(p, tunnelrpc.ServerInfo_TypeID, s.Struct) return p, err } //go-sumtype:decl Scope type Scope interface { Value() string PostgresType() string GraphQLType() string isScope() } type SystemName struct { systemName string } func NewSystemName(systemName string) *SystemName { return &SystemName{systemName: systemName} } func (s *SystemName) Value() string { return s.systemName } func (_ *SystemName) PostgresType() string { return "system_name" } func (_ *SystemName) GraphQLType() string { return "SYSTEM_NAME" } func (_ *SystemName) isScope() {} type Group struct { group string } func NewGroup(group string) *Group { return &Group{group: group} } func (g *Group) Value() string { return g.group } func (_ *Group) PostgresType() string { return "group" } func (_ *Group) GraphQLType() string { return "GROUP" } func (_ *Group) isScope() {} func MarshalScope(s tunnelrpc.Scope, p Scope) error { ss := s.Value() switch scope := p.(type) { case *SystemName: ss.SetSystemName(scope.systemName) case *Group: ss.SetGroup(scope.group) default: return fmt.Errorf("unexpected Scope value: %v", p) } return nil } func UnmarshalScope(s tunnelrpc.Scope) (Scope, error) { ss := s.Value() switch ss.Which() { case tunnelrpc.Scope_value_Which_systemName: systemName, err := ss.SystemName() if err != nil { return nil, err } return NewSystemName(systemName), nil case tunnelrpc.Scope_value_Which_group: group, err := ss.Group() if err != nil { return nil, err } return NewGroup(group), nil default: return nil, fmt.Errorf("unexpected Scope tag: %v", ss.Which()) } } type ConnectParameters struct { OriginCert []byte CloudflaredID uuid.UUID NumPreviousAttempts uint8 Tags []Tag CloudflaredVersion string Scope Scope } func MarshalConnectParameters(s tunnelrpc.CapnpConnectParameters, p *ConnectParameters) error { if err := s.SetOriginCert(p.OriginCert); err != nil { return err } cloudflaredIDBytes, err := p.CloudflaredID.MarshalBinary() if err != nil { return err } if err := s.SetCloudflaredID(cloudflaredIDBytes); err != nil { return err } s.SetNumPreviousAttempts(p.NumPreviousAttempts) if len(p.Tags) > 0 { tagsCapnpList, err := s.NewTags(int32(len(p.Tags))) if err != nil { return err } for i, tag := range p.Tags { tagCapnp := tagsCapnpList.At(i) if err := tagCapnp.SetName(tag.Name); err != nil { return err } if err := tagCapnp.SetValue(tag.Value); err != nil { return err } } } if err := s.SetCloudflaredVersion(p.CloudflaredVersion); err != nil { return err } scope, err := s.NewScope() if err != nil { return err } return MarshalScope(scope, p.Scope) } func UnmarshalConnectParameters(s tunnelrpc.CapnpConnectParameters) (*ConnectParameters, error) { originCert, err := s.OriginCert() if err != nil { return nil, err } cloudflaredIDBytes, err := s.CloudflaredID() if err != nil { return nil, err } cloudflaredID, err := uuid.FromBytes(cloudflaredIDBytes) if err != nil { return nil, err } tagsCapnpList, err := s.Tags() if err != nil { return nil, err } var tags []Tag for i := 0; i < tagsCapnpList.Len(); i++ { tagCapnp := tagsCapnpList.At(i) name, err := tagCapnp.Name() if err != nil { return nil, err } value, err := tagCapnp.Value() if err != nil { return nil, err } tags = append(tags, Tag{Name: name, Value: value}) } cloudflaredVersion, err := s.CloudflaredVersion() if err != nil { return nil, err } scopeCapnp, err := s.Scope() if err != nil { return nil, err } scope, err := UnmarshalScope(scopeCapnp) if err != nil { return nil, err } return &ConnectParameters{ OriginCert: originCert, CloudflaredID: cloudflaredID, NumPreviousAttempts: s.NumPreviousAttempts(), Tags: tags, CloudflaredVersion: cloudflaredVersion, Scope: scope, }, nil } type TunnelServer interface { RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*TunnelRegistration, error) GetServerInfo(ctx context.Context) (*ServerInfo, error) UnregisterTunnel(ctx context.Context, gracePeriodNanoSec int64) error Connect(ctx context.Context, paramaters *ConnectParameters) (*ConnectResult, error) } func TunnelServer_ServerToClient(s TunnelServer) tunnelrpc.TunnelServer { return tunnelrpc.TunnelServer_ServerToClient(TunnelServer_PogsImpl{s}) } type TunnelServer_PogsImpl struct { impl TunnelServer } func (i TunnelServer_PogsImpl) RegisterTunnel(p tunnelrpc.TunnelServer_registerTunnel) error { originCert, err := p.Params.OriginCert() if err != nil { return err } hostname, err := p.Params.Hostname() if err != nil { return err } options, err := p.Params.Options() if err != nil { return err } pogsOptions, err := UnmarshalRegistrationOptions(options) if err != nil { return err } server.Ack(p.Options) registration, err := i.impl.RegisterTunnel(p.Ctx, originCert, hostname, pogsOptions) if err != nil { return err } result, err := p.Results.NewResult() if err != nil { return err } log.Info(registration.TunnelID) return MarshalTunnelRegistration(result, registration) } func (i TunnelServer_PogsImpl) GetServerInfo(p tunnelrpc.TunnelServer_getServerInfo) error { server.Ack(p.Options) serverInfo, err := i.impl.GetServerInfo(p.Ctx) if err != nil { return err } result, err := p.Results.NewResult() if err != nil { return err } return MarshalServerInfo(result, serverInfo) } func (i TunnelServer_PogsImpl) UnregisterTunnel(p tunnelrpc.TunnelServer_unregisterTunnel) error { gracePeriodNanoSec := p.Params.GracePeriodNanoSec() server.Ack(p.Options) return i.impl.UnregisterTunnel(p.Ctx, gracePeriodNanoSec) } func (i TunnelServer_PogsImpl) Connect(p tunnelrpc.TunnelServer_connect) error { paramaters, err := p.Params.Parameters() if err != nil { return err } pogsParameters, err := UnmarshalConnectParameters(paramaters) if err != nil { return err } server.Ack(p.Options) connectResult, err := i.impl.Connect(p.Ctx, pogsParameters) if err != nil { return err } result, err := p.Results.NewResult() if err != nil { return err } return MarshalConnectResult(result, connectResult) } type TunnelServer_PogsClient struct { Client capnp.Client Conn *rpc.Conn } func (c TunnelServer_PogsClient) Close() error { return c.Conn.Close() } func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*TunnelRegistration, error) { client := tunnelrpc.TunnelServer{Client: c.Client} promise := client.RegisterTunnel(ctx, func(p tunnelrpc.TunnelServer_registerTunnel_Params) error { err := p.SetOriginCert(originCert) if err != nil { return err } err = p.SetHostname(hostname) if err != nil { return err } registrationOptions, err := p.NewOptions() if err != nil { return err } err = MarshalRegistrationOptions(registrationOptions, options) if err != nil { return err } return nil }) retval, err := promise.Result().Struct() if err != nil { return nil, err } return UnmarshalTunnelRegistration(retval) } func (c TunnelServer_PogsClient) GetServerInfo(ctx context.Context) (*ServerInfo, error) { client := tunnelrpc.TunnelServer{Client: c.Client} promise := client.GetServerInfo(ctx, func(p tunnelrpc.TunnelServer_getServerInfo_Params) error { return nil }) retval, err := promise.Result().Struct() if err != nil { return nil, err } return UnmarshalServerInfo(retval) } func (c TunnelServer_PogsClient) UnregisterTunnel(ctx context.Context, gracePeriodNanoSec int64) error { client := tunnelrpc.TunnelServer{Client: c.Client} promise := client.UnregisterTunnel(ctx, func(p tunnelrpc.TunnelServer_unregisterTunnel_Params) error { p.SetGracePeriodNanoSec(gracePeriodNanoSec) return nil }) _, err := promise.Struct() return err } func (c TunnelServer_PogsClient) Connect(ctx context.Context, parameters *ConnectParameters, ) (*ConnectResult, error) { client := tunnelrpc.TunnelServer{Client: c.Client} promise := client.Connect(ctx, func(p tunnelrpc.TunnelServer_connect_Params) error { connectParameters, err := p.NewParameters() if err != nil { return err } err = MarshalConnectParameters(connectParameters, parameters) if err != nil { return err } return nil }) retval, err := promise.Result().Struct() if err != nil { return nil, err } return UnmarshalConnectResult(retval) }