cloudflared-mirror/ingress/origin_service.go

412 lines
10 KiB
Go

package ingress
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strconv"
"time"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/ipaccess"
"github.com/cloudflare/cloudflared/management"
"github.com/cloudflare/cloudflared/socks"
"github.com/cloudflare/cloudflared/tlsconfig"
)
const (
HelloWorldService = "hello_world"
HelloWorldFlag = "hello-world"
HttpStatusService = "http_status"
)
// OriginService is something a tunnel can proxy traffic to.
type OriginService interface {
String() string
// Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World.
// If it's not managed by cloudflared, this is a no-op because the user is responsible for
// starting the origin service.
// Implementor of services managed by cloudflared should terminate the service if shutdownC is closed
start(log *zerolog.Logger, shutdownC <-chan struct{}, cfg OriginRequestConfig) error
MarshalJSON() ([]byte, error)
}
// unixSocketPath is an OriginService representing a unix socket (which accepts HTTP or HTTPS)
type unixSocketPath struct {
path string
scheme string
transport *http.Transport
}
func (o *unixSocketPath) String() string {
scheme := ""
if o.scheme == "https" {
scheme = "+tls"
}
return fmt.Sprintf("unix%s:%s", scheme, o.path)
}
func (o *unixSocketPath) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error {
transport, err := newHTTPTransport(o, cfg, log)
if err != nil {
return err
}
o.transport = transport
return nil
}
func (o unixSocketPath) MarshalJSON() ([]byte, error) {
return json.Marshal(o.String())
}
type httpService struct {
url *url.URL
hostHeader string
transport *http.Transport
matchSNIToHost bool
}
func (o *httpService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error {
transport, err := newHTTPTransport(o, cfg, log)
if err != nil {
return err
}
o.hostHeader = cfg.HTTPHostHeader
o.transport = transport
o.matchSNIToHost = cfg.MatchSNIToHost
return nil
}
func (o *httpService) String() string {
return o.url.String()
}
func (o httpService) MarshalJSON() ([]byte, error) {
return json.Marshal(o.String())
}
// rawTCPService dials TCP to the destination specified by the client
// It's used by warp routing
type rawTCPService struct {
name string
dialer net.Dialer
writeTimeout time.Duration
logger *zerolog.Logger
}
func (o *rawTCPService) String() string {
return o.name
}
func (o *rawTCPService) start(_ *zerolog.Logger, _ <-chan struct{}, _ OriginRequestConfig) error {
return nil
}
func (o rawTCPService) MarshalJSON() ([]byte, error) {
return json.Marshal(o.String())
}
// tcpOverWSService models TCP origins serving eyeballs connecting over websocket, such as
// cloudflared access commands.
type tcpOverWSService struct {
scheme string
dest string
isBastion bool
streamHandler streamHandlerFunc
dialer net.Dialer
}
type socksProxyOverWSService struct {
conn *socksProxyOverWSConnection
}
func newTCPOverWSService(url *url.URL) *tcpOverWSService {
switch url.Scheme {
case "ssh":
addPortIfMissing(url, 22)
case "rdp":
addPortIfMissing(url, 3389)
case "smb":
addPortIfMissing(url, 445)
case "tcp":
addPortIfMissing(url, 7864) // just a random port since there isn't a default in this case
}
return &tcpOverWSService{
scheme: url.Scheme,
dest: url.Host,
}
}
func newBastionService() *tcpOverWSService {
return &tcpOverWSService{
isBastion: true,
}
}
func newSocksProxyOverWSService(accessPolicy *ipaccess.Policy) *socksProxyOverWSService {
proxy := socksProxyOverWSService{
conn: &socksProxyOverWSConnection{
accessPolicy: accessPolicy,
},
}
return &proxy
}
func addPortIfMissing(uri *url.URL, port int) {
hostname := uri.Hostname()
if uri.Port() == "" {
uri.Host = net.JoinHostPort(hostname, strconv.FormatInt(int64(port), 10))
}
}
func (o *tcpOverWSService) String() string {
if o.isBastion {
return ServiceBastion
}
if o.scheme != "" {
return fmt.Sprintf("%s://%s", o.scheme, o.dest)
} else {
return o.dest
}
}
func (o *tcpOverWSService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error {
if cfg.ProxyType == socksProxy {
o.streamHandler = socks.StreamHandler
} else {
o.streamHandler = DefaultStreamHandler
}
o.dialer.Timeout = cfg.ConnectTimeout.Duration
o.dialer.KeepAlive = cfg.TCPKeepAlive.Duration
return nil
}
func (o tcpOverWSService) MarshalJSON() ([]byte, error) {
return json.Marshal(o.String())
}
func (o *socksProxyOverWSService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error {
return nil
}
func (o *socksProxyOverWSService) String() string {
return ServiceSocksProxy
}
func (o socksProxyOverWSService) MarshalJSON() ([]byte, error) {
return json.Marshal(o.String())
}
// HelloWorld is an OriginService for the built-in Hello World server.
// Users only use this for testing and experimenting with cloudflared.
type helloWorld struct {
httpService
server net.Listener
}
func (o *helloWorld) String() string {
return HelloWorldService
}
// Start starts a HelloWorld server and stores its address in the Service receiver.
func (o *helloWorld) start(
log *zerolog.Logger,
shutdownC <-chan struct{},
cfg OriginRequestConfig,
) error {
if err := o.httpService.start(log, shutdownC, cfg); err != nil {
return err
}
helloListener, err := hello.CreateTLSListener("127.0.0.1:")
if err != nil {
return errors.Wrap(err, "Cannot start Hello World Server")
}
go hello.StartHelloWorldServer(log, helloListener, shutdownC)
o.server = helloListener
o.httpService.url = &url.URL{
Scheme: "https",
Host: o.server.Addr().String(),
}
return nil
}
func (o helloWorld) MarshalJSON() ([]byte, error) {
return json.Marshal(o.String())
}
// statusCode is an OriginService that just responds with a given HTTP status.
// Typical use-case is "user wants the catch-all rule to just respond 404".
type statusCode struct {
code int
// Set only when the user has not defined any ingress rules
defaultResp bool
log *zerolog.Logger
}
func newStatusCode(status int) statusCode {
return statusCode{code: status}
}
// default status code (503) that is returned for requests to cloudflared that don't have any ingress rules setup
func newDefaultStatusCode(log *zerolog.Logger) statusCode {
return statusCode{code: 503, defaultResp: true, log: log}
}
func (o *statusCode) String() string {
return fmt.Sprintf("http_status:%d", o.code)
}
func (o *statusCode) start(
log *zerolog.Logger,
_ <-chan struct{},
cfg OriginRequestConfig,
) error {
return nil
}
func (o statusCode) MarshalJSON() ([]byte, error) {
return json.Marshal(o.String())
}
// WarpRoutingService starts a tcp stream between the origin and requests from
// warp clients.
type WarpRoutingService struct {
Proxy StreamBasedOriginProxy
}
func NewWarpRoutingService(config WarpRoutingConfig, writeTimeout time.Duration) *WarpRoutingService {
svc := &rawTCPService{
name: ServiceWarpRouting,
dialer: net.Dialer{
Timeout: config.ConnectTimeout.Duration,
KeepAlive: config.TCPKeepAlive.Duration,
},
writeTimeout: writeTimeout,
}
return &WarpRoutingService{Proxy: svc}
}
// ManagementService starts a local HTTP server to handle incoming management requests.
type ManagementService struct {
HTTPLocalProxy
}
func newManagementService(managementProxy HTTPLocalProxy) *ManagementService {
return &ManagementService{
HTTPLocalProxy: managementProxy,
}
}
func (o *ManagementService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error {
return nil
}
func (o *ManagementService) String() string {
return "management"
}
func (o ManagementService) MarshalJSON() ([]byte, error) {
return json.Marshal(o.String())
}
func NewManagementRule(management *management.ManagementService) Rule {
return Rule{
Hostname: management.Hostname,
Service: newManagementService(management),
}
}
type NopReadCloser struct{}
// Read always returns EOF to signal end of input
func (nrc *NopReadCloser) Read(buf []byte) (int, error) {
return 0, io.EOF
}
func (nrc *NopReadCloser) Close() error {
return nil
}
func newHTTPTransport(service OriginService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) {
originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, log)
if err != nil {
return nil, errors.Wrap(err, "Error loading cert pool")
}
httpTransport := http.Transport{
Proxy: http.ProxyFromEnvironment,
MaxIdleConns: cfg.KeepAliveConnections,
MaxIdleConnsPerHost: cfg.KeepAliveConnections,
IdleConnTimeout: cfg.KeepAliveTimeout.Duration,
TLSHandshakeTimeout: cfg.TLSTimeout.Duration,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{RootCAs: originCertPool, InsecureSkipVerify: cfg.NoTLSVerify},
ForceAttemptHTTP2: cfg.Http2Origin,
}
if _, isHelloWorld := service.(*helloWorld); !isHelloWorld && cfg.OriginServerName != "" {
httpTransport.TLSClientConfig.ServerName = cfg.OriginServerName
}
dialer := &net.Dialer{
Timeout: cfg.ConnectTimeout.Duration,
KeepAlive: cfg.TCPKeepAlive.Duration,
}
if cfg.NoHappyEyeballs {
dialer.FallbackDelay = -1 // As of Golang 1.12, a negative delay disables "happy eyeballs"
}
// DialContext depends on which kind of origin is being used.
dialContext := dialer.DialContext
switch service := service.(type) {
// If this origin is a unix socket, enforce network type "unix".
case *unixSocketPath:
httpTransport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
return dialContext(ctx, "unix", service.path)
}
// Otherwise, use the regular network config.
default:
httpTransport.DialContext = dialContext
}
return &httpTransport, nil
}
// MockOriginHTTPService should only be used by other packages to mock OriginService. Set Transport to configure desired RoundTripper behavior.
type MockOriginHTTPService struct {
Transport http.RoundTripper
}
func (mos MockOriginHTTPService) RoundTrip(req *http.Request) (*http.Response, error) {
return mos.Transport.RoundTrip(req)
}
func (mos MockOriginHTTPService) String() string {
return "MockOriginService"
}
func (mos MockOriginHTTPService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error {
return nil
}
func (mos MockOriginHTTPService) MarshalJSON() ([]byte, error) {
return json.Marshal(mos.String())
}