diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 331b93f9..6bf8a3b3 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -244,10 +244,10 @@ func adhocNamedTunnel(c *cli.Context, name string) error { } if r, ok := routeFromFlag(c); ok { - if err := sc.route(tunnel.ID, r); err != nil { + if res, err := sc.route(tunnel.ID, r); err != nil { sc.logger.Errorf("failed to create route, please create it manually. err: %v.", err) } else { - sc.logger.Infof(r.SuccessSummary()) + sc.logger.Infof(res.SuccessSummary()) } } diff --git a/cmd/cloudflared/tunnel/subcommand_context.go b/cmd/cloudflared/tunnel/subcommand_context.go index b868eb49..ae49c2fe 100644 --- a/cmd/cloudflared/tunnel/subcommand_context.go +++ b/cmd/cloudflared/tunnel/subcommand_context.go @@ -270,17 +270,13 @@ func (sc *subcommandContext) cleanupConnections(tunnelIDs []uuid.UUID) error { return nil } -func (sc *subcommandContext) route(tunnelID uuid.UUID, r tunnelstore.Route) error { +func (sc *subcommandContext) route(tunnelID uuid.UUID, r tunnelstore.Route) (tunnelstore.RouteResult, error) { client, err := sc.client() if err != nil { - return err + return nil, err } - if err := client.RouteTunnel(tunnelID, r); err != nil { - return err - } - - return nil + return client.RouteTunnel(tunnelID, r) } func (sc *subcommandContext) tunnelActive(name string) (*tunnelstore.Tunnel, bool, error) { diff --git a/cmd/cloudflared/tunnel/subcommands.go b/cmd/cloudflared/tunnel/subcommands.go index 57b8c2f6..b2262f19 100644 --- a/cmd/cloudflared/tunnel/subcommands.go +++ b/cmd/cloudflared/tunnel/subcommands.go @@ -438,7 +438,7 @@ func routeCommand(c *cli.Context) error { const tunnelIDIndex = 1 routeType := c.Args().First() - var r tunnelstore.Route + var route tunnelstore.Route var tunnelID uuid.UUID switch routeType { case "dns": @@ -446,7 +446,7 @@ func routeCommand(c *cli.Context) error { if err != nil { return err } - r, err = dnsRouteFromArg(c) + route, err = dnsRouteFromArg(c) if err != nil { return err } @@ -455,7 +455,7 @@ func routeCommand(c *cli.Context) error { if err != nil { return err } - r, err = lbRouteFromArg(c) + route, err = lbRouteFromArg(c) if err != nil { return err } @@ -463,10 +463,12 @@ func routeCommand(c *cli.Context) error { return cliutil.UsageError("%s is not a recognized route type. Supported route types are dns and lb", routeType) } - if err := sc.route(tunnelID, r); err != nil { + res, err := sc.route(tunnelID, route) + if err != nil { return err } - sc.logger.Infof(r.SuccessSummary()) + + sc.logger.Infof(res.SuccessSummary()) return nil } diff --git a/tunnelstore/client.go b/tunnelstore/client.go index 6544b7bf..545d1752 100644 --- a/tunnelstore/client.go +++ b/tunnelstore/client.go @@ -43,10 +43,22 @@ type Connection struct { IsPendingReconnect bool `json:"is_pending_reconnect"` } +type Change = string + +const ( + ChangeNew = "new" + ChangeUpdated = "updated" + ChangeUnchanged = "unchanged" +) + // Route represents a record type that can route to a tunnel type Route interface { json.Marshaler RecordType() string + UnmarshalResult(body io.Reader) (RouteResult, error) +} + +type RouteResult interface { // SuccessSummary explains what will route to this tunnel when it's provisioned successfully SuccessSummary() string } @@ -55,6 +67,11 @@ type DNSRoute struct { userHostname string } +type DNSRouteResult struct { + route *DNSRoute + CName Change `json:"cname"` +} + func NewDNSRoute(userHostname string) Route { return &DNSRoute{ userHostname: userHostname, @@ -72,12 +89,30 @@ func (dr *DNSRoute) MarshalJSON() ([]byte, error) { return json.Marshal(&s) } +func (dr *DNSRoute) UnmarshalResult(body io.Reader) (RouteResult, error) { + var result DNSRouteResult + if err := json.NewDecoder(body).Decode(&result); err != nil { + return nil, err + } + result.route = dr + return &result, nil +} + func (dr *DNSRoute) RecordType() string { return "dns" } -func (dr *DNSRoute) SuccessSummary() string { - return fmt.Sprintf("%s will route to your tunnel", dr.userHostname) +func (res *DNSRouteResult) SuccessSummary() string { + var msgFmt string + switch res.CName { + case ChangeNew: + msgFmt = "Added CNAME %s which will route to this tunnel" + case ChangeUpdated: // this is not currently returned by tunnelsore + msgFmt = "%s updated to route to your tunnel" + case ChangeUnchanged: + msgFmt = "%s is already configured to route to your tunnel" + } + return fmt.Sprintf(msgFmt, res.route.userHostname) } type LBRoute struct { @@ -85,6 +120,12 @@ type LBRoute struct { lbPool string } +type LBRouteResult struct { + route *LBRoute + LoadBalancer Change `json:"load_balancer"` + Pool Change `json:"pool"` +} + func NewLBRoute(lbName, lbPool string) Route { return &LBRoute{ lbName: lbName, @@ -109,8 +150,42 @@ func (lr *LBRoute) RecordType() string { return "lb" } -func (lr *LBRoute) SuccessSummary() string { - return fmt.Sprintf("Load balancer %s will route to this tunnel through pool %s", lr.lbName, lr.lbPool) +func (lr *LBRoute) UnmarshalResult(body io.Reader) (RouteResult, error) { + var result LBRouteResult + if err := json.NewDecoder(body).Decode(&result); err != nil { + return nil, err + } + result.route = lr + return &result, nil +} + +func (res *LBRouteResult) SuccessSummary() string { + var msg string + switch res.LoadBalancer + "," + res.Pool { + case "new,new": + msg = "Created load balancer %s and added a new pool %s with this tunnel as an origin" + case "new,updated": + msg = "Created load balancer %s with an existing pool %s which was updated to use this tunnel as an origin" + case "new,unchanged": + msg = "Created load balancer %s with an existing pool %s which already has this tunnel as an origin" + case "updated,new": + msg = "Added new pool %[2]s with this tunnel as an origin to load balancer %[1]s" + case "updated,updated": + msg = "Updated pool %[2]s to use this tunnel as an origin and added it to load balancer %[1]s" + case "updated,unchanged": + msg = "Added pool %[2]s, which already has this tunnel as an origin, to load balancer %[1]s" + case "unchanged,updated": + msg = "Added this tunnel as an origin in pool %[2]s which is already used by load balancer %[1]s" + case "unchanged,unchanged": + msg = "Load balancer %s already uses pool %s which has this tunnel as an origin" + case "unchanged,new": + // this state is not possible + fallthrough + default: + msg = "Something went wrong: failed to modify load balancer %s with pool %s; please check traffic manager configuration in the dashboard" + } + + return fmt.Sprintf(msg, res.route.lbName, res.route.lbPool) } type Client interface { @@ -119,7 +194,7 @@ type Client interface { DeleteTunnel(tunnelID uuid.UUID) error ListTunnels(filter *Filter) ([]*Tunnel, error) CleanupConnections(tunnelID uuid.UUID) error - RouteTunnel(tunnelID uuid.UUID, route Route) error + RouteTunnel(tunnelID uuid.UUID, route Route) (RouteResult, error) } type RESTClient struct { @@ -260,16 +335,20 @@ func (r *RESTClient) CleanupConnections(tunnelID uuid.UUID) error { return r.statusCodeToError("cleanup connections", resp) } -func (r *RESTClient) RouteTunnel(tunnelID uuid.UUID, route Route) error { +func (r *RESTClient) RouteTunnel(tunnelID uuid.UUID, route Route) (RouteResult, error) { endpoint := r.baseEndpoints.zoneLevel endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/routes", tunnelID)) resp, err := r.sendRequest("PUT", endpoint, route) if err != nil { - return errors.Wrap(err, "REST request failed") + return nil, errors.Wrap(err, "REST request failed") } defer resp.Body.Close() - return r.statusCodeToError("add route", resp) + if resp.StatusCode == http.StatusOK { + return route.UnmarshalResult(resp.Body) + } + + return nil, r.statusCodeToError("add route", resp) } func (r *RESTClient) sendRequest(method string, url url.URL, body interface{}) (*http.Response, error) { @@ -304,10 +383,10 @@ func unmarshalTunnel(reader io.Reader) (*Tunnel, error) { func (r *RESTClient) statusCodeToError(op string, resp *http.Response) error { if resp.Header.Get("Content-Type") == "application/json" { - var errorsResp struct{ + var errorsResp struct { Error string `json:"error"` } - if json.NewDecoder(resp.Body).Decode(&errorsResp) == nil && errorsResp.Error != ""{ + if json.NewDecoder(resp.Body).Decode(&errorsResp) == nil && errorsResp.Error != "" { return errors.Errorf("Failed to %s: %s", op, errorsResp.Error) } } diff --git a/tunnelstore/client_test.go b/tunnelstore/client_test.go new file mode 100644 index 00000000..61aa5cab --- /dev/null +++ b/tunnelstore/client_test.go @@ -0,0 +1,78 @@ +package tunnelstore + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDNSRouteUnmarshalResult(t *testing.T) { + route := &DNSRoute{ + userHostname: "example.com", + } + + result, err := route.UnmarshalResult(strings.NewReader(`{"cname": "new"}`)) + + assert.NoError(t, err) + assert.Equal(t, &DNSRouteResult{ + route: route, + CName: ChangeNew, + }, result) + + _, err = route.UnmarshalResult(strings.NewReader(`abc`)) + assert.NotNil(t, err) +} + +func TestLBRouteUnmarshalResult(t *testing.T) { + route := &LBRoute{ + lbName: "lb.example.com", + lbPool: "pool", + } + + result, err := route.UnmarshalResult(strings.NewReader(`{"pool": "unchanged", "load_balancer": "updated"}`)) + + assert.NoError(t, err) + assert.Equal(t, &LBRouteResult{ + route: route, + LoadBalancer: ChangeUpdated, + Pool: ChangeUnchanged, + }, result) + + _, err = route.UnmarshalResult(strings.NewReader(`abc`)) + assert.NotNil(t, err) +} + +func TestLBRouteResultSuccessSummary(t *testing.T) { + route := &LBRoute{ + lbName: "lb.example.com", + lbPool: "POOL", + } + + tests := []struct { + lb Change + pool Change + expected string + }{ + {ChangeNew, ChangeNew, "Created load balancer lb.example.com and added a new pool POOL with this tunnel as an origin" }, + {ChangeNew, ChangeUpdated, "Created load balancer lb.example.com with an existing pool POOL which was updated to use this tunnel as an origin" }, + {ChangeNew, ChangeUnchanged, "Created load balancer lb.example.com with an existing pool POOL which already has this tunnel as an origin" }, + {ChangeUpdated, ChangeNew, "Added new pool POOL with this tunnel as an origin to load balancer lb.example.com" }, + {ChangeUpdated, ChangeUpdated, "Updated pool POOL to use this tunnel as an origin and added it to load balancer lb.example.com" }, + {ChangeUpdated, ChangeUnchanged, "Added pool POOL, which already has this tunnel as an origin, to load balancer lb.example.com" }, + {ChangeUnchanged, ChangeNew, "Something went wrong: failed to modify load balancer lb.example.com with pool POOL; please check traffic manager configuration in the dashboard" }, + {ChangeUnchanged, ChangeUpdated, "Added this tunnel as an origin in pool POOL which is already used by load balancer lb.example.com" }, + {ChangeUnchanged, ChangeUnchanged, "Load balancer lb.example.com already uses pool POOL which has this tunnel as an origin" }, + {"", "", "Something went wrong: failed to modify load balancer lb.example.com with pool POOL; please check traffic manager configuration in the dashboard" }, + {"a", "b", "Something went wrong: failed to modify load balancer lb.example.com with pool POOL; please check traffic manager configuration in the dashboard" }, + } + for i, tt := range tests { + res := &LBRouteResult{ + route: route, + LoadBalancer: tt.lb, + Pool: tt.pool, + } + actual := res.SuccessSummary() + assert.Equal(t, tt.expected, actual, "case %d", i+1) + } +}