439 lines
12 KiB
Go
439 lines
12 KiB
Go
package tunnelstore
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"path"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/pkg/errors"
|
|
|
|
"github.com/cloudflare/cloudflared/logger"
|
|
)
|
|
|
|
const (
|
|
defaultTimeout = 15 * time.Second
|
|
jsonContentType = "application/json"
|
|
)
|
|
|
|
var (
|
|
ErrTunnelNameConflict = errors.New("tunnel with name already exists")
|
|
ErrUnauthorized = errors.New("unauthorized")
|
|
ErrBadRequest = errors.New("incorrect request parameters")
|
|
ErrNotFound = errors.New("not found")
|
|
ErrAPINoSuccess = errors.New("API call failed")
|
|
)
|
|
|
|
// Tunnel is a tunnel as described by the tunnelstore API. It is decoded from
// the "result" field of API responses (see unmarshalTunnel and
// parseListTunnels).
//
// NOTE(review): DeletedAt is presumably the zero time.Time for tunnels that
// have not been deleted — confirm against the API.
type Tunnel struct {
	ID          uuid.UUID    `json:"id"`
	Name        string       `json:"name"`
	CreatedAt   time.Time    `json:"created_at"`
	DeletedAt   time.Time    `json:"deleted_at"`
	Connections []Connection `json:"connections"`
}

// Connection is one connection of a Tunnel, as reported by the API.
//
// NOTE(review): ColoName presumably names the data center ("colo") serving the
// connection — confirm against the API.
type Connection struct {
	ColoName           string    `json:"colo_name"`
	ID                 uuid.UUID `json:"id"`
	IsPendingReconnect bool      `json:"is_pending_reconnect"`
}
|
|
|
|
// Change is the status the API reports for a resource touched by a route
// request (the "cname", "load_balancer" and "pool" result fields). It is a
// type alias for string, so plain string values compare against it directly.
type Change = string

// Known Change values. The SuccessSummary implementations turn them into
// user-facing messages; e.g. ChangeNew on a load balancer is reported as
// "Created load balancer ...".
const (
	ChangeNew       = "new"
	ChangeUpdated   = "updated"
	ChangeUnchanged = "unchanged"
)
|
|
|
|
// Route represents a record type that can route to a tunnel.
//
// RESTClient.RouteTunnel JSON-encodes a Route (via MarshalJSON) as the body of
// its PUT request. After a 200 OK it hands the response body to
// UnmarshalResult to build the matching RouteResult. RecordType names the kind
// of route and is emitted as the request's "type" field by the implementations
// in this file.
type Route interface {
	json.Marshaler
	RecordType() string
	UnmarshalResult(body io.Reader) (RouteResult, error)
}

// RouteResult is the decoded outcome of a route request.
type RouteResult interface {
	// SuccessSummary explains what will route to this tunnel when it's provisioned successfully
	SuccessSummary() string
}
|
|
|
|
// DNSRoute is a Route that points a user-chosen hostname at a tunnel through a
// DNS (CNAME) record.
type DNSRoute struct {
	userHostname string
}

// DNSRouteResult is the decoded result of a DNS route request.
type DNSRouteResult struct {
	// route is the request this result answers. It is unexported, so JSON
	// decoding ignores it; UnmarshalResult sets it after decoding.
	route *DNSRoute
	// CName reports what happened to the hostname's CNAME record.
	CName Change `json:"cname"`
}
|
|
|
|
func NewDNSRoute(userHostname string) Route {
|
|
return &DNSRoute{
|
|
userHostname: userHostname,
|
|
}
|
|
}
|
|
|
|
func (dr *DNSRoute) MarshalJSON() ([]byte, error) {
|
|
s := struct {
|
|
Type string `json:"type"`
|
|
UserHostname string `json:"user_hostname"`
|
|
}{
|
|
Type: dr.RecordType(),
|
|
UserHostname: dr.userHostname,
|
|
}
|
|
return json.Marshal(&s)
|
|
}
|
|
|
|
// UnmarshalResult decodes the body of a successful DNS route response into a
// *DNSRouteResult linked back to this route, so SuccessSummary can name the
// hostname.
//
// On failure it returns a nil RouteResult. A half-decoded result must not be
// mistaken for a usable one.
func (dr *DNSRoute) UnmarshalResult(body io.Reader) (RouteResult, error) {
	var result DNSRouteResult
	if err := parseResponse(body, &result); err != nil {
		return nil, err
	}
	result.route = dr
	return &result, nil
}
|
|
|
|
func (dr *DNSRoute) RecordType() string {
|
|
return "dns"
|
|
}
|
|
|
|
// SuccessSummary describes, for the user, what the DNS route request did to
// the hostname's CNAME record.
//
// An unrecognised or missing CName value yields an explicit failure message.
// Previously the format string stayed empty and fmt produced
// "%!(EXTRA string=...)".
func (res *DNSRouteResult) SuccessSummary() string {
	var msgFmt string
	switch res.CName {
	case ChangeNew:
		msgFmt = "Added CNAME %s which will route to this tunnel"
	case ChangeUpdated: // this is not currently returned by tunnelstore
		msgFmt = "%s updated to route to your tunnel"
	case ChangeUnchanged:
		msgFmt = "%s is already configured to route to your tunnel"
	default:
		msgFmt = "Something went wrong: failed to modify DNS record for %s; please check DNS configuration in the dashboard"
	}
	return fmt.Sprintf(msgFmt, res.route.userHostname)
}
|
|
|
|
// LBRoute is a Route that attaches a tunnel to a load balancer through the
// named pool.
type LBRoute struct {
	lbName string
	lbPool string
}

// LBRouteResult is the decoded result of a load balancer route request.
type LBRouteResult struct {
	// route is the request this result answers. It is unexported, so JSON
	// decoding ignores it; UnmarshalResult sets it after decoding.
	route *LBRoute
	// LoadBalancer and Pool report what happened to each resource.
	LoadBalancer Change `json:"load_balancer"`
	Pool         Change `json:"pool"`
}
|
|
|
|
func NewLBRoute(lbName, lbPool string) Route {
|
|
return &LBRoute{
|
|
lbName: lbName,
|
|
lbPool: lbPool,
|
|
}
|
|
}
|
|
|
|
func (lr *LBRoute) MarshalJSON() ([]byte, error) {
|
|
s := struct {
|
|
Type string `json:"type"`
|
|
LBName string `json:"lb_name"`
|
|
LBPool string `json:"lb_pool"`
|
|
}{
|
|
Type: lr.RecordType(),
|
|
LBName: lr.lbName,
|
|
LBPool: lr.lbPool,
|
|
}
|
|
return json.Marshal(&s)
|
|
}
|
|
|
|
func (lr *LBRoute) RecordType() string {
|
|
return "lb"
|
|
}
|
|
|
|
// UnmarshalResult decodes the body of a successful load balancer route response
// into a *LBRouteResult linked back to this route, so SuccessSummary can name
// the load balancer and pool.
//
// On failure it returns a nil RouteResult. A half-decoded result must not be
// mistaken for a usable one.
func (lr *LBRoute) UnmarshalResult(body io.Reader) (RouteResult, error) {
	var result LBRouteResult
	if err := parseResponse(body, &result); err != nil {
		return nil, err
	}
	result.route = lr
	return &result, nil
}
|
|
|
|
func (res *LBRouteResult) SuccessSummary() string {
|
|
var msg string
|
|
switch res.LoadBalancer + "," + res.Pool {
|
|
case "new,new":
|
|
msg = "Created load balancer %s and added a new pool %s with this tunnel as an origin"
|
|
case "new,updated":
|
|
msg = "Created load balancer %s with an existing pool %s which was updated to use this tunnel as an origin"
|
|
case "new,unchanged":
|
|
msg = "Created load balancer %s with an existing pool %s which already has this tunnel as an origin"
|
|
case "updated,new":
|
|
msg = "Added new pool %[2]s with this tunnel as an origin to load balancer %[1]s"
|
|
case "updated,updated":
|
|
msg = "Updated pool %[2]s to use this tunnel as an origin and added it to load balancer %[1]s"
|
|
case "updated,unchanged":
|
|
msg = "Added pool %[2]s, which already has this tunnel as an origin, to load balancer %[1]s"
|
|
case "unchanged,updated":
|
|
msg = "Added this tunnel as an origin in pool %[2]s which is already used by load balancer %[1]s"
|
|
case "unchanged,unchanged":
|
|
msg = "Load balancer %s already uses pool %s which has this tunnel as an origin"
|
|
case "unchanged,new":
|
|
// this state is not possible
|
|
fallthrough
|
|
default:
|
|
msg = "Something went wrong: failed to modify load balancer %s with pool %s; please check traffic manager configuration in the dashboard"
|
|
}
|
|
|
|
return fmt.Sprintf(msg, res.route.lbName, res.route.lbPool)
|
|
}
|
|
|
|
// Client is the set of tunnelstore API operations; RESTClient implements it
// over HTTP.
type Client interface {
	// CreateTunnel creates a tunnel with the given name and secret.
	CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error)
	// GetTunnel fetches a single tunnel by ID.
	GetTunnel(tunnelID uuid.UUID) (*Tunnel, error)
	// DeleteTunnel deletes a tunnel by ID.
	DeleteTunnel(tunnelID uuid.UUID) error
	// ListTunnels lists tunnels, narrowed by filter.
	ListTunnels(filter *Filter) ([]*Tunnel, error)
	// CleanupConnections removes the connections recorded for a tunnel.
	CleanupConnections(tunnelID uuid.UUID) error
	// RouteTunnel sends route for the tunnel and returns the decoded outcome.
	RouteTunnel(tunnelID uuid.UUID, route Route) (RouteResult, error)
}
|
|
|
|
// RESTClient is the HTTP implementation of Client. Construct it with
// NewRESTClient.
type RESTClient struct {
	baseEndpoints *baseEndpoints
	authToken     string // sent as the X-Auth-User-Service-Key header
	userAgent     string // sent as the User-Agent header
	client        http.Client
	// NOTE(review): logger is stored by NewRESTClient but not read by the
	// methods visible in this section — confirm whether it is still needed.
	logger logger.Service
}

// baseEndpoints holds the two API roots that request URLs are built from.
// Methods copy one of these values before setting Path or RawQuery, so the
// stored URLs themselves are never modified.
type baseEndpoints struct {
	accountLevel url.URL // <base>/accounts/<accountTag>/tunnels
	zoneLevel    url.URL // <base>/zones/<zoneTag>/tunnels
}

// Compile-time check that *RESTClient implements Client.
var _ Client = (*RESTClient)(nil)
|
|
|
|
// NewRESTClient builds a RESTClient for the tunnelstore API rooted at baseURL.
//
// It derives the account-level endpoint (<baseURL>/accounts/<accountTag>/tunnels)
// and the zone-level endpoint (<baseURL>/zones/<zoneTag>/tunnels). It also sets
// up an HTTP client whose TLS handshake, response-header wait and overall
// request time are each bounded by defaultTimeout. It returns an error if
// either endpoint fails to parse as a URL.
func NewRESTClient(baseURL, accountTag, zoneTag, authToken, userAgent string, logger logger.Service) (*RESTClient, error) {
	// Drop a single trailing slash so the joined paths don't contain "//".
	baseURL = strings.TrimSuffix(baseURL, "/")

	accountLevelEndpoint, err := url.Parse(fmt.Sprintf("%s/accounts/%s/tunnels", baseURL, accountTag))
	if err != nil {
		return nil, errors.Wrap(err, "failed to create account level endpoint")
	}
	zoneLevelEndpoint, err := url.Parse(fmt.Sprintf("%s/zones/%s/tunnels", baseURL, zoneTag))
	if err != nil {
		// Previously mislabelled as the account level endpoint.
		return nil, errors.Wrap(err, "failed to create zone level endpoint")
	}

	return &RESTClient{
		baseEndpoints: &baseEndpoints{
			accountLevel: *accountLevelEndpoint,
			zoneLevel:    *zoneLevelEndpoint,
		},
		authToken: authToken,
		userAgent: userAgent,
		client: http.Client{
			Transport: &http.Transport{
				TLSHandshakeTimeout:   defaultTimeout,
				ResponseHeaderTimeout: defaultTimeout,
			},
			Timeout: defaultTimeout,
		},
		logger: logger,
	}, nil
}
|
|
|
|
// newTunnel is the JSON request body CreateTunnel sends to create a tunnel.
// encoding/json encodes the []byte TunnelSecret as a base64 string.
type newTunnel struct {
	Name         string `json:"name"`
	TunnelSecret []byte `json:"tunnel_secret"`
}
|
|
|
|
// CreateTunnel POSTs a new tunnel named name with tunnelSecret to the
// account-level endpoint and returns the tunnel the API created.
//
// Empty names and names that parse as UUIDs are rejected before any request is
// sent. A 409 Conflict yields ErrTunnelNameConflict. Other non-200 statuses are
// converted by statusCodeToError.
func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error) {
	if name == "" {
		return nil, errors.New("tunnel name required")
	}
	// Presumably refused so a name can never be confused with a tunnel ID.
	if _, parseErr := uuid.Parse(name); parseErr == nil {
		return nil, errors.New("you cannot use UUIDs as tunnel names")
	}

	payload := &newTunnel{Name: name, TunnelSecret: tunnelSecret}
	resp, err := r.sendRequest(http.MethodPost, r.baseEndpoints.accountLevel, payload)
	if err != nil {
		return nil, errors.Wrap(err, "REST request failed")
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusConflict {
		return nil, ErrTunnelNameConflict
	}
	if resp.StatusCode != http.StatusOK {
		return nil, r.statusCodeToError("create tunnel", resp)
	}
	return unmarshalTunnel(resp.Body)
}
|
|
|
|
// GetTunnel fetches the tunnel identified by tunnelID with a GET on
// <account endpoint>/<tunnelID>. Non-200 statuses are converted by
// statusCodeToError.
func (r *RESTClient) GetTunnel(tunnelID uuid.UUID) (*Tunnel, error) {
	tunnelURL := r.baseEndpoints.accountLevel
	tunnelURL.Path = path.Join(tunnelURL.Path, tunnelID.String())

	resp, err := r.sendRequest(http.MethodGet, tunnelURL, nil)
	if err != nil {
		return nil, errors.Wrap(err, "REST request failed")
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, r.statusCodeToError("get tunnel", resp)
	}
	return unmarshalTunnel(resp.Body)
}
|
|
|
|
// DeleteTunnel sends a DELETE for <account endpoint>/<tunnelID>. It returns nil
// on 200 OK and otherwise the error produced by statusCodeToError.
func (r *RESTClient) DeleteTunnel(tunnelID uuid.UUID) error {
	tunnelURL := r.baseEndpoints.accountLevel
	tunnelURL.Path = path.Join(tunnelURL.Path, tunnelID.String())

	resp, err := r.sendRequest(http.MethodDelete, tunnelURL, nil)
	if err != nil {
		return errors.Wrap(err, "REST request failed")
	}
	defer resp.Body.Close()

	// statusCodeToError maps 200 OK to nil, so it doubles as the success check.
	return r.statusCodeToError("delete tunnel", resp)
}
|
|
|
|
// ListTunnels lists tunnels with a GET on the account-level endpoint.
//
// A non-nil filter is encoded into the query string. A nil filter lists
// without one instead of dereferencing the nil pointer; whether
// Filter.encode tolerates a nil receiver is not visible here.
// Non-200 statuses are converted by statusCodeToError.
func (r *RESTClient) ListTunnels(filter *Filter) ([]*Tunnel, error) {
	endpoint := r.baseEndpoints.accountLevel
	if filter != nil {
		endpoint.RawQuery = filter.encode()
	}

	resp, err := r.sendRequest(http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, errors.Wrap(err, "REST request failed")
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, r.statusCodeToError("list tunnels", resp)
	}
	return parseListTunnels(resp.Body)
}
|
|
|
|
// parseListTunnels decodes a list-tunnels response body (the API envelope
// whose "result" is an array of tunnels). On error it returns a nil slice
// rather than whatever was partially decoded. It does not close body.
func parseListTunnels(body io.ReadCloser) ([]*Tunnel, error) {
	var tunnels []*Tunnel
	if err := parseResponse(body, &tunnels); err != nil {
		return nil, err
	}
	return tunnels, nil
}
|
|
|
|
// CleanupConnections sends a DELETE for <account endpoint>/<tunnelID>/connections.
// It returns nil on 200 OK and otherwise the error produced by
// statusCodeToError.
func (r *RESTClient) CleanupConnections(tunnelID uuid.UUID) error {
	connectionsURL := r.baseEndpoints.accountLevel
	connectionsURL.Path = path.Join(connectionsURL.Path, tunnelID.String(), "connections")

	resp, err := r.sendRequest(http.MethodDelete, connectionsURL, nil)
	if err != nil {
		return errors.Wrap(err, "REST request failed")
	}
	defer resp.Body.Close()

	return r.statusCodeToError("cleanup connections", resp)
}
|
|
|
|
// RouteTunnel PUTs route (JSON-encoded via its MarshalJSON) to
// <zone endpoint>/<tunnelID>/routes. On 200 OK it lets the route decode its
// own result type; other statuses are converted by statusCodeToError.
func (r *RESTClient) RouteTunnel(tunnelID uuid.UUID, route Route) (RouteResult, error) {
	routesURL := r.baseEndpoints.zoneLevel
	routesURL.Path = path.Join(routesURL.Path, tunnelID.String(), "routes")

	resp, err := r.sendRequest(http.MethodPut, routesURL, route)
	if err != nil {
		return nil, errors.Wrap(err, "REST request failed")
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, r.statusCodeToError("add route", resp)
	}
	return route.UnmarshalResult(resp.Body)
}
|
|
|
|
// sendRequest issues an HTTP request with the given method to endpoint using
// the client's HTTP client.
//
// A non-nil body is JSON-encoded and sent with Content-Type jsonContentType.
// Every request carries the User-Agent, the X-Auth-User-Service-Key auth token
// and an Accept header selecting API version 1. As with http.Client.Do, the
// caller must close the returned response's Body when err is nil.
func (r *RESTClient) sendRequest(method string, endpoint url.URL, body interface{}) (*http.Response, error) {
	var bodyReader io.Reader
	if body != nil {
		bodyBytes, err := json.Marshal(body)
		if err != nil {
			return nil, errors.Wrap(err, "failed to serialize json body")
		}
		bodyReader = bytes.NewReader(bodyBytes)
	}

	req, err := http.NewRequest(method, endpoint.String(), bodyReader)
	if err != nil {
		return nil, errors.Wrapf(err, "can't create %s request", method)
	}
	req.Header.Set("User-Agent", r.userAgent)
	if bodyReader != nil {
		req.Header.Set("Content-Type", jsonContentType)
	}
	req.Header.Set("X-Auth-User-Service-Key", r.authToken)
	req.Header.Set("Accept", "application/json;version=1")
	return r.client.Do(req)
}
|
|
|
|
// parseResponse decodes a tunnelstore v1 response envelope from reader and, if
// the call succeeded, unmarshals its "result" field into data.
//
// data must be a non-nil pointer (e.g. *Tunnel or *[]*Tunnel). It is handed to
// json.Unmarshal as-is. Taking &data would hand over a *interface{}, which only
// works through encoding/json's interface unwrapping and silently discards the
// result when data is not a pointer.
//
// Errors, in order of precedence:
//   - a malformed envelope (wrapped "failed to decode response"),
//   - any entries in the envelope's "errors" list (via checkErrors),
//   - ErrAPINoSuccess when "success" is false,
//   - a "result" that does not fit data (wrapped "failed to decode response result").
func parseResponse(reader io.Reader, data interface{}) error {
	// Schema for Tunnelstore responses in the v1 API.
	// Roughly, it's a wrapper around a particular result that adds failures/errors/etc
	var result struct {
		Result  json.RawMessage `json:"result"`
		Success bool            `json:"success"`
		Errors  []string        `json:"errors"`
	}

	// First, parse the wrapper and check the API call succeeded
	if err := json.NewDecoder(reader).Decode(&result); err != nil {
		return errors.Wrap(err, "failed to decode response")
	}
	if err := checkErrors(result.Errors); err != nil {
		return err
	}
	if !result.Success {
		return ErrAPINoSuccess
	}

	// At this point we know the API call succeeded, so parse the inner result
	// into the caller-provided pointer.
	if err := json.Unmarshal(result.Result, data); err != nil {
		return errors.Wrap(err, "failed to decode response result")
	}
	return nil
}
|
|
|
|
// unmarshalTunnel decodes a single-tunnel response body. On error it returns a
// nil *Tunnel rather than a pointer to a zero or partially decoded value, so
// callers that forward the result cannot hand out a bogus tunnel.
func unmarshalTunnel(reader io.Reader) (*Tunnel, error) {
	var tunnel Tunnel
	if err := parseResponse(reader, &tunnel); err != nil {
		return nil, err
	}
	return &tunnel, nil
}
|
|
|
|
// checkErrors folds the error strings from an API response envelope into one
// error. It returns nil when errs is empty. A single entry is reported as
// "API error: <msg>"; several are joined with "; " under "API errors: ".
func checkErrors(errs []string) error {
	switch len(errs) {
	case 0:
		return nil
	case 1:
		return fmt.Errorf("API error: %s", errs[0])
	default:
		return fmt.Errorf("API errors: %s", strings.Join(errs, "; "))
	}
}
|
|
|
|
// statusCodeToError converts an API response for operation op into an error.
//
// If the response is JSON and its body decodes to an object with a non-empty
// "error" field, that message is returned as "Failed to <op>: <message>". This
// check runs before the status code is consulted. Otherwise the status maps to:
//   - 200 OK: nil
//   - 400: ErrBadRequest
//   - 401 or 403: ErrUnauthorized
//   - 404: ErrNotFound
//   - anything else: a generic error naming op and the status.
//
// The body may be consumed by this call.
func (r *RESTClient) statusCodeToError(op string, resp *http.Response) error {
	// Compare only the media type, case-insensitively. Servers commonly send
	// parameters such as "application/json; charset=utf-8", which the previous
	// exact string match rejected.
	contentType := resp.Header.Get("Content-Type")
	mediaType := strings.TrimSpace(strings.SplitN(contentType, ";", 2)[0])
	if strings.EqualFold(mediaType, jsonContentType) {
		var errorsResp struct {
			Error string `json:"error"`
		}
		if json.NewDecoder(resp.Body).Decode(&errorsResp) == nil && errorsResp.Error != "" {
			return errors.Errorf("Failed to %s: %s", op, errorsResp.Error)
		}
	}

	switch resp.StatusCode {
	case http.StatusOK:
		return nil
	case http.StatusBadRequest:
		return ErrBadRequest
	case http.StatusUnauthorized, http.StatusForbidden:
		return ErrUnauthorized
	case http.StatusNotFound:
		return ErrNotFound
	}
	return errors.Errorf("API call to %s failed with status %d: %s", op,
		resp.StatusCode, http.StatusText(resp.StatusCode))
}
|