This commit is contained in:
Ivan Kruglov 2026-03-15 11:23:28 +01:00 committed by GitHub
commit a361a7ba91
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 94 additions and 0 deletions

View File

@ -255,6 +255,10 @@ func validateIngress(ingress []config.UnvalidatedIngressRule, defaults OriginReq
} else if prefix := "unix+tls:"; strings.HasPrefix(r.Service, prefix) {
path := strings.TrimPrefix(r.Service, prefix)
service = &unixSocketPath{path: path, scheme: "https"}
} else if prefix := "unix+tcp:"; strings.HasPrefix(r.Service, prefix) {
// Stream raw bytes (e.g. SSH, RDP protocol) directly into a unix socket without HTTP wrapping
path := strings.TrimPrefix(r.Service, prefix)
service = &unixSocketTCPService{path: path}
} else if prefix := "http_status:"; strings.HasPrefix(r.Service, prefix) {
statusCode, err := strconv.Atoi(strings.TrimPrefix(r.Service, prefix))
if err != nil {

View File

@ -43,6 +43,19 @@ ingress:
require.Equal(t, "https", s.scheme)
}
func TestParseUnixSocketTCP(t *testing.T) {
rawYAML := `
ingress:
- service: unix+tcp:/run/sshd.sock
`
ing, err := ParseIngress(MustReadIngress(rawYAML))
require.NoError(t, err)
s, ok := ing.Rules[0].Service.(*unixSocketTCPService)
require.True(t, ok)
require.Equal(t, "/run/sshd.sock", s.path)
require.Equal(t, "unix+tcp:/run/sshd.sock", s.String())
}
func TestParseIngressNilConfig(t *testing.T) {
_, err := ParseIngress(nil)
require.Error(t, err)
@ -322,6 +335,19 @@ ingress:
},
},
},
{
name: "Unix+TCP service",
args: args{rawYAML: `
ingress:
- service: unix+tcp:/run/sshd.sock
`},
want: []Rule{
{
Service: &unixSocketTCPService{path: "/run/sshd.sock"},
Config: defaultConfig,
},
},
},
{
name: "RDP services",
args: args{rawYAML: `

View File

@ -119,3 +119,14 @@ func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string,
func (o *socksProxyOverWSService) EstablishConnection(_ context.Context, _ string, _ *zerolog.Logger) (OriginConnection, error) {
return o.conn, nil
}
func (o *unixSocketTCPService) EstablishConnection(ctx context.Context, _ string, _ *zerolog.Logger) (OriginConnection, error) {
conn, err := o.dialer.DialContext(ctx, "unix", o.path)
if err != nil {
return nil, err
}
return &tcpOverWSConnection{
conn: conn,
streamHandler: o.streamHandler,
}, nil
}

View File

@ -8,6 +8,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"testing"
"github.com/stretchr/testify/assert"
@ -186,6 +187,35 @@ func TestHTTPServiceUsesIngressRuleScheme(t *testing.T) {
}
}
func TestUnixSocketTCPServiceEstablishConnection(t *testing.T) {
dir, err := os.MkdirTemp("/tmp", "cf-test-")
require.NoError(t, err)
defer os.RemoveAll(dir)
socketPath := dir + "/sshd.sock"
originListener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
listenerClosed := make(chan struct{})
tcpListenRoutine(originListener, listenerClosed)
svc := &unixSocketTCPService{path: socketPath}
require.NoError(t, svc.start(TestLogger, make(chan struct{}), OriginRequestConfig{}))
// Successful connection to the unix socket
conn, err := svc.EstablishConnection(context.Background(), "", TestLogger)
require.NoError(t, err)
require.NotNil(t, conn)
conn.Close()
// Close the listener and verify that new connections fail
originListener.Close()
<-listenerClosed
_, err = svc.EstablishConnection(context.Background(), "", TestLogger)
require.Error(t, err)
}
func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) {
go func() {
for {

View File

@ -46,6 +46,14 @@ type unixSocketPath struct {
transport *http.Transport
}
// unixSocketTCPService is an OriginService that streams raw bytes (e.g. SSH, RDP) directly into a
// unix socket, bypassing HTTP entirely. It is the unix-socket analogue of tcpOverWSService.
type unixSocketTCPService struct {
path string
streamHandler streamHandlerFunc
dialer net.Dialer
}
func (o *unixSocketPath) String() string {
scheme := ""
if o.scheme == "https" {
@ -67,6 +75,21 @@ func (o unixSocketPath) MarshalJSON() ([]byte, error) {
return json.Marshal(o.String())
}
func (o *unixSocketTCPService) String() string {
return "unix+tcp:" + o.path
}
func (o *unixSocketTCPService) start(_ *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error {
o.streamHandler = DefaultStreamHandler
o.dialer.Timeout = cfg.ConnectTimeout.Duration
o.dialer.KeepAlive = cfg.TCPKeepAlive.Duration
return nil
}
func (o unixSocketTCPService) MarshalJSON() ([]byte, error) {
return json.Marshal(o.String())
}
type httpService struct {
url *url.URL
hostHeader string