374 lines
11 KiB
Go
374 lines
11 KiB
Go
package pogs
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/cloudflare/cloudflared/tunnelrpc"
|
|
capnp "zombiezen.com/go/capnproto2"
|
|
"zombiezen.com/go/capnproto2/pogs"
|
|
"zombiezen.com/go/capnproto2/rpc"
|
|
)
|
|
|
|
///
|
|
/// Structs
|
|
///
|
|
|
|
// CloudflaredConfig is the Go representation of the tunnelrpc.CloudflaredConfig
// capnp struct. MarshalCloudflaredConfig and UnmarshalCloudflaredConfig convert
// between the two forms.
type CloudflaredConfig struct {
	// Timestamp is encoded as Unix nanoseconds and decoded back as a UTC time.
	Timestamp time.Time
	// AutoUpdateFrequency, MetricsUpdateFrequency, HeartbeatInterval and
	// GracePeriod are each encoded as an int64 nanosecond count.
	AutoUpdateFrequency    time.Duration
	MetricsUpdateFrequency time.Duration
	HeartbeatInterval      time.Duration
	MaxFailedHeartbeats    uint64
	GracePeriod            time.Duration
	DoHProxyConfigs        []*DoHProxyConfig
	ReverseProxyConfigs    []*ReverseProxyConfig
}
|
|
|
|
// UseConfigurationResult is the Go representation of the
// tunnelrpc.UseConfigurationResult capnp struct. It is the value returned by
// CloudflaredServer_PogsClient.UseConfiguration and is converted with pogs
// (MarshalUseConfigurationResult / UnmarshalUseConfigurationResult).
type UseConfigurationResult struct {
	Success bool
	// ErrorMessage presumably explains a failure when Success is false — TODO
	// confirm against the server implementation.
	ErrorMessage string
}
|
|
|
|
// DoHProxyConfig is the Go representation of the tunnelrpc.DoHProxyConfig
// capnp struct and is converted with pogs (MarshalDoHProxyConfig /
// UnmarshalDoHProxyConfig).
// NOTE(review): the field names suggest a DNS-over-HTTPS proxy listen address
// plus upstream list — confirm the semantics against the capnp schema.
type DoHProxyConfig struct {
	ListenHost string
	ListenPort uint16
	Upstreams  []string
}
|
|
|
|
// ReverseProxyConfig is the Go representation of the
// tunnelrpc.ReverseProxyConfig capnp struct. Unlike the pogs-converted types,
// it is converted field by field (MarshalReverseProxyConfig /
// UnmarshalReverseProxyConfig). Origin corresponds to the origin union of the
// capnp struct.
type ReverseProxyConfig struct {
	TunnelID string
	// Origin holds exactly one origin variant; see OriginConfig.
	Origin  OriginConfig
	Retries uint64
	// ConnectionTimeout is encoded as an int64 nanosecond count.
	ConnectionTimeout  time.Duration
	ChunkedEncoding    bool
	CompressionQuality uint64
}
|
|
|
|
// OriginConfig is a sum type over the members of ReverseProxyConfig's origin
// union: *HTTPOriginConfig, *UnixSocketOriginConfig, *WebSocketOriginConfig
// and *HelloWorldOriginConfig. The unexported marker method means only types
// in this package can implement it; the go-sumtype directive below lets that
// linter check type switches over OriginConfig for exhaustiveness.
//go-sumtype:decl OriginConfig
type OriginConfig interface {
	isOriginConfig()
}
|
|
|
|
// HTTPOriginConfig is the http member of the OriginConfig sum type and mirrors
// the tunnelrpc.HTTPOriginConfig capnp struct. It is converted with pogs
// (MarshalHTTPOriginConfig / UnmarshalHTTPOriginConfig). The explicit capnp
// tags pin the schema field name for fields whose Go names start with an
// initialism (URL, TCP, TLS).
// NOTE(review): the time.Duration fields are handled by pogs as their
// underlying int64, presumably a nanosecond count — confirm against the schema.
type HTTPOriginConfig struct {
	URL                   string        `capnp:"url"`
	TCPKeepAlive          time.Duration `capnp:"tcpKeepAlive"`
	DialDualStack         bool
	TLSHandshakeTimeout   time.Duration `capnp:"tlsHandshakeTimeout"`
	TLSVerify             bool          `capnp:"tlsVerify"`
	OriginCAPool          string
	OriginServerName      string
	MaxIdleConnections    uint64
	IdleConnectionTimeout time.Duration
}
|
|
|
|
// isOriginConfig marks *HTTPOriginConfig as a member of the OriginConfig sum
// type. The receiver is unused, so it is left unnamed rather than named "_".
func (*HTTPOriginConfig) isOriginConfig() {}
|
|
|
|
// UnixSocketOriginConfig is the socket member of the OriginConfig sum type and
// mirrors the tunnelrpc.UnixSocketOriginConfig capnp struct (converted with
// pogs).
type UnixSocketOriginConfig struct {
	// Path presumably names the Unix domain socket file — TODO confirm.
	Path string
}
|
|
|
|
// isOriginConfig marks *UnixSocketOriginConfig as a member of the OriginConfig
// sum type. The receiver is unused, so it is left unnamed rather than named "_".
func (*UnixSocketOriginConfig) isOriginConfig() {}
|
|
|
|
// WebSocketOriginConfig is the websocket member of the OriginConfig sum type
// and mirrors the tunnelrpc.WebSocketOriginConfig capnp struct (converted with
// pogs).
type WebSocketOriginConfig struct {
	// URL maps to the schema field "url" via its capnp tag.
	URL string `capnp:"url"`
}
|
|
|
|
// isOriginConfig marks *WebSocketOriginConfig as a member of the OriginConfig
// sum type. The receiver is unused, so it is left unnamed rather than named "_".
func (*WebSocketOriginConfig) isOriginConfig() {}
|
|
|
|
// HelloWorldOriginConfig is the helloWorld member of the OriginConfig sum type.
// It carries no fields; choosing it is expressed only by which member of the
// origin union is set.
type HelloWorldOriginConfig struct{}
|
|
|
|
// isOriginConfig marks *HelloWorldOriginConfig as a member of the OriginConfig
// sum type. The receiver is unused, so it is left unnamed rather than named "_".
func (*HelloWorldOriginConfig) isOriginConfig() {}
|
|
|
|
///
|
|
/// Boilerplate to convert between these structs and the primitive structs generated by capnp-go
|
|
///
|
|
|
|
func MarshalCloudflaredConfig(s tunnelrpc.CloudflaredConfig, p *CloudflaredConfig) error {
|
|
s.SetTimestamp(p.Timestamp.UnixNano())
|
|
s.SetAutoUpdateFrequency(p.AutoUpdateFrequency.Nanoseconds())
|
|
s.SetMetricsUpdateFrequency(p.MetricsUpdateFrequency.Nanoseconds())
|
|
s.SetHeartbeatInterval(p.HeartbeatInterval.Nanoseconds())
|
|
s.SetMaxFailedHeartbeats(p.MaxFailedHeartbeats)
|
|
s.SetGracePeriod(p.GracePeriod.Nanoseconds())
|
|
err := marshalDoHProxyConfigs(s, p.DoHProxyConfigs)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = marshalReverseProxyConfigs(s, p.ReverseProxyConfigs)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func marshalDoHProxyConfigs(s tunnelrpc.CloudflaredConfig, dohProxyConfigs []*DoHProxyConfig) error {
|
|
capnpList, err := s.NewDohProxyConfigs(int32(len(dohProxyConfigs)))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for i, unmarshalledConfig := range dohProxyConfigs {
|
|
err := MarshalDoHProxyConfig(capnpList.At(i), unmarshalledConfig)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func marshalReverseProxyConfigs(s tunnelrpc.CloudflaredConfig, reverseProxyConfigs []*ReverseProxyConfig) error {
|
|
capnpList, err := s.NewReverseProxyConfigs(int32(len(reverseProxyConfigs)))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for i, unmarshalledConfig := range reverseProxyConfigs {
|
|
err := MarshalReverseProxyConfig(capnpList.At(i), unmarshalledConfig)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func UnmarshalCloudflaredConfig(s tunnelrpc.CloudflaredConfig) (*CloudflaredConfig, error) {
|
|
p := new(CloudflaredConfig)
|
|
p.Timestamp = time.Unix(0, s.Timestamp()).UTC()
|
|
p.AutoUpdateFrequency = time.Duration(s.AutoUpdateFrequency())
|
|
p.MetricsUpdateFrequency = time.Duration(s.MetricsUpdateFrequency())
|
|
p.HeartbeatInterval = time.Duration(s.HeartbeatInterval())
|
|
p.MaxFailedHeartbeats = s.MaxFailedHeartbeats()
|
|
p.GracePeriod = time.Duration(s.GracePeriod())
|
|
dohProxyConfigs, err := unmarshalDoHProxyConfigs(s)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
p.DoHProxyConfigs = dohProxyConfigs
|
|
reverseProxyConfigs, err := unmarshalReverseProxyConfigs(s)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
p.ReverseProxyConfigs = reverseProxyConfigs
|
|
return p, err
|
|
}
|
|
|
|
func unmarshalDoHProxyConfigs(s tunnelrpc.CloudflaredConfig) ([]*DoHProxyConfig, error) {
|
|
var result []*DoHProxyConfig
|
|
marshalledDoHProxyConfigs, err := s.DohProxyConfigs()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for i := 0; i < marshalledDoHProxyConfigs.Len(); i++ {
|
|
ss := marshalledDoHProxyConfigs.At(i)
|
|
dohProxyConfig, err := UnmarshalDoHProxyConfig(ss)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
result = append(result, dohProxyConfig)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func unmarshalReverseProxyConfigs(s tunnelrpc.CloudflaredConfig) ([]*ReverseProxyConfig, error) {
|
|
var result []*ReverseProxyConfig
|
|
marshalledReverseProxyConfigs, err := s.ReverseProxyConfigs()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for i := 0; i < marshalledReverseProxyConfigs.Len(); i++ {
|
|
ss := marshalledReverseProxyConfigs.At(i)
|
|
reverseProxyConfig, err := UnmarshalReverseProxyConfig(ss)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
result = append(result, reverseProxyConfig)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func MarshalUseConfigurationResult(s tunnelrpc.UseConfigurationResult, p *UseConfigurationResult) error {
|
|
return pogs.Insert(tunnelrpc.UseConfigurationResult_TypeID, s.Struct, p)
|
|
}
|
|
|
|
func UnmarshalUseConfigurationResult(s tunnelrpc.UseConfigurationResult) (*UseConfigurationResult, error) {
|
|
p := new(UseConfigurationResult)
|
|
err := pogs.Extract(p, tunnelrpc.UseConfigurationResult_TypeID, s.Struct)
|
|
return p, err
|
|
}
|
|
|
|
func MarshalDoHProxyConfig(s tunnelrpc.DoHProxyConfig, p *DoHProxyConfig) error {
|
|
return pogs.Insert(tunnelrpc.DoHProxyConfig_TypeID, s.Struct, p)
|
|
}
|
|
|
|
func UnmarshalDoHProxyConfig(s tunnelrpc.DoHProxyConfig) (*DoHProxyConfig, error) {
|
|
p := new(DoHProxyConfig)
|
|
err := pogs.Extract(p, tunnelrpc.DoHProxyConfig_TypeID, s.Struct)
|
|
return p, err
|
|
}
|
|
|
|
func MarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig, p *ReverseProxyConfig) error {
|
|
s.SetTunnelID(p.TunnelID)
|
|
switch config := p.Origin.(type) {
|
|
case *HTTPOriginConfig:
|
|
ss, err := s.Origin().NewHttp()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
MarshalHTTPOriginConfig(ss, config)
|
|
case *UnixSocketOriginConfig:
|
|
ss, err := s.Origin().NewSocket()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
MarshalUnixSocketOriginConfig(ss, config)
|
|
case *WebSocketOriginConfig:
|
|
ss, err := s.Origin().NewWebsocket()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
MarshalWebSocketOriginConfig(ss, config)
|
|
case *HelloWorldOriginConfig:
|
|
ss, err := s.Origin().NewHelloWorld()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
MarshalHelloWorldOriginConfig(ss, config)
|
|
default:
|
|
return fmt.Errorf("Unknown type for config: %T", config)
|
|
}
|
|
s.SetRetries(p.Retries)
|
|
s.SetConnectionTimeout(p.ConnectionTimeout.Nanoseconds())
|
|
s.SetChunkedEncoding(p.ChunkedEncoding)
|
|
s.SetCompressionQuality(p.CompressionQuality)
|
|
return nil
|
|
}
|
|
|
|
func UnmarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig) (*ReverseProxyConfig, error) {
|
|
p := new(ReverseProxyConfig)
|
|
tunnelID, err := s.TunnelID()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
p.TunnelID = tunnelID
|
|
switch s.Origin().Which() {
|
|
case tunnelrpc.ReverseProxyConfig_origin_Which_http:
|
|
ss, err := s.Origin().Http()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
config, err := UnmarshalHTTPOriginConfig(ss)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
p.Origin = config
|
|
case tunnelrpc.ReverseProxyConfig_origin_Which_socket:
|
|
ss, err := s.Origin().Socket()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
config, err := UnmarshalUnixSocketOriginConfig(ss)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
p.Origin = config
|
|
case tunnelrpc.ReverseProxyConfig_origin_Which_websocket:
|
|
ss, err := s.Origin().Websocket()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
config, err := UnmarshalWebSocketOriginConfig(ss)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
p.Origin = config
|
|
case tunnelrpc.ReverseProxyConfig_origin_Which_helloWorld:
|
|
ss, err := s.Origin().HelloWorld()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
config, err := UnmarshalHelloWorldOriginConfig(ss)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
p.Origin = config
|
|
}
|
|
p.Retries = s.Retries()
|
|
p.ConnectionTimeout = time.Duration(s.ConnectionTimeout())
|
|
p.ChunkedEncoding = s.ChunkedEncoding()
|
|
p.CompressionQuality = s.CompressionQuality()
|
|
return p, nil
|
|
}
|
|
|
|
func MarshalHTTPOriginConfig(s tunnelrpc.HTTPOriginConfig, p *HTTPOriginConfig) error {
|
|
return pogs.Insert(tunnelrpc.HTTPOriginConfig_TypeID, s.Struct, p)
|
|
}
|
|
|
|
func UnmarshalHTTPOriginConfig(s tunnelrpc.HTTPOriginConfig) (*HTTPOriginConfig, error) {
|
|
p := new(HTTPOriginConfig)
|
|
err := pogs.Extract(p, tunnelrpc.HTTPOriginConfig_TypeID, s.Struct)
|
|
return p, err
|
|
}
|
|
|
|
func MarshalUnixSocketOriginConfig(s tunnelrpc.UnixSocketOriginConfig, p *UnixSocketOriginConfig) error {
|
|
return pogs.Insert(tunnelrpc.UnixSocketOriginConfig_TypeID, s.Struct, p)
|
|
}
|
|
|
|
func UnmarshalUnixSocketOriginConfig(s tunnelrpc.UnixSocketOriginConfig) (*UnixSocketOriginConfig, error) {
|
|
p := new(UnixSocketOriginConfig)
|
|
err := pogs.Extract(p, tunnelrpc.UnixSocketOriginConfig_TypeID, s.Struct)
|
|
return p, err
|
|
}
|
|
|
|
func MarshalWebSocketOriginConfig(s tunnelrpc.WebSocketOriginConfig, p *WebSocketOriginConfig) error {
|
|
return pogs.Insert(tunnelrpc.WebSocketOriginConfig_TypeID, s.Struct, p)
|
|
}
|
|
|
|
func UnmarshalWebSocketOriginConfig(s tunnelrpc.WebSocketOriginConfig) (*WebSocketOriginConfig, error) {
|
|
p := new(WebSocketOriginConfig)
|
|
err := pogs.Extract(p, tunnelrpc.WebSocketOriginConfig_TypeID, s.Struct)
|
|
return p, err
|
|
}
|
|
|
|
func MarshalHelloWorldOriginConfig(s tunnelrpc.HelloWorldOriginConfig, p *HelloWorldOriginConfig) error {
|
|
return pogs.Insert(tunnelrpc.HelloWorldOriginConfig_TypeID, s.Struct, p)
|
|
}
|
|
|
|
func UnmarshalHelloWorldOriginConfig(s tunnelrpc.HelloWorldOriginConfig) (*HelloWorldOriginConfig, error) {
|
|
p := new(HelloWorldOriginConfig)
|
|
err := pogs.Extract(p, tunnelrpc.HelloWorldOriginConfig_TypeID, s.Struct)
|
|
return p, err
|
|
}
|
|
|
|
// CloudflaredServer is the Go-typed interface for the CloudflaredServer RPC
// service, expressed with the types declared in this package.
//
// NOTE(review): UseConfiguration here returns *CloudflaredConfig, whereas
// CloudflaredServer_PogsClient.UseConfiguration returns *UseConfigurationResult
// (and the client has no GetConfiguration method in this file), so the client
// does not satisfy this interface. Confirm which signature the capnp schema and
// the server implementations intend before changing either side.
type CloudflaredServer interface {
	UseConfiguration(ctx context.Context, config *CloudflaredConfig) (*CloudflaredConfig, error)
	GetConfiguration(ctx context.Context) (*CloudflaredConfig, error)
}
|
|
|
|
// CloudflaredServer_PogsClient calls the CloudflaredServer RPC service,
// accepting and returning the Go types in this package instead of raw capnp
// structs.
type CloudflaredServer_PogsClient struct {
	// Client is the capnp client that RPC calls are made on.
	Client capnp.Client
	// Conn is the RPC connection; Close closes it.
	Conn *rpc.Conn
}
|
|
|
|
func (c *CloudflaredServer_PogsClient) Close() error {
|
|
return c.Conn.Close()
|
|
}
|
|
|
|
func (c *CloudflaredServer_PogsClient) UseConfiguration(
|
|
ctx context.Context,
|
|
config *CloudflaredConfig,
|
|
) (*UseConfigurationResult, error) {
|
|
client := tunnelrpc.CloudflaredServer{Client: c.Client}
|
|
promise := client.UseConfiguration(ctx, func(p tunnelrpc.CloudflaredServer_useConfiguration_Params) error {
|
|
cloudflaredConfig, err := p.NewCloudflaredConfig()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return MarshalCloudflaredConfig(cloudflaredConfig, config)
|
|
})
|
|
retval, err := promise.Result().Struct()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return UnmarshalUseConfigurationResult(retval)
|
|
}
|