diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index dd784fc8..d6fd239f 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -73,6 +73,9 @@ const ( // s3URLFlag is the S3 URL of SSH log uploader (e.g. don't use AWS s3 and use google storage bucket instead) s3URLFlag = "s3-url-host" + + // disablePortForwarding disables both remote and local ssh port forwarding + enablePortForwardingFlag = "enable-port-forwarding" ) var ( @@ -387,7 +390,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan } sshServerAddress := "127.0.0.1:" + c.String(sshPortFlag) - server, err := sshserver.New(logManager, logger, version, sshServerAddress, shutdownC, c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag)) + server, err := sshserver.New(logManager, logger, version, sshServerAddress, shutdownC, c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag), c.Bool(enablePortForwardingFlag)) if err != nil { logger.WithError(err).Error("Cannot create new SSH Server") return errors.Wrap(err, "Cannot create new SSH Server") @@ -1020,5 +1023,11 @@ func tunnelFlags(shouldHide bool) []cli.Flag { EnvVars: []string{"S3_URL"}, Hidden: true, }), + altsrc.NewBoolFlag(&cli.BoolFlag{ + Name: enablePortForwardingFlag, + Usage: "Enables remote and local SSH port forwarding", + EnvVars: []string{"ENABLE_PORT_FORWARDING"}, + Hidden: true, + }), } } diff --git a/sshserver/sshserver_unix.go b/sshserver/sshserver_unix.go index 82781d77..5e968664 100644 --- a/sshserver/sshserver_unix.go +++ b/sshserver/sshserver_unix.go @@ -55,7 +55,7 @@ type SSHServer struct { } // New creates a new SSHServer and configures its host keys and authenication by the data provided -func New(logManager sshlog.Manager, logger *logrus.Logger, version, address string, shutdownC chan struct{}, idleTimeout, maxTimeout time.Duration) (*SSHServer, error) { +func New(logManager sshlog.Manager, logger *logrus.Logger, version, address string, shutdownC chan struct{}, idleTimeout, maxTimeout time.Duration, enablePortForwarding bool) (*SSHServer, error) { currentUser, err := user.Current() if err != nil { return nil, err @@ -64,23 +64,40 @@ func New(logManager sshlog.Manager, logger *logrus.Logger, version, address stri return nil, errors.New("cloudflared SSH server needs to run as root") } + forwardHandler := &ssh.ForwardedTCPHandler{} sshServer := SSHServer{ Server: ssh.Server{ Addr: address, MaxTimeout: maxTimeout, IdleTimeout: idleTimeout, Version: fmt.Sprintf("SSH-2.0-Cloudflare-Access_%s_%s", version, runtime.GOOS), + // Register SSH global Request handlers to respond to tcpip forwarding + RequestHandlers: map[string]ssh.RequestHandler{ + "tcpip-forward": forwardHandler.HandleSSHRequest, + "cancel-tcpip-forward": forwardHandler.HandleSSHRequest, + }, + // Register SSH channel types + ChannelHandlers: map[string]ssh.ChannelHandler{ + "session": ssh.DefaultSessionHandler, + "direct-tcpip": ssh.DirectTCPIPHandler, + }, }, logger: logger, shutdownC: shutdownC, logManager: logManager, } + if enablePortForwarding { + sshServer.LocalPortForwardingCallback = allowForward + sshServer.ReversePortForwardingCallback = allowForward + } + if err := sshServer.configureHostKeys(); err != nil { return nil, err } sshServer.configureAuthentication() + return &sshServer, nil } @@ -285,19 +302,6 @@ func (s *SSHServer) startPtySession(cmd *exec.Cmd, winCh <-chan ssh.Window, logC return tty, tty, nil } -// Sets PTY window size for terminal -func setWinsize(f *os.File, w, h int) syscall.Errno { - _, _, errNo := syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ), - uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(h), uint16(w), 0, 0}))) - return errNo -} - -func stringToUint32(str string) (uint32, error) { - uid, err := strconv.ParseUint(str, 10, 32) - return uint32(uid), err - -} - func (s *SSHServer) logAuditEvent(writer io.WriteCloser, session ssh.Session, sessionID string, eventType string) { username := "unknown" sshUser, ok := session.Context().Value("sshUser").(*User) @@ -322,3 +326,19 @@ func (s *SSHServer) logAuditEvent(writer io.WriteCloser, session ssh.Session, se line := string(data) + "\n" writer.Write([]byte(line)) } + +// Sets PTY window size for terminal +func setWinsize(f *os.File, w, h int) syscall.Errno { + _, _, errNo := syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ), + uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(h), uint16(w), 0, 0}))) + return errNo +} + +func stringToUint32(str string) (uint32, error) { + uid, err := strconv.ParseUint(str, 10, 32) + return uint32(uid), err +} + +func allowForward(_ ssh.Context, _ string, _ uint32) bool { + return true +} diff --git a/sshserver/sshserver_windows.go b/sshserver/sshserver_windows.go index d5f89744..19601d38 100644 --- a/sshserver/sshserver_windows.go +++ b/sshserver/sshserver_windows.go @@ -13,7 +13,7 @@ import ( type SSHServer struct{} -func New(_ sshlog.Manager, _ *logrus.Logger, _, _ string, _ chan struct{}, _, _ time.Duration) (*SSHServer, error) { +func New(_ sshlog.Manager, _ *logrus.Logger, _, _ string, _ chan struct{}, _, _ time.Duration, _ bool) (*SSHServer, error) { return nil, errors.New("cloudflared ssh server is not supported on windows") }