Fix cloudflared tunnel route dns zone resolution bug

When users have multiple domains, tunnel route dns would incorrectly use the default zoneID from the login certificate, creating an invalid CNAME record (e.g. app.domain1.com.domain2.com).

This fix introduces ListZones in cfapi to fetch all valid zones for the account and explicitly checks if the provided hostname exactly matches or is a subdomain of a discovered zone, preventing this behavior and dynamically adjusting the endpoint to the correct Zone ID.
This commit is contained in:
@StringerBell69 2026-03-14 23:25:43 +01:00
parent d2a87e9b93
commit ef095a3ac8
4 changed files with 140 additions and 1 deletions

View File

@ -40,6 +40,7 @@ type baseEndpoints struct {
zoneLevel url.URL
accountRoutes url.URL
accountVnets url.URL
zones url.URL
}
var _ Client = (*RESTClient)(nil)
@ -60,7 +61,11 @@ func NewRESTClient(baseURL, accountTag, zoneTag, authToken, userAgent string, lo
}
zoneLevelEndpoint, err := url.Parse(fmt.Sprintf("%s/zones/%s/tunnels", baseURL, zoneTag))
if err != nil {
return nil, errors.Wrap(err, "failed to create account level endpoint")
return nil, errors.Wrap(err, "failed to create zone level endpoint")
}
zonesEndpoint, err := url.Parse(fmt.Sprintf("%s/zones", baseURL))
if err != nil {
return nil, errors.Wrap(err, "failed to create zones endpoint")
}
httpTransport := http.Transport{
TLSHandshakeTimeout: defaultTimeout,
@ -73,6 +78,7 @@ func NewRESTClient(baseURL, accountTag, zoneTag, authToken, userAgent string, lo
zoneLevel: *zoneLevelEndpoint,
accountRoutes: *accountRoutesEndpoint,
accountVnets: *accountVnetsEndpoint,
zones: *zonesEndpoint,
},
authToken: authToken,
userAgent: userAgent,
@ -241,3 +247,16 @@ func (r *RESTClient) statusCodeToError(op string, resp *http.Response) error {
return errors.Errorf("API call to %s failed with status %d: %s", op,
resp.StatusCode, http.StatusText(resp.StatusCode))
}
func (r *RESTClient) ListZones() ([]*Zone, error) {
endpoint := r.baseEndpoints.zones
return fetchExhaustively[Zone](func(page int) (*http.Response, error) {
reqURL := endpoint
query := reqURL.Query()
query.Set("page", fmt.Sprintf("%d", page))
query.Set("per_page", "50")
// Required to get basic zone info instead of just IDs
reqURL.RawQuery = query.Encode()
return r.sendRequest("GET", reqURL, nil)
})
}

View File

@ -17,6 +17,7 @@ type TunnelClient interface {
type HostnameClient interface {
RouteTunnel(tunnelID uuid.UUID, route HostnameRoute) (HostnameRouteResult, error)
ListZones() ([]*Zone, error)
}
type IPRouteClient interface {
@ -39,3 +40,8 @@ type Client interface {
IPRouteClient
VnetClient
}
type Zone struct {
ID string `json:"id"`
Name string `json:"name"`
}

View File

@ -5,7 +5,9 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"path"
"strings"
"github.com/google/uuid"
"github.com/pkg/errors"
@ -25,6 +27,7 @@ type HostnameRoute interface {
RecordType() string
UnmarshalResult(body io.Reader) (HostnameRouteResult, error)
String() string
Hostname() string
}
type HostnameRouteResult interface {
@ -78,6 +81,10 @@ func (dr *DNSRoute) String() string {
return fmt.Sprintf("%s %s", dr.RecordType(), dr.userHostname)
}
func (dr *DNSRoute) Hostname() string {
return dr.userHostname
}
func (res *DNSRouteResult) SuccessSummary() string {
var msgFmt string
switch res.CName {
@ -139,6 +146,10 @@ func (lb *LBRoute) String() string {
return fmt.Sprintf("%s %s %s", lb.RecordType(), lb.lbName, lb.lbPool)
}
func (lr *LBRoute) Hostname() string {
return lr.lbName
}
func (lr *LBRoute) UnmarshalResult(body io.Reader) (HostnameRouteResult, error) {
var result LBRouteResult
err := parseResponse(body, &result)
@ -176,7 +187,35 @@ func (res *LBRouteResult) SuccessSummary() string {
}
func (r *RESTClient) RouteTunnel(tunnelID uuid.UUID, route HostnameRoute) (HostnameRouteResult, error) {
// First, try to find the correct zone by fetching all zones and matching the hostname
zoneID := ""
zones, err := r.ListZones()
if err == nil {
longestMatch := ""
for _, zone := range zones {
// A hostname should end with the zone name EXACTLY or be a subdomain of it.
// e.g. "app.staging.example.com" ends with ".example.com" and "example.com" == "example.com"
if route.Hostname() == zone.Name || strings.HasSuffix(route.Hostname(), "."+zone.Name) {
// We want the most specific zone if there are multiple matches
// e.g. "staging.example.com" zone vs "example.com" zone
if len(zone.Name) > len(longestMatch) {
longestMatch = zone.Name
zoneID = zone.ID
}
}
}
}
endpoint := r.baseEndpoints.zoneLevel
if zoneID != "" {
// Construct dynamic endpoint using the correct zone ID instead of the default one
baseURL := strings.TrimSuffix(r.baseEndpoints.zones.String(), "/zones")
zoneEndpoint, err := url.Parse(fmt.Sprintf("%s/zones/%s/tunnels", baseURL, zoneID))
if err == nil {
endpoint = *zoneEndpoint
}
}
endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/routes", tunnelID))
resp, err := r.sendRequest("PUT", endpoint, route)
if err != nil {

View File

@ -1,9 +1,14 @@
package cfapi
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
)
@ -97,3 +102,73 @@ func TestLBRouteResultSuccessSummary(t *testing.T) {
assert.Equal(t, tt.expected, actual, "case %d", i+1)
}
}
func TestRouteTunnel_ZoneResolution(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/zones" {
// A sample JSON response matching the Cloudflare api format, ensuring we mimic what the real API sends.
io.WriteString(w, `{
"success": true,
"errors": [],
"messages": [],
"result": [
{
"id": "zone-1",
"name": "example.com",
"status": "active",
"paused": false
},
{
"id": "zone-2",
"name": "example.co.uk",
"status": "active",
"paused": false
}
],
"result_info": {
"page": 1,
"per_page": 50,
"total_pages": 1,
"count": 2,
"total_count": 2
}
}`)
return
}
if r.URL.Path == "/zones/zone-1/tunnels/11111111-2222-3333-4444-555555555555/routes" {
io.WriteString(w, `{"success":true,"result":{"cname":"new","name":"app.example.com"}}`)
return
}
// Fallback path when zone does NOT match. It uses "default-zone-from-login" as specified in NewRESTClient arguments.
if r.URL.Path == "/zones/default-zone-from-login/tunnels/11111111-2222-3333-4444-555555555555/routes" {
io.WriteString(w, `{"success":true,"result":{"cname":"new","name":"fallback.otherdomain.com"}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer ts.Close()
logger := zerolog.Nop()
client, err := NewRESTClient(ts.URL, "account", "default-zone-from-login", "token", "agent", &logger)
assert.NoError(t, err)
tunnelID, _ := uuid.Parse("11111111-2222-3333-4444-555555555555")
t.Run("Success match", func(t *testing.T) {
route := NewDNSRoute("app.example.com", false)
res, err := client.RouteTunnel(tunnelID, route)
assert.NoError(t, err)
assert.NotNil(t, res)
assert.Equal(t, "Added CNAME app.example.com which will route to this tunnel", res.SuccessSummary())
})
t.Run("Fallback to default zone when no match", func(t *testing.T) {
route := NewDNSRoute("fallback.otherdomain.com", false)
res, err := client.RouteTunnel(tunnelID, route)
assert.NoError(t, err)
assert.NotNil(t, res)
assert.Equal(t, "Added CNAME fallback.otherdomain.com which will route to this tunnel", res.SuccessSummary())
})
}