diff --git a/ingress/ingress_test.go b/ingress/ingress_test.go index 109cb353..8dfd5416 100644 --- a/ingress/ingress_test.go +++ b/ingress/ingress_test.go @@ -43,6 +43,19 @@ ingress: require.Equal(t, "https", s.scheme) } +func TestParseUnixSocketTCP(t *testing.T) { + rawYAML := ` +ingress: +- service: unix+tcp:/run/sshd.sock +` + ing, err := ParseIngress(MustReadIngress(rawYAML)) + require.NoError(t, err) + s, ok := ing.Rules[0].Service.(*unixSocketTCPService) + require.True(t, ok) + require.Equal(t, "/run/sshd.sock", s.path) + require.Equal(t, "unix+tcp:/run/sshd.sock", s.String()) +} + func TestParseIngressNilConfig(t *testing.T) { _, err := ParseIngress(nil) require.Error(t, err) @@ -322,6 +335,19 @@ ingress: }, }, }, + { + name: "Unix+TCP service", + args: args{rawYAML: ` +ingress: +- service: unix+tcp:/run/sshd.sock +`}, + want: []Rule{ + { + Service: &unixSocketTCPService{path: "/run/sshd.sock"}, + Config: defaultConfig, + }, + }, + }, { name: "RDP services", args: args{rawYAML: ` diff --git a/ingress/origin_proxy_test.go b/ingress/origin_proxy_test.go index 7a6170a2..aab4697e 100644 --- a/ingress/origin_proxy_test.go +++ b/ingress/origin_proxy_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "os" "testing" "github.com/stretchr/testify/assert" @@ -186,6 +187,35 @@ func TestHTTPServiceUsesIngressRuleScheme(t *testing.T) { } } +func TestUnixSocketTCPServiceEstablishConnection(t *testing.T) { + dir, err := os.MkdirTemp("/tmp", "cf-test-") + require.NoError(t, err) + defer os.RemoveAll(dir) + + socketPath := dir + "/sshd.sock" + originListener, err := net.Listen("unix", socketPath) + require.NoError(t, err) + + listenerClosed := make(chan struct{}) + tcpListenRoutine(originListener, listenerClosed) + + svc := &unixSocketTCPService{path: socketPath} + require.NoError(t, svc.start(TestLogger, make(chan struct{}), OriginRequestConfig{})) + + // Successful connection to the unix socket + conn, err := svc.EstablishConnection(context.Background(), "", TestLogger) + require.NoError(t, err) + require.NotNil(t, conn) + conn.Close() + + // Close the listener and verify that new connections fail + originListener.Close() + <-listenerClosed + + _, err = svc.EstablishConnection(context.Background(), "", TestLogger) + require.Error(t, err) +} + func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) { go func() { for {