diff --git a/config/configuration.go b/config/configuration.go index a65dd6cd..1b455d3a 100644 --- a/config/configuration.go +++ b/config/configuration.go @@ -221,6 +221,14 @@ type OriginRequestConfig struct { ProxyPort *uint `yaml:"proxyPort"` // Valid options are 'socks' or empty. ProxyType *string `yaml:"proxyType"` + // IP rules for the proxy service + IPRules []IngressIPRule `yaml:"ipRules"` +} + +type IngressIPRule struct { + Prefix *string `yaml:"prefix"` + Ports []int `yaml:"ports"` + Allow bool `yaml:"allow"` } type Configuration struct { diff --git a/ingress/ingress.go b/ingress/ingress.go index 10aeabfa..9985595c 100644 --- a/ingress/ingress.go +++ b/ingress/ingress.go @@ -10,6 +10,7 @@ import ( "sync" "github.com/cloudflare/cloudflared/config" + "github.com/cloudflare/cloudflared/ipaccess" "github.com/pkg/errors" "github.com/rs/zerolog" @@ -26,6 +27,7 @@ var ( const ( ServiceBastion = "bastion" + ServiceSocksProxy = "socks-proxy" ServiceWarpRouting = "warp-routing" ) @@ -175,6 +177,23 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon service = &srv } else if r.Service == "hello_world" || r.Service == "hello-world" || r.Service == "helloworld" { service = new(helloWorld) + } else if r.Service == ServiceSocksProxy { + rules := make([]ipaccess.Rule, len(r.OriginRequest.IPRules)) + + for i, ipRule := range r.OriginRequest.IPRules { + rule, err := ipaccess.NewRuleByCIDR(ipRule.Prefix, ipRule.Ports, ipRule.Allow) + if err != nil { + return Ingress{}, fmt.Errorf("unable to create ip rule for %s: %s", r.Service, err) + } + rules[i] = rule + } + + accessPolicy, err := ipaccess.NewPolicy(false, rules) + if err != nil { + return Ingress{}, fmt.Errorf("unable to create ip access policy for %s: %s", r.Service, err) + } + + service = newSocksProxyOverWSService(accessPolicy) } else if r.Service == ServiceBastion || cfg.BastionMode { // Bastion mode will always start a Websocket proxy server, which will // overwrite the localService.URL field when `start` is called. So, diff --git a/ingress/ingress_test.go b/ingress/ingress_test.go index 7b87a35b..ecbb54ba 100644 --- a/ingress/ingress_test.go +++ b/ingress/ingress_test.go @@ -15,6 +15,7 @@ import ( "gopkg.in/yaml.v2" "github.com/cloudflare/cloudflared/config" + "github.com/cloudflare/cloudflared/ipaccess" "github.com/cloudflare/cloudflared/tlsconfig" ) @@ -304,6 +305,33 @@ ingress: }, }, }, + { + name: "SOCKS services", + args: args{rawYAML: ` +ingress: +- hostname: socks.foo.com + service: socks-proxy + originRequest: + ipRules: + - prefix: 1.1.1.0/24 + ports: [80, 443] + allow: true + - prefix: 0.0.0.0/0 + allow: false +- service: http_status:404 +`}, + want: []Rule{ + { + Hostname: "socks.foo.com", + Service: newSocksProxyOverWSService(accessPolicy()), + Config: defaultConfig, + }, + { + Service: &fourOhFour, + Config: defaultConfig, + }, + }, + }, { name: "URL isn't necessary if using bastion", args: args{rawYAML: ` @@ -548,6 +576,16 @@ func MustParseURL(t *testing.T, rawURL string) *url.URL { return u } +func accessPolicy() *ipaccess.Policy { + cidr1 := "1.1.1.0/24" + cidr2 := "0.0.0.0/0" + rule1, _ := ipaccess.NewRuleByCIDR(&cidr1, []int{80, 443}, true) + rule2, _ := ipaccess.NewRuleByCIDR(&cidr2, nil, false) + rules := []ipaccess.Rule{rule1, rule2} + accessPolicy, _ := ipaccess.NewPolicy(false, rules) + return accessPolicy +} + func BenchmarkFindMatch(b *testing.B) { rulesYAML := ` ingress: diff --git a/ingress/origin_connection.go b/ingress/origin_connection.go index 44cfdc4d..8dc651ff 100644 --- a/ingress/origin_connection.go +++ b/ingress/origin_connection.go @@ -7,6 +7,8 @@ import ( "net" "net/http" + "github.com/cloudflare/cloudflared/ipaccess" + "github.com/cloudflare/cloudflared/socks" "github.com/cloudflare/cloudflared/websocket" gws "github.com/gorilla/websocket" "github.com/rs/zerolog" @@ -107,3 +109,17 @@ func newWSConnection(clientTLSConfig *tls.Config, r *http.Request) (OriginConnec resp, }, resp, nil } + +// socksProxyOverWSConnection is an OriginConnection that streams SOCKS connections over WS. +// The connection to the origin happens inside the SOCKS code as the client specifies the origin +// details in the packet. +type socksProxyOverWSConnection struct { + accessPolicy *ipaccess.Policy +} + +func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) { + socks.StreamNetHandler(websocket.NewConn(ctx, tunnelConn, log), sp.accessPolicy, log) +} + +func (sp *socksProxyOverWSConnection) Close() { +} diff --git a/ingress/origin_proxy.go b/ingress/origin_proxy.go index 98f144e4..14162023 100644 --- a/ingress/origin_proxy.go +++ b/ingress/origin_proxy.go @@ -145,3 +145,14 @@ func (o *tcpOverWSService) bastionDest(r *http.Request) (string, error) { func removePath(dest string) string { return strings.SplitN(dest, "/", 2)[0] } + +func (o *socksProxyOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { + originConn := o.conn + resp := &http.Response{ + Status: switchingProtocolText, + StatusCode: http.StatusSwitchingProtocols, + Header: websocket.NewResponseHeader(r), + ContentLength: -1, + } + return originConn, resp, nil +} diff --git a/ingress/origin_request_config.go b/ingress/origin_request_config.go index 575e2170..02b15c4b 100644 --- a/ingress/origin_request_config.go +++ b/ingress/origin_request_config.go @@ -3,6 +3,7 @@ package ingress import ( "time" + "github.com/cloudflare/cloudflared/ipaccess" "github.com/urfave/cli/v2" "github.com/cloudflare/cloudflared/config" @@ -213,6 +214,8 @@ type OriginRequestConfig struct { ProxyPort uint `yaml:"proxyPort"` // What sort of proxy should be started ProxyType string `yaml:"proxyType"` + // IP rules for the proxy service + IPRules []ipaccess.Rule `yaml:"ipRules"` } func (defaults *OriginRequestConfig) setConnectTimeout(overrides config.OriginRequestConfig) { diff --git a/ingress/origin_service.go b/ingress/origin_service.go index 8d55e7eb..e621cf73 100644 --- a/ingress/origin_service.go +++ b/ingress/origin_service.go @@ -12,6 +12,7 @@ import ( "time" "github.com/cloudflare/cloudflared/hello" + "github.com/cloudflare/cloudflared/ipaccess" "github.com/cloudflare/cloudflared/socks" "github.com/cloudflare/cloudflared/tlsconfig" "github.com/cloudflare/cloudflared/websocket" @@ -100,6 +101,10 @@ type tcpOverWSService struct { streamHandler streamHandlerFunc } +type socksProxyOverWSService struct { + conn *socksProxyOverWSConnection +} + func newTCPOverWSService(url *url.URL) *tcpOverWSService { switch url.Scheme { case "ssh": @@ -122,6 +127,16 @@ func newBastionService() *tcpOverWSService { } } +func newSocksProxyOverWSService(accessPolicy *ipaccess.Policy) *socksProxyOverWSService { + proxy := socksProxyOverWSService{ + conn: &socksProxyOverWSConnection{ + accessPolicy: accessPolicy, + }, + } + + return &proxy +} + func addPortIfMissing(uri *url.URL, port int) { if uri.Port() == "" { uri.Host = fmt.Sprintf("%s:%d", uri.Hostname(), port) @@ -144,6 +159,14 @@ func (o *tcpOverWSService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdo return nil } +func (o *socksProxyOverWSService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { + return nil +} + +func (o *socksProxyOverWSService) String() string { + return ServiceSocksProxy +} + // HelloWorld is an OriginService for the built-in Hello World server. // Users only use this for testing and experimenting with cloudflared. type helloWorld struct { diff --git a/ipaccess/access.go b/ipaccess/access.go new file mode 100644 index 00000000..3136f2b5 --- /dev/null +++ b/ipaccess/access.go @@ -0,0 +1,101 @@ +package ipaccess + +import ( + "fmt" + "net" + "sort" +) + +type Policy struct { + defaultAllow bool + rules []Rule +} + +type Rule struct { + ipNet *net.IPNet + ports []int + allow bool +} + +func NewPolicy(defaultAllow bool, rules []Rule) (*Policy, error) { + for _, rule := range rules { + if err := rule.Validate(); err != nil { + return nil, err + } + } + + policy := Policy{ + defaultAllow: defaultAllow, + rules: rules, + } + + return &policy, nil +} + +func NewRuleByCIDR(prefix *string, ports []int, allow bool) (Rule, error) { + if prefix == nil || len(*prefix) == 0 { + return Rule{}, fmt.Errorf("no prefix provided") + } + + _, ipnet, err := net.ParseCIDR(*prefix) + if err != nil { + return Rule{}, fmt.Errorf("unable to parse cidr: %s", *prefix) + } + + return NewRule(ipnet, ports, allow) +} + +func NewRule(ipnet *net.IPNet, ports []int, allow bool) (Rule, error) { + rule := Rule{ + ipNet: ipnet, + ports: ports, + allow: allow, + } + return rule, rule.Validate() +} + +func (r *Rule) Validate() error { + if r.ipNet == nil { + return fmt.Errorf("no ipnet set on the rule") + } + + if len(r.ports) > 0 { + sort.Ints(r.ports) + for _, port := range r.ports { + if port < 1 || port > 65535 { + return fmt.Errorf("invalid port %d, needs to be between 1 and 65535", port) + } + } + } + + return nil +} + +func (h *Policy) Allowed(ip net.IP, port int) (bool, *Rule) { + if len(h.rules) == 0 { + return h.defaultAllow, nil + } + + for _, rule := range h.rules { + if rule.ipNet.Contains(ip) { + if len(rule.ports) == 0 { + return rule.allow, &rule + } else if pos := sort.SearchInts(rule.ports, port); pos < len(rule.ports) && rule.ports[pos] == port { + return rule.allow, &rule + } + } + } + + return h.defaultAllow, nil +} + +func (ipr *Rule) String() string { + return fmt.Sprintf("prefix:%s/port:%s/allow:%t", ipr.ipNet, ipr.PortsString(), ipr.allow) +} + +func (ipr *Rule) PortsString() string { + if len(ipr.ports) > 0 { + return fmt.Sprint(ipr.ports) + } + return "all" +} diff --git a/ipaccess/access_test.go b/ipaccess/access_test.go new file mode 100644 index 00000000..118ba97b --- /dev/null +++ b/ipaccess/access_test.go @@ -0,0 +1,107 @@ +package ipaccess + +import ( + "bytes" + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRuleCreation(t *testing.T) { + _, ipnet, _ := net.ParseCIDR("1.1.1.1/24") + + _, err := NewRule(nil, []int{80}, false) + assert.Error(t, err, "expected error as no ipnet provided") + + _, err = NewRule(ipnet, []int{65536, 80}, false) + assert.Error(t, err, "expected error as port higher than 65535") + + _, err = NewRule(ipnet, []int{80, -1}, false) + assert.Error(t, err, "expected error as port less than 0") + + rule, err := NewRule(ipnet, []int{443, 80}, false) + assert.NoError(t, err) + assert.True(t, ipnet.IP.Equal(rule.ipNet.IP) && bytes.Compare(ipnet.Mask, rule.ipNet.Mask) == 0, "ipnet expected to be %+v, got: %+v", ipnet, rule.ipNet) + assert.True(t, len(rule.ports) == 2 && rule.ports[0] == 80 && rule.ports[1] == 443, "expected ports to be sorted") +} + +func TestRuleCreationByCIDR(t *testing.T) { + var cidr *string + _, err := NewRuleByCIDR(cidr, []int{80}, false) + assert.Error(t, err, "expected error as cidr is nil") + + badCidr := "1.1.1.1" + cidr = &badCidr + _, err = NewRuleByCIDR(cidr, []int{80}, false) + assert.Error(t, err, "expected error as the cidr is bad") + + goodCidr := "1.1.1.1/24" + _, ipnet, _ := net.ParseCIDR("1.1.1.0/24") + cidr = &goodCidr + rule, err := NewRuleByCIDR(cidr, []int{80}, false) + assert.NoError(t, err) + assert.True(t, ipnet.IP.Equal(rule.ipNet.IP) && bytes.Compare(ipnet.Mask, rule.ipNet.Mask) == 0, "ipnet expected to be %+v, got: %+v", ipnet, rule.ipNet) +} + +func TestRulesNoRules(t *testing.T) { + ip, _, _ := net.ParseCIDR("1.2.3.4/24") + + policy, _ := NewPolicy(true, []Rule{}) + + allowed, rule := policy.Allowed(ip, 80) + assert.True(t, allowed, "expected to be allowed as no rules and default allow") + assert.Nil(t, rule, "expected to be nil as no rules") + + policy, _ = NewPolicy(false, []Rule{}) + + allowed, rule = policy.Allowed(ip, 80) + assert.False(t, allowed, "expected to be denied as no rules and default deny") + assert.Nil(t, rule, "expected to be nil as no rules") +} + +func TestRulesMatchIPAndPort(t *testing.T) { + ip1, ipnet1, _ := net.ParseCIDR("1.2.3.4/24") + ip2, _, _ := net.ParseCIDR("2.3.4.5/24") + + rule1, _ := NewRule(ipnet1, []int{80, 443}, true) + rules := []Rule{ + rule1, + } + + policy, _ := NewPolicy(false, rules) + + allowed, rule := policy.Allowed(ip1, 80) + assert.True(t, allowed, "expected to be allowed as matching rule") + assert.True(t, rule.ipNet == ipnet1, "expected to match ipnet1") + + allowed, rule = policy.Allowed(ip2, 80) + assert.False(t, allowed, "expected to be denied as no matching rule") + assert.Nil(t, rule, "expected to be nil") +} + +func TestRulesMatchIPAndPort2(t *testing.T) { + ip1, ipnet1, _ := net.ParseCIDR("1.2.3.4/24") + ip2, ipnet2, _ := net.ParseCIDR("2.3.4.5/24") + + rule1, _ := NewRule(ipnet1, []int{53, 80}, false) + rule2, _ := NewRule(ipnet2, []int{53, 80}, true) + rules := []Rule{ + rule1, + rule2, + } + + policy, _ := NewPolicy(false, rules) + + allowed, rule := policy.Allowed(ip1, 80) + assert.False(t, allowed, "expected to be denied as matching rule") + assert.True(t, rule.ipNet == ipnet1, "expected to match ipnet1") + + allowed, rule = policy.Allowed(ip2, 80) + assert.True(t, allowed, "expected to be allowed as matching rule") + assert.True(t, rule.ipNet == ipnet2, "expected to match ipnet1") + + allowed, rule = policy.Allowed(ip2, 81) + assert.False(t, allowed, "expected to be denied as no matching rule") + assert.Nil(t, rule, "expected to be nil") +} diff --git a/socks/connection_handler_test.go b/socks/connection_handler_test.go index 9370de93..424d8185 100644 --- a/socks/connection_handler_test.go +++ b/socks/connection_handler_test.go @@ -40,7 +40,7 @@ func sendSocksRequest(t *testing.T) []byte { func startTestServer(t *testing.T, httpHandler func(w http.ResponseWriter, r *http.Request)) { // create a socks server - requestHandler := NewRequestHandler(NewNetDialer()) + requestHandler := NewRequestHandler(NewNetDialer(), nil) socksServer := NewConnectionHandler(requestHandler) listener, err := net.Listen("tcp", "localhost:8086") assert.NoError(t, err) diff --git a/socks/request_handler.go b/socks/request_handler.go index 904751c9..d9266275 100644 --- a/socks/request_handler.go +++ b/socks/request_handler.go @@ -6,6 +6,7 @@ import ( "net" "strings" + "github.com/cloudflare/cloudflared/ipaccess" "github.com/rs/zerolog" ) @@ -16,14 +17,16 @@ type RequestHandler interface { // StandardRequestHandler implements the base socks5 command processing type StandardRequestHandler struct { - dialer Dialer + dialer Dialer + accessPolicy *ipaccess.Policy } // NewRequestHandler creates a standard SOCKS5 request handler // This handles the SOCKS5 commands and proxies them to their destination -func NewRequestHandler(dialer Dialer) RequestHandler { +func NewRequestHandler(dialer Dialer, accessPolicy *ipaccess.Policy) RequestHandler { return &StandardRequestHandler{ - dialer: dialer, + dialer: dialer, + accessPolicy: accessPolicy, } } @@ -46,6 +49,25 @@ func (h *StandardRequestHandler) Handle(req *Request, conn io.ReadWriter) error // handleConnect is used to handle a connect command func (h *StandardRequestHandler) handleConnect(conn io.ReadWriter, req *Request) error { + if h.accessPolicy != nil { + if req.DestAddr.IP == nil { + addr, err := net.ResolveIPAddr("ip", req.DestAddr.FQDN) + if err != nil { + _ = sendReply(conn, ruleFailure, req.DestAddr) + return fmt.Errorf("unable to resolve host to confirm acceess") + } + + req.DestAddr.IP = addr.IP + } + if allowed, rule := h.accessPolicy.Allowed(req.DestAddr.IP, req.DestAddr.Port); !allowed { + _ = sendReply(conn, ruleFailure, req.DestAddr) + if rule != nil { + return fmt.Errorf("Connect to %v denied due to iprule: %s", req.DestAddr, rule.String()) + } + return fmt.Errorf("Connect to %v denied", req.DestAddr) + } + } + target, localAddr, err := h.dialer.Dial(req.DestAddr.Address()) if err != nil { msg := err.Error() @@ -110,7 +132,17 @@ func (h *StandardRequestHandler) handleAssociate(conn io.ReadWriter, req *Reques func StreamHandler(tunnelConn io.ReadWriter, originConn net.Conn, log *zerolog.Logger) { dialer := NewConnDialer(originConn) - requestHandler := NewRequestHandler(dialer) + requestHandler := NewRequestHandler(dialer, nil) + socksServer := NewConnectionHandler(requestHandler) + + if err := socksServer.Serve(tunnelConn); err != nil { + log.Debug().Err(err).Msg("Socks stream handler error") + } +} + +func StreamNetHandler(tunnelConn io.ReadWriter, accessPolicy *ipaccess.Policy, log *zerolog.Logger) { + dialer := NewNetDialer() + requestHandler := NewRequestHandler(dialer, accessPolicy) socksServer := NewConnectionHandler(requestHandler) if err := socksServer.Serve(tunnelConn); err != nil { diff --git a/socks/request_handler_test.go b/socks/request_handler_test.go index 8a6d51c7..e45b5dbf 100644 --- a/socks/request_handler_test.go +++ b/socks/request_handler_test.go @@ -4,6 +4,7 @@ import ( "bytes" "testing" + "github.com/cloudflare/cloudflared/ipaccess" "github.com/stretchr/testify/assert" ) @@ -11,7 +12,7 @@ func TestUnsupportedBind(t *testing.T) { req := createRequest(t, socks5Version, bindCommand, "2001:db8::68", 1337, false) var b bytes.Buffer - requestHandler := NewRequestHandler(NewNetDialer()) + requestHandler := NewRequestHandler(NewNetDialer(), nil) err := requestHandler.Handle(req, &b) assert.NoError(t, err) assert.True(t, b.Bytes()[1] == commandNotSupported, "expected a response") @@ -21,8 +22,61 @@ func TestUnsupportedAssociate(t *testing.T) { req := createRequest(t, socks5Version, associateCommand, "127.0.0.1", 1337, false) var b bytes.Buffer - requestHandler := NewRequestHandler(NewNetDialer()) + requestHandler := NewRequestHandler(NewNetDialer(), nil) err := requestHandler.Handle(req, &b) assert.NoError(t, err) assert.True(t, b.Bytes()[1] == commandNotSupported, "expected a response") } + +func TestHandleConnect(t *testing.T) { + req := createRequest(t, socks5Version, connectCommand, "127.0.0.1", 1337, false) + var b bytes.Buffer + + requestHandler := NewRequestHandler(NewNetDialer(), nil) + err := requestHandler.Handle(req, &b) + assert.Error(t, err) + assert.True(t, b.Bytes()[1] == connectionRefused, "expected a response") +} + +func TestHandleConnectIPAccess(t *testing.T) { + prefix := "127.0.0.0/24" + rule1, _ := ipaccess.NewRuleByCIDR(&prefix, []int{1337}, true) + rule2, _ := ipaccess.NewRuleByCIDR(&prefix, []int{1338}, false) + rules := []ipaccess.Rule{rule1, rule2} + var b bytes.Buffer + + accessPolicy, _ := ipaccess.NewPolicy(false, nil) + requestHandler := NewRequestHandler(NewNetDialer(), accessPolicy) + req := createRequest(t, socks5Version, connectCommand, "127.0.0.1", 1337, false) + err := requestHandler.Handle(req, &b) + assert.Error(t, err) + assert.True(t, b.Bytes()[1] == ruleFailure, "expected to be denied as no rules and defaultAllow=false") + + b.Reset() + accessPolicy, _ = ipaccess.NewPolicy(true, nil) + requestHandler = NewRequestHandler(NewNetDialer(), accessPolicy) + req = createRequest(t, socks5Version, connectCommand, "127.0.0.1", 1337, false) + err = requestHandler.Handle(req, &b) + assert.Error(t, err) + assert.True(t, b.Bytes()[1] == connectionRefused, "expected to be allowed as no rules and defaultAllow=true") + + b.Reset() + accessPolicy, _ = ipaccess.NewPolicy(false, rules) + requestHandler = NewRequestHandler(NewNetDialer(), accessPolicy) + req = createRequest(t, socks5Version, connectCommand, "127.0.0.1", 1337, false) + err = requestHandler.Handle(req, &b) + assert.Error(t, err) + assert.True(t, b.Bytes()[1] == connectionRefused, "expected to be allowed as matching rule") + + b.Reset() + req = createRequest(t, socks5Version, connectCommand, "127.0.0.1", 1338, false) + err = requestHandler.Handle(req, &b) + assert.Error(t, err) + assert.True(t, b.Bytes()[1] == ruleFailure, "expected to be denied as matching rule") + + b.Reset() + req = createRequest(t, socks5Version, connectCommand, "127.0.0.1", 1339, false) + err = requestHandler.Handle(req, &b) + assert.Error(t, err) + assert.True(t, b.Bytes()[1] == ruleFailure, "expect to be denied as no matching rule and defaultAllow=false") +}