AUTH-2114: Uses short lived cert auth for outgoing client connection

This commit is contained in:
Michael Borkenstein 2019-10-09 16:56:47 -05:00
parent 4d2583edf5
commit 95704b11fb
5 changed files with 221 additions and 79 deletions

View File

@ -395,8 +395,8 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
uploadManager.Start() uploadManager.Start()
} }
sshServerAddress := "127.0.0.1:" + c.String(sshPortFlag) localServerAddress := "127.0.0.1:" + c.String(sshPortFlag)
server, err := sshserver.New(logManager, logger, version, sshServerAddress, shutdownC, c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag)) server, err := sshserver.New(logManager, logger, version, localServerAddress, c.String("hostname"), shutdownC, c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag))
if err != nil { if err != nil {
msg := "Cannot create new SSH Server" msg := "Cannot create new SSH Server"
logger.WithError(err).Error(msg) logger.WithError(err).Error(msg)
@ -411,7 +411,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
// TODO: remove when declarative tunnels are implemented. // TODO: remove when declarative tunnels are implemented.
close(shutdownC) close(shutdownC)
}() }()
c.Set("url", "ssh://"+sshServerAddress) c.Set("url", "ssh://"+localServerAddress)
} }
if host := hostnameFromURI(c.String("url")); host != "" { if host := hostnameFromURI(c.String("url")); host != "" {

View File

@ -8,7 +8,6 @@ import (
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -20,6 +19,7 @@ import (
cfpath "github.com/cloudflare/cloudflared/cmd/cloudflared/path" cfpath "github.com/cloudflare/cloudflared/cmd/cloudflared/path"
"github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/jose"
homedir "github.com/mitchellh/go-homedir" homedir "github.com/mitchellh/go-homedir"
"github.com/pkg/errors"
gossh "golang.org/x/crypto/ssh" gossh "golang.org/x/crypto/ssh"
) )
@ -73,48 +73,54 @@ func GenerateShortLivedCertificate(appURL *url.URL, token string) error {
// handleCertificateGeneration takes a JWT and uses it build a signPayload // handleCertificateGeneration takes a JWT and uses it build a signPayload
// to send to the Sign endpoint with the public key from the keypair it generated // to send to the Sign endpoint with the public key from the keypair it generated
func handleCertificateGeneration(token, fullName string) (string, error) { func handleCertificateGeneration(token, fullName string) (string, error) {
pub, err := generateKeyPair(fullName)
if err != nil {
return "", err
}
return SignCert(token, string(pub))
}
func SignCert(token, pubKey string) (string, error) {
if token == "" { if token == "" {
return "", errors.New("invalid token") return "", errors.New("invalid token")
} }
jwt, err := jose.ParseJWT(token) jwt, err := jose.ParseJWT(token)
if err != nil { if err != nil {
return "", err return "", errors.Wrap(err, "failed to parse JWT")
} }
claims, err := jwt.Claims() claims, err := jwt.Claims()
if err != nil { if err != nil {
return "", err return "", errors.Wrap(err, "failed to retrieve JWT claims")
} }
issuer, _, err := claims.StringClaim("iss") issuer, _, err := claims.StringClaim("iss")
if err != nil { if err != nil {
return "", err return "", errors.Wrap(err, "failed to retrieve JWT iss")
}
pub, err := generateKeyPair(fullName)
if err != nil {
return "", err
} }
buf, err := json.Marshal(&signPayload{ buf, err := json.Marshal(&signPayload{
PublicKey: string(pub), PublicKey: pubKey,
JWT: token, JWT: token,
Issuer: issuer, Issuer: issuer,
}) })
if err != nil { if err != nil {
return "", err return "", errors.Wrap(err, "failed to marshal signPayload")
} }
var res *http.Response var res *http.Response
if mockRequest != nil { if mockRequest != nil {
res, err = mockRequest(issuer+signEndpoint, "application/json", bytes.NewBuffer(buf)) res, err = mockRequest(issuer+signEndpoint, "application/json", bytes.NewBuffer(buf))
} else { } else {
res, err = http.Post(issuer+signEndpoint, "application/json", bytes.NewBuffer(buf)) client := http.Client{
Timeout: 10 * time.Second,
}
res, err = client.Post(issuer+signEndpoint, "application/json", bytes.NewBuffer(buf))
} }
if err != nil { if err != nil {
return "", err return "", errors.Wrap(err, "failed to send request")
} }
defer res.Body.Close() defer res.Body.Close()
@ -130,9 +136,9 @@ func handleCertificateGeneration(token, fullName string) (string, error) {
var signRes signResponse var signRes signResponse
if err := decoder.Decode(&signRes); err != nil { if err := decoder.Decode(&signRes); err != nil {
return "", err return "", errors.Wrap(err, "failed to decode HTTP response")
} }
return signRes.Certificate, err return signRes.Certificate, nil
} }
// generateKeyPair creates a EC keypair (P256) and stores them in the homedir. // generateKeyPair creates a EC keypair (P256) and stores them in the homedir.

View File

@ -3,6 +3,9 @@
package sshserver package sshserver
import ( import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -13,6 +16,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/cloudflare/cloudflared/sshgen"
"github.com/cloudflare/cloudflared/sshlog" "github.com/cloudflare/cloudflared/sshlog"
"github.com/gliderlabs/ssh" "github.com/gliderlabs/ssh"
"github.com/google/uuid" "github.com/google/uuid"
@ -30,8 +34,9 @@ const (
auditEventShell = "shell" auditEventShell = "shell"
sshContextSessionID = "sessionID" sshContextSessionID = "sessionID"
sshContextEventLogger = "eventLogger" sshContextEventLogger = "eventLogger"
sshContextDestination = "sshDest" sshContextPreamble = "sshPreamble"
sshPreambleLength = 4 sshContextSSHClient = "sshClient"
SSHPreambleLength = 4
) )
type auditEvent struct { type auditEvent struct {
@ -41,31 +46,53 @@ type auditEvent struct {
User string `json:"user,omitempty"` User string `json:"user,omitempty"`
Login string `json:"login,omitempty"` Login string `json:"login,omitempty"`
Datetime string `json:"datetime,omitempty"` Datetime string `json:"datetime,omitempty"`
Hostname string `json:"hostname,omitempty"`
Destination string `json:"destination,omitempty"` Destination string `json:"destination,omitempty"`
} }
// sshConn wraps the incoming net.Conn and a cleanup function
// This is done to allow the outgoing SSH client to be retrieved and closed when the conn itself is closed.
type sshConn struct {
net.Conn
cleanupFunc func()
}
// close calls the cleanupFunc before closing the conn
func (c sshConn) Close() error {
c.cleanupFunc()
return c.Conn.Close()
}
type SSHProxy struct { type SSHProxy struct {
ssh.Server ssh.Server
hostname string
logger *logrus.Logger logger *logrus.Logger
shutdownC chan struct{} shutdownC chan struct{}
caCert ssh.PublicKey caCert ssh.PublicKey
logManager sshlog.Manager logManager sshlog.Manager
} }
type SSHPreamble struct {
Destination string
JWT string
}
// New creates a new SSHProxy and configures its host keys and authentication by the data provided // New creates a new SSHProxy and configures its host keys and authentication by the data provided
func New(logManager sshlog.Manager, logger *logrus.Logger, version, address string, shutdownC chan struct{}, idleTimeout, maxTimeout time.Duration) (*SSHProxy, error) { func New(logManager sshlog.Manager, logger *logrus.Logger, version, localAddress, hostname string, shutdownC chan struct{}, idleTimeout, maxTimeout time.Duration) (*SSHProxy, error) {
sshProxy := SSHProxy{ sshProxy := SSHProxy{
hostname: hostname,
logger: logger, logger: logger,
shutdownC: shutdownC, shutdownC: shutdownC,
logManager: logManager, logManager: logManager,
} }
sshProxy.Server = ssh.Server{ sshProxy.Server = ssh.Server{
Addr: address, Addr: localAddress,
MaxTimeout: maxTimeout, MaxTimeout: maxTimeout,
IdleTimeout: idleTimeout, IdleTimeout: idleTimeout,
Version: fmt.Sprintf("SSH-2.0-Cloudflare-Access_%s_%s", version, runtime.GOOS), Version: fmt.Sprintf("SSH-2.0-Cloudflare-Access_%s_%s", version, runtime.GOOS),
ConnCallback: sshProxy.connCallback, PublicKeyHandler: sshProxy.proxyAuthCallback,
ConnCallback: sshProxy.connCallback,
ChannelHandlers: map[string]ssh.ChannelHandler{ ChannelHandlers: map[string]ssh.ChannelHandler{
"default": sshProxy.channelHandler, "default": sshProxy.channelHandler,
}, },
@ -92,23 +119,54 @@ func (s *SSHProxy) Start() error {
return s.ListenAndServe() return s.ListenAndServe()
} }
// proxyAuthCallback attempts to connect to ultimate SSH destination. If successful, it allows the incoming connection
// to connect to the proxy and saves the outgoing SSH client to the context. Otherwise, no connection to the
// the proxy is allowed.
func (s *SSHProxy) proxyAuthCallback(ctx ssh.Context, key ssh.PublicKey) bool {
client, err := s.dialDestination(ctx)
if err != nil {
return false
}
ctx.SetValue(sshContextSSHClient, client)
return true
}
// connCallback reads the preamble sent from the proxy server and saves an audit event logger to the context.
// If any errors occur, the connection is terminated by returning nil from the callback.
func (s *SSHProxy) connCallback(ctx ssh.Context, conn net.Conn) net.Conn { func (s *SSHProxy) connCallback(ctx ssh.Context, conn net.Conn) net.Conn {
// AUTH-2050: This is a temporary workaround of a timing issue in the tunnel muxer to allow further testing. // AUTH-2050: This is a temporary workaround of a timing issue in the tunnel muxer to allow further testing.
// TODO: Remove this // TODO: Remove this
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
if err := s.configureSSHDestination(conn, ctx); err != nil { preamble, err := s.readPreamble(conn)
if err != io.EOF { if err != nil {
s.logger.WithError(err).Error("failed to read SSH destination") if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
s.logger.Warn("Could not establish session. Client likely does not have --destination set and is using old-style ssh config")
} else if err != io.EOF {
s.logger.WithError(err).Error("failed to read SSH preamble")
} }
return nil return nil
} }
ctx.SetValue(sshContextPreamble, preamble)
if err := s.configureLogger(ctx); err != nil { logger, sessionID, err := s.auditLogger()
if err != nil {
s.logger.WithError(err).Error("failed to configure logger") s.logger.WithError(err).Error("failed to configure logger")
return nil return nil
} }
return conn ctx.SetValue(sshContextEventLogger, logger)
ctx.SetValue(sshContextSessionID, sessionID)
// attempts to retrieve and close the outgoing ssh client when the incoming conn is closed.
// If no client exists, the conn is being closed before the PublicKeyCallback was called (where the client is created).
cleanupFunc := func() {
client, ok := ctx.Value(sshContextSSHClient).(*gossh.Client)
if ok && client != nil {
client.Close()
}
}
return sshConn{conn, cleanupFunc}
} }
// channelHandler proxies incoming and outgoing SSH traffic back and forth over an SSH Channel // channelHandler proxies incoming and outgoing SSH traffic back and forth over an SSH Channel
@ -129,13 +187,12 @@ func (s *SSHProxy) channelHandler(srv *ssh.Server, conn *gossh.ServerConn, newCh
} }
defer localChan.Close() defer localChan.Close()
// AUTH-2136 TODO: multiplex ssh client between channels // client will be closed when the sshConn is closed
client, err := s.createSSHClient(ctx) client, ok := ctx.Value(sshContextSSHClient).(*gossh.Client)
if err != nil { if !ok {
s.logger.WithError(err).Error("Failed to dial remote server") s.logger.Error("Could not retrieve client from context")
return return
} }
defer client.Close()
remoteChan, remoteChanReqs, err := client.OpenChannel(newChan.ChannelType(), newChan.ExtraData()) remoteChan, remoteChanReqs, err := client.OpenChannel(newChan.ChannelType(), newChan.ExtraData())
if err != nil { if err != nil {
@ -196,54 +253,116 @@ func (s *SSHProxy) proxyChannel(localChan, remoteChan gossh.Channel, localChanRe
} }
} }
// configureSSHDestination reads a preamble from the SSH connection before any SSH traffic is sent. // readPreamble reads a preamble from the SSH connection before any SSH traffic is sent.
// This preamble contains the ultimate SSH destination the proxy will connect too. // This preamble is a JSON encoded struct containing the users JWT and ultimate destination.
// The first 4 bytes contain the length of the destination which follows immediately. // The first 4 bytes contain the length of the preamble which follows immediately.
func (s *SSHProxy) configureSSHDestination(conn net.Conn, ctx ssh.Context) error { func (s *SSHProxy) readPreamble(conn net.Conn) (*SSHPreamble, error) {
size := make([]byte, sshPreambleLength) // Set conn read deadline while reading preamble to prevent hangs if preamble wasnt sent.
if err := conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)); err != nil {
return nil, errors.Wrap(err, "failed to set conn deadline")
}
defer func() {
if err := conn.SetReadDeadline(time.Time{}); err != nil {
s.logger.WithError(err).Error("Failed to unset conn read deadline")
}
}()
size := make([]byte, SSHPreambleLength)
if _, err := io.ReadFull(conn, size); err != nil { if _, err := io.ReadFull(conn, size); err != nil {
return err return nil, err
} }
payloadLength := binary.BigEndian.Uint32(size) payloadLength := binary.BigEndian.Uint32(size)
data := make([]byte, payloadLength) payload := make([]byte, payloadLength)
if _, err := io.ReadFull(conn, data); err != nil { if _, err := io.ReadFull(conn, payload); err != nil {
return err return nil, err
} }
destAddr := string(data) var preamble SSHPreamble
destUrl, err := url.Parse(destAddr) err := json.Unmarshal(payload, &preamble)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to parse URL") return nil, err
}
destUrl, err := url.Parse(preamble.Destination)
if err != nil {
return nil, errors.Wrap(err, "failed to parse URL")
} }
if destUrl.Port() == "" { if destUrl.Port() == "" {
destAddr += ":22" preamble.Destination += ":22"
} }
ctx.SetValue(sshContextDestination, destAddr) return &preamble, nil
return nil
} }
// createSSHClient creates a new SSH client and dials the destination server // dialDestination creates a new SSH client and dials the destination server
func (s *SSHProxy) createSSHClient(ctx ssh.Context) (*gossh.Client, error) { func (s *SSHProxy) dialDestination(ctx ssh.Context) (*gossh.Client, error) {
preamble, ok := ctx.Value(sshContextPreamble).(*SSHPreamble)
if !ok {
msg := "failed to retrieve SSH preamble from context"
s.logger.Error(msg)
return nil, errors.New(msg)
}
signer, err := s.genSSHSigner(preamble.JWT)
if err != nil {
s.logger.WithError(err).Error("Failed to generate signed short lived cert")
return nil, err
}
clientConfig := &gossh.ClientConfig{ clientConfig := &gossh.ClientConfig{
User: ctx.User(), User: ctx.User(),
// AUTH-2103 TODO: proper host key check // AUTH-2103 TODO: proper host key check
HostKeyCallback: gossh.InsecureIgnoreHostKey(), HostKeyCallback: gossh.InsecureIgnoreHostKey(),
// AUTH-2114 TODO: replace with short lived cert auth Auth: []gossh.AuthMethod{gossh.PublicKeys(signer)},
Auth: []gossh.AuthMethod{gossh.Password("test")}, ClientVersion: ctx.ServerVersion(),
ClientVersion: ctx.ServerVersion(),
} }
address, ok := ctx.Value(sshContextDestination).(string) client, err := gossh.Dial("tcp", preamble.Destination, clientConfig)
if !ok {
return nil, errors.New("failed to retrieve SSH destination from context")
}
client, err := gossh.Dial("tcp", address, clientConfig)
if err != nil { if err != nil {
s.logger.WithError(err).Info("Failed to connect to destination SSH server")
return nil, err return nil, err
} }
return client, nil return client, nil
} }
// Generates a key pair and sends public key to get signed by CA
func (s *SSHProxy) genSSHSigner(jwt string) (gossh.Signer, error) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, errors.Wrap(err, "failed to generate ecdsa key pair")
}
pub, err := gossh.NewPublicKey(&key.PublicKey)
if err != nil {
return nil, errors.Wrap(err, "failed to convert ecdsa public key to SSH public key")
}
pubBytes := gossh.MarshalAuthorizedKey(pub)
signedCertBytes, err := sshgen.SignCert(jwt, string(pubBytes))
if err != nil {
return nil, errors.Wrap(err, "failed to retrieve cert from SSHCAAPI")
}
signedPub, _, _, _, err := gossh.ParseAuthorizedKey([]byte(signedCertBytes))
if err != nil {
return nil, errors.Wrap(err, "failed to parse SSH public key")
}
cert, ok := signedPub.(*gossh.Certificate)
if !ok {
return nil, errors.Wrap(err, "failed to assert public key as certificate")
}
signer, err := gossh.NewSignerFromKey(key)
if err != nil {
return nil, errors.Wrap(err, "failed to create signer")
}
certSigner, err := gossh.NewCertSigner(cert, signer)
if err != nil {
return nil, errors.Wrap(err, "failed to create cert signer")
}
return certSigner, nil
}
// forwardChannelRequest sends request req to SSH channel sshChan, waits for reply, and sends the reply back. // forwardChannelRequest sends request req to SSH channel sshChan, waits for reply, and sends the reply back.
func (s *SSHProxy) forwardChannelRequest(sshChan gossh.Channel, req *gossh.Request) error { func (s *SSHProxy) forwardChannelRequest(sshChan gossh.Channel, req *gossh.Request) error {
reply, err := sshChan.SendRequest(req.Type, req.WantReply, req.Payload) reply, err := sshChan.SendRequest(req.Type, req.WantReply, req.Payload)
@ -282,20 +401,18 @@ func (s *SSHProxy) logChannelRequest(req *gossh.Request, conn *gossh.ServerConn,
s.logAuditEvent(conn, event, eventType, ctx) s.logAuditEvent(conn, event, eventType, ctx)
} }
func (s *SSHProxy) configureLogger(ctx ssh.Context) error { func (s *SSHProxy) auditLogger() (io.WriteCloser, string, error) {
sessionUUID, err := uuid.NewRandom() sessionUUID, err := uuid.NewRandom()
if err != nil { if err != nil {
return errors.Wrap(err, "failed to create sessionID") return nil, "", errors.Wrap(err, "failed to create sessionID")
} }
sessionID := sessionUUID.String() sessionID := sessionUUID.String()
writer, err := s.logManager.NewLogger(fmt.Sprintf("%s-event.log", sessionID), s.logger) writer, err := s.logManager.NewLogger(fmt.Sprintf("%s-event.log", sessionID), s.logger)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to create logger") return nil, "", errors.Wrap(err, "failed to create logger")
} }
ctx.SetValue(sshContextEventLogger, writer) return writer, sessionID, nil
ctx.SetValue(sshContextSessionID, sessionID)
return nil
} }
func (s *SSHProxy) logAuditEvent(conn *gossh.ServerConn, event, eventType string, ctx ssh.Context) { func (s *SSHProxy) logAuditEvent(conn *gossh.ServerConn, event, eventType string, ctx ssh.Context) {
@ -306,9 +423,12 @@ func (s *SSHProxy) logAuditEvent(conn *gossh.ServerConn, event, eventType string
return return
} }
destination, destOk := ctx.Value(sshContextDestination).(string) var destination string
if !destOk { preamble, ok := ctx.Value(sshContextPreamble).(*SSHPreamble)
s.logger.Error("Failed to retrieve SSH destination from context") if ok {
destination = preamble.Destination
} else {
s.logger.Error("Failed to retrieve SSH preamble from context")
} }
ae := auditEvent{ ae := auditEvent{
@ -318,6 +438,7 @@ func (s *SSHProxy) logAuditEvent(conn *gossh.ServerConn, event, eventType string
User: conn.User(), User: conn.User(),
Login: conn.User(), Login: conn.User(),
Datetime: time.Now().UTC().Format(time.RFC3339), Datetime: time.Now().UTC().Format(time.RFC3339),
Hostname: s.hostname,
Destination: destination, Destination: destination,
} }
data, err := json.Marshal(&ae) data, err := json.Marshal(&ae)

View File

@ -13,7 +13,12 @@ import (
type SSHServer struct{} type SSHServer struct{}
func New(_ sshlog.Manager, _ *logrus.Logger, _, _ string, _ chan struct{}, _, _ time.Duration) (*SSHServer, error) { type SSHPreamble struct {
Destination string
JWT string
}
func New(_ sshlog.Manager, _ *logrus.Logger, _, _, _ string, _ chan struct{}, _, _ time.Duration) (*SSHServer, error) {
return nil, errors.New("cloudflared ssh server is not supported on windows") return nil, errors.New("cloudflared ssh server is not supported on windows")
} }

View File

@ -6,12 +6,14 @@ import (
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"encoding/binary" "encoding/binary"
"encoding/json"
"errors" "errors"
"io" "io"
"net" "net"
"net/http" "net/http"
"time" "time"
"github.com/cloudflare/cloudflared/sshserver"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -155,9 +157,11 @@ func StartProxyServer(logger *logrus.Logger, listener net.Listener, remote strin
conn.Close() conn.Close()
}() }()
token := r.Header.Get("cf-access-token")
if destination := r.Header.Get("CF-Access-SSH-Destination"); destination != "" { if destination := r.Header.Get("CF-Access-SSH-Destination"); destination != "" {
if err := sendSSHDestination(stream, destination); err != nil { if err := sendSSHPreamble(stream, destination, token); err != nil {
logger.WithError(err).Error("Failed to send SSH destination") logger.WithError(err).Error("Failed to send SSH preamble")
return
} }
} }
@ -167,16 +171,22 @@ func StartProxyServer(logger *logrus.Logger, listener net.Listener, remote strin
return httpServer.Serve(listener) return httpServer.Serve(listener)
} }
// sendSSHDestination sends the final SSH destination address to the cloudflared SSH proxy // sendSSHPreamble sends the final SSH destination address to the cloudflared SSH proxy
// The destination is preceded by its length // The destination is preceded by its length
func sendSSHDestination(stream net.Conn, destination string) error { func sendSSHPreamble(stream net.Conn, destination, token string) error {
sizeBytes := make([]byte, 4) preamble := &sshserver.SSHPreamble{Destination: destination, JWT: token}
binary.BigEndian.PutUint32(sizeBytes, uint32(len(destination))) payload, err := json.Marshal(preamble)
if err != nil {
return err
}
sizeBytes := make([]byte, sshserver.SSHPreambleLength)
binary.BigEndian.PutUint32(sizeBytes, uint32(len(payload)))
if _, err := stream.Write(sizeBytes); err != nil { if _, err := stream.Write(sizeBytes); err != nil {
return err return err
} }
if _, err := stream.Write([]byte(destination)); err != nil { if _, err := stream.Write(payload); err != nil {
return err return err
} }
return nil return nil