package socks import ( "fmt" "io" "net" "strings" "github.com/rs/zerolog" "github.com/cloudflare/cloudflared/ipaccess" ) // RequestHandler is the functions needed to handle a SOCKS5 command type RequestHandler interface { Handle(*Request, io.ReadWriter) error } // StandardRequestHandler implements the base socks5 command processing type StandardRequestHandler struct { dialer Dialer accessPolicy *ipaccess.Policy } // NewRequestHandler creates a standard SOCKS5 request handler // This handles the SOCKS5 commands and proxies them to their destination func NewRequestHandler(dialer Dialer, accessPolicy *ipaccess.Policy) RequestHandler { return &StandardRequestHandler{ dialer: dialer, accessPolicy: accessPolicy, } } // Handle processes and responds to socks5 commands func (h *StandardRequestHandler) Handle(req *Request, conn io.ReadWriter) error { switch req.Command { case connectCommand: return h.handleConnect(conn, req) case bindCommand: return h.handleBind(conn, req) case associateCommand: return h.handleAssociate(conn, req) default: if err := sendReply(conn, commandNotSupported, nil); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } return fmt.Errorf("Unsupported command: %v", req.Command) } } // handleConnect is used to handle a connect command func (h *StandardRequestHandler) handleConnect(conn io.ReadWriter, req *Request) error { if h.accessPolicy != nil { if req.DestAddr.IP == nil { addr, err := net.ResolveIPAddr("ip", req.DestAddr.FQDN) if err != nil { _ = sendReply(conn, ruleFailure, req.DestAddr) return fmt.Errorf("unable to resolve host to confirm access") } req.DestAddr.IP = addr.IP } if allowed, rule := h.accessPolicy.Allowed(req.DestAddr.IP, req.DestAddr.Port); !allowed { _ = sendReply(conn, ruleFailure, req.DestAddr) if rule != nil { return fmt.Errorf("Connect to %v denied due to iprule: %s", req.DestAddr, rule.String()) } return fmt.Errorf("Connect to %v denied", req.DestAddr) } } target, localAddr, err := h.dialer.Dial(req.DestAddr.Address()) if err != nil { msg := err.Error() resp := hostUnreachable if strings.Contains(msg, "refused") { resp = connectionRefused } else if strings.Contains(msg, "network is unreachable") { resp = networkUnreachable } if err := sendReply(conn, resp, nil); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } return fmt.Errorf("Connect to %v failed: %v", req.DestAddr, err) } defer target.Close() // Send success if err := sendReply(conn, successReply, localAddr); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } // Start proxying proxyDone := make(chan error, 2) go func() { _, e := io.Copy(target, req.bufConn) proxyDone <- e }() go func() { _, e := io.Copy(conn, target) proxyDone <- e }() // Wait for both for i := 0; i < 2; i++ { e := <-proxyDone if e != nil { return e } } return nil } // handleBind is used to handle a bind command // TODO: Support bind command func (h *StandardRequestHandler) handleBind(conn io.ReadWriter, req *Request) error { if err := sendReply(conn, commandNotSupported, nil); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } return nil } // handleAssociate is used to handle a connect command // TODO: Support associate command func (h *StandardRequestHandler) handleAssociate(conn io.ReadWriter, req *Request) error { if err := sendReply(conn, commandNotSupported, nil); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } return nil } func StreamHandler(tunnelConn io.ReadWriter, originConn net.Conn, log *zerolog.Logger) { dialer := NewConnDialer(originConn) requestHandler := NewRequestHandler(dialer, nil) socksServer := NewConnectionHandler(requestHandler) if err := socksServer.Serve(tunnelConn); err != nil { log.Debug().Err(err).Msg("Socks stream handler error") } } func StreamNetHandler(tunnelConn io.ReadWriter, accessPolicy *ipaccess.Policy, log *zerolog.Logger) { dialer := NewNetDialer() requestHandler := NewRequestHandler(dialer, accessPolicy) socksServer := NewConnectionHandler(requestHandler) if err := socksServer.Serve(tunnelConn); err != nil { log.Debug().Err(err).Msg("Socks stream handler error") } }