diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 049dc25a..6617456b 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -407,9 +407,9 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan defer wg.Done() if err = server.Start(); err != nil && err != ssh.ErrServerClosed { logger.WithError(err).Error("SSH server error") + // TODO: remove when declarative tunnels are implemented. + close(shutdownC) } - // TODO: remove when declarative tunnels are implemented. - close(shutdownC) }() c.Set("url", "ssh://"+localServerAddress) } diff --git a/sshserver/sshserver_unix.go b/sshserver/sshserver_unix.go index eeb05f23..e445527c 100644 --- a/sshserver/sshserver_unix.go +++ b/sshserver/sshserver_unix.go @@ -11,7 +11,7 @@ import ( "fmt" "io" "net" - "net/url" + "regexp" "runtime" "strings" "time" @@ -36,7 +36,7 @@ const ( sshContextEventLogger = "eventLogger" sshContextPreamble = "sshPreamble" sshContextSSHClient = "sshClient" - SSHPreambleLength = 4 + SSHPreambleLength = 2 ) type auditEvent struct { @@ -271,7 +271,7 @@ func (s *SSHProxy) readPreamble(conn net.Conn) (*SSHPreamble, error) { if _, err := io.ReadFull(conn, size); err != nil { return nil, err } - payloadLength := binary.BigEndian.Uint32(size) + payloadLength := binary.BigEndian.Uint16(size) payload := make([]byte, payloadLength) if _, err := io.ReadFull(conn, payload); err != nil { return nil, err @@ -283,13 +283,15 @@ func (s *SSHProxy) readPreamble(conn net.Conn) (*SSHPreamble, error) { return nil, err } - destUrl, err := url.Parse(preamble.Destination) + ok, err := regexp.Match(`^[^:]*:\d+$`, []byte(preamble.Destination)) if err != nil { - return nil, errors.Wrap(err, "failed to parse URL") + return nil, err } - if destUrl.Port() == "" { + + if !ok { preamble.Destination += ":22" } + return &preamble, nil } diff --git a/websocket/websocket.go b/websocket/websocket.go index 4ada6190..2eb66901 100644 --- a/websocket/websocket.go +++ b/websocket/websocket.go @@ -180,8 +180,12 @@ func sendSSHPreamble(stream net.Conn, destination, token string) error { return err } + if uint16(len(payload)) > ^uint16(0) { + return errors.New("ssh preamble payload too large") + } + sizeBytes := make([]byte, sshserver.SSHPreambleLength) - binary.BigEndian.PutUint32(sizeBytes, uint32(len(payload))) + binary.BigEndian.PutUint16(sizeBytes, uint16(len(payload))) if _, err := stream.Write(sizeBytes); err != nil { return err }