This commit is contained in:
Stringer 2026-03-15 00:42:07 +01:00 committed by GitHub
commit df732ad0f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 140 additions and 1 deletions

View File

@ -40,6 +40,7 @@ type baseEndpoints struct {
zoneLevel url.URL
accountRoutes url.URL
accountVnets url.URL
zones url.URL
}
var _ Client = (*RESTClient)(nil)
@ -60,7 +61,11 @@ func NewRESTClient(baseURL, accountTag, zoneTag, authToken, userAgent string, lo
}
zoneLevelEndpoint, err := url.Parse(fmt.Sprintf("%s/zones/%s/tunnels", baseURL, zoneTag))
if err != nil {
return nil, errors.Wrap(err, "failed to create account level endpoint")
return nil, errors.Wrap(err, "failed to create zone level endpoint")
}
zonesEndpoint, err := url.Parse(fmt.Sprintf("%s/zones", baseURL))
if err != nil {
return nil, errors.Wrap(err, "failed to create zones endpoint")
}
httpTransport := http.Transport{
TLSHandshakeTimeout: defaultTimeout,
@ -73,6 +78,7 @@ func NewRESTClient(baseURL, accountTag, zoneTag, authToken, userAgent string, lo
zoneLevel: *zoneLevelEndpoint,
accountRoutes: *accountRoutesEndpoint,
accountVnets: *accountVnetsEndpoint,
zones: *zonesEndpoint,
},
authToken: authToken,
userAgent: userAgent,
@ -241,3 +247,16 @@ func (r *RESTClient) statusCodeToError(op string, resp *http.Response) error {
return errors.Errorf("API call to %s failed with status %d: %s", op,
resp.StatusCode, http.StatusText(resp.StatusCode))
}
func (r *RESTClient) ListZones() ([]*Zone, error) {
endpoint := r.baseEndpoints.zones
return fetchExhaustively[Zone](func(page int) (*http.Response, error) {
reqURL := endpoint
query := reqURL.Query()
query.Set("page", fmt.Sprintf("%d", page))
query.Set("per_page", "50")
// Required to get basic zone info instead of just IDs
reqURL.RawQuery = query.Encode()
return r.sendRequest("GET", reqURL, nil)
})
}

View File

@ -17,6 +17,7 @@ type TunnelClient interface {
type HostnameClient interface {
RouteTunnel(tunnelID uuid.UUID, route HostnameRoute) (HostnameRouteResult, error)
ListZones() ([]*Zone, error)
}
type IPRouteClient interface {
@ -39,3 +40,8 @@ type Client interface {
IPRouteClient
VnetClient
}
type Zone struct {
ID string `json:"id"`
Name string `json:"name"`
}

View File

@ -5,7 +5,9 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"path"
"strings"
"github.com/google/uuid"
"github.com/pkg/errors"
@ -25,6 +27,7 @@ type HostnameRoute interface {
RecordType() string
UnmarshalResult(body io.Reader) (HostnameRouteResult, error)
String() string
Hostname() string
}
type HostnameRouteResult interface {
@ -78,6 +81,10 @@ func (dr *DNSRoute) String() string {
return fmt.Sprintf("%s %s", dr.RecordType(), dr.userHostname)
}
func (dr *DNSRoute) Hostname() string {
return dr.userHostname
}
func (res *DNSRouteResult) SuccessSummary() string {
var msgFmt string
switch res.CName {
@ -139,6 +146,10 @@ func (lb *LBRoute) String() string {
return fmt.Sprintf("%s %s %s", lb.RecordType(), lb.lbName, lb.lbPool)
}
func (lr *LBRoute) Hostname() string {
return lr.lbName
}
func (lr *LBRoute) UnmarshalResult(body io.Reader) (HostnameRouteResult, error) {
var result LBRouteResult
err := parseResponse(body, &result)
@ -176,7 +187,35 @@ func (res *LBRouteResult) SuccessSummary() string {
}
func (r *RESTClient) RouteTunnel(tunnelID uuid.UUID, route HostnameRoute) (HostnameRouteResult, error) {
// First, try to find the correct zone by fetching all zones and matching the hostname
zoneID := ""
zones, err := r.ListZones()
if err == nil {
longestMatch := ""
for _, zone := range zones {
// A hostname should end with the zone name EXACTLY or be a subdomain of it.
// e.g. "app.staging.example.com" ends with ".example.com" and "example.com" == "example.com"
if route.Hostname() == zone.Name || strings.HasSuffix(route.Hostname(), "."+zone.Name) {
// We want the most specific zone if there are multiple matches
// e.g. "staging.example.com" zone vs "example.com" zone
if len(zone.Name) > len(longestMatch) {
longestMatch = zone.Name
zoneID = zone.ID
}
}
}
}
endpoint := r.baseEndpoints.zoneLevel
if zoneID != "" {
// Construct dynamic endpoint using the correct zone ID instead of the default one
baseURL := strings.TrimSuffix(r.baseEndpoints.zones.String(), "/zones")
zoneEndpoint, err := url.Parse(fmt.Sprintf("%s/zones/%s/tunnels", baseURL, zoneID))
if err == nil {
endpoint = *zoneEndpoint
}
}
endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/routes", tunnelID))
resp, err := r.sendRequest("PUT", endpoint, route)
if err != nil {

View File

@ -1,9 +1,14 @@
package cfapi
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
)
@ -97,3 +102,73 @@ func TestLBRouteResultSuccessSummary(t *testing.T) {
assert.Equal(t, tt.expected, actual, "case %d", i+1)
}
}
func TestRouteTunnel_ZoneResolution(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/zones" {
// A sample JSON response matching the Cloudflare api format, ensuring we mimic what the real API sends.
io.WriteString(w, `{
"success": true,
"errors": [],
"messages": [],
"result": [
{
"id": "zone-1",
"name": "example.com",
"status": "active",
"paused": false
},
{
"id": "zone-2",
"name": "example.co.uk",
"status": "active",
"paused": false
}
],
"result_info": {
"page": 1,
"per_page": 50,
"total_pages": 1,
"count": 2,
"total_count": 2
}
}`)
return
}
if r.URL.Path == "/zones/zone-1/tunnels/11111111-2222-3333-4444-555555555555/routes" {
io.WriteString(w, `{"success":true,"result":{"cname":"new","name":"app.example.com"}}`)
return
}
// Fallback path when zone does NOT match. It uses "default-zone-from-login" as specified in NewRESTClient arguments.
if r.URL.Path == "/zones/default-zone-from-login/tunnels/11111111-2222-3333-4444-555555555555/routes" {
io.WriteString(w, `{"success":true,"result":{"cname":"new","name":"fallback.otherdomain.com"}}`)
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer ts.Close()
logger := zerolog.Nop()
client, err := NewRESTClient(ts.URL, "account", "default-zone-from-login", "token", "agent", &logger)
assert.NoError(t, err)
tunnelID, _ := uuid.Parse("11111111-2222-3333-4444-555555555555")
t.Run("Success match", func(t *testing.T) {
route := NewDNSRoute("app.example.com", false)
res, err := client.RouteTunnel(tunnelID, route)
assert.NoError(t, err)
assert.NotNil(t, res)
assert.Equal(t, "Added CNAME app.example.com which will route to this tunnel", res.SuccessSummary())
})
t.Run("Fallback to default zone when no match", func(t *testing.T) {
route := NewDNSRoute("fallback.otherdomain.com", false)
res, err := client.RouteTunnel(tunnelID, route)
assert.NoError(t, err)
assert.NotNil(t, res)
assert.Equal(t, "Added CNAME fallback.otherdomain.com which will route to this tunnel", res.SuccessSummary())
})
}