Merge 938e947be8 into d2a87e9b93
This commit is contained in:
commit
a361a7ba91
|
|
@ -255,6 +255,10 @@ func validateIngress(ingress []config.UnvalidatedIngressRule, defaults OriginReq
|
|||
} else if prefix := "unix+tls:"; strings.HasPrefix(r.Service, prefix) {
|
||||
path := strings.TrimPrefix(r.Service, prefix)
|
||||
service = &unixSocketPath{path: path, scheme: "https"}
|
||||
} else if prefix := "unix+tcp:"; strings.HasPrefix(r.Service, prefix) {
|
||||
// Stream raw bytes (e.g. SSH, RDP protocol) directly into a unix socket without HTTP wrapping
|
||||
path := strings.TrimPrefix(r.Service, prefix)
|
||||
service = &unixSocketTCPService{path: path}
|
||||
} else if prefix := "http_status:"; strings.HasPrefix(r.Service, prefix) {
|
||||
statusCode, err := strconv.Atoi(strings.TrimPrefix(r.Service, prefix))
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -43,6 +43,19 @@ ingress:
|
|||
require.Equal(t, "https", s.scheme)
|
||||
}
|
||||
|
||||
func TestParseUnixSocketTCP(t *testing.T) {
|
||||
rawYAML := `
|
||||
ingress:
|
||||
- service: unix+tcp:/run/sshd.sock
|
||||
`
|
||||
ing, err := ParseIngress(MustReadIngress(rawYAML))
|
||||
require.NoError(t, err)
|
||||
s, ok := ing.Rules[0].Service.(*unixSocketTCPService)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "/run/sshd.sock", s.path)
|
||||
require.Equal(t, "unix+tcp:/run/sshd.sock", s.String())
|
||||
}
|
||||
|
||||
func TestParseIngressNilConfig(t *testing.T) {
|
||||
_, err := ParseIngress(nil)
|
||||
require.Error(t, err)
|
||||
|
|
@ -322,6 +335,19 @@ ingress:
|
|||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Unix+TCP service",
|
||||
args: args{rawYAML: `
|
||||
ingress:
|
||||
- service: unix+tcp:/run/sshd.sock
|
||||
`},
|
||||
want: []Rule{
|
||||
{
|
||||
Service: &unixSocketTCPService{path: "/run/sshd.sock"},
|
||||
Config: defaultConfig,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "RDP services",
|
||||
args: args{rawYAML: `
|
||||
|
|
|
|||
|
|
@ -119,3 +119,14 @@ func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string,
|
|||
func (o *socksProxyOverWSService) EstablishConnection(_ context.Context, _ string, _ *zerolog.Logger) (OriginConnection, error) {
|
||||
return o.conn, nil
|
||||
}
|
||||
|
||||
func (o *unixSocketTCPService) EstablishConnection(ctx context.Context, _ string, _ *zerolog.Logger) (OriginConnection, error) {
|
||||
conn, err := o.dialer.DialContext(ctx, "unix", o.path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tcpOverWSConnection{
|
||||
conn: conn,
|
||||
streamHandler: o.streamHandler,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import (
|
|||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
|
@ -186,6 +187,35 @@ func TestHTTPServiceUsesIngressRuleScheme(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestUnixSocketTCPServiceEstablishConnection(t *testing.T) {
|
||||
dir, err := os.MkdirTemp("/tmp", "cf-test-")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
socketPath := dir + "/sshd.sock"
|
||||
originListener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
listenerClosed := make(chan struct{})
|
||||
tcpListenRoutine(originListener, listenerClosed)
|
||||
|
||||
svc := &unixSocketTCPService{path: socketPath}
|
||||
require.NoError(t, svc.start(TestLogger, make(chan struct{}), OriginRequestConfig{}))
|
||||
|
||||
// Successful connection to the unix socket
|
||||
conn, err := svc.EstablishConnection(context.Background(), "", TestLogger)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
conn.Close()
|
||||
|
||||
// Close the listener and verify that new connections fail
|
||||
originListener.Close()
|
||||
<-listenerClosed
|
||||
|
||||
_, err = svc.EstablishConnection(context.Background(), "", TestLogger)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) {
|
||||
go func() {
|
||||
for {
|
||||
|
|
|
|||
|
|
@ -46,6 +46,14 @@ type unixSocketPath struct {
|
|||
transport *http.Transport
|
||||
}
|
||||
|
||||
// unixSocketTCPService is an OriginService that streams raw bytes (e.g. SSH, RDP) directly into a
|
||||
// unix socket, bypassing HTTP entirely. It is the unix-socket analogue of tcpOverWSService.
|
||||
type unixSocketTCPService struct {
|
||||
path string
|
||||
streamHandler streamHandlerFunc
|
||||
dialer net.Dialer
|
||||
}
|
||||
|
||||
func (o *unixSocketPath) String() string {
|
||||
scheme := ""
|
||||
if o.scheme == "https" {
|
||||
|
|
@ -67,6 +75,21 @@ func (o unixSocketPath) MarshalJSON() ([]byte, error) {
|
|||
return json.Marshal(o.String())
|
||||
}
|
||||
|
||||
func (o *unixSocketTCPService) String() string {
|
||||
return "unix+tcp:" + o.path
|
||||
}
|
||||
|
||||
func (o *unixSocketTCPService) start(_ *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error {
|
||||
o.streamHandler = DefaultStreamHandler
|
||||
o.dialer.Timeout = cfg.ConnectTimeout.Duration
|
||||
o.dialer.KeepAlive = cfg.TCPKeepAlive.Duration
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o unixSocketTCPService) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(o.String())
|
||||
}
|
||||
|
||||
type httpService struct {
|
||||
url *url.URL
|
||||
hostHeader string
|
||||
|
|
|
|||
Loading…
Reference in New Issue