package pogs import ( "context" "fmt" "time" "github.com/cloudflare/cloudflared/tunnelrpc" capnp "zombiezen.com/go/capnproto2" "zombiezen.com/go/capnproto2/pogs" "zombiezen.com/go/capnproto2/rpc" ) /// /// Structs /// type CloudflaredConfig struct { Timestamp time.Time AutoUpdateFrequency time.Duration MetricsUpdateFrequency time.Duration HeartbeatInterval time.Duration MaxFailedHeartbeats uint64 GracePeriod time.Duration DoHProxyConfigs []*DoHProxyConfig ReverseProxyConfigs []*ReverseProxyConfig } type UseConfigurationResult struct { Success bool ErrorMessage string } type DoHProxyConfig struct { ListenHost string ListenPort uint16 Upstreams []string } type ReverseProxyConfig struct { TunnelID string Origin OriginConfig Retries uint64 ConnectionTimeout time.Duration ChunkedEncoding bool CompressionQuality uint64 } //go-sumtype:decl OriginConfig type OriginConfig interface { isOriginConfig() } type HTTPOriginConfig struct { URL string `capnp:"url"` TCPKeepAlive time.Duration `capnp:"tcpKeepAlive"` DialDualStack bool TLSHandshakeTimeout time.Duration `capnp:"tlsHandshakeTimeout"` TLSVerify bool `capnp:"tlsVerify"` OriginCAPool string OriginServerName string MaxIdleConnections uint64 IdleConnectionTimeout time.Duration } func (_ *HTTPOriginConfig) isOriginConfig() {} type UnixSocketOriginConfig struct { Path string } func (_ *UnixSocketOriginConfig) isOriginConfig() {} type WebSocketOriginConfig struct { URL string `capnp:"url"` } func (_ *WebSocketOriginConfig) isOriginConfig() {} type HelloWorldOriginConfig struct{} func (_ *HelloWorldOriginConfig) isOriginConfig() {} /// /// Boilerplate to convert between these structs and the primitive structs generated by capnp-go /// func MarshalCloudflaredConfig(s tunnelrpc.CloudflaredConfig, p *CloudflaredConfig) error { s.SetTimestamp(p.Timestamp.UnixNano()) s.SetAutoUpdateFrequency(p.AutoUpdateFrequency.Nanoseconds()) s.SetMetricsUpdateFrequency(p.MetricsUpdateFrequency.Nanoseconds()) s.SetHeartbeatInterval(p.HeartbeatInterval.Nanoseconds()) s.SetMaxFailedHeartbeats(p.MaxFailedHeartbeats) s.SetGracePeriod(p.GracePeriod.Nanoseconds()) err := marshalDoHProxyConfigs(s, p.DoHProxyConfigs) if err != nil { return err } err = marshalReverseProxyConfigs(s, p.ReverseProxyConfigs) if err != nil { return err } return nil } func marshalDoHProxyConfigs(s tunnelrpc.CloudflaredConfig, dohProxyConfigs []*DoHProxyConfig) error { capnpList, err := s.NewDohProxyConfigs(int32(len(dohProxyConfigs))) if err != nil { return err } for i, unmarshalledConfig := range dohProxyConfigs { err := MarshalDoHProxyConfig(capnpList.At(i), unmarshalledConfig) if err != nil { return err } } return nil } func marshalReverseProxyConfigs(s tunnelrpc.CloudflaredConfig, reverseProxyConfigs []*ReverseProxyConfig) error { capnpList, err := s.NewReverseProxyConfigs(int32(len(reverseProxyConfigs))) if err != nil { return err } for i, unmarshalledConfig := range reverseProxyConfigs { err := MarshalReverseProxyConfig(capnpList.At(i), unmarshalledConfig) if err != nil { return err } } return nil } func UnmarshalCloudflaredConfig(s tunnelrpc.CloudflaredConfig) (*CloudflaredConfig, error) { p := new(CloudflaredConfig) p.Timestamp = time.Unix(0, s.Timestamp()).UTC() p.AutoUpdateFrequency = time.Duration(s.AutoUpdateFrequency()) p.MetricsUpdateFrequency = time.Duration(s.MetricsUpdateFrequency()) p.HeartbeatInterval = time.Duration(s.HeartbeatInterval()) p.MaxFailedHeartbeats = s.MaxFailedHeartbeats() p.GracePeriod = time.Duration(s.GracePeriod()) dohProxyConfigs, err := unmarshalDoHProxyConfigs(s) if err != nil { return nil, err } p.DoHProxyConfigs = dohProxyConfigs reverseProxyConfigs, err := unmarshalReverseProxyConfigs(s) if err != nil { return nil, err } p.ReverseProxyConfigs = reverseProxyConfigs return p, err } func unmarshalDoHProxyConfigs(s tunnelrpc.CloudflaredConfig) ([]*DoHProxyConfig, error) { var result []*DoHProxyConfig marshalledDoHProxyConfigs, err := s.DohProxyConfigs() if err != nil { return nil, err } for i := 0; i < marshalledDoHProxyConfigs.Len(); i++ { ss := marshalledDoHProxyConfigs.At(i) dohProxyConfig, err := UnmarshalDoHProxyConfig(ss) if err != nil { return nil, err } result = append(result, dohProxyConfig) } return result, nil } func unmarshalReverseProxyConfigs(s tunnelrpc.CloudflaredConfig) ([]*ReverseProxyConfig, error) { var result []*ReverseProxyConfig marshalledReverseProxyConfigs, err := s.ReverseProxyConfigs() if err != nil { return nil, err } for i := 0; i < marshalledReverseProxyConfigs.Len(); i++ { ss := marshalledReverseProxyConfigs.At(i) reverseProxyConfig, err := UnmarshalReverseProxyConfig(ss) if err != nil { return nil, err } result = append(result, reverseProxyConfig) } return result, nil } func MarshalUseConfigurationResult(s tunnelrpc.UseConfigurationResult, p *UseConfigurationResult) error { return pogs.Insert(tunnelrpc.UseConfigurationResult_TypeID, s.Struct, p) } func UnmarshalUseConfigurationResult(s tunnelrpc.UseConfigurationResult) (*UseConfigurationResult, error) { p := new(UseConfigurationResult) err := pogs.Extract(p, tunnelrpc.UseConfigurationResult_TypeID, s.Struct) return p, err } func MarshalDoHProxyConfig(s tunnelrpc.DoHProxyConfig, p *DoHProxyConfig) error { return pogs.Insert(tunnelrpc.DoHProxyConfig_TypeID, s.Struct, p) } func UnmarshalDoHProxyConfig(s tunnelrpc.DoHProxyConfig) (*DoHProxyConfig, error) { p := new(DoHProxyConfig) err := pogs.Extract(p, tunnelrpc.DoHProxyConfig_TypeID, s.Struct) return p, err } func MarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig, p *ReverseProxyConfig) error { s.SetTunnelID(p.TunnelID) switch config := p.Origin.(type) { case *HTTPOriginConfig: ss, err := s.Origin().NewHttp() if err != nil { return err } MarshalHTTPOriginConfig(ss, config) case *UnixSocketOriginConfig: ss, err := s.Origin().NewSocket() if err != nil { return err } MarshalUnixSocketOriginConfig(ss, config) case *WebSocketOriginConfig: ss, err := s.Origin().NewWebsocket() if err != nil { return err } MarshalWebSocketOriginConfig(ss, config) case *HelloWorldOriginConfig: ss, err := s.Origin().NewHelloWorld() if err != nil { return err } MarshalHelloWorldOriginConfig(ss, config) default: return fmt.Errorf("Unknown type for config: %T", config) } s.SetRetries(p.Retries) s.SetConnectionTimeout(p.ConnectionTimeout.Nanoseconds()) s.SetChunkedEncoding(p.ChunkedEncoding) s.SetCompressionQuality(p.CompressionQuality) return nil } func UnmarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig) (*ReverseProxyConfig, error) { p := new(ReverseProxyConfig) tunnelID, err := s.TunnelID() if err != nil { return nil, err } p.TunnelID = tunnelID switch s.Origin().Which() { case tunnelrpc.ReverseProxyConfig_origin_Which_http: ss, err := s.Origin().Http() if err != nil { return nil, err } config, err := UnmarshalHTTPOriginConfig(ss) if err != nil { return nil, err } p.Origin = config case tunnelrpc.ReverseProxyConfig_origin_Which_socket: ss, err := s.Origin().Socket() if err != nil { return nil, err } config, err := UnmarshalUnixSocketOriginConfig(ss) if err != nil { return nil, err } p.Origin = config case tunnelrpc.ReverseProxyConfig_origin_Which_websocket: ss, err := s.Origin().Websocket() if err != nil { return nil, err } config, err := UnmarshalWebSocketOriginConfig(ss) if err != nil { return nil, err } p.Origin = config case tunnelrpc.ReverseProxyConfig_origin_Which_helloWorld: ss, err := s.Origin().HelloWorld() if err != nil { return nil, err } config, err := UnmarshalHelloWorldOriginConfig(ss) if err != nil { return nil, err } p.Origin = config } p.Retries = s.Retries() p.ConnectionTimeout = time.Duration(s.ConnectionTimeout()) p.ChunkedEncoding = s.ChunkedEncoding() p.CompressionQuality = s.CompressionQuality() return p, nil } func MarshalHTTPOriginConfig(s tunnelrpc.HTTPOriginConfig, p *HTTPOriginConfig) error { return pogs.Insert(tunnelrpc.HTTPOriginConfig_TypeID, s.Struct, p) } func UnmarshalHTTPOriginConfig(s tunnelrpc.HTTPOriginConfig) (*HTTPOriginConfig, error) { p := new(HTTPOriginConfig) err := pogs.Extract(p, tunnelrpc.HTTPOriginConfig_TypeID, s.Struct) return p, err } func MarshalUnixSocketOriginConfig(s tunnelrpc.UnixSocketOriginConfig, p *UnixSocketOriginConfig) error { return pogs.Insert(tunnelrpc.UnixSocketOriginConfig_TypeID, s.Struct, p) } func UnmarshalUnixSocketOriginConfig(s tunnelrpc.UnixSocketOriginConfig) (*UnixSocketOriginConfig, error) { p := new(UnixSocketOriginConfig) err := pogs.Extract(p, tunnelrpc.UnixSocketOriginConfig_TypeID, s.Struct) return p, err } func MarshalWebSocketOriginConfig(s tunnelrpc.WebSocketOriginConfig, p *WebSocketOriginConfig) error { return pogs.Insert(tunnelrpc.WebSocketOriginConfig_TypeID, s.Struct, p) } func UnmarshalWebSocketOriginConfig(s tunnelrpc.WebSocketOriginConfig) (*WebSocketOriginConfig, error) { p := new(WebSocketOriginConfig) err := pogs.Extract(p, tunnelrpc.WebSocketOriginConfig_TypeID, s.Struct) return p, err } func MarshalHelloWorldOriginConfig(s tunnelrpc.HelloWorldOriginConfig, p *HelloWorldOriginConfig) error { return pogs.Insert(tunnelrpc.HelloWorldOriginConfig_TypeID, s.Struct, p) } func UnmarshalHelloWorldOriginConfig(s tunnelrpc.HelloWorldOriginConfig) (*HelloWorldOriginConfig, error) { p := new(HelloWorldOriginConfig) err := pogs.Extract(p, tunnelrpc.HelloWorldOriginConfig_TypeID, s.Struct) return p, err } type CloudflaredServer interface { UseConfiguration(ctx context.Context, config *CloudflaredConfig) (*CloudflaredConfig, error) GetConfiguration(ctx context.Context) (*CloudflaredConfig, error) } type CloudflaredServer_PogsClient struct { Client capnp.Client Conn *rpc.Conn } func (c *CloudflaredServer_PogsClient) Close() error { return c.Conn.Close() } func (c *CloudflaredServer_PogsClient) UseConfiguration( ctx context.Context, config *CloudflaredConfig, ) (*UseConfigurationResult, error) { client := tunnelrpc.CloudflaredServer{Client: c.Client} promise := client.UseConfiguration(ctx, func(p tunnelrpc.CloudflaredServer_useConfiguration_Params) error { cloudflaredConfig, err := p.NewCloudflaredConfig() if err != nil { return err } return MarshalCloudflaredConfig(cloudflaredConfig, config) }) retval, err := promise.Result().Struct() if err != nil { return nil, err } return UnmarshalUseConfigurationResult(retval) }