diff --git a/ingress/ingress.go b/ingress/ingress.go index a325271a..f00ed231 100644 --- a/ingress/ingress.go +++ b/ingress/ingress.go @@ -255,6 +255,10 @@ func validateIngress(ingress []config.UnvalidatedIngressRule, defaults OriginReq } else if prefix := "unix+tls:"; strings.HasPrefix(r.Service, prefix) { path := strings.TrimPrefix(r.Service, prefix) service = &unixSocketPath{path: path, scheme: "https"} + } else if prefix := "unix+tcp:"; strings.HasPrefix(r.Service, prefix) { + // Stream raw bytes (e.g. SSH, RDP protocol) directly into a unix socket without HTTP wrapping + path := strings.TrimPrefix(r.Service, prefix) + service = &unixSocketTCPService{path: path} } else if prefix := "http_status:"; strings.HasPrefix(r.Service, prefix) { statusCode, err := strconv.Atoi(strings.TrimPrefix(r.Service, prefix)) if err != nil { diff --git a/ingress/origin_proxy.go b/ingress/origin_proxy.go index 7371eac9..10f33156 100644 --- a/ingress/origin_proxy.go +++ b/ingress/origin_proxy.go @@ -119,3 +119,14 @@ func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string, func (o *socksProxyOverWSService) EstablishConnection(_ context.Context, _ string, _ *zerolog.Logger) (OriginConnection, error) { return o.conn, nil } + +func (o *unixSocketTCPService) EstablishConnection(ctx context.Context, _ string, _ *zerolog.Logger) (OriginConnection, error) { + conn, err := o.dialer.DialContext(ctx, "unix", o.path) + if err != nil { + return nil, err + } + return &tcpOverWSConnection{ + conn: conn, + streamHandler: o.streamHandler, + }, nil +} diff --git a/ingress/origin_service.go b/ingress/origin_service.go index e13204c5..0a5a2be2 100644 --- a/ingress/origin_service.go +++ b/ingress/origin_service.go @@ -46,6 +46,14 @@ type unixSocketPath struct { transport *http.Transport } +// unixSocketTCPService is an OriginService that streams raw bytes (e.g. SSH, RDP) directly into a +// unix socket, bypassing HTTP entirely. It is the unix-socket analogue of tcpOverWSService. +type unixSocketTCPService struct { + path string + streamHandler streamHandlerFunc + dialer net.Dialer +} + func (o *unixSocketPath) String() string { scheme := "" if o.scheme == "https" { @@ -67,6 +75,21 @@ func (o unixSocketPath) MarshalJSON() ([]byte, error) { return json.Marshal(o.String()) } +func (o *unixSocketTCPService) String() string { + return "unix+tcp:" + o.path +} + +func (o *unixSocketTCPService) start(_ *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error { + o.streamHandler = DefaultStreamHandler + o.dialer.Timeout = cfg.ConnectTimeout.Duration + o.dialer.KeepAlive = cfg.TCPKeepAlive.Duration + return nil +} + +func (o unixSocketTCPService) MarshalJSON() ([]byte, error) { + return json.Marshal(o.String()) +} + type httpService struct { url *url.URL hostHeader string