diff --git a/cfapi/base_client.go b/cfapi/base_client.go index 92544071..aa139779 100644 --- a/cfapi/base_client.go +++ b/cfapi/base_client.go @@ -109,20 +109,34 @@ func (r *RESTClient) sendRequest(method string, url url.URL, body interface{}) ( return r.client.Do(req) } -func parseResponse(reader io.Reader, data interface{}) error { +func parseResponseEnvelope(reader io.Reader) (*response, error) { // Schema for Tunnelstore responses in the v1 API. // Roughly, it's a wrapper around a particular result that adds failures/errors/etc var result response // First, parse the wrapper and check the API call succeeded if err := json.NewDecoder(reader).Decode(&result); err != nil { - return errors.Wrap(err, "failed to decode response") + return nil, errors.Wrap(err, "failed to decode response") } if err := result.checkErrors(); err != nil { - return err + return nil, err } if !result.Success { - return ErrAPINoSuccess + return nil, ErrAPINoSuccess } + + return &result, nil +} + +func parseResponse(reader io.Reader, data interface{}) error { + result, err := parseResponseEnvelope(reader) + if err != nil { + return err + } + + return parseResponseBody(result, data) +} + +func parseResponseBody(result *response, data interface{}) error { // At this point we know the API call succeeded, so, parse out the inner // result into the datatype provided as a parameter. if err := json.Unmarshal(result.Result, &data); err != nil { @@ -131,11 +145,58 @@ func parseResponse(reader io.Reader, data interface{}) error { return nil } +func fetchExhaustively[T any](requestFn func(int) (*http.Response, error)) ([]*T, error) { + page := 0 + var fullResponse []*T + + for { + page += 1 + envelope, parsedBody, err := fetchPage[T](requestFn, page) + + if err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("Error Parsing page %d", page)) + } + + fullResponse = append(fullResponse, parsedBody...) + if envelope.Pagination.Count < envelope.Pagination.PerPage || len(fullResponse) >= envelope.Pagination.TotalCount { + break + } + + } + return fullResponse, nil +} + +func fetchPage[T any](requestFn func(int) (*http.Response, error), page int) (*response, []*T, error) { + pageResp, err := requestFn(page) + if err != nil { + return nil, nil, errors.Wrap(err, "REST request failed") + } + defer pageResp.Body.Close() + if pageResp.StatusCode == http.StatusOK { + envelope, err := parseResponseEnvelope(pageResp.Body) + if err != nil { + return nil, nil, err + } + var parsedRspBody []*T + return envelope, parsedRspBody, parseResponseBody(envelope, &parsedRspBody) + + } + return nil, nil, errors.New(fmt.Sprintf("Failed to fetch page. Server returned: %d", pageResp.StatusCode)) +} + type response struct { - Success bool `json:"success,omitempty"` - Errors []apiErr `json:"errors,omitempty"` - Messages []string `json:"messages,omitempty"` - Result json.RawMessage `json:"result,omitempty"` + Success bool `json:"success,omitempty"` + Errors []apiErr `json:"errors,omitempty"` + Messages []string `json:"messages,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Pagination Pagination `json:"result_info,omitempty"` +} + +type Pagination struct { + Count int `json:"count,omitempty"` + Page int `json:"page,omitempty"` + PerPage int `json:"per_page,omitempty"` + TotalCount int `json:"total_count,omitempty"` } func (r *response) checkErrors() error { diff --git a/cfapi/ip_route.go b/cfapi/ip_route.go index a45c00bf..f451d996 100644 --- a/cfapi/ip_route.go +++ b/cfapi/ip_route.go @@ -137,20 +137,24 @@ type GetRouteByIpParams struct { } // ListRoutes calls the Tunnelstore GET endpoint for all routes under an account. +// Due to pagination on the server side it will call the endpoint multiple times if needed. func (r *RESTClient) ListRoutes(filter *IpRouteFilter) ([]*DetailedRoute, error) { - endpoint := r.baseEndpoints.accountRoutes - endpoint.RawQuery = filter.Encode() - resp, err := r.sendRequest("GET", endpoint, nil) - if err != nil { - return nil, errors.Wrap(err, "REST request failed") - } - defer resp.Body.Close() + fetchFn := func(page int) (*http.Response, error) { + endpoint := r.baseEndpoints.accountRoutes + filter.Page(page) + endpoint.RawQuery = filter.Encode() + rsp, err := r.sendRequest("GET", endpoint, nil) - if resp.StatusCode == http.StatusOK { - return parseListDetailedRoutes(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "REST request failed") + } + if rsp.StatusCode != http.StatusOK { + rsp.Body.Close() + return nil, r.statusCodeToError("list routes", rsp) + } + return rsp, nil } - - return nil, r.statusCodeToError("list routes", resp) + return fetchExhaustively[DetailedRoute](fetchFn) } // AddRoute calls the Tunnelstore POST endpoint for a given route. @@ -208,12 +212,6 @@ func (r *RESTClient) GetByIP(params GetRouteByIpParams) (DetailedRoute, error) { return DetailedRoute{}, r.statusCodeToError("get route by IP", resp) } -func parseListDetailedRoutes(body io.ReadCloser) ([]*DetailedRoute, error) { - var routes []*DetailedRoute - err := parseResponse(body, &routes) - return routes, err -} - func parseRoute(body io.ReadCloser) (Route, error) { var route Route err := parseResponse(body, &route) diff --git a/cfapi/ip_route_filter.go b/cfapi/ip_route_filter.go index 455a434e..eda9805a 100644 --- a/cfapi/ip_route_filter.go +++ b/cfapi/ip_route_filter.go @@ -167,6 +167,10 @@ func (f *IpRouteFilter) MaxFetchSize(max uint) { f.queryParams.Set("per_page", strconv.Itoa(int(max))) } +func (f *IpRouteFilter) Page(page int) { + f.queryParams.Set("page", strconv.Itoa(page)) +} + func (f IpRouteFilter) Encode() string { return f.queryParams.Encode() } diff --git a/cfapi/tunnel.go b/cfapi/tunnel.go index 0d34d222..dc80c6a1 100644 --- a/cfapi/tunnel.go +++ b/cfapi/tunnel.go @@ -177,25 +177,22 @@ func (r *RESTClient) DeleteTunnel(tunnelID uuid.UUID, cascade bool) error { } func (r *RESTClient) ListTunnels(filter *TunnelFilter) ([]*Tunnel, error) { - endpoint := r.baseEndpoints.accountLevel - endpoint.RawQuery = filter.encode() - resp, err := r.sendRequest("GET", endpoint, nil) - if err != nil { - return nil, errors.Wrap(err, "REST request failed") - } - defer resp.Body.Close() - - if resp.StatusCode == http.StatusOK { - return parseListTunnels(resp.Body) + fetchFn := func(page int) (*http.Response, error) { + endpoint := r.baseEndpoints.accountLevel + filter.Page(page) + endpoint.RawQuery = filter.encode() + rsp, err := r.sendRequest("GET", endpoint, nil) + if err != nil { + return nil, errors.Wrap(err, "REST request failed") + } + if rsp.StatusCode != http.StatusOK { + rsp.Body.Close() + return nil, r.statusCodeToError("list tunnels", rsp) + } + return rsp, nil } - return nil, r.statusCodeToError("list tunnels", resp) -} - -func parseListTunnels(body io.ReadCloser) ([]*Tunnel, error) { - var tunnels []*Tunnel - err := parseResponse(body, &tunnels) - return tunnels, err + return fetchExhaustively[Tunnel](fetchFn) } func (r *RESTClient) ListActiveClients(tunnelID uuid.UUID) ([]*ActiveClient, error) { diff --git a/cfapi/tunnel_filter.go b/cfapi/tunnel_filter.go index df8932bc..736b752e 100644 --- a/cfapi/tunnel_filter.go +++ b/cfapi/tunnel_filter.go @@ -50,6 +50,10 @@ func (f *TunnelFilter) MaxFetchSize(max uint) { f.queryParams.Set("per_page", strconv.Itoa(int(max))) } +func (f *TunnelFilter) Page(page int) { + f.queryParams.Set("page", strconv.Itoa(page)) +} + func (f TunnelFilter) encode() string { return f.queryParams.Encode() } diff --git a/cfapi/tunnel_test.go b/cfapi/tunnel_test.go index c61bdc45..2c012825 100644 --- a/cfapi/tunnel_test.go +++ b/cfapi/tunnel_test.go @@ -3,7 +3,6 @@ package cfapi import ( "bytes" "fmt" - "io" "net" "reflect" "strings" @@ -16,52 +15,6 @@ import ( var loc, _ = time.LoadLocation("UTC") -func Test_parseListTunnels(t *testing.T) { - type args struct { - body string - } - tests := []struct { - name string - args args - want []*Tunnel - wantErr bool - }{ - { - name: "empty list", - args: args{body: `{"success": true, "result": []}`}, - want: []*Tunnel{}, - }, - { - name: "success is false", - args: args{body: `{"success": false, "result": []}`}, - wantErr: true, - }, - { - name: "errors are present", - args: args{body: `{"errors": [{"code": 1003, "message":"An A, AAAA or CNAME record already exists with that host"}], "result": []}`}, - wantErr: true, - }, - { - name: "invalid response", - args: args{body: `abc`}, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - body := io.NopCloser(bytes.NewReader([]byte(tt.args.body))) - got, err := parseListTunnels(body) - if (err != nil) != tt.wantErr { - t.Errorf("parseListTunnels() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("parseListTunnels() = %v, want %v", got, tt.want) - } - }) - } -} - func Test_unmarshalTunnel(t *testing.T) { type args struct { body string