Fix cloudflared tunnel route dns zone resolution bug
When users have multiple domains, tunnel route dns would incorrectly use the default zoneID from the login certificate, creating an invalid CNAME record (e.g. app.domain1.com.domain2.com). This fix introduces ListZones in cfapi to fetch all valid zones for the account and explicitly checks if the provided hostname exactly matches or is a subdomain of a discovered zone, preventing this behavior and dynamically adjusting the endpoint to the correct Zone ID.
This commit is contained in:
parent
d2a87e9b93
commit
ef095a3ac8
|
|
@ -40,6 +40,7 @@ type baseEndpoints struct {
|
|||
zoneLevel url.URL
|
||||
accountRoutes url.URL
|
||||
accountVnets url.URL
|
||||
zones url.URL
|
||||
}
|
||||
|
||||
var _ Client = (*RESTClient)(nil)
|
||||
|
|
@ -60,7 +61,11 @@ func NewRESTClient(baseURL, accountTag, zoneTag, authToken, userAgent string, lo
|
|||
}
|
||||
zoneLevelEndpoint, err := url.Parse(fmt.Sprintf("%s/zones/%s/tunnels", baseURL, zoneTag))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create account level endpoint")
|
||||
return nil, errors.Wrap(err, "failed to create zone level endpoint")
|
||||
}
|
||||
zonesEndpoint, err := url.Parse(fmt.Sprintf("%s/zones", baseURL))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create zones endpoint")
|
||||
}
|
||||
httpTransport := http.Transport{
|
||||
TLSHandshakeTimeout: defaultTimeout,
|
||||
|
|
@ -73,6 +78,7 @@ func NewRESTClient(baseURL, accountTag, zoneTag, authToken, userAgent string, lo
|
|||
zoneLevel: *zoneLevelEndpoint,
|
||||
accountRoutes: *accountRoutesEndpoint,
|
||||
accountVnets: *accountVnetsEndpoint,
|
||||
zones: *zonesEndpoint,
|
||||
},
|
||||
authToken: authToken,
|
||||
userAgent: userAgent,
|
||||
|
|
@ -241,3 +247,16 @@ func (r *RESTClient) statusCodeToError(op string, resp *http.Response) error {
|
|||
return errors.Errorf("API call to %s failed with status %d: %s", op,
|
||||
resp.StatusCode, http.StatusText(resp.StatusCode))
|
||||
}
|
||||
|
||||
func (r *RESTClient) ListZones() ([]*Zone, error) {
|
||||
endpoint := r.baseEndpoints.zones
|
||||
return fetchExhaustively[Zone](func(page int) (*http.Response, error) {
|
||||
reqURL := endpoint
|
||||
query := reqURL.Query()
|
||||
query.Set("page", fmt.Sprintf("%d", page))
|
||||
query.Set("per_page", "50")
|
||||
// Required to get basic zone info instead of just IDs
|
||||
reqURL.RawQuery = query.Encode()
|
||||
return r.sendRequest("GET", reqURL, nil)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ type TunnelClient interface {
|
|||
|
||||
type HostnameClient interface {
|
||||
RouteTunnel(tunnelID uuid.UUID, route HostnameRoute) (HostnameRouteResult, error)
|
||||
ListZones() ([]*Zone, error)
|
||||
}
|
||||
|
||||
type IPRouteClient interface {
|
||||
|
|
@ -39,3 +40,8 @@ type Client interface {
|
|||
IPRouteClient
|
||||
VnetClient
|
||||
}
|
||||
|
||||
type Zone struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,9 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
|
|
@ -25,6 +27,7 @@ type HostnameRoute interface {
|
|||
RecordType() string
|
||||
UnmarshalResult(body io.Reader) (HostnameRouteResult, error)
|
||||
String() string
|
||||
Hostname() string
|
||||
}
|
||||
|
||||
type HostnameRouteResult interface {
|
||||
|
|
@ -78,6 +81,10 @@ func (dr *DNSRoute) String() string {
|
|||
return fmt.Sprintf("%s %s", dr.RecordType(), dr.userHostname)
|
||||
}
|
||||
|
||||
func (dr *DNSRoute) Hostname() string {
|
||||
return dr.userHostname
|
||||
}
|
||||
|
||||
func (res *DNSRouteResult) SuccessSummary() string {
|
||||
var msgFmt string
|
||||
switch res.CName {
|
||||
|
|
@ -139,6 +146,10 @@ func (lb *LBRoute) String() string {
|
|||
return fmt.Sprintf("%s %s %s", lb.RecordType(), lb.lbName, lb.lbPool)
|
||||
}
|
||||
|
||||
func (lr *LBRoute) Hostname() string {
|
||||
return lr.lbName
|
||||
}
|
||||
|
||||
func (lr *LBRoute) UnmarshalResult(body io.Reader) (HostnameRouteResult, error) {
|
||||
var result LBRouteResult
|
||||
err := parseResponse(body, &result)
|
||||
|
|
@ -176,7 +187,35 @@ func (res *LBRouteResult) SuccessSummary() string {
|
|||
}
|
||||
|
||||
func (r *RESTClient) RouteTunnel(tunnelID uuid.UUID, route HostnameRoute) (HostnameRouteResult, error) {
|
||||
// First, try to find the correct zone by fetching all zones and matching the hostname
|
||||
zoneID := ""
|
||||
zones, err := r.ListZones()
|
||||
if err == nil {
|
||||
longestMatch := ""
|
||||
for _, zone := range zones {
|
||||
// A hostname should end with the zone name EXACTLY or be a subdomain of it.
|
||||
// e.g. "app.staging.example.com" ends with ".example.com" and "example.com" == "example.com"
|
||||
if route.Hostname() == zone.Name || strings.HasSuffix(route.Hostname(), "."+zone.Name) {
|
||||
// We want the most specific zone if there are multiple matches
|
||||
// e.g. "staging.example.com" zone vs "example.com" zone
|
||||
if len(zone.Name) > len(longestMatch) {
|
||||
longestMatch = zone.Name
|
||||
zoneID = zone.ID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
endpoint := r.baseEndpoints.zoneLevel
|
||||
if zoneID != "" {
|
||||
// Construct dynamic endpoint using the correct zone ID instead of the default one
|
||||
baseURL := strings.TrimSuffix(r.baseEndpoints.zones.String(), "/zones")
|
||||
zoneEndpoint, err := url.Parse(fmt.Sprintf("%s/zones/%s/tunnels", baseURL, zoneID))
|
||||
if err == nil {
|
||||
endpoint = *zoneEndpoint
|
||||
}
|
||||
}
|
||||
|
||||
endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/routes", tunnelID))
|
||||
resp, err := r.sendRequest("PUT", endpoint, route)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -1,9 +1,14 @@
|
|||
package cfapi
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
|
|
@ -97,3 +102,73 @@ func TestLBRouteResultSuccessSummary(t *testing.T) {
|
|||
assert.Equal(t, tt.expected, actual, "case %d", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteTunnel_ZoneResolution(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/zones" {
|
||||
// A sample JSON response matching the Cloudflare api format, ensuring we mimic what the real API sends.
|
||||
io.WriteString(w, `{
|
||||
"success": true,
|
||||
"errors": [],
|
||||
"messages": [],
|
||||
"result": [
|
||||
{
|
||||
"id": "zone-1",
|
||||
"name": "example.com",
|
||||
"status": "active",
|
||||
"paused": false
|
||||
},
|
||||
{
|
||||
"id": "zone-2",
|
||||
"name": "example.co.uk",
|
||||
"status": "active",
|
||||
"paused": false
|
||||
}
|
||||
],
|
||||
"result_info": {
|
||||
"page": 1,
|
||||
"per_page": 50,
|
||||
"total_pages": 1,
|
||||
"count": 2,
|
||||
"total_count": 2
|
||||
}
|
||||
}`)
|
||||
return
|
||||
}
|
||||
if r.URL.Path == "/zones/zone-1/tunnels/11111111-2222-3333-4444-555555555555/routes" {
|
||||
io.WriteString(w, `{"success":true,"result":{"cname":"new","name":"app.example.com"}}`)
|
||||
return
|
||||
}
|
||||
|
||||
// Fallback path when zone does NOT match. It uses "default-zone-from-login" as specified in NewRESTClient arguments.
|
||||
if r.URL.Path == "/zones/default-zone-from-login/tunnels/11111111-2222-3333-4444-555555555555/routes" {
|
||||
io.WriteString(w, `{"success":true,"result":{"cname":"new","name":"fallback.otherdomain.com"}}`)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
logger := zerolog.Nop()
|
||||
client, err := NewRESTClient(ts.URL, "account", "default-zone-from-login", "token", "agent", &logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
tunnelID, _ := uuid.Parse("11111111-2222-3333-4444-555555555555")
|
||||
|
||||
t.Run("Success match", func(t *testing.T) {
|
||||
route := NewDNSRoute("app.example.com", false)
|
||||
res, err := client.RouteTunnel(tunnelID, route)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, res)
|
||||
assert.Equal(t, "Added CNAME app.example.com which will route to this tunnel", res.SuccessSummary())
|
||||
})
|
||||
|
||||
t.Run("Fallback to default zone when no match", func(t *testing.T) {
|
||||
route := NewDNSRoute("fallback.otherdomain.com", false)
|
||||
res, err := client.RouteTunnel(tunnelID, route)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, res)
|
||||
assert.Equal(t, "Added CNAME fallback.otherdomain.com which will route to this tunnel", res.SuccessSummary())
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue