Merge ef095a3ac8 into d2a87e9b93
This commit is contained in:
commit
df732ad0f7
|
|
@ -40,6 +40,7 @@ type baseEndpoints struct {
|
|||
zoneLevel url.URL
|
||||
accountRoutes url.URL
|
||||
accountVnets url.URL
|
||||
zones url.URL
|
||||
}
|
||||
|
||||
var _ Client = (*RESTClient)(nil)
|
||||
|
|
@ -60,7 +61,11 @@ func NewRESTClient(baseURL, accountTag, zoneTag, authToken, userAgent string, lo
|
|||
}
|
||||
zoneLevelEndpoint, err := url.Parse(fmt.Sprintf("%s/zones/%s/tunnels", baseURL, zoneTag))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create account level endpoint")
|
||||
return nil, errors.Wrap(err, "failed to create zone level endpoint")
|
||||
}
|
||||
zonesEndpoint, err := url.Parse(fmt.Sprintf("%s/zones", baseURL))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create zones endpoint")
|
||||
}
|
||||
httpTransport := http.Transport{
|
||||
TLSHandshakeTimeout: defaultTimeout,
|
||||
|
|
@ -73,6 +78,7 @@ func NewRESTClient(baseURL, accountTag, zoneTag, authToken, userAgent string, lo
|
|||
zoneLevel: *zoneLevelEndpoint,
|
||||
accountRoutes: *accountRoutesEndpoint,
|
||||
accountVnets: *accountVnetsEndpoint,
|
||||
zones: *zonesEndpoint,
|
||||
},
|
||||
authToken: authToken,
|
||||
userAgent: userAgent,
|
||||
|
|
@ -241,3 +247,16 @@ func (r *RESTClient) statusCodeToError(op string, resp *http.Response) error {
|
|||
return errors.Errorf("API call to %s failed with status %d: %s", op,
|
||||
resp.StatusCode, http.StatusText(resp.StatusCode))
|
||||
}
|
||||
|
||||
func (r *RESTClient) ListZones() ([]*Zone, error) {
|
||||
endpoint := r.baseEndpoints.zones
|
||||
return fetchExhaustively[Zone](func(page int) (*http.Response, error) {
|
||||
reqURL := endpoint
|
||||
query := reqURL.Query()
|
||||
query.Set("page", fmt.Sprintf("%d", page))
|
||||
query.Set("per_page", "50")
|
||||
// Required to get basic zone info instead of just IDs
|
||||
reqURL.RawQuery = query.Encode()
|
||||
return r.sendRequest("GET", reqURL, nil)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ type TunnelClient interface {
|
|||
|
||||
type HostnameClient interface {
|
||||
RouteTunnel(tunnelID uuid.UUID, route HostnameRoute) (HostnameRouteResult, error)
|
||||
ListZones() ([]*Zone, error)
|
||||
}
|
||||
|
||||
type IPRouteClient interface {
|
||||
|
|
@ -39,3 +40,8 @@ type Client interface {
|
|||
IPRouteClient
|
||||
VnetClient
|
||||
}
|
||||
|
||||
type Zone struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,9 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
|
|
@ -25,6 +27,7 @@ type HostnameRoute interface {
|
|||
RecordType() string
|
||||
UnmarshalResult(body io.Reader) (HostnameRouteResult, error)
|
||||
String() string
|
||||
Hostname() string
|
||||
}
|
||||
|
||||
type HostnameRouteResult interface {
|
||||
|
|
@ -78,6 +81,10 @@ func (dr *DNSRoute) String() string {
|
|||
return fmt.Sprintf("%s %s", dr.RecordType(), dr.userHostname)
|
||||
}
|
||||
|
||||
func (dr *DNSRoute) Hostname() string {
|
||||
return dr.userHostname
|
||||
}
|
||||
|
||||
func (res *DNSRouteResult) SuccessSummary() string {
|
||||
var msgFmt string
|
||||
switch res.CName {
|
||||
|
|
@ -139,6 +146,10 @@ func (lb *LBRoute) String() string {
|
|||
return fmt.Sprintf("%s %s %s", lb.RecordType(), lb.lbName, lb.lbPool)
|
||||
}
|
||||
|
||||
func (lr *LBRoute) Hostname() string {
|
||||
return lr.lbName
|
||||
}
|
||||
|
||||
func (lr *LBRoute) UnmarshalResult(body io.Reader) (HostnameRouteResult, error) {
|
||||
var result LBRouteResult
|
||||
err := parseResponse(body, &result)
|
||||
|
|
@ -176,7 +187,35 @@ func (res *LBRouteResult) SuccessSummary() string {
|
|||
}
|
||||
|
||||
func (r *RESTClient) RouteTunnel(tunnelID uuid.UUID, route HostnameRoute) (HostnameRouteResult, error) {
|
||||
// First, try to find the correct zone by fetching all zones and matching the hostname
|
||||
zoneID := ""
|
||||
zones, err := r.ListZones()
|
||||
if err == nil {
|
||||
longestMatch := ""
|
||||
for _, zone := range zones {
|
||||
// A hostname should end with the zone name EXACTLY or be a subdomain of it.
|
||||
// e.g. "app.staging.example.com" ends with ".example.com" and "example.com" == "example.com"
|
||||
if route.Hostname() == zone.Name || strings.HasSuffix(route.Hostname(), "."+zone.Name) {
|
||||
// We want the most specific zone if there are multiple matches
|
||||
// e.g. "staging.example.com" zone vs "example.com" zone
|
||||
if len(zone.Name) > len(longestMatch) {
|
||||
longestMatch = zone.Name
|
||||
zoneID = zone.ID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
endpoint := r.baseEndpoints.zoneLevel
|
||||
if zoneID != "" {
|
||||
// Construct dynamic endpoint using the correct zone ID instead of the default one
|
||||
baseURL := strings.TrimSuffix(r.baseEndpoints.zones.String(), "/zones")
|
||||
zoneEndpoint, err := url.Parse(fmt.Sprintf("%s/zones/%s/tunnels", baseURL, zoneID))
|
||||
if err == nil {
|
||||
endpoint = *zoneEndpoint
|
||||
}
|
||||
}
|
||||
|
||||
endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/routes", tunnelID))
|
||||
resp, err := r.sendRequest("PUT", endpoint, route)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -1,9 +1,14 @@
|
|||
package cfapi
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
|
|
@ -97,3 +102,73 @@ func TestLBRouteResultSuccessSummary(t *testing.T) {
|
|||
assert.Equal(t, tt.expected, actual, "case %d", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteTunnel_ZoneResolution(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/zones" {
|
||||
// A sample JSON response matching the Cloudflare api format, ensuring we mimic what the real API sends.
|
||||
io.WriteString(w, `{
|
||||
"success": true,
|
||||
"errors": [],
|
||||
"messages": [],
|
||||
"result": [
|
||||
{
|
||||
"id": "zone-1",
|
||||
"name": "example.com",
|
||||
"status": "active",
|
||||
"paused": false
|
||||
},
|
||||
{
|
||||
"id": "zone-2",
|
||||
"name": "example.co.uk",
|
||||
"status": "active",
|
||||
"paused": false
|
||||
}
|
||||
],
|
||||
"result_info": {
|
||||
"page": 1,
|
||||
"per_page": 50,
|
||||
"total_pages": 1,
|
||||
"count": 2,
|
||||
"total_count": 2
|
||||
}
|
||||
}`)
|
||||
return
|
||||
}
|
||||
if r.URL.Path == "/zones/zone-1/tunnels/11111111-2222-3333-4444-555555555555/routes" {
|
||||
io.WriteString(w, `{"success":true,"result":{"cname":"new","name":"app.example.com"}}`)
|
||||
return
|
||||
}
|
||||
|
||||
// Fallback path when zone does NOT match. It uses "default-zone-from-login" as specified in NewRESTClient arguments.
|
||||
if r.URL.Path == "/zones/default-zone-from-login/tunnels/11111111-2222-3333-4444-555555555555/routes" {
|
||||
io.WriteString(w, `{"success":true,"result":{"cname":"new","name":"fallback.otherdomain.com"}}`)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
logger := zerolog.Nop()
|
||||
client, err := NewRESTClient(ts.URL, "account", "default-zone-from-login", "token", "agent", &logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
tunnelID, _ := uuid.Parse("11111111-2222-3333-4444-555555555555")
|
||||
|
||||
t.Run("Success match", func(t *testing.T) {
|
||||
route := NewDNSRoute("app.example.com", false)
|
||||
res, err := client.RouteTunnel(tunnelID, route)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, res)
|
||||
assert.Equal(t, "Added CNAME app.example.com which will route to this tunnel", res.SuccessSummary())
|
||||
})
|
||||
|
||||
t.Run("Fallback to default zone when no match", func(t *testing.T) {
|
||||
route := NewDNSRoute("fallback.otherdomain.com", false)
|
||||
res, err := client.RouteTunnel(tunnelID, route)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, res)
|
||||
assert.Equal(t, "Added CNAME fallback.otherdomain.com which will route to this tunnel", res.SuccessSummary())
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue