cloudflared-mirror/validation/validation.go

254 lines
6.9 KiB
Go

package validation
import (
"context"
"fmt"
"net"
"net/url"
"strings"
"time"
"net/http"
"github.com/coreos/go-oidc"
"github.com/pkg/errors"
"golang.org/x/net/idna"
)
const (
	// defaultScheme is used when an origin URL is given without a scheme.
	defaultScheme = "http"
	// accessDomain is the domain a Cloudflare Access token issuer must belong to.
	accessDomain = "cloudflareaccess.com"
	// accessCertPath is appended to the Access domain URL to fetch signing keys.
	accessCertPath = "/cdn-cgi/access/certs"
	// accessJwtHeader is the request header that carries the Access JWT.
	accessJwtHeader = "Cf-access-jwt-assertion"
)

var (
	// supportedProtocols lists the URL schemes accepted by validateScheme.
	supportedProtocols = []string{"http", "https", "rdp"}
	// validationTimeout bounds each probe request made by ValidateHTTPService.
	validationTimeout = 30 * time.Second
)
func ValidateHostname(hostname string) (string, error) {
if hostname == "" {
return "", nil
}
// users gives url(contains schema) not just hostname
if strings.Contains(hostname, ":") || strings.Contains(hostname, "%3A") {
unescapeHostname, err := url.PathUnescape(hostname)
if err != nil {
return "", fmt.Errorf("Hostname(actually a URL) %s has invalid escape characters %s", hostname, unescapeHostname)
}
hostnameToURL, err := url.Parse(unescapeHostname)
if err != nil {
return "", fmt.Errorf("Hostname(actually a URL) %s has invalid format %s", hostname, hostnameToURL)
}
asciiHostname, err := idna.ToASCII(hostnameToURL.Hostname())
if err != nil {
return "", fmt.Errorf("Hostname(actually a URL) %s has invalid ASCII encdoing %s", hostname, asciiHostname)
}
return asciiHostname, nil
}
asciiHostname, err := idna.ToASCII(hostname)
if err != nil {
return "", fmt.Errorf("Hostname %s has invalid ASCII encdoing %s", hostname, asciiHostname)
}
hostnameToURL, err := url.Parse(asciiHostname)
if err != nil {
return "", fmt.Errorf("Hostname %s is not valid", hostnameToURL)
}
return hostnameToURL.RequestURI(), nil
}
// ValidateUrl normalizes an origin URL into "scheme://host[:port]" form.
//
// Accepted inputs are bare IP addresses (IPv6 optionally in brackets),
// "ip:port", "host", "host:port" and absolute URLs whose scheme passes
// validateScheme. Inputs without a scheme get defaultScheme. Non-IP hosts are
// normalized through ValidateHostname. For absolute URLs only the scheme,
// host and port are kept. An empty input is an error.
func ValidateUrl(originUrl string) (string, error) {
	if originUrl == "" {
		return "", fmt.Errorf("URL should not be empty")
	}
	if net.ParseIP(originUrl) != nil {
		return validateIP("", originUrl, "")
	}
	if strings.HasPrefix(originUrl, "[") && strings.HasSuffix(originUrl, "]") {
		// ParseIP doesn't recognize [::1].
		// NOTE(review): the bracketed content is not checked to be an IP here.
		return validateIP("", originUrl[1:len(originUrl)-1], "")
	}
	host, port, err := net.SplitHostPort(originUrl)
	// user might pass in an ip address like 127.0.0.1:8080
	if err == nil && net.ParseIP(host) != nil {
		return validateIP("", host, port)
	}

	unescapedUrl, err := url.PathUnescape(originUrl)
	if err != nil {
		return "", errors.Wrapf(err, "URL %s has invalid escape characters", originUrl)
	}
	parsedUrl, err := url.Parse(unescapedUrl)
	if err != nil {
		return "", errors.Wrapf(err, "URL %s has invalid format", originUrl)
	}

	// For input of the form host:port, IsAbs() treats host as the scheme but
	// leaves Host empty, so both conditions are needed to detect a real scheme.
	if parsedUrl.IsAbs() && parsedUrl.Host != "" {
		if err := validateScheme(parsedUrl.Scheme); err != nil {
			return "", err
		}
		// The earlier IP checks miss http://[::1] and http://[::1]:8080.
		if net.ParseIP(parsedUrl.Hostname()) != nil {
			return validateIP(parsedUrl.Scheme, parsedUrl.Hostname(), parsedUrl.Port())
		}
		hostname, err := ValidateHostname(parsedUrl.Hostname())
		if err != nil {
			return "", errors.Wrapf(err, "URL %s has invalid format", originUrl)
		}
		if parsedUrl.Port() != "" {
			hostname = net.JoinHostPort(hostname, parsedUrl.Port())
		}
		return fmt.Sprintf("%s://%s", parsedUrl.Scheme, hostname), nil
	}

	// No scheme was given: fall back to defaultScheme.
	if host == "" {
		// Not in host:port form, so validate the input as a whole.
		hostname, err := ValidateHostname(originUrl)
		if err != nil {
			return "", errors.Wrapf(err, "URL %s has invalid format", originUrl)
		}
		return fmt.Sprintf("%s://%s", defaultScheme, hostname), nil
	}
	hostname, err := ValidateHostname(host)
	if err != nil {
		return "", errors.Wrapf(err, "URL %s has invalid format", originUrl)
	}
	return fmt.Sprintf("%s://%s", defaultScheme, net.JoinHostPort(hostname, port)), nil
}
// validateScheme returns nil when scheme appears in supportedProtocols and a
// descriptive error otherwise.
func validateScheme(scheme string) error {
	supported := false
	for i := 0; i < len(supportedProtocols) && !supported; i++ {
		supported = supportedProtocols[i] == scheme
	}
	if supported {
		return nil
	}
	return fmt.Errorf("Currently Argo Tunnel does not support %s protocol.", scheme)
}
// validateIP renders host (and port, when non-empty) as "scheme://authority",
// wrapping hosts that contain ':' (IPv6 literals) in brackets. An empty scheme
// falls back to defaultScheme. The host itself is not validated here, and the
// returned error is always nil.
func validateIP(scheme, host, port string) (string, error) {
	effectiveScheme := scheme
	if effectiveScheme == "" {
		effectiveScheme = defaultScheme
	}
	var authority string
	switch {
	case port != "":
		authority = net.JoinHostPort(host, port)
	case strings.Contains(host, ":"):
		// IPv6 literals must be bracketed inside a URL.
		authority = "[" + host + "]"
	default:
		authority = host
	}
	return fmt.Sprintf("%s://%s", effectiveScheme, authority), nil
}
// ValidateHTTPService checks that originURL answers a GET request sent with
// its Host header set to hostname, using the given transport. Redirects are
// not followed and the status code is not inspected: any response counts as
// success.
//
// If that request fails and the scheme is http or https, the same URL is
// retried over the other scheme. If the retry succeeds, an error advising
// the user to switch protocols is returned; otherwise the first error is.
func ValidateHTTPService(originURL string, hostname string, transport http.RoundTripper) error {
	parsedURL, err := url.Parse(originURL)
	if err != nil {
		return err
	}
	client := &http.Client{
		Transport: transport,
		// Stop at the first response instead of following redirects.
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			return http.ErrUseLastResponse
		},
		Timeout: validationTimeout,
	}

	initialRequest, err := http.NewRequest("GET", parsedURL.String(), nil)
	if err != nil {
		return err
	}
	initialRequest.Host = hostname
	initialResp, initialErr := client.Do(initialRequest)
	if initialErr == nil {
		// The body is never read, but it must still be closed or the
		// underlying connection may be leaked.
		initialResp.Body.Close()
		return nil
	}

	// Attempt the same endpoint via the other protocol (http/https); maybe we have better luck?
	oldScheme := parsedURL.Scheme
	parsedURL.Scheme = toggleProtocol(oldScheme)
	if parsedURL.Scheme == oldScheme {
		// Neither http nor https: a retry would hit the identical URL.
		return initialErr
	}
	secondRequest, err := http.NewRequest("GET", parsedURL.String(), nil)
	if err != nil {
		return err
	}
	secondRequest.Host = hostname
	secondResp, secondErr := client.Do(secondRequest)
	if secondErr != nil {
		return initialErr
	}
	secondResp.Body.Close()
	// Worked this time--advise the user to switch protocols.
	return errors.Errorf(
		"%s doesn't seem to work over %s, but does seem to work over %s. Consider changing the origin URL to %s",
		parsedURL.Host,
		oldScheme,
		parsedURL.Scheme,
		parsedURL,
	)
}
// toggleProtocol swaps "http" and "https"; any other scheme is returned
// unchanged.
func toggleProtocol(httpProtocol string) string {
	if httpProtocol == "http" {
		return "https"
	}
	if httpProtocol == "https" {
		return "http"
	}
	return httpProtocol
}
// Access checks if a JWT from Cloudflare Access is valid.
// Construct it with NewAccessValidator; the zero value has no verifier.
type Access struct {
	// verifier is the OIDC ID-token verifier that Validate runs each JWT through.
	verifier *oidc.IDTokenVerifier
}
// NewAccessValidator builds an Access validator for one Access application.
//
// domain and issuer are both normalized with ValidateUrl. Signing keys are
// fetched from domain+accessCertPath. The issuer is forced to https because
// Access issuers always use HTTPS. applicationAUD is the expected audience
// (ClientID) of validated tokens. An invalid domain or issuer is returned as
// a wrapped error naming which one failed.
func NewAccessValidator(ctx context.Context, domain, issuer, applicationAUD string) (*Access, error) {
	domainURL, err := ValidateUrl(domain)
	if err != nil {
		return nil, errors.Wrap(err, "invalid Access domain")
	}
	issuerURL, err := ValidateUrl(issuer)
	if err != nil {
		return nil, errors.Wrap(err, "invalid Access issuer")
	}
	// An issuerURL from Cloudflare Access will always use HTTPS; rewrite only
	// the scheme prefix, never an "http:" elsewhere in the string.
	if strings.HasPrefix(issuerURL, "http://") {
		issuerURL = "https://" + strings.TrimPrefix(issuerURL, "http://")
	}
	// NOTE(review): domainURL keeps the scheme ValidateUrl produced (http by
	// default), so signing keys may be fetched over plain HTTP — confirm.
	keySet := oidc.NewRemoteKeySet(ctx, domainURL+accessCertPath)
	return &Access{oidc.NewVerifier(issuerURL, keySet, &oidc.Config{ClientID: applicationAUD})}, nil
}
// Validate verifies jwt with the configured OIDC verifier and then checks
// that the token's issuer host is accessDomain or a subdomain of it.
//
// Error messages deliberately omit the token: it is a bearer credential and
// errors are commonly logged.
func (a *Access) Validate(ctx context.Context, jwt string) error {
	token, err := a.verifier.Verify(ctx, jwt)
	if err != nil {
		return errors.Wrap(err, "token is invalid")
	}
	// Perform extra sanity checks, just to be safe.
	if token == nil {
		return errors.New("token is nil")
	}
	// Compare on a label boundary so that a host such as
	// "evilcloudflareaccess.com" is not accepted merely for its suffix.
	issuerHost := token.Issuer
	if parsed, perr := url.Parse(token.Issuer); perr == nil && parsed.Hostname() != "" {
		issuerHost = parsed.Hostname()
	}
	if issuerHost != accessDomain && !strings.HasSuffix(issuerHost, "."+accessDomain) {
		return fmt.Errorf("token has non-cloudflare issuer of %s", token.Issuer)
	}
	return nil
}
// ValidateRequest validates the Access JWT found in r's accessJwtHeader
// header. A missing header is passed to Validate as an empty token.
func (a *Access) ValidateRequest(ctx context.Context, r *http.Request) error {
	jwt := r.Header.Get(accessJwtHeader)
	return a.Validate(ctx, jwt)
}