package validation

import (
	"context"
	"fmt"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/coreos/go-oidc/v3/oidc"
	"github.com/pkg/errors"
	"golang.org/x/net/idna"
)

const (
	// defaultScheme is prepended to origin URLs that do not carry a scheme.
	defaultScheme = "http"
	// accessDomain is the domain suffix a token issuer must belong to.
	accessDomain = "cloudflareaccess.com"
	// accessCertPath is appended to the validated domain URL to locate the remote key set.
	accessCertPath = "/cdn-cgi/access/certs"
	// accessJwtHeader is the request header ValidateRequest reads the JWT from.
	accessJwtHeader = "Cf-access-jwt-assertion"
)

var (
	// supportedProtocols lists the URL schemes accepted by validateScheme.
	supportedProtocols = []string{"http", "https", "rdp", "ssh", "smb", "tcp"}
	validationTimeout  = 30 * time.Second
)

// ValidateHostname converts hostname to its ASCII (IDNA) form.
//
// An empty hostname is returned unchanged with a nil error. If the input
// contains ":" or an escaped colon ("%3A") it is treated as a URL: it is
// path-unescaped, parsed, and only the host part (without the port) is
// converted and returned. Otherwise the whole input is converted with
// idna.ToASCII, parsed as a URL reference, and its request URI is returned.
//
// Every returned error names the original input and wraps the underlying
// cause.
func ValidateHostname(hostname string) (string, error) {
	if hostname == "" {
		return "", nil
	}

	// The user gave a URL (it contains a scheme), not just a hostname.
	if strings.Contains(hostname, ":") || strings.Contains(hostname, "%3A") {
		unescaped, err := url.PathUnescape(hostname)
		if err != nil {
			return "", fmt.Errorf("Hostname(actually a URL) %s has invalid escape characters: %w", hostname, err)
		}
		parsed, err := url.Parse(unescaped)
		if err != nil {
			return "", fmt.Errorf("Hostname(actually a URL) %s has invalid format: %w", hostname, err)
		}
		ascii, err := idna.ToASCII(parsed.Hostname())
		if err != nil {
			return "", fmt.Errorf("Hostname(actually a URL) %s has invalid ASCII encoding: %w", hostname, err)
		}
		return ascii, nil
	}

	ascii, err := idna.ToASCII(hostname)
	if err != nil {
		return "", fmt.Errorf("Hostname %s has invalid ASCII encoding: %w", hostname, err)
	}
	parsed, err := url.Parse(ascii)
	if err != nil {
		return "", fmt.Errorf("Hostname %s is not valid: %w", hostname, err)
	}
	return parsed.RequestURI(), nil
}

// ValidateUrl returns a validated version of `originUrl` with a scheme prepended (by default http://).
// Note: when originUrl contains a scheme, the path is removed:
//
//	ValidateUrl("https://localhost:8080/api/") => "https://localhost:8080"
//
// but when it does not, the path is preserved:
//
//	ValidateUrl("localhost:8080/api/") => "http://localhost:8080/api/"
//
// This is arguably a bug, but changing it might break some cloudflared users.
func ValidateUrl(originUrl string) (*url.URL, error) {
	urlStr, err := validateUrlString(originUrl)
	if err != nil {
		return nil, err
	}
	return url.Parse(urlStr)
}

// validateUrlString is the string-returning core of ValidateUrl.
//
// It rejects the empty string, short-circuits bare IP addresses (with or
// without brackets or a port), and otherwise unescapes and parses the input.
// Inputs with a supported scheme are rebuilt as scheme://host[:port]; inputs
// without a scheme get defaultScheme prepended.
func validateUrlString(originUrl string) (string, error) {
	if originUrl == "" {
		return "", fmt.Errorf("URL should not be empty")
	}

	if net.ParseIP(originUrl) != nil {
		return validateIP("", originUrl, "")
	}
	if strings.HasPrefix(originUrl, "[") && strings.HasSuffix(originUrl, "]") {
		// ParseIP doesn't recognize [::1], so strip the brackets first.
		return validateIP("", originUrl[1:len(originUrl)-1], "")
	}

	// The user might pass in an IP address with a port, like 127.0.0.1:8080.
	// When the split fails, host and port are both empty.
	host, port, err := net.SplitHostPort(originUrl)
	if err == nil && net.ParseIP(host) != nil {
		return validateIP("", host, port)
	}

	unescapedUrl, err := url.PathUnescape(originUrl)
	if err != nil {
		return "", fmt.Errorf("URL %s has invalid escape characters: %w", originUrl, err)
	}

	parsedUrl, err := url.Parse(unescapedUrl)
	if err != nil {
		return "", fmt.Errorf("URL %s has invalid format", originUrl)
	}

	// If the URL is in the form host:port, IsAbs() will think host is the
	// scheme, so a non-empty Host is required as well.
	hasScheme := parsedUrl.IsAbs() && parsedUrl.Host != ""
	if hasScheme {
		if err := validateScheme(parsedUrl.Scheme); err != nil {
			return "", err
		}
		// The earlier checks for IP addresses miss http://[::1] and
		// http://[::1]:8080.
		if net.ParseIP(parsedUrl.Hostname()) != nil {
			return validateIP(parsedUrl.Scheme, parsedUrl.Hostname(), parsedUrl.Port())
		}
		hostname, err := ValidateHostname(parsedUrl.Hostname())
		if err != nil {
			return "", fmt.Errorf("URL %s has invalid format", originUrl)
		}
		if parsedUrl.Port() != "" {
			return fmt.Sprintf("%s://%s", parsedUrl.Scheme, net.JoinHostPort(hostname, parsedUrl.Port())), nil
		}
		return fmt.Sprintf("%s://%s", parsedUrl.Scheme, hostname), nil
	}

	if host == "" {
		hostname, err := ValidateHostname(originUrl)
		if err != nil {
			return "", fmt.Errorf("URL no %s has invalid format", originUrl)
		}
		return fmt.Sprintf("%s://%s", defaultScheme, hostname), nil
	}

	hostname, err := ValidateHostname(host)
	if err != nil {
		return "", fmt.Errorf("URL %s has invalid format", originUrl)
	}
	// This is why the path is preserved when `originUrl` doesn't have a scheme.
	// Using `parsedUrl.Port()` here, instead of `port`, would remove the path.
	return fmt.Sprintf("%s://%s", defaultScheme, net.JoinHostPort(hostname, port)), nil
}

// validateScheme returns nil when scheme is one of supportedProtocols and an
// error naming the unsupported scheme otherwise.
func validateScheme(scheme string) error {
	for _, protocol := range supportedProtocols {
		if scheme == protocol {
			return nil
		}
	}
	return fmt.Errorf("Currently Cloudflare Tunnel does not support %s protocol.", scheme)
}

// validateIP formats an IP host as scheme://host[:port]. An empty scheme is
// replaced by defaultScheme. The returned error is always nil.
func validateIP(scheme, host, port string) (string, error) {
	if scheme == "" {
		scheme = defaultScheme
	}
	switch {
	case port != "":
		return fmt.Sprintf("%s://%s", scheme, net.JoinHostPort(host, port)), nil
	case strings.Contains(host, ":"):
		// IPv6: a host containing ":" is wrapped in brackets.
		return fmt.Sprintf("%s://[%s]", scheme, host), nil
	default:
		return fmt.Sprintf("%s://%s", scheme, host), nil
	}
}

// Access checks if a JWT from Cloudflare Access is valid.
type Access struct {
	verifier *oidc.IDTokenVerifier
}

// NewAccessValidator builds an Access validator.
//
// Both domain and issuer are normalized with validateUrlString. The remote
// key set is read from the normalized domain plus accessCertPath, and the
// verifier is configured with the normalized issuer (forced to https) and
// applicationAUD as the client ID.
func NewAccessValidator(ctx context.Context, domain, issuer, applicationAUD string) (*Access, error) {
	domainURL, err := validateUrlString(domain)
	if err != nil {
		return nil, err
	}

	issuerURL, err := validateUrlString(issuer)
	if err != nil {
		return nil, err
	}

	// An issuerURL from Cloudflare Access will always use HTTPS.
	issuerURL = strings.Replace(issuerURL, "http:", "https:", 1)

	keySet := oidc.NewRemoteKeySet(ctx, domainURL+accessCertPath)
	return &Access{oidc.NewVerifier(issuerURL, keySet, &oidc.Config{ClientID: applicationAUD})}, nil
}

// Validate verifies jwt with the configured verifier and then checks that
// the token's issuer belongs to accessDomain.
//
// NOTE(review): every error returned here embeds the raw JWT, so callers
// that log the error will also log the token.
func (a *Access) Validate(ctx context.Context, jwt string) error {
	token, err := a.verifier.Verify(ctx, jwt)
	if err != nil {
		return errors.Wrapf(err, "token is invalid: %s", jwt)
	}

	// Perform extra sanity checks, just to be safe.
	if token == nil {
		return fmt.Errorf("token is nil: %s", jwt)
	}
	if !isAccessIssuer(token.Issuer) {
		return fmt.Errorf("token has non-cloudflare issuer of %s: %s", token.Issuer, jwt)
	}
	return nil
}

// isAccessIssuer reports whether issuer ends with accessDomain and, when it
// parses as a URL with a host, whether that host is accessDomain itself or a
// subdomain of it. The host check rejects look-alike issuers such as
// "https://evilcloudflareaccess.com" that a plain suffix test would accept.
func isAccessIssuer(issuer string) bool {
	if !strings.HasSuffix(issuer, accessDomain) {
		return false
	}
	host := issuer
	if parsed, err := url.Parse(issuer); err == nil && parsed.Host != "" {
		host = parsed.Hostname()
	}
	return host == accessDomain || strings.HasSuffix(host, "."+accessDomain)
}

// ValidateRequest validates the JWT carried in the accessJwtHeader header of r.
func (a *Access) ValidateRequest(ctx context.Context, r *http.Request) error {
	return a.Validate(ctx, r.Header.Get(accessJwtHeader))
}