diff --git a/.github/ISSUE_TEMPLATE/bug-report---.md b/.github/ISSUE_TEMPLATE/bug-report---.md new file mode 100644 index 00000000..bb54d9f8 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug-report---.md @@ -0,0 +1,31 @@ +--- +name: "Bug report \U0001F41B" +about: Create a report to help us improve cloudflared +title: '' +labels: awaiting reply, bug +assignees: '' + +--- + +**Describe the bug** +A clear and concise description of what the bug is. + +**To Reproduce** +Steps to reproduce the behavior: +1. Configure '...' +2. Run '....' +3. See error + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**Environment and versions** + - OS: [e.g. MacOS] + - Architecture: [e.g. AMD, ARM] + - Version: [e.g. 2022.02.0] + +**Logs and errors** +If applicable, add logs or errors to help explain your problem. + +**Additional context** +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/feature-request---.md b/.github/ISSUE_TEMPLATE/feature-request---.md new file mode 100644 index 00000000..62fdbe20 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature-request---.md @@ -0,0 +1,17 @@ +--- +name: "Feature request \U0001F4A1" +about: Suggest a feature or enhancement for cloudflared +title: '' +labels: awaiting reply, feature-request +assignees: '' + +--- + +**Describe the feature you'd like** +A clear and concise description of the feature. What problem does it solve for you? + +**Describe alternatives you've considered** +Are there any alternatives to solving this problem? If so, what was your experience with them? + +**Additional context** +Add any other context or screenshots about the feature request here. diff --git a/LICENSE b/LICENSE index 436a62ab..7642b8f5 100644 --- a/LICENSE +++ b/LICENSE @@ -1,155 +1,211 @@ -SERVICES AGREEMENT +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ -Your installation of this software is symbol of your signature indicating that -you accept the terms of this Services Agreement (this "Agreement"). This -Agreement is a legal agreement between you (either an individual or a single -entity) and CloudFlare, Inc. for the services being provided to you by -CloudFlare or its authorized representative (the "Services"), including any -computer software and any associated media, printed materials, and "online" or -electronic documentation provided in connection with the Services (the -"Software" and together with the Services are hereinafter collectively referred -to as the "Solution"). If the user is not an individual, then "you" means your -company, its officers, members, employees, agents, representatives, successors -and assigns. BY USING THE SOLUTION, YOU ARE INDICATING THAT YOU HAVE READ, AND -AGREE TO BE BOUND BY, THE POLICIES, TERMS, AND CONDITIONS SET FORTH BELOW IN -THEIR ENTIRETY WITHOUT LIMITATION OR QUALIFICATION, AS WELL AS BY ALL APPLICABLE -LAWS AND REGULATIONS, AS IF YOU HAD HANDWRITTEN YOUR NAME ON A CONTRACT. IF YOU -DO NOT AGREE TO THESE TERMS AND CONDITIONS, YOU MAY NOT USE THE SOLUTION. +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION -1. GRANT OF RIGHTS +1. Definitions. -1.1 Grant of License. The Solution is licensed by CloudFlare and its -licensors, not sold. Subject to the terms and conditions of this Agreement, -CloudFlare hereby grants you a nonexclusive, nonsublicensable, nontransferable -license to use the Solution. You may examine source code, if provided to you, -solely for the limited purpose of evaluating the Software for security flaws. -You may also use the Service to create derivative works which are exclusively -compatible with any CloudFlare product serviceand no other product or service. -This license applies to the parts of the Solution developed by CloudFlare. The -Solution may also incorporate externally maintained libraries and other open software. -These resources may be governed by other licenses. +"License" shall mean the terms and conditions for use, reproduction, +and distribution as defined by Sections 1 through 9 of this document. -1.2 Restrictions. The license granted herein is granted solely to you and -not, by implication or otherwise, to any of your parents, subsidiaries or -affiliates. No right is granted hereunder to use the Solution to perform -services for third parties. All rights not expressly granted hereunder are -reserved to CloudFlare. You may not use the Solution except as explicitly -permitted under this Agreement. You are expressly prohibited from modifying, -adapting, translating, preparing derivative works from, decompiling, reverse -engineering, disassembling or otherwise attempting to derive source code from -the Software used to provide the Services or any internal data files generated -by the Solution. You are also prohibited from removing, obscuring or altering -any copyright notice, trademarks, or other proprietary rights notices affixed to -or associated with the Solution. +"Licensor" shall mean the copyright owner or entity authorized by +the copyright owner that is granting the License. -1.3 Ownership. As between the parties, CloudFlare and/or its licensors own -and shall retain all right, title, and interest in and to the Solution, -including any and all technology embodied therein, including all copyrights, -patents, trade secrets, trade dress and other proprietary rights associated -therewith, and any derivative works created there from. +"Legal Entity" shall mean the union of the acting entity and all +other entities that control, are controlled by, or are under common +control with that entity. For the purposes of this definition, +"control" means (i) the power, direct or indirect, to cause the +direction or management of such entity, whether by contract or +otherwise, or (ii) ownership of fifty percent (50%) or more of the +outstanding shares, or (iii) beneficial ownership of such entity. -2. LIMITATION OF LIABILITY +"You" (or "Your") shall mean an individual or Legal Entity +exercising permissions granted by this License. -YOU EXPRESSLY ACKNOWLEDGE AND AGREE THAT DOWNLOADING THE SOFTWARE IS AT YOUR -SOLE RISK. THE SOFTWARE IS PROVIDED "AS IS" AND WITHOUT WARRANTY OF ANY KIND -AND CLOUDFLARE, ITS LICENSORS AND ITS AUTHORIZED REPRESENTATIVES (TOGETHER FOR -PURPOSES HEREOF, "CLOUDFLARE") EXPRESSLY DISCLAIM ALL WARRANTIES, EXPRESS OR -IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. CLOUDFLARE DOES NOT -WARRANT THAT THE FUNCTIONS CONTAINED IN THE SOFTWARE WILL MEET YOUR -REQUIREMENTS, OR THAT THE OPERATION OF THE SOFTWARE WILL BE UNINTERRUPTED OR -ERROR-FREE, OR THAT DEFECTS IN THE SOFTWARE WILL BE CORRECTED. FURTHERMORE, -CLOUDFLARE DOES NOT WARRANT OR MAKE ANY REPRESENTATIONS REGARDING THE SOFTWARE -OR RELATED DOCUMENTATION IN TERMS OF THEIR CORRECTNESS, ACCURACY, RELIABILITY, -OR OTHERWISE. NO ORAL OR WRITTEN INFORMATION OR ADVICE GIVEN BY CLOUDFLARE SHALL -CREATE A WARRANTY OR IN ANY WAY INCREASE THE SCOPE OF THIS WARRANTY. +"Source" form shall mean the preferred form for making modifications, +including but not limited to software source code, documentation +source, and configuration files. -3. CONFIDENTIALITY +"Object" form shall mean any form resulting from mechanical +transformation or translation of a Source form, including but +not limited to compiled object code, generated documentation, +and conversions to other media types. -It may be necessary during the set up and performance of the Solution for the -parties to exchange Confidential Information. "Confidential Information" means -any information whether oral, or written, of a private, secret, proprietary or -confidential nature, concerning either party or its business operations, -including without limitation: (a) your data and (b) CloudFlare's access control -systems, specialized network equipment and techniques related to the Solution, -use policies, which include trade secrets of CloudFlare and its licensors. Each -party agrees to use the same degree of care to protect the confidentiality of -the Confidential Information of the other party and to prevent its unauthorized -use or dissemination as it uses to protect its own Confidential Information of a -similar nature, but in no event shall exercise less than due diligence and -reasonable care. Each party agrees to use the Confidential Information of the -other party only for purposes related to the performance of this Agreement. All -Confidential Information remains the property of the party disclosing the -information and no license or other rights to Confidential Information is -granted or implied hereby. +"Work" shall mean the work of authorship, whether in Source or +Object form, made available under the License, as indicated by a +copyright notice that is included in or attached to the work +(an example is provided in the Appendix below). -4. TERM AND TERMINATION +"Derivative Works" shall mean any work, whether in Source or Object +form, that is based on (or derived from) the Work and for which the +editorial revisions, annotations, elaborations, or other modifications +represent, as a whole, an original work of authorship. For the purposes +of this License, Derivative Works shall not include works that remain +separable from, or merely link (or bind by name) to the interfaces of, +the Work and Derivative Works thereof. -4.1 Term. This Agreement shall be effective upon download or install of the -Software. +"Contribution" shall mean any work of authorship, including +the original version of the Work and any modifications or additions +to that Work or Derivative Works thereof, that is intentionally +submitted to Licensor for inclusion in the Work by the copyright owner +or by an individual or Legal Entity authorized to submit on behalf of +the copyright owner. For the purposes of this definition, "submitted" +means any form of electronic, verbal, or written communication sent +to the Licensor or its representatives, including but not limited to +communication on electronic mailing lists, source code control systems, +and issue tracking systems that are managed by, or on behalf of, the +Licensor for the purpose of discussing and improving the Work, but +excluding communication that is conspicuously marked or otherwise +designated in writing by the copyright owner as "Not a Contribution." -4.2 Termination. This Agreement may be terminated by CloudFlare or its -authorized representative by written notice to you if any of the following -events occur: (i) you fail to pay any amounts due for the Services and the -Solution when due and after written notice of such nonpayment has been given to -you; (ii) you are in material breach of any term, condition, or provision of -this Agreement or any other agreement executed by you with CloudFlare or its -authorized representative in connection with the provision of the Solution and -Services (a "Related Agreement"); or (iii) you terminate or suspend your -business, becomes subject to any bankruptcy or insolvency proceeding under -federal or state statutes, or become insolvent or subject to direct control by a -trustee, receiver or similar authority. +"Contributor" shall mean Licensor and any individual or Legal Entity +on behalf of whom a Contribution has been received by Licensor and +subsequently incorporated within the Work. -4.3 Effect of Termination. Upon the termination of this Agreement for any -reason: (1) all license rights granted hereunder shall terminate and (2) all -Confidential Information shall be returned to the disclosing party or destroyed. +2. Grant of Copyright License. Subject to the terms and conditions of +this License, each Contributor hereby grants to You a perpetual, +worldwide, non-exclusive, no-charge, royalty-free, irrevocable +copyright license to reproduce, prepare Derivative Works of, +publicly display, publicly perform, sublicense, and distribute the +Work and such Derivative Works in Source or Object form. -5. MISCELLANEOUS +3. Grant of Patent License. Subject to the terms and conditions of +this License, each Contributor hereby grants to You a perpetual, +worldwide, non-exclusive, no-charge, royalty-free, irrevocable +(except as stated in this section) patent license to make, have made, +use, offer to sell, sell, import, and otherwise transfer the Work, +where such license applies only to those patent claims licensable +by such Contributor that are necessarily infringed by their +Contribution(s) alone or by combination of their Contribution(s) +with the Work to which such Contribution(s) was submitted. If You +institute patent litigation against any entity (including a +cross-claim or counterclaim in a lawsuit) alleging that the Work +or a Contribution incorporated within the Work constitutes direct +or contributory patent infringement, then any patent licenses +granted to You under this License for that Work shall terminate +as of the date such litigation is filed. -5.1 Assignment. You may not assign any of your rights or delegate any of -your obligations under this Agreement, whether by operation of law or otherwise, -without the prior express written consent of CloudFlare or its authorized -representative. Any such assignment without the prior express written consent -of CloudFlare or its authorized representative shall be void. Subject to the -foregoing, this Agreement will bind and inure to the benefit of the parties, -their respective successors and permitted assigns. +4. Redistribution. You may reproduce and distribute copies of the +Work or Derivative Works thereof in any medium, with or without +modifications, and in Source or Object form, provided that You +meet the following conditions: -5.2 Waiver and Amendment. No modification, amendment or waiver of any -provision of this Agreement shall be effective unless in writing and signed by -the party to be charged. No failure or delay by either party in exercising any -right, power, or remedy under this Agreement, except as specifically provided -herein, shall operate as a waiver of any such right, power or remedy. Without -limiting the foregoing, terms and conditions on any purchase orders or similar -materials submitted by you to CloudFlare or its authorized representative shall -be of no force or effect. +(a) You must give any other recipients of the Work or +Derivative Works a copy of this License; and -5.3 Governing Law. This Agreement shall be governed by the laws of the State -of California, USA, excluding conflict of laws and provisions, and excluding the -United Nations Convention on Contracts for the International Sale of Goods. +(b) You must cause any modified files to carry prominent notices +stating that You changed the files; and -5.4 Notices. All notices, demands or consents required or permitted under -this Agreement shall be in writing. Notice shall be sent to you at the e-mail -address provided by you to CloudFlare or its authorized representative in -connection with the Solution. +(c) You must retain, in the Source form of any Derivative Works +that You distribute, all copyright, patent, trademark, and +attribution notices from the Source form of the Work, +excluding those notices that do not pertain to any part of +the Derivative Works; and -5.5 Independent Contractors. The parties are independent contractors. -Neither party shall be deemed to be an employee, agent, partner or legal -representative of the other for any purpose and neither shall have any right, -power or authority to create any obligation or responsibility on behalf of the -other. +(d) If the Work includes a "NOTICE" text file as part of its +distribution, then any Derivative Works that You distribute must +include a readable copy of the attribution notices contained +within such NOTICE file, excluding those notices that do not +pertain to any part of the Derivative Works, in at least one +of the following places: within a NOTICE text file distributed +as part of the Derivative Works; within the Source form or +documentation, if provided along with the Derivative Works; or, +within a display generated by the Derivative Works, if and +wherever such third-party notices normally appear. The contents +of the NOTICE file are for informational purposes only and +do not modify the License. You may add Your own attribution +notices within Derivative Works that You distribute, alongside +or as an addendum to the NOTICE text from the Work, provided +that such additional attribution notices cannot be construed +as modifying the License. -5.6 Severability. If any provision of this Agreement is held by a court of -competent jurisdiction to be contrary to law, such provision shall be changed -and interpreted so as to best accomplish the objectives of the original -provision to the fullest extent allowed by law and the remaining provisions of -this Agreement shall remain in full force and effect. +You may add Your own copyright statement to Your modifications and +may provide additional or different license terms and conditions +for use, reproduction, or distribution of Your modifications, or +for any such Derivative Works as a whole, provided Your use, +reproduction, and distribution of the Work otherwise complies with +the conditions stated in this License. -5.7 Force Majeure. CloudFlare shall not be liable to the other party for any -failure or delay in performance caused by reasons beyond its reasonable control. +5. Submission of Contributions. Unless You explicitly state otherwise, +any Contribution intentionally submitted for inclusion in the Work +by You to the Licensor shall be under the terms and conditions of +this License, without any additional terms or conditions. +Notwithstanding the above, nothing herein shall supersede or modify +the terms of any separate license agreement you may have executed +with Licensor regarding such Contributions. -5.8 Complete Understanding. This Agreement and the Related Agreement -constitute the final, complete and exclusive agreement between the parties with -respect to the subject matter hereof, and supersedes all previous written and -oral agreements and communications related to the subject matter of this -Agreement. To the extent this Agreement and the Related Agreement conflict, -this Agreement shall control. +6. Trademarks. This License does not grant permission to use the trade +names, trademarks, service marks, or product names of the Licensor, +except as required for reasonable and customary use in describing the +origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or +agreed to in writing, Licensor provides the Work (and each +Contributor provides its Contributions) on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied, including, without limitation, any warranties or conditions +of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A +PARTICULAR PURPOSE. You are solely responsible for determining the +appropriateness of using or redistributing the Work and assume any +risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, +whether in tort (including negligence), contract, or otherwise, +unless required by applicable law (such as deliberate and grossly +negligent acts) or agreed to in writing, shall any Contributor be +liable to You for damages, including any direct, indirect, special, +incidental, or consequential damages of any character arising as a +result of this License or out of the use or inability to use the +Work (including but not limited to damages for loss of goodwill, +work stoppage, computer failure or malfunction, or any and all +other commercial damages or losses), even if such Contributor +has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing +the Work or Derivative Works thereof, You may choose to offer, +and charge a fee for, acceptance of support, warranty, indemnity, +or other liability obligations and/or rights consistent with this +License. However, in accepting such obligations, You may act only +on Your own behalf and on Your sole responsibility, not on behalf +of any other Contributor, and only if You agree to indemnify, +defend, and hold each Contributor harmless for any liability +incurred by, or claims asserted against, such Contributor by reason +of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + +To apply the Apache License to your work, attach the following +boilerplate notice, with the fields enclosed by brackets "[]" +replaced with your own identifying information. (Don't include +the brackets!) The text should be enclosed in the appropriate +comment syntax for the file format. We also recommend that a +file or class name and description of purpose be included on the +same "printed page" as the copyright notice for easier +identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + + + +## Runtime Library Exception to the Apache 2.0 License: ## + + +As an exception, if you use this Software to compile your source code and +portions of this Software are embedded into the binary product as a result, +you may redistribute such product without providing attribution as would +otherwise be required by Sections 4(a), 4(b) and 4(d) of the License. diff --git a/RELEASE_NOTES b/RELEASE_NOTES index fc973dbf..153685ed 100644 --- a/RELEASE_NOTES +++ b/RELEASE_NOTES @@ -1,3 +1,26 @@ +2022.2.2 +- 2022-02-22 TUN-5754: Allow ingress validate to take plaintext option +- 2022-02-17 TUN-5678: Cloudflared uses typed tunnel API + +2022.2.1 +- 2022-02-10 TUN-5184: Handle errors in bidrectional streaming (websocket#Stream) gracefully when 1 side has ended +- 2022-02-14 Update issue templates +- 2022-02-14 Update issue templates +- 2022-02-11 TUN-5768: Update cloudflared license file +- 2022-02-11 TUN-5698: Make ingress rules and warp routing dynamically configurable +- 2022-02-14 TUN-5678: Adapt cloudflared to use new typed APIs +- 2022-02-17 Revert "TUN-5678: Adapt cloudflared to use new typed APIs" +- 2022-02-11 TUN-5697: Listen for UpdateConfiguration RPC in quic transport +- 2022-02-04 TUN-5744: Add a test to make sure cloudflared uses scheme defined in ingress rule, not X-Forwarded-Proto header +- 2022-02-07 TUN-5749: Refactor cloudflared to pave way for reconfigurable ingress - Split origin into supervisor and proxy packages - Create configManager to handle dynamic config +- 2021-10-19 TUN-5184: Make sure outstanding websocket write is finished, and no more writes after shutdown + +2022.2.0 +- 2022-02-02 TUN-4947: Use http when talking to Unix sockets origins +- 2022-02-02 TUN-5695: Define RPC method to update configuration +- 2022-01-27 TUN-5621: Correctly manage QUIC stream closing +- 2022-01-28 TUN-5702: Allow to deserialize config from JSON + 2022.1.3 - 2022-01-21 TUN-5477: Unhide vnet commands - 2022-01-24 TUN-5669: Change network command to vnet diff --git a/cfapi/base_client.go b/cfapi/base_client.go index 42b99316..48b349c3 100644 --- a/cfapi/base_client.go +++ b/cfapi/base_client.go @@ -48,7 +48,7 @@ func NewRESTClient(baseURL, accountTag, zoneTag, authToken, userAgent string, lo if strings.HasSuffix(baseURL, "/") { baseURL = baseURL[:len(baseURL)-1] } - accountLevelEndpoint, err := url.Parse(fmt.Sprintf("%s/accounts/%s/tunnels", baseURL, accountTag)) + accountLevelEndpoint, err := url.Parse(fmt.Sprintf("%s/accounts/%s/cfd_tunnel", baseURL, accountTag)) if err != nil { return nil, errors.Wrap(err, "failed to create account level endpoint") } diff --git a/cfapi/client.go b/cfapi/client.go index b4f17927..d1e794f7 100644 --- a/cfapi/client.go +++ b/cfapi/client.go @@ -5,7 +5,7 @@ import ( ) type TunnelClient interface { - CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error) + CreateTunnel(name string, tunnelSecret []byte) (*TunnelWithToken, error) GetTunnel(tunnelID uuid.UUID) (*Tunnel, error) DeleteTunnel(tunnelID uuid.UUID) error ListTunnels(filter *TunnelFilter) ([]*Tunnel, error) diff --git a/cfapi/tunnel.go b/cfapi/tunnel.go index 5d4ea298..cede09c9 100644 --- a/cfapi/tunnel.go +++ b/cfapi/tunnel.go @@ -23,6 +23,11 @@ type Tunnel struct { Connections []Connection `json:"connections"` } +type TunnelWithToken struct { + Tunnel + Token string `json:"token"` +} + type Connection struct { ColoName string `json:"colo_name"` ID uuid.UUID `json:"id"` @@ -63,7 +68,7 @@ func (cp CleanupParams) encode() string { return cp.queryParams.Encode() } -func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error) { +func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*TunnelWithToken, error) { if name == "" { return nil, errors.New("tunnel name required") } @@ -83,7 +88,11 @@ func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, er switch resp.StatusCode { case http.StatusOK: - return unmarshalTunnel(resp.Body) + var tunnel TunnelWithToken + if serdeErr := parseResponse(resp.Body, &tunnel); err != nil { + return nil, serdeErr + } + return &tunnel, nil case http.StatusConflict: return nil, ErrTunnelNameConflict } diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 8833e209..2a692f73 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -31,8 +31,9 @@ import ( "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/metrics" - "github.com/cloudflare/cloudflared/origin" + "github.com/cloudflare/cloudflared/orchestration" "github.com/cloudflare/cloudflared/signal" + "github.com/cloudflare/cloudflared/supervisor" "github.com/cloudflare/cloudflared/tlsconfig" "github.com/cloudflare/cloudflared/tunneldns" ) @@ -223,7 +224,7 @@ func routeFromFlag(c *cli.Context) (route cfapi.HostnameRoute, ok bool) { func StartServer( c *cli.Context, info *cliutil.BuildInfo, - namedTunnel *connection.NamedTunnelConfig, + namedTunnel *connection.NamedTunnelProperties, log *zerolog.Logger, isUIEnabled bool, ) error { @@ -333,7 +334,7 @@ func StartServer( observer.SendURL(quickTunnelURL) } - tunnelConfig, ingressRules, err := prepareTunnelConfig(c, info, log, logTransport, observer, namedTunnel) + tunnelConfig, dynamicConfig, err := prepareTunnelConfig(c, info, log, logTransport, observer, namedTunnel) if err != nil { log.Err(err).Msg("Couldn't start tunnel") return err @@ -353,11 +354,12 @@ func StartServer( errC <- metrics.ServeMetrics(metricsListener, ctx.Done(), readinessServer, quickTunnelURL, log) }() - if err := ingressRules.StartOrigins(&wg, log, ctx.Done(), errC); err != nil { + orchestrator, err := orchestration.NewOrchestrator(ctx, dynamicConfig, tunnelConfig.Tags, tunnelConfig.Log) + if err != nil { return err } - reconnectCh := make(chan origin.ReconnectSignal, 1) + reconnectCh := make(chan supervisor.ReconnectSignal, 1) if c.IsSet("stdin-control") { log.Info().Msg("Enabling control through stdin") go stdinControl(reconnectCh, log) @@ -369,7 +371,7 @@ func StartServer( wg.Done() log.Info().Msg("Tunnel server stopped") }() - errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, reconnectCh, graceShutdownC) + errC <- supervisor.StartTunnelDaemon(ctx, tunnelConfig, orchestrator, connectedSignal, reconnectCh, graceShutdownC) }() if isUIEnabled { @@ -377,7 +379,7 @@ func StartServer( info.Version(), hostname, metricsListener.Addr().String(), - &ingressRules, + dynamicConfig.Ingress, tunnelConfig.HAConnections, ) app := tunnelUI.Launch(ctx, log, logTransport) @@ -998,7 +1000,7 @@ func configureProxyDNSFlags(shouldHide bool) []cli.Flag { } } -func stdinControl(reconnectCh chan origin.ReconnectSignal, log *zerolog.Logger) { +func stdinControl(reconnectCh chan supervisor.ReconnectSignal, log *zerolog.Logger) { for { scanner := bufio.NewScanner(os.Stdin) for scanner.Scan() { @@ -1009,7 +1011,7 @@ func stdinControl(reconnectCh chan origin.ReconnectSignal, log *zerolog.Logger) case "": break case "reconnect": - var reconnect origin.ReconnectSignal + var reconnect supervisor.ReconnectSignal if len(parts) > 1 { var err error if reconnect.Delay, err = time.ParseDuration(parts[1]); err != nil { diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 0d29cc7d..27ce90e0 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -23,7 +23,8 @@ import ( "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/ingress" - "github.com/cloudflare/cloudflared/origin" + "github.com/cloudflare/cloudflared/orchestration" + "github.com/cloudflare/cloudflared/supervisor" "github.com/cloudflare/cloudflared/tlsconfig" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/validation" @@ -87,7 +88,7 @@ func logClientOptions(c *cli.Context, log *zerolog.Logger) { } } -func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.NamedTunnelConfig) bool { +func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.NamedTunnelProperties) bool { return c.IsSet("proxy-dns") && (!c.IsSet("hostname") && !c.IsSet("tag") && !c.IsSet("hello-world") && namedTunnel == nil) } @@ -152,44 +153,44 @@ func prepareTunnelConfig( info *cliutil.BuildInfo, log, logTransport *zerolog.Logger, observer *connection.Observer, - namedTunnel *connection.NamedTunnelConfig, -) (*origin.TunnelConfig, ingress.Ingress, error) { + namedTunnel *connection.NamedTunnelProperties, +) (*supervisor.TunnelConfig, *orchestration.Config, error) { isNamedTunnel := namedTunnel != nil configHostname := c.String("hostname") hostname, err := validation.ValidateHostname(configHostname) if err != nil { log.Err(err).Str(LogFieldHostname, configHostname).Msg("Invalid hostname") - return nil, ingress.Ingress{}, errors.Wrap(err, "Invalid hostname") + return nil, nil, errors.Wrap(err, "Invalid hostname") } clientID := c.String("id") if !c.IsSet("id") { clientID, err = generateRandomClientID(log) if err != nil { - return nil, ingress.Ingress{}, err + return nil, nil, err } } tags, err := NewTagSliceFromCLI(c.StringSlice("tag")) if err != nil { log.Err(err).Msg("Tag parse failure") - return nil, ingress.Ingress{}, errors.Wrap(err, "Tag parse failure") + return nil, nil, errors.Wrap(err, "Tag parse failure") } tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID}) var ( ingressRules ingress.Ingress - classicTunnel *connection.ClassicTunnelConfig + classicTunnel *connection.ClassicTunnelProperties ) cfg := config.GetConfiguration() if isNamedTunnel { clientUUID, err := uuid.NewRandom() if err != nil { - return nil, ingress.Ingress{}, errors.Wrap(err, "can't generate connector UUID") + return nil, nil, errors.Wrap(err, "can't generate connector UUID") } log.Info().Msgf("Generated Connector ID: %s", clientUUID) - features := append(c.StringSlice("features"), origin.FeatureSerializedHeaders) + features := append(c.StringSlice("features"), supervisor.FeatureSerializedHeaders) namedTunnel.Client = tunnelpogs.ClientInfo{ ClientID: clientUUID[:], Features: dedup(features), @@ -198,10 +199,10 @@ func prepareTunnelConfig( } ingressRules, err = ingress.ParseIngress(cfg) if err != nil && err != ingress.ErrNoIngressRules { - return nil, ingress.Ingress{}, err + return nil, nil, err } if !ingressRules.IsEmpty() && c.IsSet("url") { - return nil, ingress.Ingress{}, ingress.ErrURLIncompatibleWithIngress + return nil, nil, ingress.ErrURLIncompatibleWithIngress } } else { @@ -212,10 +213,10 @@ func prepareTunnelConfig( originCert, err := getOriginCert(originCertPath, &originCertLog) if err != nil { - return nil, ingress.Ingress{}, errors.Wrap(err, "Error getting origin cert") + return nil, nil, errors.Wrap(err, "Error getting origin cert") } - classicTunnel = &connection.ClassicTunnelConfig{ + classicTunnel = &connection.ClassicTunnelProperties{ Hostname: hostname, OriginCert: originCert, // turn off use of reconnect token and auth refresh when using named tunnels @@ -227,20 +228,14 @@ func prepareTunnelConfig( if ingressRules.IsEmpty() { ingressRules, err = ingress.NewSingleOrigin(c, !isNamedTunnel) if err != nil { - return nil, ingress.Ingress{}, err + return nil, nil, err } } - var warpRoutingService *ingress.WarpRoutingService warpRoutingEnabled := isWarpRoutingEnabled(cfg.WarpRouting, isNamedTunnel) - if warpRoutingEnabled { - warpRoutingService = ingress.NewWarpRoutingService() - log.Info().Msgf("Warp-routing is enabled") - } - - protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.ProtocolPercentage, origin.ResolveTTL, log) + protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.ProtocolPercentage, supervisor.ResolveTTL, log) if err != nil { - return nil, ingress.Ingress{}, err + return nil, nil, err } log.Info().Msgf("Initial protocol %s", protocolSelector.Current()) @@ -248,11 +243,11 @@ func prepareTunnelConfig( for _, p := range connection.ProtocolList { tlsSettings := p.TLSSettings() if tlsSettings == nil { - return nil, ingress.Ingress{}, fmt.Errorf("%s has unknown TLS settings", p) + return nil, nil, fmt.Errorf("%s has unknown TLS settings", p) } edgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, tlsSettings.ServerName) if err != nil { - return nil, ingress.Ingress{}, errors.Wrap(err, "unable to create TLS config to connect with edge") + return nil, nil, errors.Wrap(err, "unable to create TLS config to connect with edge") } if len(tlsSettings.NextProtos) > 0 { edgeTLSConfig.NextProtos = tlsSettings.NextProtos @@ -260,15 +255,9 @@ func prepareTunnelConfig( edgeTLSConfigs[p] = edgeTLSConfig } - originProxy := origin.NewOriginProxy(ingressRules, warpRoutingService, tags, log) gracePeriod, err := gracePeriod(c) if err != nil { - return nil, ingress.Ingress{}, err - } - connectionConfig := &connection.Config{ - OriginProxy: originProxy, - GracePeriod: gracePeriod, - ReplaceExisting: c.Bool("force"), + return nil, nil, err } muxerConfig := &connection.MuxerConfig{ HeartbeatInterval: c.Duration("heartbeat-interval"), @@ -279,21 +268,22 @@ func prepareTunnelConfig( MetricsUpdateFreq: c.Duration("metrics-update-freq"), } - return &origin.TunnelConfig{ - ConnectionConfig: connectionConfig, - OSArch: info.OSArch(), - ClientID: clientID, - EdgeAddrs: c.StringSlice("edge"), - Region: c.String("region"), - HAConnections: c.Int("ha-connections"), - IncidentLookup: origin.NewIncidentLookup(), - IsAutoupdated: c.Bool("is-autoupdated"), - LBPool: c.String("lb-pool"), - Tags: tags, - Log: log, - LogTransport: logTransport, - Observer: observer, - ReportedVersion: info.Version(), + tunnelConfig := &supervisor.TunnelConfig{ + GracePeriod: gracePeriod, + ReplaceExisting: c.Bool("force"), + OSArch: info.OSArch(), + ClientID: clientID, + EdgeAddrs: c.StringSlice("edge"), + Region: c.String("region"), + HAConnections: c.Int("ha-connections"), + IncidentLookup: supervisor.NewIncidentLookup(), + IsAutoupdated: c.Bool("is-autoupdated"), + LBPool: c.String("lb-pool"), + Tags: tags, + Log: log, + LogTransport: logTransport, + Observer: observer, + ReportedVersion: info.Version(), // Note TUN-3758 , we use Int because UInt is not supported with altsrc Retries: uint(c.Int("retries")), RunFromTerminal: isRunningFromTerminal(), @@ -302,7 +292,12 @@ func prepareTunnelConfig( MuxerConfig: muxerConfig, ProtocolSelector: protocolSelector, EdgeTLSConfigs: edgeTLSConfigs, - }, ingressRules, nil + } + dynamicConfig := &orchestration.Config{ + Ingress: &ingressRules, + WarpRoutingEnabled: warpRoutingEnabled, + } + return tunnelConfig, dynamicConfig, nil } func gracePeriod(c *cli.Context) (time.Duration, error) { diff --git a/cmd/cloudflared/tunnel/ingress_subcommands.go b/cmd/cloudflared/tunnel/ingress_subcommands.go index 22a5944b..cc55e7a1 100644 --- a/cmd/cloudflared/tunnel/ingress_subcommands.go +++ b/cmd/cloudflared/tunnel/ingress_subcommands.go @@ -1,6 +1,7 @@ package tunnel import ( + "encoding/json" "fmt" "net/url" @@ -12,6 +13,15 @@ import ( "github.com/urfave/cli/v2" ) +const ingressDataJSONFlagName = "json" + +var ingressDataJSON = &cli.StringFlag{ + Name: ingressDataJSONFlagName, + Aliases: []string{"j"}, + Usage: `Accepts data in the form of json as an input rather than read from a file`, + EnvVars: []string{"TUNNEL_INGRESS_VALIDATE_JSON"}, +} + func buildIngressSubcommand() *cli.Command { return &cli.Command{ Name: "ingress", @@ -49,6 +59,7 @@ func buildValidateIngressCommand() *cli.Command { Usage: "Validate the ingress configuration ", UsageText: "cloudflared tunnel [--config FILEPATH] ingress validate", Description: "Validates the configuration file, ensuring your ingress rules are OK.", + Flags: []cli.Flag{ingressDataJSON}, } } @@ -69,12 +80,11 @@ func buildTestURLCommand() *cli.Command { // validateIngressCommand check the syntax of the ingress rules in the cloudflared config file func validateIngressCommand(c *cli.Context, warnings string) error { - conf := config.GetConfiguration() - if conf.Source() == "" { - fmt.Println("No configuration file was found. Please create one, or use the --config flag to specify its filepath. You can use the help command to learn more about configuration files") - return nil + conf, err := getConfiguration(c) + if err != nil { + return err } - fmt.Println("Validating rules from", conf.Source()) + if _, err := ingress.ParseIngress(conf); err != nil { return errors.Wrap(err, "Validation failed") } @@ -90,6 +100,22 @@ func validateIngressCommand(c *cli.Context, warnings string) error { return nil } +func getConfiguration(c *cli.Context) (*config.Configuration, error) { + var conf *config.Configuration + if c.IsSet(ingressDataJSONFlagName) { + ingressJSON := c.String(ingressDataJSONFlagName) + fmt.Println("Validating rules from cmdline flag --json") + err := json.Unmarshal([]byte(ingressJSON), &conf) + return conf, err + } + conf = config.GetConfiguration() + if conf.Source() == "" { + return nil, errors.New("No configuration file was found. Please create one, or use the --config flag to specify its filepath. You can use the help command to learn more about configuration files") + } + fmt.Println("Validating rules from", conf.Source()) + return conf, nil +} + // testURLCommand checks which ingress rule matches the given URL. func testURLCommand(c *cli.Context) error { requestArg := c.Args().First() diff --git a/cmd/cloudflared/tunnel/quick_tunnel.go b/cmd/cloudflared/tunnel/quick_tunnel.go index 08b5ff78..0dd7747b 100644 --- a/cmd/cloudflared/tunnel/quick_tunnel.go +++ b/cmd/cloudflared/tunnel/quick_tunnel.go @@ -55,7 +55,6 @@ func RunQuickTunnel(sc *subcommandContext) error { AccountTag: data.Result.AccountTag, TunnelSecret: data.Result.Secret, TunnelID: tunnelID, - TunnelName: data.Result.Name, } url := data.Result.Hostname @@ -77,7 +76,7 @@ func RunQuickTunnel(sc *subcommandContext) error { return StartServer( sc.c, buildInfo, - &connection.NamedTunnelConfig{Credentials: credentials, QuickTunnelUrl: data.Result.Hostname}, + &connection.NamedTunnelProperties{Credentials: credentials, QuickTunnelUrl: data.Result.Hostname}, sc.log, sc.isUIEnabled, ) diff --git a/cmd/cloudflared/tunnel/subcommand_context.go b/cmd/cloudflared/tunnel/subcommand_context.go index cb5b15be..53609c3a 100644 --- a/cmd/cloudflared/tunnel/subcommand_context.go +++ b/cmd/cloudflared/tunnel/subcommand_context.go @@ -185,7 +185,6 @@ func (sc *subcommandContext) create(name string, credentialsFilePath string, sec AccountTag: credential.cert.AccountID, TunnelSecret: tunnelSecret, TunnelID: tunnel.ID, - TunnelName: name, } usedCertPath := false if credentialsFilePath == "" { @@ -221,7 +220,9 @@ func (sc *subcommandContext) create(name string, credentialsFilePath string, sec } fmt.Println(" Keep this file secret. To revoke these credentials, delete the tunnel.") fmt.Printf("\nCreated tunnel %s with id %s\n", tunnel.Name, tunnel.ID) - return tunnel, nil + fmt.Printf("\nTunnel Token: %s\n", tunnel.Token) + + return &tunnel.Tunnel, nil } func (sc *subcommandContext) list(filter *cfapi.TunnelFilter) ([]*cfapi.Tunnel, error) { @@ -301,10 +302,16 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error { return err } + return sc.runWithCredentials(credentials) +} + +func (sc *subcommandContext) runWithCredentials(credentials connection.Credentials) error { + sc.log.Info().Str(LogFieldTunnelID, credentials.TunnelID.String()).Msg("Starting tunnel") + return StartServer( sc.c, buildInfo, - &connection.NamedTunnelConfig{Credentials: credentials}, + &connection.NamedTunnelProperties{Credentials: credentials}, sc.log, sc.isUIEnabled, ) @@ -370,7 +377,7 @@ func (sc *subcommandContext) findID(input string) (uuid.UUID, error) { // Look up name in the credentials file. credFinder := newStaticPath(sc.c.String(CredFileFlag), sc.fs) if credentials, err := sc.readTunnelCredentials(credFinder); err == nil { - if credentials.TunnelID != uuid.Nil && input == credentials.TunnelName { + if credentials.TunnelID != uuid.Nil { return credentials.TunnelID, nil } } diff --git a/cmd/cloudflared/tunnel/subcommand_context_test.go b/cmd/cloudflared/tunnel/subcommand_context_test.go index 61a1e68b..31d04b05 100644 --- a/cmd/cloudflared/tunnel/subcommand_context_test.go +++ b/cmd/cloudflared/tunnel/subcommand_context_test.go @@ -11,6 +11,7 @@ import ( "github.com/google/uuid" "github.com/pkg/errors" "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" "github.com/urfave/cli/v2" "github.com/cloudflare/cloudflared/cfapi" @@ -115,7 +116,6 @@ func Test_subcommandContext_findCredentials(t *testing.T) { AccountTag: accountTag, TunnelID: tunnelID, TunnelSecret: secret, - TunnelName: name, }, }, { @@ -160,7 +160,6 @@ func Test_subcommandContext_findCredentials(t *testing.T) { AccountTag: accountTag, TunnelID: tunnelID, TunnelSecret: secret, - TunnelName: name, }, }, } @@ -322,3 +321,48 @@ func Test_subcommandContext_Delete(t *testing.T) { }) } } + +func Test_subcommandContext_ValidateIngressCommand(t *testing.T) { + var tests = []struct { + name string + c *cli.Context + wantErr bool + expectedErr error + }{ + { + name: "read a valid configuration from data", + c: func() *cli.Context { + data := `{ "warp-routing": {"enabled": true}, "originRequest" : {"connectTimeout": 10}, "ingress" : [ {"hostname": "test", "service": "https://localhost:8000" } , {"service": "http_status:404"} ]}` + flagSet := flag.NewFlagSet("json", flag.PanicOnError) + flagSet.String(ingressDataJSONFlagName, data, "") + c := cli.NewContext(cli.NewApp(), flagSet, nil) + _ = c.Set(ingressDataJSONFlagName, data) + return c + }(), + }, + { + name: "read an invalid configuration with multiple mistakes", + c: func() *cli.Context { + data := `{ "ingress" : [ {"hostname": "test", "service": "localhost:8000" } , {"service": "http_status:invalid_status"} ]}` + flagSet := flag.NewFlagSet("json", flag.PanicOnError) + flagSet.String(ingressDataJSONFlagName, data, "") + c := cli.NewContext(cli.NewApp(), flagSet, nil) + _ = c.Set(ingressDataJSONFlagName, data) + return c + }(), + wantErr: true, + expectedErr: errors.New("Validation failed: localhost:8000 is an invalid address, please make sure it has a scheme and a hostname"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateIngressCommand(tt.c, "") + if tt.wantErr { + assert.Equal(t, tt.expectedErr.Error(), err.Error()) + } else { + assert.Nil(t, err) + } + }) + } +} diff --git a/cmd/cloudflared/tunnel/subcommands.go b/cmd/cloudflared/tunnel/subcommands.go index 22dca2a4..8362d8f8 100644 --- a/cmd/cloudflared/tunnel/subcommands.go +++ b/cmd/cloudflared/tunnel/subcommands.go @@ -2,6 +2,7 @@ package tunnel import ( "crypto/rand" + "encoding/base64" "encoding/json" "fmt" "io/ioutil" @@ -34,6 +35,7 @@ const ( CredFileFlagAlias = "cred-file" CredFileFlag = "credentials-file" CredContentsFlag = "credentials-contents" + TunnelTokenFlag = "token" overwriteDNSFlagName = "overwrite-dns" LogFieldTunnelID = "tunnelID" @@ -118,6 +120,11 @@ var ( Usage: "Contents of the tunnel credentials JSON file to use. When provided along with credentials-file, this will take precedence.", EnvVars: []string{"TUNNEL_CRED_CONTENTS"}, }) + tunnelTokenFlag = altsrc.NewStringFlag(&cli.StringFlag{ + Name: TunnelTokenFlag, + Usage: "The Tunnel token. When provided along with credentials, this will take precedence.", + EnvVars: []string{"TUNNEL_TOKEN"}, + }) forceDeleteFlag = &cli.BoolFlag{ Name: "force", Aliases: []string{"f"}, @@ -597,6 +604,7 @@ func buildRunCommand() *cli.Command { credentialsContentsFlag, selectProtocolFlag, featuresFlag, + tunnelTokenFlag, } flags = append(flags, configureProxyFlags(false)...) return &cli.Command{ @@ -627,14 +635,6 @@ func runCommand(c *cli.Context) error { if c.NArg() > 1 { return cliutil.UsageError(`"cloudflared tunnel run" accepts only one argument, the ID or name of the tunnel to run.`) } - tunnelRef := c.Args().First() - if tunnelRef == "" { - // see if tunnel id was in the config file - tunnelRef = config.GetConfiguration().TunnelID - if tunnelRef == "" { - return cliutil.UsageError(`"cloudflared tunnel run" requires the ID or name of the tunnel to run as the last command line argument or in the configuration file.`) - } - } if c.String("hostname") != "" { sc.log.Warn().Msg("The property `hostname` in your configuration is ignored because you configured a Named Tunnel " + @@ -642,7 +642,38 @@ func runCommand(c *cli.Context) error { "your origin will not be reachable. You should remove the `hostname` property to avoid this warning.") } - return runNamedTunnel(sc, tunnelRef) + // Check if token is provided and if not use default tunnelID flag method + if tokenStr := c.String(TunnelTokenFlag); tokenStr != "" { + if token, err := parseToken(tokenStr); err == nil { + return sc.runWithCredentials(token.Credentials()) + } + + return cliutil.UsageError("Provided Tunnel token is not valid.") + } else { + tunnelRef := c.Args().First() + if tunnelRef == "" { + // see if tunnel id was in the config file + tunnelRef = config.GetConfiguration().TunnelID + if tunnelRef == "" { + return cliutil.UsageError(`"cloudflared tunnel run" requires the ID or name of the tunnel to run as the last command line argument or in the configuration file.`) + } + } + + return runNamedTunnel(sc, tunnelRef) + } +} + +func parseToken(tokenStr string) (*connection.TunnelToken, error) { + content, err := base64.StdEncoding.DecodeString(tokenStr) + if err != nil { + return nil, err + } + + var token connection.TunnelToken + if err := json.Unmarshal(content, &token); err != nil { + return nil, err + } + return &token, nil } func runNamedTunnel(sc *subcommandContext, tunnelRef string) error { @@ -650,9 +681,6 @@ func runNamedTunnel(sc *subcommandContext, tunnelRef string) error { if err != nil { return errors.Wrap(err, "error parsing tunnel ID") } - - sc.log.Info().Str(LogFieldTunnelID, tunnelID.String()).Msg("Starting tunnel") - return sc.run(tunnelID) } diff --git a/cmd/cloudflared/tunnel/subcommands_test.go b/cmd/cloudflared/tunnel/subcommands_test.go index 4ebbd922..81f542c7 100644 --- a/cmd/cloudflared/tunnel/subcommands_test.go +++ b/cmd/cloudflared/tunnel/subcommands_test.go @@ -1,14 +1,18 @@ package tunnel import ( + "encoding/base64" + "encoding/json" "path/filepath" "testing" "github.com/google/uuid" homedir "github.com/mitchellh/go-homedir" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/cloudflare/cloudflared/cfapi" + "github.com/cloudflare/cloudflared/connection" ) func Test_fmtConnections(t *testing.T) { @@ -177,3 +181,24 @@ func Test_validateHostname(t *testing.T) { }) } } + +func Test_TunnelToken(t *testing.T) { + token, err := parseToken("aabc") + require.Error(t, err) + require.Nil(t, token) + + expectedToken := &connection.TunnelToken{ + AccountTag: "abc", + TunnelSecret: []byte("secret"), + TunnelID: uuid.New(), + } + + tokenJsonStr, err := json.Marshal(expectedToken) + require.NoError(t, err) + + token64 := base64.StdEncoding.EncodeToString(tokenJsonStr) + + token, err = parseToken(token64) + require.NoError(t, err) + require.Equal(t, token, expectedToken) +} diff --git a/cmd/cloudflared/updater/update.go b/cmd/cloudflared/updater/update.go index c385cf04..23aa327a 100644 --- a/cmd/cloudflared/updater/update.go +++ b/cmd/cloudflared/updater/update.go @@ -19,7 +19,7 @@ import ( const ( DefaultCheckUpdateFreq = time.Hour * 24 - noUpdateInShellMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/run-tunnel/run-as-service" + noUpdateInShellMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/run-tunnel/as-a-service/" noUpdateOnWindowsMessage = "cloudflared will not automatically update on Windows systems." noUpdateManagedPackageMessage = "cloudflared will not automatically update if installed by a package manager." isManagedInstallFile = ".installedFromPackageManager" diff --git a/config/configuration.go b/config/configuration.go index 961de6c6..8b16d4fe 100644 --- a/config/configuration.go +++ b/config/configuration.go @@ -175,60 +175,62 @@ func ValidateUrl(c *cli.Context, allowURLFromArgs bool) (*url.URL, error) { } type UnvalidatedIngressRule struct { - Hostname string - Path string - Service string - OriginRequest OriginRequestConfig `yaml:"originRequest"` + Hostname string `json:"hostname"` + Path string `json:"path"` + Service string `json:"service"` + OriginRequest OriginRequestConfig `yaml:"originRequest" json:"originRequest"` } // OriginRequestConfig is a set of optional fields that users may set to // customize how cloudflared sends requests to origin services. It is used to set // up general config that apply to all rules, and also, specific per-rule // config. -// Note: To specify a time.Duration in go-yaml, use e.g. "3s" or "24h". +// Note: +// - To specify a time.Duration in go-yaml, use e.g. "3s" or "24h". +// - To specify a time.Duration in json, use int64 of the nanoseconds type OriginRequestConfig struct { // HTTP proxy timeout for establishing a new connection - ConnectTimeout *time.Duration `yaml:"connectTimeout"` + ConnectTimeout *time.Duration `yaml:"connectTimeout" json:"connectTimeout"` // HTTP proxy timeout for completing a TLS handshake - TLSTimeout *time.Duration `yaml:"tlsTimeout"` + TLSTimeout *time.Duration `yaml:"tlsTimeout" json:"tlsTimeout"` // HTTP proxy TCP keepalive duration - TCPKeepAlive *time.Duration `yaml:"tcpKeepAlive"` + TCPKeepAlive *time.Duration `yaml:"tcpKeepAlive" json:"tcpKeepAlive"` // HTTP proxy should disable "happy eyeballs" for IPv4/v6 fallback - NoHappyEyeballs *bool `yaml:"noHappyEyeballs"` + NoHappyEyeballs *bool `yaml:"noHappyEyeballs" json:"noHappyEyeballs"` // HTTP proxy maximum keepalive connection pool size - KeepAliveConnections *int `yaml:"keepAliveConnections"` + KeepAliveConnections *int `yaml:"keepAliveConnections" json:"keepAliveConnections"` // HTTP proxy timeout for closing an idle connection - KeepAliveTimeout *time.Duration `yaml:"keepAliveTimeout"` + KeepAliveTimeout *time.Duration `yaml:"keepAliveTimeout" json:"keepAliveTimeout"` // Sets the HTTP Host header for the local webserver. - HTTPHostHeader *string `yaml:"httpHostHeader"` + HTTPHostHeader *string `yaml:"httpHostHeader" json:"httpHostHeader"` // Hostname on the origin server certificate. - OriginServerName *string `yaml:"originServerName"` + OriginServerName *string `yaml:"originServerName" json:"originServerName"` // Path to the CA for the certificate of your origin. // This option should be used only if your certificate is not signed by Cloudflare. - CAPool *string `yaml:"caPool"` + CAPool *string `yaml:"caPool" json:"caPool"` // Disables TLS verification of the certificate presented by your origin. // Will allow any certificate from the origin to be accepted. // Note: The connection from your machine to Cloudflare's Edge is still encrypted. - NoTLSVerify *bool `yaml:"noTLSVerify"` + NoTLSVerify *bool `yaml:"noTLSVerify" json:"noTLSVerify"` // Disables chunked transfer encoding. // Useful if you are running a WSGI server. - DisableChunkedEncoding *bool `yaml:"disableChunkedEncoding"` + DisableChunkedEncoding *bool `yaml:"disableChunkedEncoding" json:"disableChunkedEncoding"` // Runs as jump host - BastionMode *bool `yaml:"bastionMode"` + BastionMode *bool `yaml:"bastionMode" json:"bastionMode"` // Listen address for the proxy. - ProxyAddress *string `yaml:"proxyAddress"` + ProxyAddress *string `yaml:"proxyAddress" json:"proxyAddress"` // Listen port for the proxy. - ProxyPort *uint `yaml:"proxyPort"` + ProxyPort *uint `yaml:"proxyPort" json:"proxyPort"` // Valid options are 'socks' or empty. - ProxyType *string `yaml:"proxyType"` + ProxyType *string `yaml:"proxyType" json:"proxyType"` // IP rules for the proxy service - IPRules []IngressIPRule `yaml:"ipRules"` + IPRules []IngressIPRule `yaml:"ipRules" json:"ipRules"` } type IngressIPRule struct { - Prefix *string `yaml:"prefix"` - Ports []int `yaml:"ports"` - Allow bool `yaml:"allow"` + Prefix *string `yaml:"prefix" json:"prefix"` + Ports []int `yaml:"ports" json:"ports"` + Allow bool `yaml:"allow" json:"allow"` } type Configuration struct { @@ -240,7 +242,7 @@ type Configuration struct { } type WarpRoutingConfig struct { - Enabled bool `yaml:"enabled"` + Enabled bool `yaml:"enabled" json:"enabled"` } type configFileSettings struct { diff --git a/config/configuration_test.go b/config/configuration_test.go index 58ec9639..11db35db 100644 --- a/config/configuration_test.go +++ b/config/configuration_test.go @@ -1,6 +1,7 @@ package config import ( + "encoding/json" "testing" "time" @@ -26,6 +27,18 @@ func TestConfigFileSettings(t *testing.T) { ) rawYAML := ` tunnel: config-file-test +originRequest: + ipRules: + - prefix: "10.0.0.0/8" + ports: + - 80 + - 8080 + allow: false + - prefix: "fc00::/7" + ports: + - 443 + - 4443 + allow: true ingress: - hostname: tunnel1.example.com path: /id @@ -53,6 +66,21 @@ counters: assert.Equal(t, firstIngress, config.Ingress[0]) assert.Equal(t, secondIngress, config.Ingress[1]) assert.Equal(t, warpRouting, config.WarpRouting) + privateV4 := "10.0.0.0/8" + privateV6 := "fc00::/7" + ipRules := []IngressIPRule{ + { + Prefix: &privateV4, + Ports: []int{80, 8080}, + Allow: false, + }, + { + Prefix: &privateV6, + Ports: []int{443, 4443}, + Allow: true, + }, + } + assert.Equal(t, ipRules, config.OriginRequest.IPRules) retries, err := config.Int("retries") assert.NoError(t, err) @@ -81,3 +109,71 @@ counters: assert.Equal(t, 456, counters[1]) } + +func TestUnmarshalOriginRequestConfig(t *testing.T) { + raw := []byte(` +{ + "connectTimeout": 10000000000, + "tlsTimeout": 30000000000, + "tcpKeepAlive": 30000000000, + "noHappyEyeballs": true, + "keepAliveTimeout": 60000000000, + "keepAliveConnections": 10, + "httpHostHeader": "app.tunnel.com", + "originServerName": "app.tunnel.com", + "caPool": "/etc/capool", + "noTLSVerify": true, + "disableChunkedEncoding": true, + "bastionMode": true, + "proxyAddress": "127.0.0.3", + "proxyPort": 9000, + "proxyType": "socks", + "ipRules": [ + { + "prefix": "10.0.0.0/8", + "ports": [80, 8080], + "allow": false + }, + { + "prefix": "fc00::/7", + "ports": [443, 4443], + "allow": true + } + ] +} +`) + var config OriginRequestConfig + assert.NoError(t, json.Unmarshal(raw, &config)) + assert.Equal(t, time.Second*10, *config.ConnectTimeout) + assert.Equal(t, time.Second*30, *config.TLSTimeout) + assert.Equal(t, time.Second*30, *config.TCPKeepAlive) + assert.Equal(t, true, *config.NoHappyEyeballs) + assert.Equal(t, time.Second*60, *config.KeepAliveTimeout) + assert.Equal(t, 10, *config.KeepAliveConnections) + assert.Equal(t, "app.tunnel.com", *config.HTTPHostHeader) + assert.Equal(t, "app.tunnel.com", *config.OriginServerName) + assert.Equal(t, "/etc/capool", *config.CAPool) + assert.Equal(t, true, *config.NoTLSVerify) + assert.Equal(t, true, *config.DisableChunkedEncoding) + assert.Equal(t, true, *config.BastionMode) + assert.Equal(t, "127.0.0.3", *config.ProxyAddress) + assert.Equal(t, true, *config.NoTLSVerify) + assert.Equal(t, uint(9000), *config.ProxyPort) + assert.Equal(t, "socks", *config.ProxyType) + + privateV4 := "10.0.0.0/8" + privateV6 := "fc00::/7" + ipRules := []IngressIPRule{ + { + Prefix: &privateV4, + Ports: []int{80, 8080}, + Allow: false, + }, + { + Prefix: &privateV6, + Ports: []int{443, 4443}, + Allow: true, + }, + } + assert.Equal(t, ipRules, config.IPRules) +} diff --git a/connection/connection.go b/connection/connection.go index 2a57229f..525c1a6e 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -25,13 +25,12 @@ const ( var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols)) -type Config struct { - OriginProxy OriginProxy - GracePeriod time.Duration - ReplaceExisting bool +type Orchestrator interface { + UpdateConfig(version int32, config []byte) *pogs.UpdateConfigurationResponse + GetOriginProxy() (OriginProxy, error) } -type NamedTunnelConfig struct { +type NamedTunnelProperties struct { Credentials Credentials Client pogs.ClientInfo QuickTunnelUrl string @@ -42,7 +41,6 @@ type Credentials struct { AccountTag string TunnelSecret []byte TunnelID uuid.UUID - TunnelName string } func (c *Credentials) Auth() pogs.TunnelAuth { @@ -52,7 +50,22 @@ func (c *Credentials) Auth() pogs.TunnelAuth { } } -type ClassicTunnelConfig struct { +// TunnelToken are Credentials but encoded with custom fields namings. +type TunnelToken struct { + AccountTag string `json:"a"` + TunnelSecret []byte `json:"s"` + TunnelID uuid.UUID `json:"t"` +} + +func (t TunnelToken) Credentials() Credentials { + return Credentials{ + AccountTag: t.AccountTag, + TunnelSecret: t.TunnelSecret, + TunnelID: t.TunnelID, + } +} + +type ClassicTunnelProperties struct { Hostname string OriginCert []byte // feature-flag to use new edge reconnect tokens diff --git a/connection/connection_test.go b/connection/connection_test.go index e8e477ea..9e43fee2 100644 --- a/connection/connection_test.go +++ b/connection/connection_test.go @@ -4,33 +4,28 @@ import ( "context" "fmt" "io" + "math/rand" "net/http" - "net/url" "testing" "time" - "github.com/gobwas/ws/wsutil" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" - "github.com/cloudflare/cloudflared/ingress" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/cloudflare/cloudflared/websocket" ) const ( - largeFileSize = 2 * 1024 * 1024 + largeFileSize = 2 * 1024 * 1024 + testGracePeriod = time.Millisecond * 100 ) var ( - unusedWarpRoutingService = (*ingress.WarpRoutingService)(nil) - testConfig = &Config{ - OriginProxy: &mockOriginProxy{}, - GracePeriod: time.Millisecond * 100, + testOrchestrator = &mockOrchestrator{ + originProxy: &mockOriginProxy{}, } log = zerolog.Nop() - testOriginURL = &url.URL{ - Scheme: "https", - Host: "connectiontest.argotunnel.com", - } testLargeResp = make([]byte, largeFileSize) ) @@ -42,6 +37,20 @@ type testRequest struct { isProxyError bool } +type mockOrchestrator struct { + originProxy OriginProxy +} + +func (*mockOrchestrator) UpdateConfig(version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse { + return &tunnelpogs.UpdateConfigurationResponse{ + LastAppliedVersion: version, + } +} + +func (mcr *mockOrchestrator) GetOriginProxy() (OriginProxy, error) { + return mcr.originProxy, nil +} + type mockOriginProxy struct{} func (moc *mockOriginProxy) ProxyHTTP( @@ -50,7 +59,15 @@ func (moc *mockOriginProxy) ProxyHTTP( isWebsocket bool, ) error { if isWebsocket { - return wsEndpoint(w, req) + switch req.URL.Path { + case "/ws/echo": + return wsEchoEndpoint(w, req) + case "/ws/flaky": + return wsFlakyEndpoint(w, req) + default: + originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found")) + return fmt.Errorf("Unknwon websocket endpoint %s", req.URL.Path) + } } switch req.URL.Path { case "/ok": @@ -78,32 +95,82 @@ func (moc *mockOriginProxy) ProxyTCP( return nil } -type nowriter struct { - io.Reader +type echoPipe struct { + reader *io.PipeReader + writer *io.PipeWriter } -func (nowriter) Write(p []byte) (int, error) { - return 0, fmt.Errorf("Writer not implemented") +func (ep *echoPipe) Read(p []byte) (int, error) { + return ep.reader.Read(p) } -func wsEndpoint(w ResponseWriter, r *http.Request) error { +func (ep *echoPipe) Write(p []byte) (int, error) { + return ep.writer.Write(p) +} + +// A mock origin that echos data by streaming like a tcpOverWSConnection +// https://github.com/cloudflare/cloudflared/blob/master/ingress/origin_connection.go +func wsEchoEndpoint(w ResponseWriter, r *http.Request) error { resp := &http.Response{ StatusCode: http.StatusSwitchingProtocols, } - _ = w.WriteRespHeaders(resp.StatusCode, resp.Header) - clientReader := nowriter{r.Body} + if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil { + return err + } + wsCtx, cancel := context.WithCancel(r.Context()) + readPipe, writePipe := io.Pipe() + wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log) go func() { - for { - data, err := wsutil.ReadClientText(clientReader) - if err != nil { - return - } - if err := wsutil.WriteServerText(w, data); err != nil { - return - } + select { + case <-wsCtx.Done(): + case <-r.Context().Done(): } + readPipe.Close() + writePipe.Close() }() - <-r.Context().Done() + + originConn := &echoPipe{reader: readPipe, writer: writePipe} + websocket.Stream(wsConn, originConn, &log) + cancel() + wsConn.Close() + return nil +} + +type flakyConn struct { + closeAt time.Time +} + +func (fc *flakyConn) Read(p []byte) (int, error) { + if time.Now().After(fc.closeAt) { + return 0, io.EOF + } + n := copy(p, "Read from flaky connection") + return n, nil +} + +func (fc *flakyConn) Write(p []byte) (int, error) { + if time.Now().After(fc.closeAt) { + return 0, fmt.Errorf("flaky connection closed") + } + return len(p), nil +} + +func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error { + resp := &http.Response{ + StatusCode: http.StatusSwitchingProtocols, + } + if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil { + return err + } + wsCtx, cancel := context.WithCancel(r.Context()) + + wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log) + + closedAfter := time.Millisecond * time.Duration(rand.Intn(50)) + originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)} + websocket.Stream(wsConn, originConn, &log) + cancel() + wsConn.Close() return nil } diff --git a/connection/control.go b/connection/control.go index c0c6a1d7..2467e80a 100644 --- a/connection/control.go +++ b/connection/control.go @@ -16,9 +16,9 @@ type RPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) Na type controlStream struct { observer *Observer - connectedFuse ConnectedFuse - namedTunnelConfig *NamedTunnelConfig - connIndex uint8 + connectedFuse ConnectedFuse + namedTunnelProperties *NamedTunnelProperties + connIndex uint8 newRPCClientFunc RPCClientFunc @@ -39,7 +39,7 @@ type ControlStreamHandler interface { func NewControlStream( observer *Observer, connectedFuse ConnectedFuse, - namedTunnelConfig *NamedTunnelConfig, + namedTunnelConfig *NamedTunnelProperties, connIndex uint8, newRPCClientFunc RPCClientFunc, gracefulShutdownC <-chan struct{}, @@ -49,13 +49,13 @@ func NewControlStream( newRPCClientFunc = newRegistrationRPCClient } return &controlStream{ - observer: observer, - connectedFuse: connectedFuse, - namedTunnelConfig: namedTunnelConfig, - newRPCClientFunc: newRPCClientFunc, - connIndex: connIndex, - gracefulShutdownC: gracefulShutdownC, - gracePeriod: gracePeriod, + observer: observer, + connectedFuse: connectedFuse, + namedTunnelProperties: namedTunnelConfig, + newRPCClientFunc: newRPCClientFunc, + connIndex: connIndex, + gracefulShutdownC: gracefulShutdownC, + gracePeriod: gracePeriod, } } @@ -66,7 +66,7 @@ func (c *controlStream) ServeControlStream( ) error { rpcClient := c.newRPCClientFunc(ctx, rw, c.observer.log) - if err := rpcClient.RegisterConnection(ctx, c.namedTunnelConfig, connOptions, c.connIndex, c.observer); err != nil { + if err := rpcClient.RegisterConnection(ctx, c.namedTunnelProperties, connOptions, c.connIndex, c.observer); err != nil { rpcClient.Close() return err } diff --git a/connection/h2mux.go b/connection/h2mux.go index 1e7c652b..1c7276ac 100644 --- a/connection/h2mux.go +++ b/connection/h2mux.go @@ -22,9 +22,10 @@ const ( ) type h2muxConnection struct { - config *Config - muxerConfig *MuxerConfig - muxer *h2mux.Muxer + orchestrator Orchestrator + gracePeriod time.Duration + muxerConfig *MuxerConfig + muxer *h2mux.Muxer // connectionID is only used by metrics, and prometheus requires labels to be string connIndexStr string connIndex uint8 @@ -60,7 +61,8 @@ func (mc *MuxerConfig) H2MuxerConfig(h h2mux.MuxedStreamHandler, log *zerolog.Lo // NewTunnelHandler returns a TunnelHandler, origin LAN IP and error func NewH2muxConnection( - config *Config, + orchestrator Orchestrator, + gracePeriod time.Duration, muxerConfig *MuxerConfig, edgeConn net.Conn, connIndex uint8, @@ -68,7 +70,8 @@ func NewH2muxConnection( gracefulShutdownC <-chan struct{}, ) (*h2muxConnection, error, bool) { h := &h2muxConnection{ - config: config, + orchestrator: orchestrator, + gracePeriod: gracePeriod, muxerConfig: muxerConfig, connIndexStr: uint8ToString(connIndex), connIndex: connIndex, @@ -88,7 +91,7 @@ func NewH2muxConnection( return h, nil, false } -func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error { +func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelProperties, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error { errGroup, serveCtx := errgroup.WithContext(ctx) errGroup.Go(func() error { return h.serveMuxer(serveCtx) @@ -117,7 +120,7 @@ func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *Nam return err } -func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel *ClassicTunnelConfig, credentialManager CredentialManager, registrationOptions *tunnelpogs.RegistrationOptions, connectedFuse ConnectedFuse) error { +func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel *ClassicTunnelProperties, credentialManager CredentialManager, registrationOptions *tunnelpogs.RegistrationOptions, connectedFuse ConnectedFuse) error { errGroup, serveCtx := errgroup.WithContext(ctx) errGroup.Go(func() error { return h.serveMuxer(serveCtx) @@ -224,7 +227,13 @@ func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error { sourceConnectionType = TypeWebsocket } - err := h.config.OriginProxy.ProxyHTTP(respWriter, req, sourceConnectionType == TypeWebsocket) + originProxy, err := h.orchestrator.GetOriginProxy() + if err != nil { + respWriter.WriteErrorResponse() + return err + } + + err = originProxy.ProxyHTTP(respWriter, req, sourceConnectionType == TypeWebsocket) if err != nil { respWriter.WriteErrorResponse() } diff --git a/connection/h2mux_test.go b/connection/h2mux_test.go index e6eab072..787cfd17 100644 --- a/connection/h2mux_test.go +++ b/connection/h2mux_test.go @@ -48,7 +48,7 @@ func newH2MuxConnection(t require.TestingT) (*h2muxConnection, *h2mux.Muxer) { }() var connIndex = uint8(0) testObserver := NewObserver(&log, &log, false) - h2muxConn, err, _ := NewH2muxConnection(testConfig, testMuxerConfig, originConn, connIndex, testObserver, nil) + h2muxConn, err, _ := NewH2muxConnection(testOrchestrator, testGracePeriod, testMuxerConfig, originConn, connIndex, testObserver, nil) require.NoError(t, err) return h2muxConn, <-edgeMuxChan } @@ -147,7 +147,7 @@ func TestServeStreamWS(t *testing.T) { headers := []h2mux.Header{ { Name: ":path", - Value: "/ws", + Value: "/ws/echo", }, { Name: "connection", @@ -167,10 +167,10 @@ func TestServeStreamWS(t *testing.T) { assert.True(t, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderOrigin)) data := []byte("test websocket") - err = wsutil.WriteClientText(writePipe, data) + err = wsutil.WriteClientBinary(writePipe, data) require.NoError(t, err) - respBody, err := wsutil.ReadServerText(stream) + respBody, err := wsutil.ReadServerBinary(stream) require.NoError(t, err) require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody))) diff --git a/connection/http2.go b/connection/http2.go index c0ab8f23..d1e78c1f 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -30,12 +30,12 @@ var errEdgeConnectionClosed = fmt.Errorf("connection with edge closed") // HTTP2Connection represents a net.Conn that uses HTTP2 frames to proxy traffic from the edge to cloudflared on the // origin. type HTTP2Connection struct { - conn net.Conn - server *http2.Server - config *Config - connOptions *tunnelpogs.ConnectionOptions - observer *Observer - connIndex uint8 + conn net.Conn + server *http2.Server + orchestrator Orchestrator + connOptions *tunnelpogs.ConnectionOptions + observer *Observer + connIndex uint8 // newRPCClientFunc allows us to mock RPCs during testing newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient @@ -49,7 +49,7 @@ type HTTP2Connection struct { // NewHTTP2Connection returns a new instance of HTTP2Connection. func NewHTTP2Connection( conn net.Conn, - config *Config, + orchestrator Orchestrator, connOptions *tunnelpogs.ConnectionOptions, observer *Observer, connIndex uint8, @@ -61,7 +61,7 @@ func NewHTTP2Connection( server: &http2.Server{ MaxConcurrentStreams: MaxConcurrentStreams, }, - config: config, + orchestrator: orchestrator, connOptions: connOptions, observer: observer, connIndex: connIndex, @@ -106,6 +106,12 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + originProxy, err := c.orchestrator.GetOriginProxy() + if err != nil { + c.observer.log.Error().Msg(err.Error()) + return + } + switch connType { case TypeControlStream: if err := c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions); err != nil { @@ -116,7 +122,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { case TypeWebsocket, TypeHTTP: stripWebsocketUpgradeHeader(r) - if err := c.config.OriginProxy.ProxyHTTP(respWriter, r, connType == TypeWebsocket); err != nil { + if err := originProxy.ProxyHTTP(respWriter, r, connType == TypeWebsocket); err != nil { err := fmt.Errorf("Failed to proxy HTTP: %w", err) c.log.Error().Err(err) respWriter.WriteErrorResponse() @@ -131,7 +137,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { } rws := NewHTTPResponseReadWriterAcker(respWriter, r) - if err := c.config.OriginProxy.ProxyTCP(r.Context(), rws, &TCPRequest{ + if err := originProxy.ProxyTCP(r.Context(), rws, &TCPRequest{ Dest: host, CFRay: FindCfRayHeader(r), LBProbe: IsLBProbeRequest(r), diff --git a/connection/http2_test.go b/connection/http2_test.go index 4b7435bd..c067229c 100644 --- a/connection/http2_test.go +++ b/connection/http2_test.go @@ -2,6 +2,7 @@ package connection import ( "context" + "errors" "fmt" "io" "io/ioutil" @@ -27,22 +28,23 @@ var ( ) func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) { - edgeConn, originConn := net.Pipe() + edgeConn, cfdConn := net.Pipe() var connIndex = uint8(0) log := zerolog.Nop() obs := NewObserver(&log, &log, false) controlStream := NewControlStream( obs, mockConnectedFuse{}, - &NamedTunnelConfig{}, + &NamedTunnelProperties{}, connIndex, nil, nil, 1*time.Second, ) return NewHTTP2Connection( - originConn, - testConfig, + cfdConn, + // OriginProxy is set in testConfigManager + testOrchestrator, &pogs.ConnectionOptions{}, obs, connIndex, @@ -130,7 +132,7 @@ type mockNamedTunnelRPCClient struct { func (mc mockNamedTunnelRPCClient) RegisterConnection( c context.Context, - config *NamedTunnelConfig, + properties *NamedTunnelProperties, options *tunnelpogs.ConnectionOptions, connIndex uint8, observer *Observer, @@ -166,6 +168,8 @@ type wsRespWriter struct { *httptest.ResponseRecorder readPipe *io.PipeReader writePipe *io.PipeWriter + closed bool + panicked bool } func newWSRespWriter() *wsRespWriter { @@ -174,46 +178,59 @@ func newWSRespWriter() *wsRespWriter { httptest.NewRecorder(), readPipe, writePipe, + false, + false, } } +type nowriter struct { + io.Reader +} + +func (nowriter) Write(_ []byte) (int, error) { + return 0, fmt.Errorf("writer not implemented") +} + func (w *wsRespWriter) RespBody() io.ReadWriter { return nowriter{w.readPipe} } func (w *wsRespWriter) Write(data []byte) (n int, err error) { + if w.closed { + w.panicked = true + return 0, errors.New("wsRespWriter panicked") + } return w.writePipe.Write(data) } +func (w *wsRespWriter) close() { + w.closed = true +} + func TestServeWS(t *testing.T) { http2Conn, _ := newTestHTTP2Connection() ctx, cancel := context.WithCancel(context.Background()) - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - http2Conn.Serve(ctx) - }() respWriter := newWSRespWriter() readPipe, writePipe := io.Pipe() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws", readPipe) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws/echo", readPipe) require.NoError(t, err) req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade) - wg.Add(1) + serveDone := make(chan struct{}) go func() { - defer wg.Done() + defer close(serveDone) http2Conn.ServeHTTP(respWriter, req) + respWriter.close() }() data := []byte("test websocket") - err = wsutil.WriteClientText(writePipe, data) + err = wsutil.WriteClientBinary(writePipe, data) require.NoError(t, err) - respBody, err := wsutil.ReadServerText(respWriter.RespBody()) + respBody, err := wsutil.ReadServerBinary(respWriter.RespBody()) require.NoError(t, err) require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody))) @@ -223,7 +240,65 @@ func TestServeWS(t *testing.T) { require.Equal(t, http.StatusOK, resp.StatusCode) require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader)) + <-serveDone + require.False(t, respWriter.panicked) +} + +// TestNoWriteAfterServeHTTPReturns is a regression test of https://jira.cfops.it/browse/TUN-5184 +// to make sure we don't write to the ResponseWriter after the ServeHTTP method returns +func TestNoWriteAfterServeHTTPReturns(t *testing.T) { + cfdHTTP2Conn, edgeTCPConn := newTestHTTP2Connection() + + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + + serverDone := make(chan struct{}) + go func() { + defer close(serverDone) + cfdHTTP2Conn.Serve(ctx) + }() + + edgeTransport := http2.Transport{} + edgeHTTP2Conn, err := edgeTransport.NewClientConn(edgeTCPConn) + require.NoError(t, err) + message := []byte(t.Name()) + + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + readPipe, writePipe := io.Pipe() + reqCtx, reqCancel := context.WithCancel(ctx) + req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe) + require.NoError(t, err) + req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade) + + resp, err := edgeHTTP2Conn.RoundTrip(req) + require.NoError(t, err) + // http2RespWriter should rewrite status 101 to 200 + require.Equal(t, http.StatusOK, resp.StatusCode) + + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-reqCtx.Done(): + return + default: + } + _ = wsutil.WriteClientBinary(writePipe, message) + } + }() + + time.Sleep(time.Millisecond * 100) + reqCancel() + }() + } + wg.Wait() + cancel() + <-serverDone } func TestServeControlStream(t *testing.T) { @@ -238,7 +313,7 @@ func TestServeControlStream(t *testing.T) { controlStream := NewControlStream( obs, mockConnectedFuse{}, - &NamedTunnelConfig{}, + &NamedTunnelProperties{}, 1, rpcClientFactory.newMockRPCClient, nil, @@ -288,7 +363,7 @@ func TestFailRegistration(t *testing.T) { controlStream := NewControlStream( obs, mockConnectedFuse{}, - &NamedTunnelConfig{}, + &NamedTunnelProperties{}, http2Conn.connIndex, rpcClientFactory.newMockRPCClient, nil, @@ -334,7 +409,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) { controlStream := NewControlStream( obs, mockConnectedFuse{}, - &NamedTunnelConfig{}, + &NamedTunnelProperties{}, http2Conn.connIndex, rpcClientFactory.newMockRPCClient, shutdownC, diff --git a/connection/protocol.go b/connection/protocol.go index 399e6d9d..b94bb80e 100644 --- a/connection/protocol.go +++ b/connection/protocol.go @@ -195,7 +195,7 @@ type PercentageFetcher func() (edgediscovery.ProtocolPercents, error) func NewProtocolSelector( protocolFlag string, warpRoutingEnabled bool, - namedTunnel *NamedTunnelConfig, + namedTunnel *NamedTunnelProperties, fetchFunc PercentageFetcher, ttl time.Duration, log *zerolog.Logger, diff --git a/connection/protocol_test.go b/connection/protocol_test.go index 9bb8c50c..9ab5aae3 100644 --- a/connection/protocol_test.go +++ b/connection/protocol_test.go @@ -16,7 +16,7 @@ const ( ) var ( - testNamedTunnelConfig = &NamedTunnelConfig{ + testNamedTunnelProperties = &NamedTunnelProperties{ Credentials: Credentials{ AccountTag: "testAccountTag", }, @@ -51,7 +51,7 @@ func TestNewProtocolSelector(t *testing.T) { hasFallback bool expectedFallback Protocol warpRoutingEnabled bool - namedTunnelConfig *NamedTunnelConfig + namedTunnelConfig *NamedTunnelProperties fetchFunc PercentageFetcher wantErr bool }{ @@ -66,35 +66,35 @@ func TestNewProtocolSelector(t *testing.T) { protocol: "h2mux", expectedProtocol: H2mux, fetchFunc: func() (edgediscovery.ProtocolPercents, error) { return nil, nil }, - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "named tunnel over http2", protocol: "http2", expectedProtocol: HTTP2, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}), - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "named tunnel http2 disabled still gets http2 because it is manually picked", protocol: "http2", expectedProtocol: HTTP2, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "named tunnel quic disabled still gets quic because it is manually picked", protocol: "quic", expectedProtocol: QUIC, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}), - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "named tunnel quic and http2 disabled", protocol: "auto", expectedProtocol: H2mux, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}), - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "named tunnel quic disabled", @@ -104,21 +104,21 @@ func TestNewProtocolSelector(t *testing.T) { hasFallback: true, expectedFallback: H2mux, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}), - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "named tunnel auto all http2 disabled", protocol: "auto", expectedProtocol: H2mux, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "named tunnel auto to h2mux", protocol: "auto", expectedProtocol: H2mux, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}), - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "named tunnel auto to http2", @@ -127,7 +127,7 @@ func TestNewProtocolSelector(t *testing.T) { hasFallback: true, expectedFallback: H2mux, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "named tunnel auto to quic", @@ -136,7 +136,7 @@ func TestNewProtocolSelector(t *testing.T) { hasFallback: true, expectedFallback: HTTP2, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "warp routing requesting h2mux", @@ -145,7 +145,7 @@ func TestNewProtocolSelector(t *testing.T) { hasFallback: false, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), warpRoutingEnabled: true, - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "warp routing requesting h2mux picks HTTP2 even if http2 percent is -1", @@ -154,7 +154,7 @@ func TestNewProtocolSelector(t *testing.T) { hasFallback: false, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), warpRoutingEnabled: true, - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "warp routing http2", @@ -163,7 +163,7 @@ func TestNewProtocolSelector(t *testing.T) { hasFallback: false, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), warpRoutingEnabled: true, - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "warp routing quic", @@ -173,7 +173,7 @@ func TestNewProtocolSelector(t *testing.T) { expectedFallback: HTTP2Warp, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), warpRoutingEnabled: true, - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "warp routing auto", @@ -182,7 +182,7 @@ func TestNewProtocolSelector(t *testing.T) { hasFallback: false, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), warpRoutingEnabled: true, - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "warp routing auto- quic", @@ -192,7 +192,7 @@ func TestNewProtocolSelector(t *testing.T) { expectedFallback: HTTP2Warp, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), warpRoutingEnabled: true, - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { // None named tunnel can only use h2mux, so specifying an unknown protocol is not an error @@ -204,14 +204,14 @@ func TestNewProtocolSelector(t *testing.T) { name: "named tunnel unknown protocol", protocol: "unknown", fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, wantErr: true, }, { name: "named tunnel fetch error", protocol: "auto", fetchFunc: mockFetcher(true), - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, expectedProtocol: HTTP2, wantErr: false, }, @@ -237,7 +237,7 @@ func TestNewProtocolSelector(t *testing.T) { func TestAutoProtocolSelectorRefresh(t *testing.T) { fetcher := dynamicMockFetcher{} - selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log) + selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), testNoTTL, &log) assert.NoError(t, err) assert.Equal(t, H2mux, selector.Current()) @@ -267,7 +267,7 @@ func TestAutoProtocolSelectorRefresh(t *testing.T) { func TestHTTP2ProtocolSelectorRefresh(t *testing.T) { fetcher := dynamicMockFetcher{} // Since the user chooses http2 on purpose, we always stick to it. - selector, err := NewProtocolSelector("http2", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log) + selector, err := NewProtocolSelector("http2", noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), testNoTTL, &log) assert.NoError(t, err) assert.Equal(t, HTTP2, selector.Current()) @@ -297,7 +297,7 @@ func TestHTTP2ProtocolSelectorRefresh(t *testing.T) { func TestProtocolSelectorRefreshTTL(t *testing.T) { fetcher := dynamicMockFetcher{} fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}} - selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), time.Hour, &log) + selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), time.Hour, &log) assert.NoError(t, err) assert.Equal(t, QUIC, selector.Current()) diff --git a/connection/quic.go b/connection/quic.go index c1b4ff9d..1b9f2e55 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -36,7 +36,7 @@ const ( type QUICConnection struct { session quic.Session logger *zerolog.Logger - httpProxy OriginProxy + orchestrator Orchestrator sessionManager datagramsession.Manager controlStreamHandler ControlStreamHandler connOptions *tunnelpogs.ConnectionOptions @@ -47,7 +47,7 @@ func NewQUICConnection( quicConfig *quic.Config, edgeAddr net.Addr, tlsConfig *tls.Config, - httpProxy OriginProxy, + orchestrator Orchestrator, connOptions *tunnelpogs.ConnectionOptions, controlStreamHandler ControlStreamHandler, logger *zerolog.Logger, @@ -66,7 +66,7 @@ func NewQUICConnection( return &QUICConnection{ session: session, - httpProxy: httpProxy, + orchestrator: orchestrator, logger: logger, sessionManager: sessionManager, controlStreamHandler: controlStreamHandler, @@ -122,7 +122,7 @@ func (q *QUICConnection) serveControlStream(ctx context.Context, controlStream q func (q *QUICConnection) acceptStream(ctx context.Context) error { defer q.Close() for { - stream, err := q.session.AcceptStream(ctx) + quicStream, err := q.session.AcceptStream(ctx) if err != nil { // context.Canceled is usually a user ctrl+c. We don't want to log an error here as it's intentional. if errors.Is(err, context.Canceled) || q.controlStreamHandler.IsStopped() { @@ -131,7 +131,9 @@ func (q *QUICConnection) acceptStream(ctx context.Context) error { return fmt.Errorf("failed to accept QUIC stream: %w", err) } go func() { + stream := quicpogs.NewSafeStreamCloser(quicStream) defer stream.Close() + if err = q.handleStream(stream); err != nil { q.logger.Err(err).Msg("Failed to handle QUIC stream") } @@ -144,7 +146,7 @@ func (q *QUICConnection) Close() { q.session.CloseWithError(0, "") } -func (q *QUICConnection) handleStream(stream quic.Stream) error { +func (q *QUICConnection) handleStream(stream io.ReadWriteCloser) error { signature, err := quicpogs.DetermineProtocol(stream) if err != nil { return err @@ -173,6 +175,10 @@ func (q *QUICConnection) handleDataStream(stream *quicpogs.RequestServerStream) return err } + originProxy, err := q.orchestrator.GetOriginProxy() + if err != nil { + return err + } switch connectRequest.Type { case quicpogs.ConnectionTypeHTTP, quicpogs.ConnectionTypeWebsocket: req, err := buildHTTPRequest(connectRequest, stream) @@ -181,16 +187,16 @@ func (q *QUICConnection) handleDataStream(stream *quicpogs.RequestServerStream) } w := newHTTPResponseAdapter(stream) - return q.httpProxy.ProxyHTTP(w, req, connectRequest.Type == quicpogs.ConnectionTypeWebsocket) + return originProxy.ProxyHTTP(w, req, connectRequest.Type == quicpogs.ConnectionTypeWebsocket) case quicpogs.ConnectionTypeTCP: rwa := &streamReadWriteAcker{stream} - return q.httpProxy.ProxyTCP(context.Background(), rwa, &TCPRequest{Dest: connectRequest.Dest}) + return originProxy.ProxyTCP(context.Background(), rwa, &TCPRequest{Dest: connectRequest.Dest}) } return nil } func (q *QUICConnection) handleRPCStream(rpcStream *quicpogs.RPCServerStream) error { - return rpcStream.Serve(q, q.logger) + return rpcStream.Serve(q, q, q.logger) } // RegisterUdpSession is the RPC method invoked by edge to register and run a session @@ -258,6 +264,11 @@ func (q *QUICConnection) UnregisterUdpSession(ctx context.Context, sessionID uui return q.sessionManager.UnregisterSession(ctx, sessionID, message, true) } +// UpdateConfiguration is the RPC method invoked by edge when there is a new configuration +func (q *QUICConnection) UpdateConfiguration(ctx context.Context, version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse { + return q.orchestrator.UpdateConfig(version, config) +} + // streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to // the client. type streamReadWriteAcker struct { diff --git a/connection/quic_test.go b/connection/quic_test.go index ac945400..9763ae33 100644 --- a/connection/quic_test.go +++ b/connection/quic_test.go @@ -3,14 +3,9 @@ package connection import ( "bytes" "context" - "crypto/rand" - "crypto/rsa" "crypto/tls" - "crypto/x509" - "encoding/pem" "fmt" "io" - "math/big" "net" "net/http" "net/url" @@ -33,7 +28,7 @@ import ( ) var ( - testTLSServerConfig = generateTLSConfig() + testTLSServerConfig = quicpogs.GenerateTLSConfig() testQUICConfig = &quic.Config{ KeepAlive: true, EnableDatagrams: true, @@ -52,7 +47,7 @@ func TestQUICServer(t *testing.T) { // This is simply a sample websocket frame message. wsBuf := &bytes.Buffer{} - wsutil.WriteClientText(wsBuf, []byte("Hello")) + wsutil.WriteClientBinary(wsBuf, []byte("Hello")) var tests = []struct { desc string @@ -84,7 +79,7 @@ func TestQUICServer(t *testing.T) { }, { desc: "test http body request streaming", - dest: "/echo_body", + dest: "/slow_echo_body", connectionType: quicpogs.ConnectionTypeHTTP, metadata: []quicpogs.Metadata{ { @@ -109,7 +104,7 @@ func TestQUICServer(t *testing.T) { }, { desc: "test ws proxy", - dest: "/ok", + dest: "/ws/echo", connectionType: quicpogs.ConnectionTypeWebsocket, metadata: []quicpogs.Metadata{ { @@ -130,7 +125,7 @@ func TestQUICServer(t *testing.T) { }, }, message: wsBuf.Bytes(), - expectedResponse: []byte{0x81, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f}, + expectedResponse: []byte{0x82, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f}, }, { desc: "test tcp proxy", @@ -195,8 +190,9 @@ func quicServer( session, err := earlyListener.Accept(ctx) require.NoError(t, err) - stream, err := session.OpenStreamSync(context.Background()) + quicStream, err := session.OpenStreamSync(context.Background()) require.NoError(t, err) + stream := quicpogs.NewSafeStreamCloser(quicStream) reqClientStream := quicpogs.RequestClientStream{ReadWriteCloser: stream} err = reqClientStream.WriteConnectRequestData(dest, connectionType, metadata...) @@ -207,42 +203,20 @@ func quicServer( if message != nil { // ALPN successful. Write data. - _, err := stream.Write([]byte(message)) + _, err := stream.Write(message) require.NoError(t, err) } response := make([]byte, len(expectedResponse)) - stream.Read(response) - require.NoError(t, err) + _, err = stream.Read(response) + if err != io.EOF { + require.NoError(t, err) + } // For now it is an echo server. Verify if the same data is returned. assert.Equal(t, expectedResponse, response) } -// Setup a bare-bones TLS config for the server -func generateTLSConfig() *tls.Config { - key, err := rsa.GenerateKey(rand.Reader, 1024) - if err != nil { - panic(err) - } - template := x509.Certificate{SerialNumber: big.NewInt(1)} - certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) - if err != nil { - panic(err) - } - keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) - certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) - - tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) - if err != nil { - panic(err) - } - return &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - NextProtos: []string{"argotunnel"}, - } -} - type mockOriginProxyWithRequest struct{} func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Request, isWebsocket bool) error { @@ -259,11 +233,14 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Reque } if isWebsocket { - return wsEndpoint(w, r) + return wsEchoEndpoint(w, r) } switch r.URL.Path { case "/ok": originRespEndpoint(w, http.StatusOK, []byte(http.StatusText(http.StatusOK))) + case "/slow_echo_body": + time.Sleep(5) + fallthrough case "/echo_body": resp := &http.Response{ StatusCode: http.StatusOK, @@ -583,12 +560,12 @@ func serveSession(ctx context.Context, qc *QUICConnection, edgeQUICSession quic. if closeType != closedByRemote { // Session was not closed by remote, so closeUDPSession should be invoked to unregister from remote unregisterFromEdgeChan := make(chan struct{}) - rpcServer := &mockSessionRPCServer{ + sessionRPCServer := &mockSessionRPCServer{ sessionID: sessionID, unregisterReason: expectedReason, calledUnregisterChan: unregisterFromEdgeChan, } - go runMockSessionRPCServer(ctx, edgeQUICSession, rpcServer, t) + go runRPCServer(ctx, edgeQUICSession, sessionRPCServer, nil, t) <-unregisterFromEdgeChan } @@ -604,7 +581,7 @@ const ( closedByTimeout ) -func runMockSessionRPCServer(ctx context.Context, session quic.Session, rpcServer *mockSessionRPCServer, t *testing.T) { +func runRPCServer(ctx context.Context, session quic.Session, sessionRPCServer tunnelpogs.SessionManager, configRPCServer tunnelpogs.ConfigurationManager, t *testing.T) { stream, err := session.AcceptStream(ctx) require.NoError(t, err) @@ -619,7 +596,7 @@ func runMockSessionRPCServer(ctx context.Context, session quic.Session, rpcServe assert.NoError(t, err) log := zerolog.New(os.Stdout) - err = rpcServerStream.Serve(rpcServer, &log) + err = rpcServerStream.Serve(sessionRPCServer, configRPCServer, &log) assert.NoError(t, err) } @@ -641,7 +618,6 @@ func (s mockSessionRPCServer) UnregisterUdpSession(ctx context.Context, sessionI return fmt.Errorf("expect unregister reason %s, got %s", s.unregisterReason, reason) } close(s.calledUnregisterChan) - fmt.Println("unregister from edge") return nil } @@ -651,13 +627,12 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T) *QUICConnection NextProtos: []string{"argotunnel"}, } // Start a mock httpProxy - originProxy := &mockOriginProxyWithRequest{} log := zerolog.New(os.Stdout) qc, err := NewQUICConnection( testQUICConfig, udpListenerAddr, tlsClientConfig, - originProxy, + &mockOrchestrator{originProxy: &mockOriginProxyWithRequest{}}, &tunnelpogs.ConnectionOptions{}, fakeControlStream{}, &log, diff --git a/connection/rpc.go b/connection/rpc.go index e8eb6f4a..937604b3 100644 --- a/connection/rpc.go +++ b/connection/rpc.go @@ -37,7 +37,7 @@ func NewTunnelServerClient( } } -func (tsc *tunnelServerClient) Authenticate(ctx context.Context, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) (tunnelpogs.AuthOutcome, error) { +func (tsc *tunnelServerClient) Authenticate(ctx context.Context, classicTunnel *ClassicTunnelProperties, registrationOptions *tunnelpogs.RegistrationOptions) (tunnelpogs.AuthOutcome, error) { authResp, err := tsc.client.Authenticate(ctx, classicTunnel.OriginCert, classicTunnel.Hostname, registrationOptions) if err != nil { return nil, err @@ -54,7 +54,7 @@ func (tsc *tunnelServerClient) Close() { type NamedTunnelRPCClient interface { RegisterConnection( c context.Context, - config *NamedTunnelConfig, + config *NamedTunnelProperties, options *tunnelpogs.ConnectionOptions, connIndex uint8, observer *Observer, @@ -86,15 +86,15 @@ func newRegistrationRPCClient( func (rsc *registrationServerClient) RegisterConnection( ctx context.Context, - config *NamedTunnelConfig, + properties *NamedTunnelProperties, options *tunnelpogs.ConnectionOptions, connIndex uint8, observer *Observer, ) error { conn, err := rsc.client.RegisterConnection( ctx, - config.Credentials.Auth(), - config.Credentials.TunnelID, + properties.Credentials.Auth(), + properties.Credentials.TunnelID, connIndex, options, ) @@ -137,7 +137,7 @@ const ( authenticate rpcName = " authenticate" ) -func (h *h2muxConnection) registerTunnel(ctx context.Context, credentialSetter CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error { +func (h *h2muxConnection) registerTunnel(ctx context.Context, credentialSetter CredentialManager, classicTunnel *ClassicTunnelProperties, registrationOptions *tunnelpogs.RegistrationOptions) error { h.observer.sendRegisteringEvent(registrationOptions.ConnectionID) stream, err := h.newRPCStream(ctx, register) @@ -174,7 +174,7 @@ type CredentialManager interface { func (h *h2muxConnection) processRegistrationSuccess( registration *tunnelpogs.TunnelRegistration, name rpcName, - credentialManager CredentialManager, classicTunnel *ClassicTunnelConfig, + credentialManager CredentialManager, classicTunnel *ClassicTunnelProperties, ) error { for _, logLine := range registration.LogLines { h.observer.log.Info().Msg(logLine) @@ -205,7 +205,7 @@ func (h *h2muxConnection) processRegisterTunnelError(err tunnelpogs.TunnelRegist } } -func (h *h2muxConnection) reconnectTunnel(ctx context.Context, credentialManager CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error { +func (h *h2muxConnection) reconnectTunnel(ctx context.Context, credentialManager CredentialManager, classicTunnel *ClassicTunnelProperties, registrationOptions *tunnelpogs.RegistrationOptions) error { token, err := credentialManager.ReconnectToken() if err != nil { return err @@ -264,7 +264,7 @@ func (h *h2muxConnection) logServerInfo(ctx context.Context, rpcClient *tunnelSe func (h *h2muxConnection) registerNamedTunnel( ctx context.Context, - namedTunnel *NamedTunnelConfig, + namedTunnel *NamedTunnelProperties, connOptions *tunnelpogs.ConnectionOptions, ) error { stream, err := h.newRPCStream(ctx, register) @@ -283,7 +283,7 @@ func (h *h2muxConnection) registerNamedTunnel( func (h *h2muxConnection) unregister(isNamedTunnel bool) { h.observer.sendUnregisteringEvent(h.connIndex) - unregisterCtx, cancel := context.WithTimeout(context.Background(), h.config.GracePeriod) + unregisterCtx, cancel := context.WithTimeout(context.Background(), h.gracePeriod) defer cancel() stream, err := h.newRPCStream(unregisterCtx, unregister) @@ -296,13 +296,13 @@ func (h *h2muxConnection) unregister(isNamedTunnel bool) { rpcClient := h.newRPCClientFunc(unregisterCtx, stream, h.observer.log) defer rpcClient.Close() - rpcClient.GracefulShutdown(unregisterCtx, h.config.GracePeriod) + rpcClient.GracefulShutdown(unregisterCtx, h.gracePeriod) } else { rpcClient := NewTunnelServerClient(unregisterCtx, stream, h.observer.log) defer rpcClient.Close() // gracePeriod is encoded in int64 using capnproto - _ = rpcClient.client.UnregisterTunnel(unregisterCtx, h.config.GracePeriod.Nanoseconds()) + _ = rpcClient.client.UnregisterTunnel(unregisterCtx, h.gracePeriod.Nanoseconds()) } h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unregistered tunnel connection") diff --git a/ingress/origin_request_config.go b/ingress/config.go similarity index 79% rename from ingress/origin_request_config.go rename to ingress/config.go index 37e3ebbc..b389bc9b 100644 --- a/ingress/origin_request_config.go +++ b/ingress/config.go @@ -1,6 +1,7 @@ package ingress import ( + "encoding/json" "time" "github.com/urfave/cli/v2" @@ -38,6 +39,34 @@ const ( socksProxy = "socks" ) +// RemoteConfig models ingress settings that can be managed remotely, for example through the dashboard. +type RemoteConfig struct { + Ingress Ingress + WarpRouting config.WarpRoutingConfig +} + +type remoteConfigJSON struct { + GlobalOriginRequest config.OriginRequestConfig `json:"originRequest"` + IngressRules []config.UnvalidatedIngressRule `json:"ingress"` + WarpRouting config.WarpRoutingConfig `json:"warp-routing"` +} + +func (rc *RemoteConfig) UnmarshalJSON(b []byte) error { + var rawConfig remoteConfigJSON + if err := json.Unmarshal(b, &rawConfig); err != nil { + return err + } + ingress, err := validateIngress(rawConfig.IngressRules, originRequestFromConfig(rawConfig.GlobalOriginRequest)) + if err != nil { + return err + } + + rc.Ingress = ingress + rc.WarpRouting = rawConfig.WarpRouting + + return nil +} + func originRequestFromSingeRule(c *cli.Context) OriginRequestConfig { var connectTimeout time.Duration = defaultConnectTimeout var tlsTimeout time.Duration = defaultTLSTimeout @@ -119,7 +148,7 @@ func originRequestFromSingeRule(c *cli.Context) OriginRequestConfig { } } -func originRequestFromYAML(y config.OriginRequestConfig) OriginRequestConfig { +func originRequestFromConfig(c config.OriginRequestConfig) OriginRequestConfig { out := OriginRequestConfig{ ConnectTimeout: defaultConnectTimeout, TLSTimeout: defaultTLSTimeout, @@ -128,50 +157,58 @@ func originRequestFromYAML(y config.OriginRequestConfig) OriginRequestConfig { KeepAliveTimeout: defaultKeepAliveTimeout, ProxyAddress: defaultProxyAddress, } - if y.ConnectTimeout != nil { - out.ConnectTimeout = *y.ConnectTimeout + if c.ConnectTimeout != nil { + out.ConnectTimeout = *c.ConnectTimeout } - if y.TLSTimeout != nil { - out.TLSTimeout = *y.TLSTimeout + if c.TLSTimeout != nil { + out.TLSTimeout = *c.TLSTimeout } - if y.TCPKeepAlive != nil { - out.TCPKeepAlive = *y.TCPKeepAlive + if c.TCPKeepAlive != nil { + out.TCPKeepAlive = *c.TCPKeepAlive } - if y.NoHappyEyeballs != nil { - out.NoHappyEyeballs = *y.NoHappyEyeballs + if c.NoHappyEyeballs != nil { + out.NoHappyEyeballs = *c.NoHappyEyeballs } - if y.KeepAliveConnections != nil { - out.KeepAliveConnections = *y.KeepAliveConnections + if c.KeepAliveConnections != nil { + out.KeepAliveConnections = *c.KeepAliveConnections } - if y.KeepAliveTimeout != nil { - out.KeepAliveTimeout = *y.KeepAliveTimeout + if c.KeepAliveTimeout != nil { + out.KeepAliveTimeout = *c.KeepAliveTimeout } - if y.HTTPHostHeader != nil { - out.HTTPHostHeader = *y.HTTPHostHeader + if c.HTTPHostHeader != nil { + out.HTTPHostHeader = *c.HTTPHostHeader } - if y.OriginServerName != nil { - out.OriginServerName = *y.OriginServerName + if c.OriginServerName != nil { + out.OriginServerName = *c.OriginServerName } - if y.CAPool != nil { - out.CAPool = *y.CAPool + if c.CAPool != nil { + out.CAPool = *c.CAPool } - if y.NoTLSVerify != nil { - out.NoTLSVerify = *y.NoTLSVerify + if c.NoTLSVerify != nil { + out.NoTLSVerify = *c.NoTLSVerify } - if y.DisableChunkedEncoding != nil { - out.DisableChunkedEncoding = *y.DisableChunkedEncoding + if c.DisableChunkedEncoding != nil { + out.DisableChunkedEncoding = *c.DisableChunkedEncoding } - if y.BastionMode != nil { - out.BastionMode = *y.BastionMode + if c.BastionMode != nil { + out.BastionMode = *c.BastionMode } - if y.ProxyAddress != nil { - out.ProxyAddress = *y.ProxyAddress + if c.ProxyAddress != nil { + out.ProxyAddress = *c.ProxyAddress } - if y.ProxyPort != nil { - out.ProxyPort = *y.ProxyPort + if c.ProxyPort != nil { + out.ProxyPort = *c.ProxyPort } - if y.ProxyType != nil { - out.ProxyType = *y.ProxyType + if c.ProxyType != nil { + out.ProxyType = *c.ProxyType + } + if len(c.IPRules) > 0 { + for _, r := range c.IPRules { + rule, err := ipaccess.NewRuleByCIDR(r.Prefix, r.Ports, r.Allow) + if err == nil { + out.IPRules = append(out.IPRules, rule) + } + } } return out } @@ -188,10 +225,10 @@ type OriginRequestConfig struct { TCPKeepAlive time.Duration `yaml:"tcpKeepAlive"` // HTTP proxy should disable "happy eyeballs" for IPv4/v6 fallback NoHappyEyeballs bool `yaml:"noHappyEyeballs"` - // HTTP proxy maximum keepalive connection pool size - KeepAliveConnections int `yaml:"keepAliveConnections"` // HTTP proxy timeout for closing an idle connection KeepAliveTimeout time.Duration `yaml:"keepAliveTimeout"` + // HTTP proxy maximum keepalive connection pool size + KeepAliveConnections int `yaml:"keepAliveConnections"` // Sets the HTTP Host header for the local webserver. HTTPHostHeader string `yaml:"httpHostHeader"` // Hostname on the origin server certificate. @@ -308,6 +345,19 @@ func (defaults *OriginRequestConfig) setProxyType(overrides config.OriginRequest } } +func (defaults *OriginRequestConfig) setIPRules(overrides config.OriginRequestConfig) { + if val := overrides.IPRules; len(val) > 0 { + ipAccessRule := make([]ipaccess.Rule, len(overrides.IPRules)) + for i, r := range overrides.IPRules { + rule, err := ipaccess.NewRuleByCIDR(r.Prefix, r.Ports, r.Allow) + if err == nil { + ipAccessRule[i] = rule + } + } + defaults.IPRules = ipAccessRule + } +} + // SetConfig gets config for the requests that cloudflared sends to origins. // Each field has a setter method which sets a value for the field by trying to find: // 1. The user config for this rule @@ -332,5 +382,6 @@ func setConfig(defaults OriginRequestConfig, overrides config.OriginRequestConfi cfg.setProxyPort(overrides) cfg.setProxyAddress(overrides) cfg.setProxyType(overrides) + cfg.setIPRules(overrides) return cfg } diff --git a/ingress/config_test.go b/ingress/config_test.go new file mode 100644 index 00000000..ff0a3c0a --- /dev/null +++ b/ingress/config_test.go @@ -0,0 +1,422 @@ +package ingress + +import ( + "encoding/json" + "flag" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/urfave/cli/v2" + yaml "gopkg.in/yaml.v2" + + "github.com/cloudflare/cloudflared/config" + "github.com/cloudflare/cloudflared/ipaccess" +) + +// Ensure that the nullable config from `config` package and the +// non-nullable config from `ingress` package have the same number of +// fields. +// This test ensures that programmers didn't add a new field to +// one struct and forget to add it to the other ;) +func TestCorrespondingFields(t *testing.T) { + require.Equal( + t, + CountFields(t, config.OriginRequestConfig{}), + CountFields(t, OriginRequestConfig{}), + ) +} + +func CountFields(t *testing.T, val interface{}) int { + b, err := yaml.Marshal(val) + require.NoError(t, err) + m := make(map[string]interface{}, 0) + err = yaml.Unmarshal(b, &m) + require.NoError(t, err) + return len(m) +} + +func TestUnmarshalRemoteConfigOverridesGlobal(t *testing.T) { + rawConfig := []byte(` +{ + "originRequest": { + "connectTimeout": 90, + "noHappyEyeballs": true + }, + "ingress": [ + { + "hostname": "jira.cfops.com", + "service": "http://192.16.19.1:80", + "originRequest": { + "noTLSVerify": true, + "connectTimeout": 10 + } + }, + { + "service": "http_status:404" + } + ], + "warp-routing": { + "enabled": true + } +} +`) + var remoteConfig RemoteConfig + err := json.Unmarshal(rawConfig, &remoteConfig) + require.NoError(t, err) + require.True(t, remoteConfig.Ingress.Rules[0].Config.NoTLSVerify) + require.True(t, remoteConfig.Ingress.defaults.NoHappyEyeballs) +} + +func TestOriginRequestConfigOverrides(t *testing.T) { + validate := func(ing Ingress) { + // Rule 0 didn't override anything, so it inherits the user-specified + // root-level configuration. + actual0 := ing.Rules[0].Config + expected0 := OriginRequestConfig{ + ConnectTimeout: 1 * time.Minute, + TLSTimeout: 1 * time.Second, + TCPKeepAlive: 1 * time.Second, + NoHappyEyeballs: true, + KeepAliveTimeout: 1 * time.Second, + KeepAliveConnections: 1, + HTTPHostHeader: "abc", + OriginServerName: "a1", + CAPool: "/tmp/path0", + NoTLSVerify: true, + DisableChunkedEncoding: true, + BastionMode: true, + ProxyAddress: "127.1.2.3", + ProxyPort: uint(100), + ProxyType: "socks5", + IPRules: []ipaccess.Rule{ + newIPRule(t, "10.0.0.0/8", []int{80, 8080}, false), + newIPRule(t, "fc00::/7", []int{443, 4443}, true), + }, + } + require.Equal(t, expected0, actual0) + + // Rule 1 overrode all the root-level config. + actual1 := ing.Rules[1].Config + expected1 := OriginRequestConfig{ + ConnectTimeout: 2 * time.Minute, + TLSTimeout: 2 * time.Second, + TCPKeepAlive: 2 * time.Second, + NoHappyEyeballs: false, + KeepAliveTimeout: 2 * time.Second, + KeepAliveConnections: 2, + HTTPHostHeader: "def", + OriginServerName: "b2", + CAPool: "/tmp/path1", + NoTLSVerify: false, + DisableChunkedEncoding: false, + BastionMode: false, + ProxyAddress: "interface", + ProxyPort: uint(200), + ProxyType: "", + IPRules: []ipaccess.Rule{ + newIPRule(t, "10.0.0.0/16", []int{3000, 3030}, false), + newIPRule(t, "192.16.0.0/24", []int{5000, 5050}, true), + }, + } + require.Equal(t, expected1, actual1) + } + + rulesYAML := ` +originRequest: + connectTimeout: 1m + tlsTimeout: 1s + noHappyEyeballs: true + tcpKeepAlive: 1s + keepAliveConnections: 1 + keepAliveTimeout: 1s + httpHostHeader: abc + originServerName: a1 + caPool: /tmp/path0 + noTLSVerify: true + disableChunkedEncoding: true + bastionMode: True + proxyAddress: 127.1.2.3 + proxyPort: 100 + proxyType: socks5 + ipRules: + - prefix: "10.0.0.0/8" + ports: + - 80 + - 8080 + allow: false + - prefix: "fc00::/7" + ports: + - 443 + - 4443 + allow: true +ingress: +- hostname: tun.example.com + service: https://localhost:8000 +- hostname: "*" + service: https://localhost:8001 + originRequest: + connectTimeout: 2m + tlsTimeout: 2s + noHappyEyeballs: false + tcpKeepAlive: 2s + keepAliveConnections: 2 + keepAliveTimeout: 2s + httpHostHeader: def + originServerName: b2 + caPool: /tmp/path1 + noTLSVerify: false + disableChunkedEncoding: false + bastionMode: false + proxyAddress: interface + proxyPort: 200 + proxyType: "" + ipRules: + - prefix: "10.0.0.0/16" + ports: + - 3000 + - 3030 + allow: false + - prefix: "192.16.0.0/24" + ports: + - 5000 + - 5050 + allow: true +` + + ing, err := ParseIngress(MustReadIngress(rulesYAML)) + require.NoError(t, err) + validate(ing) + + rawConfig := []byte(` +{ + "originRequest": { + "connectTimeout": 60000000000, + "tlsTimeout": 1000000000, + "noHappyEyeballs": true, + "tcpKeepAlive": 1000000000, + "keepAliveConnections": 1, + "keepAliveTimeout": 1000000000, + "httpHostHeader": "abc", + "originServerName": "a1", + "caPool": "/tmp/path0", + "noTLSVerify": true, + "disableChunkedEncoding": true, + "bastionMode": true, + "proxyAddress": "127.1.2.3", + "proxyPort": 100, + "proxyType": "socks5", + "ipRules": [ + { + "prefix": "10.0.0.0/8", + "ports": [80, 8080], + "allow": false + }, + { + "prefix": "fc00::/7", + "ports": [443, 4443], + "allow": true + } + ] + }, + "ingress": [ + { + "hostname": "tun.example.com", + "service": "https://localhost:8000" + }, + { + "hostname": "*", + "service": "https://localhost:8001", + "originRequest": { + "connectTimeout": 120000000000, + "tlsTimeout": 2000000000, + "noHappyEyeballs": false, + "tcpKeepAlive": 2000000000, + "keepAliveConnections": 2, + "keepAliveTimeout": 2000000000, + "httpHostHeader": "def", + "originServerName": "b2", + "caPool": "/tmp/path1", + "noTLSVerify": false, + "disableChunkedEncoding": false, + "bastionMode": false, + "proxyAddress": "interface", + "proxyPort": 200, + "proxyType": "", + "ipRules": [ + { + "prefix": "10.0.0.0/16", + "ports": [3000, 3030], + "allow": false + }, + { + "prefix": "192.16.0.0/24", + "ports": [5000, 5050], + "allow": true + } + ] + } + } + ], + "warp-routing": { + "enabled": true + } +} +`) + var remoteConfig RemoteConfig + err = json.Unmarshal(rawConfig, &remoteConfig) + require.NoError(t, err) + validate(remoteConfig.Ingress) +} + +func TestOriginRequestConfigDefaults(t *testing.T) { + validate := func(ing Ingress) { + // Rule 0 didn't override anything, so it inherits the cloudflared defaults + actual0 := ing.Rules[0].Config + expected0 := OriginRequestConfig{ + ConnectTimeout: defaultConnectTimeout, + TLSTimeout: defaultTLSTimeout, + TCPKeepAlive: defaultTCPKeepAlive, + KeepAliveConnections: defaultKeepAliveConnections, + KeepAliveTimeout: defaultKeepAliveTimeout, + ProxyAddress: defaultProxyAddress, + } + require.Equal(t, expected0, actual0) + + // Rule 1 overrode all defaults. + actual1 := ing.Rules[1].Config + expected1 := OriginRequestConfig{ + ConnectTimeout: 2 * time.Minute, + TLSTimeout: 2 * time.Second, + TCPKeepAlive: 2 * time.Second, + NoHappyEyeballs: false, + KeepAliveTimeout: 2 * time.Second, + KeepAliveConnections: 2, + HTTPHostHeader: "def", + OriginServerName: "b2", + CAPool: "/tmp/path1", + NoTLSVerify: false, + DisableChunkedEncoding: false, + BastionMode: false, + ProxyAddress: "interface", + ProxyPort: uint(200), + ProxyType: "", + IPRules: []ipaccess.Rule{ + newIPRule(t, "10.0.0.0/16", []int{3000, 3030}, false), + newIPRule(t, "192.16.0.0/24", []int{5000, 5050}, true), + }, + } + require.Equal(t, expected1, actual1) + } + + rulesYAML := ` +ingress: +- hostname: tun.example.com + service: https://localhost:8000 +- hostname: "*" + service: https://localhost:8001 + originRequest: + connectTimeout: 2m + tlsTimeout: 2s + noHappyEyeballs: false + tcpKeepAlive: 2s + keepAliveConnections: 2 + keepAliveTimeout: 2s + httpHostHeader: def + originServerName: b2 + caPool: /tmp/path1 + noTLSVerify: false + disableChunkedEncoding: false + bastionMode: false + proxyAddress: interface + proxyPort: 200 + proxyType: "" + ipRules: + - prefix: "10.0.0.0/16" + ports: + - 3000 + - 3030 + allow: false + - prefix: "192.16.0.0/24" + ports: + - 5000 + - 5050 + allow: true +` + ing, err := ParseIngress(MustReadIngress(rulesYAML)) + if err != nil { + t.Error(err) + } + validate(ing) + + rawConfig := []byte(` +{ + "ingress": [ + { + "hostname": "tun.example.com", + "service": "https://localhost:8000" + }, + { + "hostname": "*", + "service": "https://localhost:8001", + "originRequest": { + "connectTimeout": 120000000000, + "tlsTimeout": 2000000000, + "noHappyEyeballs": false, + "tcpKeepAlive": 2000000000, + "keepAliveConnections": 2, + "keepAliveTimeout": 2000000000, + "httpHostHeader": "def", + "originServerName": "b2", + "caPool": "/tmp/path1", + "noTLSVerify": false, + "disableChunkedEncoding": false, + "bastionMode": false, + "proxyAddress": "interface", + "proxyPort": 200, + "proxyType": "", + "ipRules": [ + { + "prefix": "10.0.0.0/16", + "ports": [3000, 3030], + "allow": false + }, + { + "prefix": "192.16.0.0/24", + "ports": [5000, 5050], + "allow": true + } + ] + } + } + ] +} +`) + + var remoteConfig RemoteConfig + err = json.Unmarshal(rawConfig, &remoteConfig) + require.NoError(t, err) + validate(remoteConfig.Ingress) +} + +func TestDefaultConfigFromCLI(t *testing.T) { + set := flag.NewFlagSet("contrive", 0) + c := cli.NewContext(nil, set, nil) + + expected := OriginRequestConfig{ + ConnectTimeout: defaultConnectTimeout, + TLSTimeout: defaultTLSTimeout, + TCPKeepAlive: defaultTCPKeepAlive, + KeepAliveConnections: defaultKeepAliveConnections, + KeepAliveTimeout: defaultKeepAliveTimeout, + ProxyAddress: defaultProxyAddress, + } + actual := originRequestFromSingeRule(c) + require.Equal(t, expected, actual) +} + +func newIPRule(t *testing.T, prefix string, ports []int, allow bool) ipaccess.Rule { + rule, err := ipaccess.NewRuleByCIDR(&prefix, ports, allow) + require.NoError(t, err) + return rule +} diff --git a/ingress/ingress.go b/ingress/ingress.go index 0529e382..5e5f9655 100644 --- a/ingress/ingress.go +++ b/ingress/ingress.go @@ -7,7 +7,6 @@ import ( "regexp" "strconv" "strings" - "sync" "github.com/pkg/errors" "github.com/rs/zerolog" @@ -145,13 +144,11 @@ func (ing Ingress) IsSingleRule() bool { // StartOrigins will start any origin services managed by cloudflared, e.g. proxy servers or Hello World. func (ing Ingress) StartOrigins( - wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, - errC chan error, ) error { for _, rule := range ing.Rules { - if err := rule.Service.start(wg, log, shutdownC, errC, rule.Config); err != nil { + if err := rule.Service.start(log, shutdownC, rule.Config); err != nil { return errors.Wrapf(err, "Error starting local service %s", rule.Service) } } @@ -163,7 +160,7 @@ func (ing Ingress) CatchAll() *Rule { return &ing.Rules[len(ing.Rules)-1] } -func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestConfig) (Ingress, error) { +func validateIngress(ingress []config.UnvalidatedIngressRule, defaults OriginRequestConfig) (Ingress, error) { rules := make([]Rule, len(ingress)) for i, r := range ingress { cfg := setConfig(defaults, r.OriginRequest) @@ -290,7 +287,7 @@ func ParseIngress(conf *config.Configuration) (Ingress, error) { if len(conf.Ingress) == 0 { return Ingress{}, ErrNoIngressRules } - return validate(conf.Ingress, originRequestFromYAML(conf.OriginRequest)) + return validateIngress(conf.Ingress, originRequestFromConfig(conf.OriginRequest)) } func isHTTPService(url *url.URL) bool { diff --git a/ingress/ingress_test.go b/ingress/ingress_test.go index 7867d2b5..9d09e8f8 100644 --- a/ingress/ingress_test.go +++ b/ingress/ingress_test.go @@ -34,7 +34,7 @@ func Test_parseIngress(t *testing.T) { localhost8000 := MustParseURL(t, "https://localhost:8000") localhost8001 := MustParseURL(t, "https://localhost:8001") fourOhFour := newStatusCode(404) - defaultConfig := setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{}) + defaultConfig := setConfig(originRequestFromConfig(config.OriginRequestConfig{}), config.OriginRequestConfig{}) require.Equal(t, defaultKeepAliveConnections, defaultConfig.KeepAliveConnections) tr := true type args struct { @@ -324,7 +324,17 @@ ingress: { Hostname: "socks.foo.com", Service: newSocksProxyOverWSService(accessPolicy()), - Config: defaultConfig, + Config: setConfig(originRequestFromConfig(config.OriginRequestConfig{}), config.OriginRequestConfig{IPRules: []config.IngressIPRule{ + { + Prefix: ipRulePrefix("1.1.1.0/24"), + Ports: []int{80, 443}, + Allow: true, + }, + { + Prefix: ipRulePrefix("0.0.0.0/0"), + Allow: false, + }, + }}), }, { Service: &fourOhFour, @@ -345,7 +355,7 @@ ingress: { Hostname: "bastion.foo.com", Service: newBastionService(), - Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}), + Config: setConfig(originRequestFromConfig(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}), }, { Service: &fourOhFour, @@ -365,7 +375,7 @@ ingress: { Hostname: "bastion.foo.com", Service: newBastionService(), - Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}), + Config: setConfig(originRequestFromConfig(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}), }, { Service: &fourOhFour, @@ -397,6 +407,10 @@ ingress: } } +func ipRulePrefix(s string) *string { + return &s +} + func TestSingleOriginSetsConfig(t *testing.T) { flagSet := flag.NewFlagSet(t.Name(), flag.PanicOnError) flagSet.Bool("hello-world", true, "") diff --git a/ingress/origin_connection.go b/ingress/origin_connection.go index 9588ce36..2e8b946e 100644 --- a/ingress/origin_connection.go +++ b/ingress/origin_connection.go @@ -53,7 +53,7 @@ func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWri wc.streamHandler(wsConn, wc.conn, log) cancel() // Makes sure wsConn stops sending ping before terminating the stream - wsConn.WaitForShutdown() + wsConn.Close() } func (wc *tcpOverWSConnection) Close() { @@ -73,21 +73,8 @@ func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io. socks.StreamNetHandler(wsConn, sp.accessPolicy, log) cancel() // Makes sure wsConn stops sending ping before terminating the stream - wsConn.WaitForShutdown() + wsConn.Close() } func (sp *socksProxyOverWSConnection) Close() { } - -// wsProxyConnection represents a bidirectional stream for a websocket connection to the origin -type wsProxyConnection struct { - rwc io.ReadWriteCloser -} - -func (conn *wsProxyConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) { - websocket.Stream(tunnelConn, conn.rwc, log) -} - -func (conn *wsProxyConnection) Close() { - conn.rwc.Close() -} diff --git a/ingress/origin_proxy.go b/ingress/origin_proxy.go index ffc3d9cb..63c10137 100644 --- a/ingress/origin_proxy.go +++ b/ingress/origin_proxy.go @@ -23,6 +23,7 @@ type StreamBasedOriginProxy interface { } func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) { + req.URL.Scheme = "http" return o.transport.RoundTrip(req) } diff --git a/ingress/origin_proxy_test.go b/ingress/origin_proxy_test.go index b14408b8..cc244aee 100644 --- a/ingress/origin_proxy_test.go +++ b/ingress/origin_proxy_test.go @@ -8,7 +8,6 @@ import ( "net/http" "net/http/httptest" "net/url" - "sync" "testing" "github.com/stretchr/testify/assert" @@ -132,10 +131,8 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) { httpService := &httpService{ url: originURL, } - var wg sync.WaitGroup shutdownC := make(chan struct{}) - errC := make(chan error) - require.NoError(t, httpService.start(&wg, testLogger, shutdownC, errC, cfg)) + require.NoError(t, httpService.start(testLogger, shutdownC, cfg)) req, err := http.NewRequest(http.MethodGet, originURL.String(), nil) require.NoError(t, err) @@ -147,7 +144,46 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) { respBody, err := ioutil.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, respBody, []byte(originURL.Host)) +} +// TestHTTPServiceUsesIngressRuleScheme makes sure httpService uses scheme defined in ingress rule and not by eyeball request +func TestHTTPServiceUsesIngressRuleScheme(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + require.NotNil(t, r.TLS) + // Echo the X-Forwarded-Proto header for assertions + w.Write([]byte(r.Header.Get("X-Forwarded-Proto"))) + } + origin := httptest.NewTLSServer(http.HandlerFunc(handler)) + defer origin.Close() + + originURL, err := url.Parse(origin.URL) + require.NoError(t, err) + require.Equal(t, "https", originURL.Scheme) + + cfg := OriginRequestConfig{ + NoTLSVerify: true, + } + httpService := &httpService{ + url: originURL, + } + shutdownC := make(chan struct{}) + require.NoError(t, httpService.start(testLogger, shutdownC, cfg)) + + // Tunnel uses scheme defined in the service field of the ingress rule, independent of the X-Forwarded-Proto header + protos := []string{"https", "http", "dne"} + for _, p := range protos { + req, err := http.NewRequest(http.MethodGet, originURL.String(), nil) + require.NoError(t, err) + req.Header.Add("X-Forwarded-Proto", p) + + resp, err := httpService.RoundTrip(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + respBody, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, respBody, []byte(p)) + } } func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) { diff --git a/ingress/origin_request_config_test.go b/ingress/origin_request_config_test.go deleted file mode 100644 index d4113d1d..00000000 --- a/ingress/origin_request_config_test.go +++ /dev/null @@ -1,203 +0,0 @@ -package ingress - -import ( - "flag" - "testing" - "time" - - "github.com/stretchr/testify/require" - "github.com/urfave/cli/v2" - yaml "gopkg.in/yaml.v2" - - "github.com/cloudflare/cloudflared/config" -) - -// Ensure that the nullable config from `config` package and the -// non-nullable config from `ingress` package have the same number of -// fields. -// This test ensures that programmers didn't add a new field to -// one struct and forget to add it to the other ;) -func TestCorrespondingFields(t *testing.T) { - require.Equal( - t, - CountFields(t, config.OriginRequestConfig{}), - CountFields(t, OriginRequestConfig{}), - ) -} - -func CountFields(t *testing.T, val interface{}) int { - b, err := yaml.Marshal(val) - require.NoError(t, err) - m := make(map[string]interface{}, 0) - err = yaml.Unmarshal(b, &m) - require.NoError(t, err) - return len(m) -} - -func TestOriginRequestConfigOverrides(t *testing.T) { - rulesYAML := ` -originRequest: - connectTimeout: 1m - tlsTimeout: 1s - noHappyEyeballs: true - tcpKeepAlive: 1s - keepAliveConnections: 1 - keepAliveTimeout: 1s - httpHostHeader: abc - originServerName: a1 - caPool: /tmp/path0 - noTLSVerify: true - disableChunkedEncoding: true - bastionMode: True - proxyAddress: 127.1.2.3 - proxyPort: 100 - proxyType: socks5 -ingress: -- hostname: tun.example.com - service: https://localhost:8000 -- hostname: "*" - service: https://localhost:8001 - originRequest: - connectTimeout: 2m - tlsTimeout: 2s - noHappyEyeballs: false - tcpKeepAlive: 2s - keepAliveConnections: 2 - keepAliveTimeout: 2s - httpHostHeader: def - originServerName: b2 - caPool: /tmp/path1 - noTLSVerify: false - disableChunkedEncoding: false - bastionMode: false - proxyAddress: interface - proxyPort: 200 - proxyType: "" -` - ing, err := ParseIngress(MustReadIngress(rulesYAML)) - if err != nil { - t.Error(err) - } - - // Rule 0 didn't override anything, so it inherits the user-specified - // root-level configuration. - actual0 := ing.Rules[0].Config - expected0 := OriginRequestConfig{ - ConnectTimeout: 1 * time.Minute, - TLSTimeout: 1 * time.Second, - NoHappyEyeballs: true, - TCPKeepAlive: 1 * time.Second, - KeepAliveConnections: 1, - KeepAliveTimeout: 1 * time.Second, - HTTPHostHeader: "abc", - OriginServerName: "a1", - CAPool: "/tmp/path0", - NoTLSVerify: true, - DisableChunkedEncoding: true, - BastionMode: true, - ProxyAddress: "127.1.2.3", - ProxyPort: uint(100), - ProxyType: "socks5", - } - require.Equal(t, expected0, actual0) - - // Rule 1 overrode all the root-level config. - actual1 := ing.Rules[1].Config - expected1 := OriginRequestConfig{ - ConnectTimeout: 2 * time.Minute, - TLSTimeout: 2 * time.Second, - NoHappyEyeballs: false, - TCPKeepAlive: 2 * time.Second, - KeepAliveConnections: 2, - KeepAliveTimeout: 2 * time.Second, - HTTPHostHeader: "def", - OriginServerName: "b2", - CAPool: "/tmp/path1", - NoTLSVerify: false, - DisableChunkedEncoding: false, - BastionMode: false, - ProxyAddress: "interface", - ProxyPort: uint(200), - ProxyType: "", - } - require.Equal(t, expected1, actual1) -} - -func TestOriginRequestConfigDefaults(t *testing.T) { - rulesYAML := ` -ingress: -- hostname: tun.example.com - service: https://localhost:8000 -- hostname: "*" - service: https://localhost:8001 - originRequest: - connectTimeout: 2m - tlsTimeout: 2s - noHappyEyeballs: false - tcpKeepAlive: 2s - keepAliveConnections: 2 - keepAliveTimeout: 2s - httpHostHeader: def - originServerName: b2 - caPool: /tmp/path1 - noTLSVerify: false - disableChunkedEncoding: false - bastionMode: false - proxyAddress: interface - proxyPort: 200 - proxyType: "" -` - ing, err := ParseIngress(MustReadIngress(rulesYAML)) - if err != nil { - t.Error(err) - } - - // Rule 0 didn't override anything, so it inherits the cloudflared defaults - actual0 := ing.Rules[0].Config - expected0 := OriginRequestConfig{ - ConnectTimeout: defaultConnectTimeout, - TLSTimeout: defaultTLSTimeout, - TCPKeepAlive: defaultTCPKeepAlive, - KeepAliveConnections: defaultKeepAliveConnections, - KeepAliveTimeout: defaultKeepAliveTimeout, - ProxyAddress: defaultProxyAddress, - } - require.Equal(t, expected0, actual0) - - // Rule 1 overrode all defaults. - actual1 := ing.Rules[1].Config - expected1 := OriginRequestConfig{ - ConnectTimeout: 2 * time.Minute, - TLSTimeout: 2 * time.Second, - NoHappyEyeballs: false, - TCPKeepAlive: 2 * time.Second, - KeepAliveConnections: 2, - KeepAliveTimeout: 2 * time.Second, - HTTPHostHeader: "def", - OriginServerName: "b2", - CAPool: "/tmp/path1", - NoTLSVerify: false, - DisableChunkedEncoding: false, - BastionMode: false, - ProxyAddress: "interface", - ProxyPort: uint(200), - ProxyType: "", - } - require.Equal(t, expected1, actual1) -} - -func TestDefaultConfigFromCLI(t *testing.T) { - set := flag.NewFlagSet("contrive", 0) - c := cli.NewContext(nil, set, nil) - - expected := OriginRequestConfig{ - ConnectTimeout: defaultConnectTimeout, - TLSTimeout: defaultTLSTimeout, - TCPKeepAlive: defaultTCPKeepAlive, - KeepAliveConnections: defaultKeepAliveConnections, - KeepAliveTimeout: defaultKeepAliveTimeout, - ProxyAddress: defaultProxyAddress, - } - actual := originRequestFromSingeRule(c) - require.Equal(t, expected, actual) -} diff --git a/ingress/origin_service.go b/ingress/origin_service.go index fc636c86..116b77f0 100644 --- a/ingress/origin_service.go +++ b/ingress/origin_service.go @@ -8,7 +8,6 @@ import ( "net" "net/http" "net/url" - "sync" "time" "github.com/pkg/errors" @@ -20,13 +19,18 @@ import ( "github.com/cloudflare/cloudflared/tlsconfig" ) +const ( + HelloWorldService = "Hello World test origin" +) + // OriginService is something a tunnel can proxy traffic to. type OriginService interface { String() string // Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World. // If it's not managed by cloudflared, this is a no-op because the user is responsible for // starting the origin service. - start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error + // Implementor of services managed by cloudflared should terminate the service if shutdownC is closed + start(log *zerolog.Logger, shutdownC <-chan struct{}, cfg OriginRequestConfig) error } // unixSocketPath is an OriginService representing a unix socket (which accepts HTTP) @@ -39,7 +43,7 @@ func (o *unixSocketPath) String() string { return "unix socket: " + o.path } -func (o *unixSocketPath) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { +func (o *unixSocketPath) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error { transport, err := newHTTPTransport(o, cfg, log) if err != nil { return err @@ -54,7 +58,7 @@ type httpService struct { transport *http.Transport } -func (o *httpService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { +func (o *httpService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error { transport, err := newHTTPTransport(o, cfg, log) if err != nil { return err @@ -78,7 +82,7 @@ func (o *rawTCPService) String() string { return o.name } -func (o *rawTCPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { +func (o *rawTCPService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error { return nil } @@ -139,7 +143,7 @@ func (o *tcpOverWSService) String() string { return o.dest } -func (o *tcpOverWSService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { +func (o *tcpOverWSService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error { if cfg.ProxyType == socksProxy { o.streamHandler = socks.StreamHandler } else { @@ -148,7 +152,7 @@ func (o *tcpOverWSService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdo return nil } -func (o *socksProxyOverWSService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { +func (o *socksProxyOverWSService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error { return nil } @@ -164,18 +168,16 @@ type helloWorld struct { } func (o *helloWorld) String() string { - return "Hello World test origin" + return HelloWorldService } // Start starts a HelloWorld server and stores its address in the Service receiver. func (o *helloWorld) start( - wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, - errC chan error, cfg OriginRequestConfig, ) error { - if err := o.httpService.start(wg, log, shutdownC, errC, cfg); err != nil { + if err := o.httpService.start(log, shutdownC, cfg); err != nil { return err } @@ -183,11 +185,7 @@ func (o *helloWorld) start( if err != nil { return errors.Wrap(err, "Cannot start Hello World Server") } - wg.Add(1) - go func() { - defer wg.Done() - _ = hello.StartHelloWorldServer(log, helloListener, shutdownC) - }() + go hello.StartHelloWorldServer(log, helloListener, shutdownC) o.server = helloListener o.httpService.url = &url.URL{ @@ -218,10 +216,8 @@ func (o *statusCode) String() string { } func (o *statusCode) start( - wg *sync.WaitGroup, log *zerolog.Logger, - shutdownC <-chan struct{}, - errC chan error, + _ <-chan struct{}, cfg OriginRequestConfig, ) error { return nil @@ -296,6 +292,6 @@ func (mos MockOriginHTTPService) String() string { return "MockOriginService" } -func (mos MockOriginHTTPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { +func (mos MockOriginHTTPService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error { return nil } diff --git a/orchestration/config.go b/orchestration/config.go new file mode 100644 index 00000000..dff7e701 --- /dev/null +++ b/orchestration/config.go @@ -0,0 +1,15 @@ +package orchestration + +import ( + "github.com/cloudflare/cloudflared/ingress" +) + +type newConfig struct { + ingress.RemoteConfig + // Add more fields when we support other settings in tunnel orchestration +} + +type Config struct { + Ingress *ingress.Ingress + WarpRoutingEnabled bool +} diff --git a/orchestration/orchestrator.go b/orchestration/orchestrator.go new file mode 100644 index 00000000..d072e966 --- /dev/null +++ b/orchestration/orchestrator.go @@ -0,0 +1,158 @@ +package orchestration + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "sync/atomic" + + "github.com/pkg/errors" + "github.com/rs/zerolog" + + "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/ingress" + "github.com/cloudflare/cloudflared/proxy" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" +) + +// Orchestrator manages configurations so they can be updatable during runtime +// properties are static, so it can be read without lock +// currentVersion and config are read/write infrequently, so their access are synchronized with RWMutex +// access to proxy is synchronized with atmoic.Value, because it uses copy-on-write to provide scalable frequently +// read when update is infrequent +type Orchestrator struct { + currentVersion int32 + // Used by UpdateConfig to make sure one update at a time + lock sync.RWMutex + // Underlying value is proxy.Proxy, can be read without the lock, but still needs the lock to update + proxy atomic.Value + config *Config + tags []tunnelpogs.Tag + log *zerolog.Logger + + // orchestrator must not handle any more updates after shutdownC is closed + shutdownC <-chan struct{} + // Closing proxyShutdownC will close the previous proxy + proxyShutdownC chan<- struct{} +} + +func NewOrchestrator(ctx context.Context, config *Config, tags []tunnelpogs.Tag, log *zerolog.Logger) (*Orchestrator, error) { + o := &Orchestrator{ + // Lowest possible version, any remote configuration will have version higher than this + currentVersion: 0, + config: config, + tags: tags, + log: log, + shutdownC: ctx.Done(), + } + if err := o.updateIngress(*config.Ingress, config.WarpRoutingEnabled); err != nil { + return nil, err + } + go o.waitToCloseLastProxy() + return o, nil +} + +// Update creates a new proxy with the new ingress rules +func (o *Orchestrator) UpdateConfig(version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse { + o.lock.Lock() + defer o.lock.Unlock() + + if o.currentVersion >= version { + o.log.Debug(). + Int32("current_version", o.currentVersion). + Int32("received_version", version). + Msg("Current version is equal or newer than receivied version") + return &tunnelpogs.UpdateConfigurationResponse{ + LastAppliedVersion: o.currentVersion, + } + } + var newConf newConfig + if err := json.Unmarshal(config, &newConf); err != nil { + o.log.Err(err). + Int32("version", version). + Str("config", string(config)). + Msgf("Failed to deserialize new configuration") + return &tunnelpogs.UpdateConfigurationResponse{ + LastAppliedVersion: o.currentVersion, + Err: err, + } + } + + if err := o.updateIngress(newConf.Ingress, newConf.WarpRouting.Enabled); err != nil { + o.log.Err(err). + Int32("version", version). + Str("config", string(config)). + Msgf("Failed to update ingress") + return &tunnelpogs.UpdateConfigurationResponse{ + LastAppliedVersion: o.currentVersion, + Err: err, + } + } + o.currentVersion = version + + o.log.Info(). + Int32("version", version). + Str("config", string(config)). + Msg("Updated to new configuration") + return &tunnelpogs.UpdateConfigurationResponse{ + LastAppliedVersion: o.currentVersion, + } +} + +// The caller is responsible to make sure there is no concurrent access +func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRoutingEnabled bool) error { + select { + case <-o.shutdownC: + return fmt.Errorf("cloudflared already shutdown") + default: + } + + // Start new proxy before closing the ones from last version. + // The upside is we don't need to restart proxy from last version, which can fail + // The downside is new version might have ingress rule that require previous version to be shutdown first + // The downside is minimized because none of the ingress.OriginService implementation have that requirement + proxyShutdownC := make(chan struct{}) + if err := ingressRules.StartOrigins(o.log, proxyShutdownC); err != nil { + return errors.Wrap(err, "failed to start origin") + } + newProxy := proxy.NewOriginProxy(ingressRules, warpRoutingEnabled, o.tags, o.log) + o.proxy.Store(newProxy) + o.config.Ingress = &ingressRules + o.config.WarpRoutingEnabled = warpRoutingEnabled + + // If proxyShutdownC is nil, there is no previous running proxy + if o.proxyShutdownC != nil { + close(o.proxyShutdownC) + } + o.proxyShutdownC = proxyShutdownC + return nil +} + +// GetOriginProxy returns an interface to proxy to origin. It satisfies connection.ConfigManager interface +func (o *Orchestrator) GetOriginProxy() (connection.OriginProxy, error) { + val := o.proxy.Load() + if val == nil { + err := fmt.Errorf("origin proxy not configured") + o.log.Error().Msg(err.Error()) + return nil, err + } + proxy, ok := val.(*proxy.Proxy) + if !ok { + err := fmt.Errorf("origin proxy has unexpected value %+v", val) + o.log.Error().Msg(err.Error()) + return nil, err + } + return proxy, nil +} + +func (o *Orchestrator) waitToCloseLastProxy() { + <-o.shutdownC + o.lock.Lock() + defer o.lock.Unlock() + + if o.proxyShutdownC != nil { + close(o.proxyShutdownC) + o.proxyShutdownC = nil + } +} diff --git a/orchestration/orchestrator_test.go b/orchestration/orchestrator_test.go new file mode 100644 index 00000000..b4b19224 --- /dev/null +++ b/orchestration/orchestrator_test.go @@ -0,0 +1,686 @@ +package orchestration + +import ( + "context" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/gobwas/ws/wsutil" + gows "github.com/gorilla/websocket" + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" + + "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/ingress" + "github.com/cloudflare/cloudflared/proxy" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" +) + +var ( + testLogger = zerolog.Logger{} + testTags = []tunnelpogs.Tag{ + { + Name: "package", + Value: "orchestration", + }, + { + Name: "purpose", + Value: "test", + }, + } +) + +// TestUpdateConfiguration tests that +// - configurations can be deserialized +// - proxy can be updated +// - last applied version and error are returned +// - configurations can be deserialized +// - receiving an old version is noop +func TestUpdateConfiguration(t *testing.T) { + initConfig := &Config{ + Ingress: &ingress.Ingress{}, + WarpRoutingEnabled: false, + } + orchestrator, err := NewOrchestrator(context.Background(), initConfig, testTags, &testLogger) + require.NoError(t, err) + initOriginProxy, err := orchestrator.GetOriginProxy() + require.NoError(t, err) + require.IsType(t, &proxy.Proxy{}, initOriginProxy) + + configJSONV2 := []byte(` +{ + "unknown_field": "not_deserialized", + "originRequest": { + "connectTimeout": 90000000000, + "noHappyEyeballs": true + }, + "ingress": [ + { + "hostname": "jira.tunnel.org", + "path": "^\/login", + "service": "http://192.16.19.1:443", + "originRequest": { + "noTLSVerify": true, + "connectTimeout": 10000000000 + } + }, + { + "hostname": "jira.tunnel.org", + "service": "http://172.32.20.6:80", + "originRequest": { + "noTLSVerify": true, + "connectTimeout": 30000000000 + } + }, + { + "service": "http_status:404" + } + ], + "warp-routing": { + "enabled": true + } +} +`) + + updateWithValidation(t, orchestrator, 2, configJSONV2) + configV2 := orchestrator.config + // Validate ingress rule 0 + require.Equal(t, "jira.tunnel.org", configV2.Ingress.Rules[0].Hostname) + require.True(t, configV2.Ingress.Rules[0].Matches("jira.tunnel.org", "/login")) + require.True(t, configV2.Ingress.Rules[0].Matches("jira.tunnel.org", "/login/2fa")) + require.False(t, configV2.Ingress.Rules[0].Matches("jira.tunnel.org", "/users")) + require.Equal(t, "http://192.16.19.1:443", configV2.Ingress.Rules[0].Service.String()) + require.Len(t, configV2.Ingress.Rules, 3) + // originRequest of this ingress rule overrides global default + require.Equal(t, time.Second*10, configV2.Ingress.Rules[0].Config.ConnectTimeout) + require.Equal(t, true, configV2.Ingress.Rules[0].Config.NoTLSVerify) + // Inherited from global default + require.Equal(t, true, configV2.Ingress.Rules[0].Config.NoHappyEyeballs) + // Validate ingress rule 1 + require.Equal(t, "jira.tunnel.org", configV2.Ingress.Rules[1].Hostname) + require.True(t, configV2.Ingress.Rules[1].Matches("jira.tunnel.org", "/users")) + require.Equal(t, "http://172.32.20.6:80", configV2.Ingress.Rules[1].Service.String()) + // originRequest of this ingress rule overrides global default + require.Equal(t, time.Second*30, configV2.Ingress.Rules[1].Config.ConnectTimeout) + require.Equal(t, true, configV2.Ingress.Rules[1].Config.NoTLSVerify) + // Inherited from global default + require.Equal(t, true, configV2.Ingress.Rules[1].Config.NoHappyEyeballs) + // Validate ingress rule 2, it's the catch-all rule + require.True(t, configV2.Ingress.Rules[2].Matches("blogs.tunnel.io", "/2022/02/10")) + // Inherited from global default + require.Equal(t, time.Second*90, configV2.Ingress.Rules[2].Config.ConnectTimeout) + require.Equal(t, false, configV2.Ingress.Rules[2].Config.NoTLSVerify) + require.Equal(t, true, configV2.Ingress.Rules[2].Config.NoHappyEyeballs) + require.True(t, configV2.WarpRoutingEnabled) + + originProxyV2, err := orchestrator.GetOriginProxy() + require.NoError(t, err) + require.IsType(t, &proxy.Proxy{}, originProxyV2) + require.NotEqual(t, originProxyV2, initOriginProxy) + + // Should not downgrade to an older version + resp := orchestrator.UpdateConfig(1, nil) + require.NoError(t, resp.Err) + require.Equal(t, int32(2), resp.LastAppliedVersion) + + invalidJSON := []byte(` +{ + "originRequest": +} + +`) + + resp = orchestrator.UpdateConfig(3, invalidJSON) + require.Error(t, resp.Err) + require.Equal(t, int32(2), resp.LastAppliedVersion) + originProxyV3, err := orchestrator.GetOriginProxy() + require.NoError(t, err) + require.Equal(t, originProxyV2, originProxyV3) + + configJSONV10 := []byte(` +{ + "ingress": [ + { + "service": "hello-world" + } + ], + "warp-routing": { + "enabled": false + } +} +`) + updateWithValidation(t, orchestrator, 10, configJSONV10) + configV10 := orchestrator.config + require.Len(t, configV10.Ingress.Rules, 1) + require.True(t, configV10.Ingress.Rules[0].Matches("blogs.tunnel.io", "/2022/02/10")) + require.Equal(t, ingress.HelloWorldService, configV10.Ingress.Rules[0].Service.String()) + require.False(t, configV10.WarpRoutingEnabled) + + originProxyV10, err := orchestrator.GetOriginProxy() + require.NoError(t, err) + require.IsType(t, &proxy.Proxy{}, originProxyV10) + require.NotEqual(t, originProxyV10, originProxyV2) +} + +// TestConcurrentUpdateAndRead makes sure orchestrator can receive updates and return origin proxy concurrently +func TestConcurrentUpdateAndRead(t *testing.T) { + const ( + concurrentRequests = 200 + hostname = "public.tunnels.org" + expectedHost = "internal.tunnels.svc.cluster.local" + tcpBody = "testProxyTCP" + ) + + httpOrigin := httptest.NewServer(&validateHostHandler{ + expectedHost: expectedHost, + body: t.Name(), + }) + defer httpOrigin.Close() + + tcpOrigin, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer tcpOrigin.Close() + + var ( + configJSONV1 = []byte(fmt.Sprintf(` +{ + "originRequest": { + "connectTimeout": 90000000000, + "noHappyEyeballs": true + }, + "ingress": [ + { + "hostname": "%s", + "service": "%s", + "originRequest": { + "httpHostHeader": "%s", + "connectTimeout": 10000000000 + } + }, + { + "service": "http_status:404" + } + ], + "warp-routing": { + "enabled": true + } +} +`, hostname, httpOrigin.URL, expectedHost)) + configJSONV2 = []byte(` +{ + "ingress": [ + { + "service": "http_status:204" + } + ], + "warp-routing": { + "enabled": false + } +} +`) + + configJSONV3 = []byte(` +{ + "ingress": [ + { + "service": "http_status:418" + } + ], + "warp-routing": { + "enabled": true + } +} +`) + + // appliedV2 makes sure v3 is applied after v2 + appliedV2 = make(chan struct{}) + + initConfig = &Config{ + Ingress: &ingress.Ingress{}, + WarpRoutingEnabled: false, + } + ) + + orchestrator, err := NewOrchestrator(context.Background(), initConfig, testTags, &testLogger) + require.NoError(t, err) + + updateWithValidation(t, orchestrator, 1, configJSONV1) + + var wg sync.WaitGroup + // tcpOrigin will be closed when the test exits. Only the handler routines are included in the wait group + go func() { + serveTCPOrigin(t, tcpOrigin, &wg) + }() + for i := 0; i < concurrentRequests; i++ { + originProxy, err := orchestrator.GetOriginProxy() + require.NoError(t, err) + wg.Add(1) + go func(i int, originProxy connection.OriginProxy) { + defer wg.Done() + resp, err := proxyHTTP(t, originProxy, hostname) + require.NoError(t, err) + + var warpRoutingDisabled bool + // The response can be from initOrigin, http_status:204 or http_status:418 + switch resp.StatusCode { + // v1 proxy, warp enabled + case 200: + body, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, t.Name(), string(body)) + warpRoutingDisabled = false + // v2 proxy, warp disabled + case 204: + require.Greater(t, i, concurrentRequests/4) + warpRoutingDisabled = true + // v3 proxy, warp enabled + case 418: + require.Greater(t, i, concurrentRequests/2) + warpRoutingDisabled = false + } + + // Once we have originProxy, it won't be changed by configuration updates. + // We can infer the version by the ProxyHTTP response code + pr, pw := io.Pipe() + // concurrentRespWriter makes sure ResponseRecorder is not read/write concurrently, and read waits for the first write + w := newRespReadWriteFlusher() + + // Write TCP message and make sure it's echo back. This has to be done in a go routune since ProxyTCP doesn't + // return until the stream is closed. + if !warpRoutingDisabled { + wg.Add(1) + go func() { + defer wg.Done() + defer pw.Close() + tcpEyeball(t, pw, tcpBody, w) + }() + } + proxyTCP(t, originProxy, tcpOrigin.Addr().String(), w, pr, warpRoutingDisabled) + }(i, originProxy) + + if i == concurrentRequests/4 { + wg.Add(1) + go func() { + defer wg.Done() + updateWithValidation(t, orchestrator, 2, configJSONV2) + close(appliedV2) + }() + } + + if i == concurrentRequests/2 { + wg.Add(1) + go func() { + defer wg.Done() + <-appliedV2 + updateWithValidation(t, orchestrator, 3, configJSONV3) + }() + } + } + + wg.Wait() +} + +func proxyHTTP(t *testing.T, originProxy connection.OriginProxy, hostname string) (*http.Response, error) { + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", hostname), nil) + require.NoError(t, err) + + w := httptest.NewRecorder() + respWriter, err := connection.NewHTTP2RespWriter(req, w, connection.TypeHTTP) + require.NoError(t, err) + + err = originProxy.ProxyHTTP(respWriter, req, false) + if err != nil { + return nil, err + } + + return w.Result(), nil +} + +func tcpEyeball(t *testing.T, reqWriter io.WriteCloser, body string, respReadWriter *respReadWriteFlusher) { + writeN, err := reqWriter.Write([]byte(body)) + require.NoError(t, err) + + readBuffer := make([]byte, writeN) + n, err := respReadWriter.Read(readBuffer) + require.NoError(t, err) + require.Equal(t, body, string(readBuffer[:n])) + require.Equal(t, writeN, n) +} + +func proxyTCP(t *testing.T, originProxy connection.OriginProxy, originAddr string, w http.ResponseWriter, reqBody io.ReadCloser, expectErr bool) { + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originAddr), reqBody) + require.NoError(t, err) + + respWriter, err := connection.NewHTTP2RespWriter(req, w, connection.TypeTCP) + require.NoError(t, err) + + tcpReq := &connection.TCPRequest{ + Dest: originAddr, + CFRay: "123", + LBProbe: false, + } + rws := connection.NewHTTPResponseReadWriterAcker(respWriter, req) + if expectErr { + require.Error(t, originProxy.ProxyTCP(context.Background(), rws, tcpReq)) + return + } + + require.NoError(t, originProxy.ProxyTCP(context.Background(), rws, tcpReq)) +} + +func serveTCPOrigin(t *testing.T, tcpOrigin net.Listener, wg *sync.WaitGroup) { + for { + conn, err := tcpOrigin.Accept() + if err != nil { + return + } + wg.Add(1) + go func() { + defer wg.Done() + defer conn.Close() + + echoTCP(t, conn) + }() + } +} + +func echoTCP(t *testing.T, conn net.Conn) { + readBuf := make([]byte, 1000) + readN, err := conn.Read(readBuf) + require.NoError(t, err) + + writeN, err := conn.Write(readBuf[:readN]) + require.NoError(t, err) + require.Equal(t, readN, writeN) +} + +type validateHostHandler struct { + expectedHost string + body string +} + +func (vhh *validateHostHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Host != vhh.expectedHost { + w.WriteHeader(http.StatusBadRequest) + return + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(vhh.body)) +} + +func updateWithValidation(t *testing.T, orchestrator *Orchestrator, version int32, config []byte) { + resp := orchestrator.UpdateConfig(version, config) + require.NoError(t, resp.Err) + require.Equal(t, version, resp.LastAppliedVersion) +} + +// TestClosePreviousProxies makes sure proxies started in the pervious configuration version are shutdown +func TestClosePreviousProxies(t *testing.T) { + var ( + hostname = "hello.tunnel1.org" + configWithHelloWorld = []byte(fmt.Sprintf(` +{ + "ingress": [ + { + "hostname": "%s", + "service": "hello-world" + }, + { + "service": "http_status:404" + } + ], + "warp-routing": { + "enabled": true + } +} +`, hostname)) + + configTeapot = []byte(` +{ + "ingress": [ + { + "service": "http_status:418" + } + ], + "warp-routing": { + "enabled": true + } +} +`) + initConfig = &Config{ + Ingress: &ingress.Ingress{}, + WarpRoutingEnabled: false, + } + ) + + ctx, cancel := context.WithCancel(context.Background()) + orchestrator, err := NewOrchestrator(ctx, initConfig, testTags, &testLogger) + require.NoError(t, err) + + updateWithValidation(t, orchestrator, 1, configWithHelloWorld) + + originProxyV1, err := orchestrator.GetOriginProxy() + require.NoError(t, err) + resp, err := proxyHTTP(t, originProxyV1, hostname) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + updateWithValidation(t, orchestrator, 2, configTeapot) + + originProxyV2, err := orchestrator.GetOriginProxy() + require.NoError(t, err) + resp, err = proxyHTTP(t, originProxyV2, hostname) + require.NoError(t, err) + require.Equal(t, http.StatusTeapot, resp.StatusCode) + + // The hello-world server in config v1 should have been stopped + resp, err = proxyHTTP(t, originProxyV1, hostname) + require.Error(t, err) + require.Nil(t, resp) + + // Apply the config with hello world server again, orchestrator should spin up another hello world server + updateWithValidation(t, orchestrator, 3, configWithHelloWorld) + + originProxyV3, err := orchestrator.GetOriginProxy() + require.NoError(t, err) + require.NotEqual(t, originProxyV1, originProxyV3) + + resp, err = proxyHTTP(t, originProxyV3, hostname) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + // cancel the context should terminate the last proxy + cancel() + // Wait for proxies to shutdown + time.Sleep(time.Millisecond * 10) + + resp, err = proxyHTTP(t, originProxyV3, hostname) + require.Error(t, err) + require.Nil(t, resp) +} + +// TestPersistentConnection makes sure updating the ingress doesn't intefere with existing connections +func TestPersistentConnection(t *testing.T) { + const ( + hostname = "http://ws.tunnel.org" + ) + msg := t.Name() + initConfig := &Config{ + Ingress: &ingress.Ingress{}, + WarpRoutingEnabled: false, + } + orchestrator, err := NewOrchestrator(context.Background(), initConfig, testTags, &testLogger) + require.NoError(t, err) + + wsOrigin := httptest.NewServer(http.HandlerFunc(wsEcho)) + defer wsOrigin.Close() + + tcpOrigin, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer tcpOrigin.Close() + + configWithWSAndWarp := []byte(fmt.Sprintf(` +{ + "ingress": [ + { + "service": "%s" + } + ], + "warp-routing": { + "enabled": true + } +} +`, wsOrigin.URL)) + + updateWithValidation(t, orchestrator, 1, configWithWSAndWarp) + + originProxy, err := orchestrator.GetOriginProxy() + require.NoError(t, err) + + wsReqReader, wsReqWriter := io.Pipe() + wsRespReadWriter := newRespReadWriteFlusher() + + tcpReqReader, tcpReqWriter := io.Pipe() + tcpRespReadWriter := newRespReadWriteFlusher() + + var wg sync.WaitGroup + wg.Add(3) + // Start TCP origin + go func() { + defer wg.Done() + conn, err := tcpOrigin.Accept() + require.NoError(t, err) + defer conn.Close() + + // Expect 3 TCP messages + for i := 0; i < 3; i++ { + echoTCP(t, conn) + } + }() + // Simulate cloudflared recieving a TCP connection + go func() { + defer wg.Done() + proxyTCP(t, originProxy, tcpOrigin.Addr().String(), tcpRespReadWriter, tcpReqReader, false) + }() + // Simulate cloudflared recieving a WS connection + go func() { + defer wg.Done() + + req, err := http.NewRequest(http.MethodGet, hostname, wsReqReader) + require.NoError(t, err) + // ProxyHTTP will add Connection, Upgrade and Sec-Websocket-Version headers + req.Header.Add("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") + + respWriter, err := connection.NewHTTP2RespWriter(req, wsRespReadWriter, connection.TypeWebsocket) + require.NoError(t, err) + + err = originProxy.ProxyHTTP(respWriter, req, true) + require.NoError(t, err) + }() + + // Simulate eyeball WS and TCP connections + validateWsEcho(t, msg, wsReqWriter, wsRespReadWriter) + tcpEyeball(t, tcpReqWriter, msg, tcpRespReadWriter) + + configNoWSAndWarp := []byte(` +{ + "ingress": [ + { + "service": "http_status:404" + } + ], + "warp-routing": { + "enabled": false + } +} +`) + + updateWithValidation(t, orchestrator, 2, configNoWSAndWarp) + // Make sure connection is still up + validateWsEcho(t, msg, wsReqWriter, wsRespReadWriter) + tcpEyeball(t, tcpReqWriter, msg, tcpRespReadWriter) + + updateWithValidation(t, orchestrator, 3, configWithWSAndWarp) + // Make sure connection is still up + validateWsEcho(t, msg, wsReqWriter, wsRespReadWriter) + tcpEyeball(t, tcpReqWriter, msg, tcpRespReadWriter) + + wsReqWriter.Close() + tcpReqWriter.Close() + wg.Wait() +} + +func wsEcho(w http.ResponseWriter, r *http.Request) { + upgrader := gows.Upgrader{} + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + for { + mt, message, err := conn.ReadMessage() + if err != nil { + fmt.Println("read message err", err) + break + } + err = conn.WriteMessage(mt, message) + if err != nil { + fmt.Println("write message err", err) + break + } + } +} + +func validateWsEcho(t *testing.T, msg string, reqWriter io.Writer, respReadWriter io.ReadWriter) { + err := wsutil.WriteClientText(reqWriter, []byte(msg)) + require.NoError(t, err) + + receivedMsg, err := wsutil.ReadServerText(respReadWriter) + require.NoError(t, err) + require.Equal(t, msg, string(receivedMsg)) +} + +type respReadWriteFlusher struct { + io.Reader + w io.Writer + headers http.Header + statusCode int + setStatusOnce sync.Once + hasStatus chan struct{} +} + +func newRespReadWriteFlusher() *respReadWriteFlusher { + pr, pw := io.Pipe() + return &respReadWriteFlusher{ + Reader: pr, + w: pw, + headers: make(http.Header), + hasStatus: make(chan struct{}), + } +} + +func (rrw *respReadWriteFlusher) Write(buf []byte) (int, error) { + rrw.WriteHeader(http.StatusOK) + return rrw.w.Write(buf) +} + +func (rrw *respReadWriteFlusher) Flush() {} + +func (rrw *respReadWriteFlusher) Header() http.Header { + return rrw.headers +} + +func (rrw *respReadWriteFlusher) WriteHeader(statusCode int) { + rrw.setStatusOnce.Do(func() { + rrw.statusCode = statusCode + close(rrw.hasStatus) + }) +} diff --git a/origin/metrics.go b/proxy/metrics.go similarity index 85% rename from origin/metrics.go rename to proxy/metrics.go index 1e54f271..e5406681 100644 --- a/origin/metrics.go +++ b/proxy/metrics.go @@ -1,4 +1,4 @@ -package origin +package proxy import ( "github.com/prometheus/client_golang/prometheus" @@ -43,14 +43,6 @@ var ( Help: "Count of error proxying to origin", }, ) - haConnections = prometheus.NewGauge( - prometheus.GaugeOpts{ - Namespace: connection.MetricsNamespace, - Subsystem: connection.TunnelSubsystem, - Name: "ha_connections", - Help: "Number of active ha connections", - }, - ) ) func init() { @@ -59,7 +51,6 @@ func init() { concurrentRequests, responseByCode, requestErrors, - haConnections, ) } diff --git a/origin/pool.go b/proxy/pool.go similarity index 96% rename from origin/pool.go rename to proxy/pool.go index 396a4a76..fe2cf4a5 100644 --- a/origin/pool.go +++ b/proxy/pool.go @@ -1,4 +1,4 @@ -package origin +package proxy import ( "sync" diff --git a/origin/proxy.go b/proxy/proxy.go similarity index 98% rename from origin/proxy.go rename to proxy/proxy.go index dca03746..5ec8743b 100644 --- a/origin/proxy.go +++ b/proxy/proxy.go @@ -1,4 +1,4 @@ -package origin +package proxy import ( "bufio" @@ -38,17 +38,22 @@ type Proxy struct { // NewOriginProxy returns a new instance of the Proxy struct. func NewOriginProxy( ingressRules ingress.Ingress, - warpRouting *ingress.WarpRoutingService, + warpRoutingEnabled bool, tags []tunnelpogs.Tag, log *zerolog.Logger, ) *Proxy { - return &Proxy{ + proxy := &Proxy{ ingressRules: ingressRules, - warpRouting: warpRouting, tags: tags, log: log, bufferPool: newBufferPool(512 * 1024), } + if warpRoutingEnabled { + proxy.warpRouting = ingress.NewWarpRoutingService() + log.Info().Msgf("Warp-routing is enabled") + } + + return proxy } // ProxyHTTP further depends on ingress rules to establish a connection with the origin service. This may be diff --git a/origin/proxy_posix_test.go b/proxy/proxy_posix_test.go similarity index 98% rename from origin/proxy_posix_test.go rename to proxy/proxy_posix_test.go index 1b649a43..40d070c7 100644 --- a/origin/proxy_posix_test.go +++ b/proxy/proxy_posix_test.go @@ -1,7 +1,7 @@ //go:build !windows // +build !windows -package origin +package proxy import ( "io/ioutil" diff --git a/origin/proxy_test.go b/proxy/proxy_test.go similarity index 96% rename from origin/proxy_test.go rename to proxy/proxy_test.go index e4184d7a..db8747f7 100644 --- a/origin/proxy_test.go +++ b/proxy/proxy_test.go @@ -1,4 +1,4 @@ -package origin +package proxy import ( "bytes" @@ -31,8 +31,7 @@ import ( ) var ( - testTags = []tunnelpogs.Tag{tunnelpogs.Tag{Name: "Name", Value: "value"}} - unusedWarpRoutingService = (*ingress.WarpRoutingService)(nil) + testTags = []tunnelpogs.Tag{tunnelpogs.Tag{Name: "Name", Value: "value"}} ) type mockHTTPRespWriter struct { @@ -131,17 +130,14 @@ func TestProxySingleOrigin(t *testing.T) { ingressRule, err := ingress.NewSingleOrigin(cliCtx, allowURLFromArgs) require.NoError(t, err) - var wg sync.WaitGroup - errC := make(chan error) - require.NoError(t, ingressRule.StartOrigins(&wg, &log, ctx.Done(), errC)) + require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done())) - proxy := NewOriginProxy(ingressRule, unusedWarpRoutingService, testTags, &log) + proxy := NewOriginProxy(ingressRule, false, testTags, &log) t.Run("testProxyHTTP", testProxyHTTP(proxy)) t.Run("testProxyWebsocket", testProxyWebsocket(proxy)) t.Run("testProxySSE", testProxySSE(proxy)) t.Run("testProxySSEAllData", testProxySSEAllData(proxy)) cancel() - wg.Wait() } func testProxyHTTP(proxy connection.OriginProxy) func(t *testing.T) { @@ -341,11 +337,9 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat log := zerolog.Nop() ctx, cancel := context.WithCancel(context.Background()) - errC := make(chan error) - var wg sync.WaitGroup - require.NoError(t, ingress.StartOrigins(&wg, &log, ctx.Done(), errC)) + require.NoError(t, ingress.StartOrigins(&log, ctx.Done())) - proxy := NewOriginProxy(ingress, unusedWarpRoutingService, testTags, &log) + proxy := NewOriginProxy(ingress, false, testTags, &log) for _, test := range tests { responseWriter := newMockHTTPRespWriter() @@ -363,7 +357,6 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat } } cancel() - wg.Wait() } type mockAPI struct{} @@ -394,7 +387,7 @@ func TestProxyError(t *testing.T) { log := zerolog.Nop() - proxy := NewOriginProxy(ing, unusedWarpRoutingService, testTags, &log) + proxy := NewOriginProxy(ing, false, testTags, &log) responseWriter := newMockHTTPRespWriter() req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) @@ -634,10 +627,9 @@ func TestConnections(t *testing.T) { test.args.originService(t, ln) ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String()) - var wg sync.WaitGroup - errC := make(chan error) - ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC) - proxy := NewOriginProxy(ingressRule, test.args.warpRoutingService, testTags, logger) + ingressRule.StartOrigins(logger, ctx.Done()) + proxy := NewOriginProxy(ingressRule, true, testTags, logger) + proxy.warpRouting = test.args.warpRoutingService dest := ln.Addr().String() req, err := http.NewRequest( diff --git a/quic/quic_protocol.go b/quic/quic_protocol.go index bba808ee..1e0c6eef 100644 --- a/quic/quic_protocol.go +++ b/quic/quic_protocol.go @@ -17,8 +17,8 @@ import ( tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) -// The first 6 bytes of the stream is used to distinguish the type of stream. It ensures whoever performs a handshake does -// not write data before writing the metadata. +// ProtocolSignature defines the first 6 bytes of the stream, which is used to distinguish the type of stream. It +// ensures whoever performs a handshake does not write data before writing the metadata. type ProtocolSignature [6]byte var ( @@ -29,12 +29,15 @@ var ( RPCStreamProtocolSignature = ProtocolSignature{0x52, 0xBB, 0x82, 0x5C, 0xDB, 0x65} ) -const protocolVersionLength = 2 - type protocolVersion string const ( protocolV1 protocolVersion = "01" + + protocolVersionLength = 2 + + HandshakeIdleTimeout = 5 * time.Second + MaxIdleTimeout = 15 * time.Second ) // RequestServerStream is a stream to serve requests @@ -122,7 +125,7 @@ func (rcs *RequestClientStream) ReadConnectResponseData() (*ConnectResponse, err return nil, err } if signature != DataStreamProtocolSignature { - return nil, fmt.Errorf("Wrong protocol signature %v", signature) + return nil, fmt.Errorf("wrong protocol signature %v", signature) } // This is a NO-OP for now. We could cause a branching if we wanted to use multiple versions. @@ -154,13 +157,13 @@ func NewRPCServerStream(stream io.ReadWriteCloser, protocol ProtocolSignature) ( return &RPCServerStream{stream}, nil } -func (s *RPCServerStream) Serve(sessionManager tunnelpogs.SessionManager, logger *zerolog.Logger) error { +func (s *RPCServerStream) Serve(sessionManager tunnelpogs.SessionManager, configManager tunnelpogs.ConfigurationManager, logger *zerolog.Logger) error { // RPC logs are very robust, create a new logger that only logs error to reduce noise rpcLogger := logger.Level(zerolog.ErrorLevel) rpcTransport := tunnelrpc.NewTransportLogger(&rpcLogger, rpc.StreamTransport(s)) defer rpcTransport.Close() - main := tunnelpogs.SessionManager_ServerToClient(sessionManager) + main := tunnelpogs.CloudflaredServer_ServerToClient(sessionManager, configManager) rpcConn := rpc.NewConn( rpcTransport, rpc.MainInterface(main.Client), @@ -220,7 +223,7 @@ func writeSignature(stream io.Writer, signature ProtocolSignature) error { // RPCClientStream is a stream to call methods of SessionManager type RPCClientStream struct { - client tunnelpogs.SessionManager_PogsClient + client tunnelpogs.CloudflaredServer_PogsClient transport rpc.Transport } @@ -238,7 +241,7 @@ func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, logger * tunnelrpc.ConnLog(logger), ) return &RPCClientStream{ - client: tunnelpogs.SessionManager_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn}, + client: tunnelpogs.NewCloudflaredServer_PogsClient(conn.Bootstrap(ctx), conn), transport: transport, }, nil } @@ -255,6 +258,10 @@ func (rcs *RPCClientStream) UnregisterUdpSession(ctx context.Context, sessionID return rcs.client.UnregisterUdpSession(ctx, sessionID, message) } +func (rcs *RPCClientStream) UpdateConfiguration(ctx context.Context, version int32, config []byte) (*tunnelpogs.UpdateConfigurationResponse, error) { + return rcs.client.UpdateConfiguration(ctx, version, config) +} + func (rcs *RPCClientStream) Close() { _ = rcs.client.Close() _ = rcs.transport.Close() diff --git a/quic/quic_protocol_test.go b/quic/quic_protocol_test.go index f77359e5..a801b63e 100644 --- a/quic/quic_protocol_test.go +++ b/quic/quic_protocol_test.go @@ -14,6 +14,8 @@ import ( "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) const ( @@ -108,14 +110,10 @@ func TestConnectResponseMeta(t *testing.T) { } func TestRegisterUdpSession(t *testing.T) { - clientReader, serverWriter := io.Pipe() - serverReader, clientWriter := io.Pipe() - - clientStream := mockRPCStream{clientReader, clientWriter} - serverStream := mockRPCStream{serverReader, serverWriter} + clientStream, serverStream := newMockRPCStreams() unregisterMessage := "closed by eyeball" - rpcServer := mockRPCServer{ + sessionRPCServer := mockSessionRPCServer{ sessionID: uuid.New(), dstIP: net.IP{172, 16, 0, 1}, dstPort: 8000, @@ -129,7 +127,7 @@ func TestRegisterUdpSession(t *testing.T) { assert.NoError(t, err) rpcServerStream, err := NewRPCServerStream(serverStream, protocol) assert.NoError(t, err) - err = rpcServerStream.Serve(rpcServer, &logger) + err = rpcServerStream.Serve(sessionRPCServer, nil, &logger) assert.NoError(t, err) serverStream.Close() @@ -139,12 +137,12 @@ func TestRegisterUdpSession(t *testing.T) { rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, &logger) assert.NoError(t, err) - assert.NoError(t, rpcClientStream.RegisterUdpSession(context.Background(), rpcServer.sessionID, rpcServer.dstIP, rpcServer.dstPort, testCloseIdleAfterHint)) + assert.NoError(t, rpcClientStream.RegisterUdpSession(context.Background(), sessionRPCServer.sessionID, sessionRPCServer.dstIP, sessionRPCServer.dstPort, testCloseIdleAfterHint)) // Different sessionID, the RPC server should reject the registraion - assert.Error(t, rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), rpcServer.dstIP, rpcServer.dstPort, testCloseIdleAfterHint)) + assert.Error(t, rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), sessionRPCServer.dstIP, sessionRPCServer.dstPort, testCloseIdleAfterHint)) - assert.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), rpcServer.sessionID, unregisterMessage)) + assert.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), sessionRPCServer.sessionID, unregisterMessage)) // Different sessionID, the RPC server should reject the unregistraion assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), uuid.New(), unregisterMessage)) @@ -153,7 +151,48 @@ func TestRegisterUdpSession(t *testing.T) { <-sessionRegisteredChan } -type mockRPCServer struct { +func TestManageConfiguration(t *testing.T) { + var ( + version int32 = 168 + config = []byte(t.Name()) + ) + clientStream, serverStream := newMockRPCStreams() + + configRPCServer := mockConfigRPCServer{ + version: version, + config: config, + } + + logger := zerolog.Nop() + updatedChan := make(chan struct{}) + go func() { + protocol, err := DetermineProtocol(serverStream) + assert.NoError(t, err) + rpcServerStream, err := NewRPCServerStream(serverStream, protocol) + assert.NoError(t, err) + err = rpcServerStream.Serve(nil, configRPCServer, &logger) + assert.NoError(t, err) + + serverStream.Close() + close(updatedChan) + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + rpcClientStream, err := NewRPCClientStream(ctx, clientStream, &logger) + assert.NoError(t, err) + + result, err := rpcClientStream.UpdateConfiguration(ctx, version, config) + assert.NoError(t, err) + + require.Equal(t, version, result.LastAppliedVersion) + require.NoError(t, result.Err) + + rpcClientStream.Close() + <-updatedChan +} + +type mockSessionRPCServer struct { sessionID uuid.UUID dstIP net.IP dstPort uint16 @@ -161,7 +200,7 @@ type mockRPCServer struct { unregisterMessage string } -func (s mockRPCServer) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfter time.Duration) error { +func (s mockSessionRPCServer) RegisterUdpSession(_ context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfter time.Duration) error { if s.sessionID != sessionID { return fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID) } @@ -177,7 +216,7 @@ func (s mockRPCServer) RegisterUdpSession(ctx context.Context, sessionID uuid.UU return nil } -func (s mockRPCServer) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error { +func (s mockSessionRPCServer) UnregisterUdpSession(_ context.Context, sessionID uuid.UUID, message string) error { if s.sessionID != sessionID { return fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID) } @@ -187,11 +226,39 @@ func (s mockRPCServer) UnregisterUdpSession(ctx context.Context, sessionID uuid. return nil } +type mockConfigRPCServer struct { + version int32 + config []byte +} + +func (s mockConfigRPCServer) UpdateConfiguration(_ context.Context, version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse { + if s.version != version { + return &tunnelpogs.UpdateConfigurationResponse{ + Err: fmt.Errorf("expect version %d, got %d", s.version, version), + } + } + if !bytes.Equal(s.config, config) { + return &tunnelpogs.UpdateConfigurationResponse{ + Err: fmt.Errorf("expect config %v, got %v", s.config, config), + } + } + return &tunnelpogs.UpdateConfigurationResponse{LastAppliedVersion: version} +} + type mockRPCStream struct { io.ReadCloser io.WriteCloser } +func newMockRPCStreams() (client io.ReadWriteCloser, server io.ReadWriteCloser) { + clientReader, serverWriter := io.Pipe() + serverReader, clientWriter := io.Pipe() + + client = mockRPCStream{clientReader, clientWriter} + server = mockRPCStream{serverReader, serverWriter} + return +} + func (s mockRPCStream) Close() error { _ = s.ReadCloser.Close() _ = s.WriteCloser.Close() diff --git a/quic/safe_stream.go b/quic/safe_stream.go new file mode 100644 index 00000000..12ba76f4 --- /dev/null +++ b/quic/safe_stream.go @@ -0,0 +1,43 @@ +package quic + +import ( + "sync" + "time" + + "github.com/lucas-clemente/quic-go" +) + +type SafeStreamCloser struct { + lock sync.Mutex + stream quic.Stream +} + +func NewSafeStreamCloser(stream quic.Stream) *SafeStreamCloser { + return &SafeStreamCloser{ + stream: stream, + } +} + +func (s *SafeStreamCloser) Read(p []byte) (n int, err error) { + return s.stream.Read(p) +} + +func (s *SafeStreamCloser) Write(p []byte) (n int, err error) { + s.lock.Lock() + defer s.lock.Unlock() + return s.stream.Write(p) +} + +func (s *SafeStreamCloser) Close() error { + // Make sure a possible writer does not block the lock forever. We need it, so we can close the writer + // side of the stream safely. + _ = s.stream.SetWriteDeadline(time.Now()) + + // This lock is eventually acquired despite Write also acquiring it, because we set a deadline to writes. + s.lock.Lock() + defer s.lock.Unlock() + + // We have to clean up the receiving stream ourselves since the Close in the bottom does not handle that. + s.stream.CancelRead(0) + return s.stream.Close() +} diff --git a/quic/safe_stream_test.go b/quic/safe_stream_test.go new file mode 100644 index 00000000..48ffb559 --- /dev/null +++ b/quic/safe_stream_test.go @@ -0,0 +1,142 @@ +package quic + +import ( + "context" + "crypto/tls" + "io" + "net" + "sync" + "testing" + + "github.com/lucas-clemente/quic-go" + "github.com/stretchr/testify/require" +) + +var ( + testTLSServerConfig = GenerateTLSConfig() + testQUICConfig = &quic.Config{ + KeepAlive: true, + EnableDatagrams: true, + } + exchanges = 1000 + msgsPerExchange = 10 + testMsg = "Ok message" +) + +func TestSafeStreamClose(t *testing.T) { + udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + require.NoError(t, err) + udpListener, err := net.ListenUDP(udpAddr.Network(), udpAddr) + require.NoError(t, err) + defer udpListener.Close() + + var serverReady sync.WaitGroup + serverReady.Add(1) + + var done sync.WaitGroup + done.Add(1) + go func() { + defer done.Done() + quicServer(t, &serverReady, udpListener) + }() + + done.Add(1) + go func() { + serverReady.Wait() + defer done.Done() + quicClient(t, udpListener.LocalAddr()) + }() + + done.Wait() +} + +func quicClient(t *testing.T, addr net.Addr) { + tlsConf := &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"argotunnel"}, + } + session, err := quic.DialAddr(addr.String(), tlsConf, testQUICConfig) + require.NoError(t, err) + + var wg sync.WaitGroup + for exchange := 0; exchange < exchanges; exchange++ { + quicStream, err := session.AcceptStream(context.Background()) + require.NoError(t, err) + wg.Add(1) + + go func(iter int) { + defer wg.Done() + + stream := NewSafeStreamCloser(quicStream) + defer stream.Close() + + // Do a bunch of round trips over this stream that should work. + for msg := 0; msg < msgsPerExchange; msg++ { + clientRoundTrip(t, stream, true) + } + // And one that won't work necessarily, but shouldn't break other streams in the session. + if iter%2 == 0 { + clientRoundTrip(t, stream, false) + } + }(exchange) + } + + wg.Wait() +} + +func quicServer(t *testing.T, serverReady *sync.WaitGroup, conn net.PacketConn) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + earlyListener, err := quic.Listen(conn, testTLSServerConfig, testQUICConfig) + require.NoError(t, err) + + serverReady.Done() + session, err := earlyListener.Accept(ctx) + require.NoError(t, err) + + var wg sync.WaitGroup + for exchange := 0; exchange < exchanges; exchange++ { + quicStream, err := session.OpenStreamSync(context.Background()) + require.NoError(t, err) + wg.Add(1) + + go func(iter int) { + defer wg.Done() + + stream := NewSafeStreamCloser(quicStream) + defer stream.Close() + + // Do a bunch of round trips over this stream that should work. + for msg := 0; msg < msgsPerExchange; msg++ { + serverRoundTrip(t, stream, true) + } + // And one that won't work necessarily, but shouldn't break other streams in the session. + if iter%2 == 1 { + serverRoundTrip(t, stream, false) + } + }(exchange) + } + + wg.Wait() +} + +func clientRoundTrip(t *testing.T, stream io.ReadWriteCloser, mustWork bool) { + response := make([]byte, len(testMsg)) + _, err := stream.Read(response) + if !mustWork { + return + } + if err != io.EOF { + require.NoError(t, err) + } + require.Equal(t, testMsg, string(response)) +} + +func serverRoundTrip(t *testing.T, stream io.ReadWriteCloser, mustWork bool) { + _, err := stream.Write([]byte(testMsg)) + if !mustWork { + return + } + require.NoError(t, err) +} diff --git a/quic/test_utils.go b/quic/test_utils.go new file mode 100644 index 00000000..56c342f6 --- /dev/null +++ b/quic/test_utils.go @@ -0,0 +1,34 @@ +package quic + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "math/big" +) + +// GenerateTLSConfig sets up a bare-bones TLS config for a QUIC server +func GenerateTLSConfig() *tls.Config { + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + panic(err) + } + template := x509.Certificate{SerialNumber: big.NewInt(1)} + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) + if err != nil { + panic(err) + } + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + + tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + panic(err) + } + return &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + NextProtos: []string{"argotunnel"}, + } +} diff --git a/origin/cloudflare_status_page.go b/supervisor/cloudflare_status_page.go similarity index 99% rename from origin/cloudflare_status_page.go rename to supervisor/cloudflare_status_page.go index dfa9143a..93d9e849 100644 --- a/origin/cloudflare_status_page.go +++ b/supervisor/cloudflare_status_page.go @@ -1,4 +1,4 @@ -package origin +package supervisor import ( "encoding/json" diff --git a/origin/cloudflare_status_page_test.go b/supervisor/cloudflare_status_page_test.go similarity index 99% rename from origin/cloudflare_status_page_test.go rename to supervisor/cloudflare_status_page_test.go index 21985dcc..a86fb63f 100644 --- a/origin/cloudflare_status_page_test.go +++ b/supervisor/cloudflare_status_page_test.go @@ -1,4 +1,4 @@ -package origin +package supervisor import ( "testing" diff --git a/origin/conn_aware_logger.go b/supervisor/conn_aware_logger.go similarity index 97% rename from origin/conn_aware_logger.go rename to supervisor/conn_aware_logger.go index b8021121..6e717588 100644 --- a/origin/conn_aware_logger.go +++ b/supervisor/conn_aware_logger.go @@ -1,4 +1,4 @@ -package origin +package supervisor import ( "github.com/rs/zerolog" diff --git a/origin/external_control.go b/supervisor/external_control.go similarity index 95% rename from origin/external_control.go rename to supervisor/external_control.go index cd9ef364..f170cde2 100644 --- a/origin/external_control.go +++ b/supervisor/external_control.go @@ -1,4 +1,4 @@ -package origin +package supervisor import ( "time" diff --git a/supervisor/metrics.go b/supervisor/metrics.go new file mode 100644 index 00000000..e6d50cdd --- /dev/null +++ b/supervisor/metrics.go @@ -0,0 +1,27 @@ +package supervisor + +import ( + "github.com/prometheus/client_golang/prometheus" + + "github.com/cloudflare/cloudflared/connection" +) + +// Metrics uses connection.MetricsNamespace(aka cloudflared) as namespace and connection.TunnelSubsystem +// (tunnel) as subsystem to keep them consistent with the previous qualifier. + +var ( + haConnections = prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: connection.MetricsNamespace, + Subsystem: connection.TunnelSubsystem, + Name: "ha_connections", + Help: "Number of active ha connections", + }, + ) +) + +func init() { + prometheus.MustRegister( + haConnections, + ) +} diff --git a/origin/reconnect.go b/supervisor/reconnect.go similarity index 99% rename from origin/reconnect.go rename to supervisor/reconnect.go index 8b43977b..040c2714 100644 --- a/origin/reconnect.go +++ b/supervisor/reconnect.go @@ -1,4 +1,4 @@ -package origin +package supervisor import ( "context" diff --git a/origin/reconnect_test.go b/supervisor/reconnect_test.go similarity index 99% rename from origin/reconnect_test.go rename to supervisor/reconnect_test.go index fb2a1df9..593d16d1 100644 --- a/origin/reconnect_test.go +++ b/supervisor/reconnect_test.go @@ -1,4 +1,4 @@ -package origin +package supervisor import ( "context" diff --git a/origin/supervisor.go b/supervisor/supervisor.go similarity index 96% rename from origin/supervisor.go rename to supervisor/supervisor.go index 304fc3c2..f1661bf3 100644 --- a/origin/supervisor.go +++ b/supervisor/supervisor.go @@ -1,4 +1,4 @@ -package origin +package supervisor import ( "context" @@ -13,6 +13,7 @@ import ( "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/edgediscovery/allregions" "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/orchestration" "github.com/cloudflare/cloudflared/retry" "github.com/cloudflare/cloudflared/signal" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" @@ -38,6 +39,7 @@ const ( type Supervisor struct { cloudflaredUUID uuid.UUID config *TunnelConfig + orchestrator *orchestration.Orchestrator edgeIPs *edgediscovery.Edge tunnelErrors chan tunnelError tunnelsConnecting map[int]chan struct{} @@ -64,7 +66,7 @@ type tunnelError struct { err error } -func NewSupervisor(config *TunnelConfig, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) { +func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrator, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) { cloudflaredUUID, err := uuid.NewRandom() if err != nil { return nil, fmt.Errorf("failed to generate cloudflared instance ID: %w", err) @@ -88,6 +90,7 @@ func NewSupervisor(config *TunnelConfig, reconnectCh chan ReconnectSignal, grace return &Supervisor{ cloudflaredUUID: cloudflaredUUID, config: config, + orchestrator: orchestrator, edgeIPs: edgeIPs, tunnelErrors: make(chan tunnelError), tunnelsConnecting: map[int]chan struct{}{}, @@ -243,6 +246,7 @@ func (s *Supervisor) startFirstTunnel( ctx, s.reconnectCredentialManager, s.config, + s.orchestrator, addr, s.log, firstConnIndex, @@ -277,6 +281,7 @@ func (s *Supervisor) startFirstTunnel( ctx, s.reconnectCredentialManager, s.config, + s.orchestrator, addr, s.log, firstConnIndex, @@ -311,6 +316,7 @@ func (s *Supervisor) startTunnel( ctx, s.reconnectCredentialManager, s.config, + s.orchestrator, addr, s.log, uint8(index), @@ -380,7 +386,7 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) defer rpcClient.Close() const arbitraryConnectionID = uint8(0) - registrationOptions := s.config.RegistrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID) + registrationOptions := s.config.registrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID) registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts) return rpcClient.Authenticate(ctx, s.config.ClassicTunnel, registrationOptions) } diff --git a/origin/tunnel.go b/supervisor/tunnel.go similarity index 89% rename from origin/tunnel.go rename to supervisor/tunnel.go index bcfea720..bca5e081 100644 --- a/origin/tunnel.go +++ b/supervisor/tunnel.go @@ -1,4 +1,4 @@ -package origin +package supervisor import ( "context" @@ -20,6 +20,7 @@ import ( "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/edgediscovery/allregions" "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/orchestration" quicpogs "github.com/cloudflare/cloudflared/quic" "github.com/cloudflare/cloudflared/retry" "github.com/cloudflare/cloudflared/signal" @@ -31,37 +32,36 @@ const ( dialTimeout = 15 * time.Second FeatureSerializedHeaders = "serialized_headers" FeatureQuickReconnects = "quick_reconnects" - quicHandshakeIdleTimeout = 5 * time.Second - quicMaxIdleTimeout = 15 * time.Second ) type TunnelConfig struct { - ConnectionConfig *connection.Config - OSArch string - ClientID string - CloseConnOnce *sync.Once // Used to close connectedSignal no more than once - EdgeAddrs []string - Region string - HAConnections int - IncidentLookup IncidentLookup - IsAutoupdated bool - LBPool string - Tags []tunnelpogs.Tag - Log *zerolog.Logger - LogTransport *zerolog.Logger - Observer *connection.Observer - ReportedVersion string - Retries uint - RunFromTerminal bool + GracePeriod time.Duration + ReplaceExisting bool + OSArch string + ClientID string + CloseConnOnce *sync.Once // Used to close connectedSignal no more than once + EdgeAddrs []string + Region string + HAConnections int + IncidentLookup IncidentLookup + IsAutoupdated bool + LBPool string + Tags []tunnelpogs.Tag + Log *zerolog.Logger + LogTransport *zerolog.Logger + Observer *connection.Observer + ReportedVersion string + Retries uint + RunFromTerminal bool - NamedTunnel *connection.NamedTunnelConfig - ClassicTunnel *connection.ClassicTunnelConfig + NamedTunnel *connection.NamedTunnelProperties + ClassicTunnel *connection.ClassicTunnelProperties MuxerConfig *connection.MuxerConfig ProtocolSelector connection.ProtocolSelector EdgeTLSConfigs map[connection.Protocol]*tls.Config } -func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions { +func (c *TunnelConfig) registrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions { policy := tunnelrpc.ExistingTunnelPolicy_balance if c.HAConnections <= 1 && c.LBPool == "" { policy = tunnelrpc.ExistingTunnelPolicy_disconnect @@ -83,7 +83,7 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str } } -func (c *TunnelConfig) ConnectionOptions(originLocalAddr string, numPreviousAttempts uint8) *tunnelpogs.ConnectionOptions { +func (c *TunnelConfig) connectionOptions(originLocalAddr string, numPreviousAttempts uint8) *tunnelpogs.ConnectionOptions { // attempt to parse out origin IP, but don't fail since it's informational field host, _, _ := net.SplitHostPort(originLocalAddr) originIP := net.ParseIP(host) @@ -91,7 +91,7 @@ func (c *TunnelConfig) ConnectionOptions(originLocalAddr string, numPreviousAtte return &tunnelpogs.ConnectionOptions{ Client: c.NamedTunnel.Client, OriginLocalIP: originIP, - ReplaceExisting: c.ConnectionConfig.ReplaceExisting, + ReplaceExisting: c.ReplaceExisting, CompressionQuality: uint8(c.MuxerConfig.CompressionSetting), NumPreviousAttempts: numPreviousAttempts, } @@ -108,11 +108,12 @@ func (c *TunnelConfig) SupportedFeatures() []string { func StartTunnelDaemon( ctx context.Context, config *TunnelConfig, + orchestrator *orchestration.Orchestrator, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal, graceShutdownC <-chan struct{}, ) error { - s, err := NewSupervisor(config, reconnectCh, graceShutdownC) + s, err := NewSupervisor(config, orchestrator, reconnectCh, graceShutdownC) if err != nil { return err } @@ -123,6 +124,7 @@ func ServeTunnelLoop( ctx context.Context, credentialManager *reconnectCredentialManager, config *TunnelConfig, + orchestrator *orchestration.Orchestrator, addr *allregions.EdgeAddr, connAwareLogger *ConnAwareLogger, connIndex uint8, @@ -158,6 +160,7 @@ func ServeTunnelLoop( connLog, credentialManager, config, + orchestrator, addr, connIndex, connectedFuse, @@ -256,6 +259,7 @@ func ServeTunnel( connLog *ConnAwareLogger, credentialManager *reconnectCredentialManager, config *TunnelConfig, + orchestrator *orchestration.Orchestrator, addr *allregions.EdgeAddr, connIndex uint8, fuse *h2mux.BooleanFuse, @@ -284,6 +288,7 @@ func ServeTunnel( connLog, credentialManager, config, + orchestrator, addr, connIndex, fuse, @@ -332,6 +337,7 @@ func serveTunnel( connLog *ConnAwareLogger, credentialManager *reconnectCredentialManager, config *TunnelConfig, + orchestrator *orchestration.Orchestrator, addr *allregions.EdgeAddr, connIndex uint8, fuse *h2mux.BooleanFuse, @@ -341,7 +347,6 @@ func serveTunnel( protocol connection.Protocol, gracefulShutdownC <-chan struct{}, ) (err error, recoverable bool) { - connectedFuse := &connectedFuse{ fuse: fuse, backoff: backoff, @@ -353,15 +358,16 @@ func serveTunnel( connIndex, nil, gracefulShutdownC, - config.ConnectionConfig.GracePeriod, + config.GracePeriod, ) switch protocol { case connection.QUIC, connection.QUICWarp: - connOptions := config.ConnectionOptions(addr.UDP.String(), uint8(backoff.Retries())) + connOptions := config.connectionOptions(addr.UDP.String(), uint8(backoff.Retries())) return ServeQUIC(ctx, addr.UDP, config, + orchestrator, connLog, connOptions, controlStream, @@ -376,11 +382,12 @@ func serveTunnel( return err, true } - connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries())) + connOptions := config.connectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries())) if err := ServeHTTP2( ctx, connLog, config, + orchestrator, edgeConn, connOptions, controlStream, @@ -403,6 +410,7 @@ func serveTunnel( connLog, credentialManager, config, + orchestrator, edgeConn, connIndex, connectedFuse, @@ -429,6 +437,7 @@ func ServeH2mux( connLog *ConnAwareLogger, credentialManager *reconnectCredentialManager, config *TunnelConfig, + orchestrator *orchestration.Orchestrator, edgeConn net.Conn, connIndex uint8, connectedFuse *connectedFuse, @@ -439,7 +448,8 @@ func ServeH2mux( connLog.Logger().Debug().Msgf("Connecting via h2mux") // Returns error from parsing the origin URL or handshake errors handler, err, recoverable := connection.NewH2muxConnection( - config.ConnectionConfig, + orchestrator, + config.GracePeriod, config.MuxerConfig, edgeConn, connIndex, @@ -457,10 +467,10 @@ func ServeH2mux( errGroup.Go(func() error { if config.NamedTunnel != nil { - connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.Retries())) + connOptions := config.connectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.Retries())) return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, connOptions, connectedFuse) } - registrationOptions := config.RegistrationOptions(connIndex, edgeConn.LocalAddr().String(), cloudflaredUUID) + registrationOptions := config.registrationOptions(connIndex, edgeConn.LocalAddr().String(), cloudflaredUUID) return handler.ServeClassicTunnel(serveCtx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse) }) @@ -475,6 +485,7 @@ func ServeHTTP2( ctx context.Context, connLog *ConnAwareLogger, config *TunnelConfig, + orchestrator *orchestration.Orchestrator, tlsServerConn net.Conn, connOptions *tunnelpogs.ConnectionOptions, controlStreamHandler connection.ControlStreamHandler, @@ -485,7 +496,7 @@ func ServeHTTP2( connLog.Logger().Debug().Msgf("Connecting via http2") h2conn := connection.NewHTTP2Connection( tlsServerConn, - config.ConnectionConfig, + orchestrator, connOptions, config.Observer, connIndex, @@ -514,6 +525,7 @@ func ServeQUIC( ctx context.Context, edgeAddr *net.UDPAddr, config *TunnelConfig, + orchestrator *orchestration.Orchestrator, connLogger *ConnAwareLogger, connOptions *tunnelpogs.ConnectionOptions, controlStreamHandler connection.ControlStreamHandler, @@ -523,8 +535,8 @@ func ServeQUIC( ) (err error, recoverable bool) { tlsConfig := config.EdgeTLSConfigs[connection.QUIC] quicConfig := &quic.Config{ - HandshakeIdleTimeout: quicHandshakeIdleTimeout, - MaxIdleTimeout: quicMaxIdleTimeout, + HandshakeIdleTimeout: quicpogs.HandshakeIdleTimeout, + MaxIdleTimeout: quicpogs.MaxIdleTimeout, MaxIncomingStreams: connection.MaxConcurrentStreams, MaxIncomingUniStreams: connection.MaxConcurrentStreams, KeepAlive: true, @@ -537,7 +549,7 @@ func ServeQUIC( quicConfig, edgeAddr, tlsConfig, - config.ConnectionConfig.OriginProxy, + orchestrator, connOptions, controlStreamHandler, connLogger.Logger()) diff --git a/origin/tunnel_test.go b/supervisor/tunnel_test.go similarity index 96% rename from origin/tunnel_test.go rename to supervisor/tunnel_test.go index 870a5049..2e646089 100644 --- a/origin/tunnel_test.go +++ b/supervisor/tunnel_test.go @@ -1,4 +1,4 @@ -package origin +package supervisor import ( "testing" @@ -32,11 +32,7 @@ func TestWaitForBackoffFallback(t *testing.T) { } log := zerolog.Nop() resolveTTL := time.Duration(0) - namedTunnel := &connection.NamedTunnelConfig{ - Credentials: connection.Credentials{ - AccountTag: "test-account", - }, - } + namedTunnel := &connection.NamedTunnelProperties{} mockFetcher := dynamicMockFetcher{ protocolPercents: edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}, } diff --git a/origin/tunnelsforha.go b/supervisor/tunnelsforha.go similarity index 98% rename from origin/tunnelsforha.go rename to supervisor/tunnelsforha.go index 61673737..80704e38 100644 --- a/origin/tunnelsforha.go +++ b/supervisor/tunnelsforha.go @@ -1,4 +1,4 @@ -package origin +package supervisor import ( "fmt" diff --git a/tunnelrpc/pogs/cloudflaredrpc.go b/tunnelrpc/pogs/cloudflaredrpc.go new file mode 100644 index 00000000..208f1004 --- /dev/null +++ b/tunnelrpc/pogs/cloudflaredrpc.go @@ -0,0 +1,53 @@ +package pogs + +import ( + "github.com/cloudflare/cloudflared/tunnelrpc" + capnp "zombiezen.com/go/capnproto2" + "zombiezen.com/go/capnproto2/rpc" +) + +type CloudflaredServer interface { + SessionManager + ConfigurationManager +} + +type CloudflaredServer_PogsImpl struct { + SessionManager_PogsImpl + ConfigurationManager_PogsImpl +} + +func CloudflaredServer_ServerToClient(s SessionManager, c ConfigurationManager) tunnelrpc.CloudflaredServer { + return tunnelrpc.CloudflaredServer_ServerToClient(CloudflaredServer_PogsImpl{ + SessionManager_PogsImpl: SessionManager_PogsImpl{s}, + ConfigurationManager_PogsImpl: ConfigurationManager_PogsImpl{c}, + }) +} + +type CloudflaredServer_PogsClient struct { + SessionManager_PogsClient + ConfigurationManager_PogsClient + Client capnp.Client + Conn *rpc.Conn +} + +func NewCloudflaredServer_PogsClient(client capnp.Client, conn *rpc.Conn) CloudflaredServer_PogsClient { + sessionManagerClient := SessionManager_PogsClient{ + Client: client, + Conn: conn, + } + configManagerClient := ConfigurationManager_PogsClient{ + Client: client, + Conn: conn, + } + return CloudflaredServer_PogsClient{ + SessionManager_PogsClient: sessionManagerClient, + ConfigurationManager_PogsClient: configManagerClient, + Client: client, + Conn: conn, + } +} + +func (c CloudflaredServer_PogsClient) Close() error { + c.Client.Close() + return c.Conn.Close() +} diff --git a/tunnelrpc/pogs/configurationrpc.go b/tunnelrpc/pogs/configurationrpc.go new file mode 100644 index 00000000..b82cdd53 --- /dev/null +++ b/tunnelrpc/pogs/configurationrpc.go @@ -0,0 +1,95 @@ +package pogs + +import ( + "context" + "fmt" + + "github.com/cloudflare/cloudflared/tunnelrpc" + capnp "zombiezen.com/go/capnproto2" + "zombiezen.com/go/capnproto2/rpc" + "zombiezen.com/go/capnproto2/server" +) + +type ConfigurationManager interface { + UpdateConfiguration(ctx context.Context, version int32, config []byte) *UpdateConfigurationResponse +} + +type ConfigurationManager_PogsImpl struct { + impl ConfigurationManager +} + +func ConfigurationManager_ServerToClient(c ConfigurationManager) tunnelrpc.ConfigurationManager { + return tunnelrpc.ConfigurationManager_ServerToClient(ConfigurationManager_PogsImpl{c}) +} + +func (i ConfigurationManager_PogsImpl) UpdateConfiguration(p tunnelrpc.ConfigurationManager_updateConfiguration) error { + server.Ack(p.Options) + + version := p.Params.Version() + config, err := p.Params.Config() + if err != nil { + return err + } + + result, err := p.Results.NewResult() + if err != nil { + return err + } + + updateResp := i.impl.UpdateConfiguration(p.Ctx, version, config) + return updateResp.Marshal(result) +} + +type ConfigurationManager_PogsClient struct { + Client capnp.Client + Conn *rpc.Conn +} + +func (c ConfigurationManager_PogsClient) Close() error { + c.Client.Close() + return c.Conn.Close() +} + +func (c ConfigurationManager_PogsClient) UpdateConfiguration(ctx context.Context, version int32, config []byte) (*UpdateConfigurationResponse, error) { + client := tunnelrpc.ConfigurationManager{Client: c.Client} + promise := client.UpdateConfiguration(ctx, func(p tunnelrpc.ConfigurationManager_updateConfiguration_Params) error { + p.SetVersion(version) + return p.SetConfig(config) + }) + result, err := promise.Result().Struct() + if err != nil { + return nil, wrapRPCError(err) + } + response := new(UpdateConfigurationResponse) + + err = response.Unmarshal(result) + if err != nil { + return nil, err + } + return response, nil +} + +type UpdateConfigurationResponse struct { + LastAppliedVersion int32 `json:"lastAppliedVersion"` + Err error `json:"err"` +} + +func (p *UpdateConfigurationResponse) Marshal(s tunnelrpc.UpdateConfigurationResponse) error { + s.SetLatestAppliedVersion(p.LastAppliedVersion) + if p.Err != nil { + return s.SetErr(p.Err.Error()) + } + return nil +} + +func (p *UpdateConfigurationResponse) Unmarshal(s tunnelrpc.UpdateConfigurationResponse) error { + p.LastAppliedVersion = s.LatestAppliedVersion() + respErr, err := s.Err() + if err != nil { + return err + } + if respErr != "" { + p.Err = fmt.Errorf(respErr) + } + return nil +} diff --git a/tunnelrpc/tunnelrpc.capnp b/tunnelrpc/tunnelrpc.capnp index fd840d9c..8d91f6d1 100644 --- a/tunnelrpc/tunnelrpc.capnp +++ b/tunnelrpc/tunnelrpc.capnp @@ -151,4 +151,19 @@ interface SessionManager { # Let the edge decide closeAfterIdle to make sure cloudflared doesn't close session before the edge closes its side registerUdpSession @0 (sessionId :Data, dstIp :Data, dstPort: UInt16, closeAfterIdleHint: Int64) -> (result :RegisterUdpSessionResponse); unregisterUdpSession @1 (sessionId :Data, message: Text) -> (); -} \ No newline at end of file +} + +struct UpdateConfigurationResponse { + # Latest configuration that was applied successfully. The err field might be populated at the same time to indicate + # that cloudflared is using an older configuration because the latest cannot be applied + latestAppliedVersion @0 :Int32; + # Any error encountered when trying to apply the last configuration + err @1 :Text; +} + +# ConfigurationManager defines RPC to manage cloudflared configuration remotely +interface ConfigurationManager { + updateConfiguration @0 (version :Int32, config :Data) -> (result: UpdateConfigurationResponse); +} + +interface CloudflaredServer extends(SessionManager, ConfigurationManager) {} \ No newline at end of file diff --git a/tunnelrpc/tunnelrpc.capnp.go b/tunnelrpc/tunnelrpc.capnp.go index 89606120..ac1e152a 100644 --- a/tunnelrpc/tunnelrpc.capnp.go +++ b/tunnelrpc/tunnelrpc.capnp.go @@ -3880,204 +3880,661 @@ func (p SessionManager_unregisterUdpSession_Results_Promise) Struct() (SessionMa return SessionManager_unregisterUdpSession_Results{s}, err } -const schema_db8274f9144abc7e = "x\xda\xccY}p\x14e\x9a\x7f\x9e\xee\x99t\x02\x19" + - "f\xbaz 0%\x97\x93\xc2\xf2\x88\x82\x06\xce+\x8e" + - "\xb3.\x09\x06\xceD>\xd23p\xe5\x09Zvf\xde" + - "\x84\xc9\xcdt\x0f\xdd=\x91 \xc8\x87 b\xf9\x05\x82" + - "\"\xca\xc9ayW\xa0\xde\xc1\xa9\xe7\xb2%\xb5\xb2+" + - "*\xa5\xa8X\xb0\x85\x8a\xb5\x8b\xc8\xeeJ\xc1\xba\"\xac" + - "\xe5\xaeko=\xdd\xd3\x1f\x99\x84$\xc8\xfe\xb1\xffM" + - "\x9e~\xde\xf7}>~\xcf\xef}\xde'\xd7wT6" + - "r\xf5\xe1\x9a\x08\x80\xbc%\\a\xb1\xba\x0f\x97n\xbf" + - "\xeag\xabAN Z\xf7\xbc\xd6\x1a\xff\xd6\\\xfd\x09" + - "\x84y\x01`\xca\xe2\x8a\xa5(\xad\xad\x10\x00\xa4U\x15" + - "\xbf\x06\xb4\xee\x1b\xb5\xfb\x99\xe7fl\xba\x17\xc4\x04\xef" + - "+\x03NaB+J=\x02i\x16\x85u\xd2Q\xfa" + - "e\xdd\"^\xb7 \xfe\xc1{\xa4\x1d\xdc:D[\xef" + - "\x13\xeaP:d/8(\xd0\xd67\xe6\xdf\xdf\xf1\x0f" + - "\x9b\xdfY\x03b\x82\xeb\xb5\xf5+\x95KQ:XI" + - "\x9a\x07*\xe7\x02Z_o\x1a\xfd\xfc\x7f\xbe\xf7\xf6Z" + - "\x10\xafF(Y\xfai\xe5\xc7\x08(}U\xf9\xbf\x80" + - "\xd6\xa1\x0b\x0b\xce\xbf\xfc\xe6\x0d\xf7\x818\x81\x14\x90\x14" + - "6T\x8d\xe3\x00\xa5\x9dU\x0d\x80\xd6\xe93\x7f\\w" + - "\xf7\x849\x8f\x82<\x019\x800G\x1a\x07\xab\x12\xa4" + - "q\xa2\x8a\xaci\x98yhob\xca\xe3\x9b\xcaL\xb7" + - "\x15\xf7\x0f\xabC\xe9\xf002\xe8\xd0\xb0\xbb\x00\xad\xdf" + - "\x8fx\xea\xbd\xe2M\xaf>^:\xcfV\xaa\x1f^G" + - "\xbb\xb5\x0c'\x85q\xddW\xdd\xf9\xd3\x03/=\x01\xf2" + - "DD\xebx\xfb5G\xf9m\xbb>\x81\xf9(\xd0\xf1" + - "Sv\x0e\xdfA\xc6\xef\xb5u\xdf\xbf\xf6\xb5\x1f?\xfa" + - "\xd2\xba\xa7@\xbe\x1a\x11\xc0\x0e\xd6\xd8\xea?\x90B}" + - "5\x19\xbf\xe9\xd8\xbe9\xf9\x0d[w8\xee\xdb\xdf\xff" + - "\xad\x9a\xe3 d\xadi\xf9&?\xff\xd9\xd4\xb3\xa5\xc0" + - "\x84\xe9\xd3\xec\xeas\x088E\xa9\xaeE@\xeb\x86\x8f" + - "O\xcd\x9d\xfd\x7f\x1d\xff\x1dX\xbb<\xb2\x94\xd6\xae\xeb" + - "8\xb7?\x96\xcc?_\xe6\xb0\x1d\xbb\x9e\xc8.\x946" + - "D\xc8\xe1\x87\"d\xc2\x8b\x7fsK\xd5\x92S3w" + - "\x838\xd1\xdd\xe6\xc5H\x92\xb6\x09\xdd\xce\x7f\xafl\xf9" + - "\xc9\xcb\xe5p\xb2c\xb23\xd2\x8e\xd2>\xdag\xca\xde" + - "\x88m\xcf\x03\xfb\xb7^S\xf9\xcc\xd7\xaf\xf4\x17\xe6\x13" + - "#\xdaQ\xba0\x82N\xfdj\x04Efd\x0b\x1e\x7f" + - "\xbd>\xf4j0\xefr\xf44E\x86E)\xefc\xcf" + - "N\x8f\xa8_\xae~\xbdl7[1\x1ckEiL" + - "\x8cv\x1b\x19#\xe5\xd6\x05\x8fm\x0c\x9fz\xec-\xb2" + - "4\x00\xb80\x01m\xca\x9e\x98\x8e\xd2\x81\x98\x9d\xedX" + - "\x0d\x0fh%v\xff\xd3\xffL\xcf|\xf4N?\x96J" + - "M\xf1s\xd2\xec8\xfdj\x89\x93\xa1''\xee\xb9\xfb" + - "\x8b\x87\x0e\x1f)\x19j\xc7\xf0\xb9\xb8\x9d\xc2\xbdq\x8a" + - "\x9f\x87\x80\xb2(\xd9\x9a\x1f\xc5\xbbP:ko\xf7\x85" + - "\xad\xcd\x9dR\xc6\xac\xfc\xf9?\x1f\x0f$\xedl\xfc3" + - "\x84\x905\xe7_\x17tU-?y2x\xd0\x89\xb8" + - "\x1d\x91\x0b\xf6\xd2\xdf\xfe\xd7\xe9G\xce\xe43\xbf\xb2\x81" + - "\xe7\xc6l\xe4\xc8i\x04\xcd\x89#\x09\xe85\xb5\x91\x19" + - "\xe3\x8e\xb5\x9dvR\xe9lQ5j:)\\9\x8a" + - "\xb6\xb8\xe1\xce&\xb6p\xea\xad\xa7\xfb\x94|\xd3\xa8i" + - "(\xc9\xa3l\x90\x8dZ\x87\x12\xab\xa9\x01\xb0\xba\xff\x7f" + - "\xc3\xad\xcf\xbf1\xe7\x9cS\x0b\xb6\xb1\xf3k&\x134" + - "\x1e\xbe\xa7y\xee?\x8e\xdb\x7f.h\xec\xec\x1aB\xa7" + - "\xa4\xd4\xd0I\x1dS\xcf\xfc\xcbU\x0f\xbfy\xae?\x08" + - "\xae\xaa\xa9CiC\x8d\x0dAR\xfer\xe6\x7f\x1cI" + - "D\x13\xe7\xcb\x02Xa'\xaf\xa6\x0b\xa5\x035v\xf2" + - "j\xde\"\x98\xdd\xf7\xc9\x1dK>\xbc\xf7\xeb\x0b\xe5\xb9" + - "\xb6\xb7~eL\x12\xa5\x83cl~\x19C\xc8xb" + - "\xdeoV\x9c\xd9<\xea\x9b\xbe$\x97\xe8B\xa9'a" + - "\x93\\b\x9dt\x94~Y\x1f\x08\xcf\xd67\xafx\xe7" + - "\xdb@-\xecK\xb4\x92\xc3\x8f\x0bO\x9f\\\xf9\x8b;" + - "\xbe\x0b:\xbc7\xf1\x199|(A\x0e/\xfb\xf2\xc9" + - "\x9b\x1fY\xf8\xc2\xf7\xc1\xc4&V\xd3R\xb3\xa8\xaa," + - "\xa7\x17B\xe9\xeb\xdc\x9f\xe9Ii\xa5\xa0\x16\xa65\x15" + - "\xcdEL5\xb3i\xc5dI\xd6`\x144\xd5`m" + - "\x88r\x8c\x0f\x01\x84\x10@T\xba\x00\xe4;y\x94s" + - "\x1c\x8a\x88qJ\xbd\x98%\xe1\"\x1ee\x93C\x91\xe3" + - "\xe2\xc4<\xe2\xe2q\x00r\x8eGy\x09\x87\xc8\xc7\x91" + - "\x07\x10\x8b\x1b\x01\xe4%<\xcak8\xb4\x0aL\xcf+" + - "*S!j\xce\xd0u\xac\x06\x0e\xab\x01-\x9d\x99z" + - "\x8f\xd2\x9e\x83(\x0b\x88\x85\xae\xbbL\x8c\x00\x87\x11@" + - "k\x91V\xd4\x8d\xf9\xaa\x89\xd9\\\x92u\xe8\xcc\xc0E" + - "X\x01\x1cV\x00\x0e\xe4^\x8a\x19FVSg+\xaa" + - "\xd2\xc9t\x00\xf2\xac\x92\x0f\x03x\xa4\x8d.\xbd\x8b\xf5" + - "[\x81\x13'\x0a\xe830\xba\xf0\x13\xaf\xdc\x05\x9c8" + - "V\xb0t\xd6\x995L\xa6\xe3\xfcL\xc1\xde\x9b\xd7\xd4" + - "F\xb4\x8a\xaa\xf3\x01\x99\xee|\x88\xd2\xa9\x8d\xd8\x86\xbe" + - "u|_\xebn\xcae\x99jF[\xd4\x0e\xad,\xe4" + - "\xad\xfd\x85\xbc\xb5\x14\xf25\x81\x90\xaf\x9a\x0e /\xe3" + - "Q\xbe\x9fC\x91/\xc5|m\x1d\x80\xbc\x92G\xf9A" + - "\x0e\xad\xb4}HK\x06\x00\xbchv0\xc5,\xea\xcc" + - " \xd9\x08\xc06\x1e\xed\xa0\x8f\x00\\\xd1\xcdt\xb2\xdd" + - "MBT\xd1\xd3\x8b\xbcD\x0d\x10\xe9\x19K\xb2\x86\x99" + - "U;\xe7\xd9\xf2\x866-\x97M\xf7\x90W\xd5\xb6\x9d" + - "c\xa7\x01 \x8a#o\x03@N\x14\xa7\x034d;" + - "UMgV&k\xa45Ue\xc0\xa7\xcd\x15\xedJ" + - "NQ\xd3\xcc;\xa8\xa2\xefA\xce\x01)\xa6w3}" + - "\x92\x12\x80\xef\xf86EW\xf8\xbc!W{q\x9cq" + - "\x1b\x80\xdc\xcc\xa3\xdc\x16\x88\xe3l\x8a\xe3,\x1e\xe5[" + - "\x03q\x9cOql\xe3Q^\xc8\xa1\xa5\xe9\xd9\xce\xac" + - "z\x13\x03^\x0f\"\xd00U%\xcf(f\xa5x\xac" + - "\xd0\x0afVS\x0d\x8c\xf9\xfc\x0f\x88\xb1@\xa4\x84\xc1" + - "09\xc9\x85\x94\x8b(M\x1d\x9fdFQ\xc8\x99\x86" + - "\x1c\xf2<\x89L\x03\x90+y\x94\xe3\x1c6\xe8\xcc(" + - "\xe6L\x8c\xf9\xd7\xec_\xe2T7|\x01\x18&\xfb\x83" + - "\xe1d\x009\xc3\xa3\\\xe0\x10K\xd1\xcbO\x0f\xb0\x01" + - "\x8f\x0e\x0a\x17o\x05\x90M\x1e\xe5\x95\x1cZ\x86sH" + - "\x0b`\xc6\x8dhm\xc60[\x0a\xee_+2\x86\xd9" + - "\xa6\xe9&\x0a\xc0\xa1\x00\x84[\xcd`M\x1dTS-" + - "\x99\x1c\xbb9\xcb\xab&\x86\x81\xc30\x0cXT\x0e>" + - "\xa2DlN\xb5\xbb\xdeL 0\xfc\x1d\x8f\xf2\xdf\x07" + - "\xbc\xa9'\x1e\xbb\x9eG\xf9F\x0e-%\x9d\xd6\x8a\xaa" + - "9\x0fx\xa5\xb3\x0c\xf3)\x06\xd1\xb4\xce|8\x0c=" + - "\xd4.9\x94\x05;\xaa+y#h^\xb2?\xf3(" + - "\xb0\xd7\xf2(O\xed?\x86+\xf2\xcc0\x94N\xd6\xa7" + - "B\xc3\xfd\xb0\x0dUY\x9a\x00\x9bd\x0e\xcfO\xd2\x99" + - "!\x14s&YQmY\x8e\x19\x94\xde\xf1<\xca\xd7" + - "s\x18\xc1\xef-\xc7\x8e\x89\x1b\xfd0\xd52]\xd7t" + - "\x8c\xf9\xf7`\x09}\xe9\xd2\x01\xa8\xa9\xcd\xccT\xb29" + - "\xa4\xca\xf0\x9a\xb22\x8c\x0eV\xda~\xd8\x1c\xf1\xf8\x06" + - "\x02h\xbeWQ\x10\xc2b<\xcaWphu\xeaJ" + - "\x9a\xb51\x1d\xb3Zf\x8e\xa2j)\x9e\xa5\xfb\xe0e" + - "\xc4\xa5\x1e\x9a\xb4K\xcd\x00o\xd5\xc0\xebuV\x0aB" + - "iy[\xadcs\xdc\xb3y\xf98\xff>\xf4\xd2\xbc" + - "\xaa\xdd'l\x8f\x92\xd6\x13^\xef\xe7Q\xde\x14\xa0\xf6" + - "\x0dD^\x8f\xf2(?\xcd\xa1\x18\x0a\xc51\x04 >" + - "I(\xd9\xc4\xa3\xbc\x9d\xeb}k\xb2n\xa6\x9a\xcd\xd9" + - "N\x10\x98\xe1K\xc9\xc4\xe6l'\x03\xde\xb8\\z\xab" + - "\x1c$\x1eZ\xbb\xa1\xe5\x98\xc9\x9aY:\xa7\xe8\x8a\x99" + - "\xedf\xce\xf7\x12\x18\xdd\xa4\x0e\x84\xdbd\x9f\xea!\xfc" + - "F\xddF%\x00\x87q>G\x0a,\xd0_\x0c`\xad" + - "\xb39Y\xa6\xa9}0\xe0WL\x09\x07h\x0ct\x05" + - "\xfa\xeas\x0bfV\xd0T\x83\xec\x0b\xa4~Z\x7f\xa9" + - "\xd7\xfd\xd4\xbbt\xba~u0\xf3%:\xdd\xb0\xd5O" + - "\xb2\x18\xe2\x9c\xcco\xdb\x01 o\xe7Q~\x81\xc3\x06" + - "\xe7\xa6\xc7\x98\xffR.e\xcb\xb9\xcffiP\x9bV" + - "r>\xe5Z:+\xe4\x944\x9b\x81\xa5\xbb\x1b\x10\x81" + - "C\xb4!\x92/\xe8\xcc00\xab\xa9rQ\xc9ey" + - "\xb3\xc7\xeb\xb7\xd4b\xbeMg\xddY\xd4\x8aF\x93i" + - "\xb2\xbcP0\x8d\xa1tc~\x80\x88\x1f\x84l\xce(" + - "c\xe8:\x9f{\xbc\x00M\xec\xf2)0Z,f=" + - "\xee\xb3rZ\xda\xce\x1bD\xe7(\xf9\xbe\x14X1h" + - "\xad\xf6\xaat\x97\x91\xff\x9a\xba\x87\x81\x1bvr\xdd\xee" + - "h\x03&S\x094\xf2(\xcf\x0a\x98\xdc29\xe0\x87" + - "k\xf2\xecv\xdf\x0f\xe1\xdfY\x8fkU-\xcb\x13s" + - "\xbb\xc1,9\xd3\x04\xc2-\xbe\xce@\xf6\x05\x0bjn" + - "\xa1\xd6\xf6\x90l\x9c\xea\xda(\xf5`+@j\x09\xf2" + - "\x98Z\x83\xbe\x99\xd2*\x9c\x0e\x90ZF\xf2\xfb\xd1\xb7" + - "TZ\x8b\x09\x80\xd4J\x92?\x88\xde\xc3BZ\x8f\xbb" + - "\x00R\x0f\x92x\x0b\xa9\x87x\xbb$\xa4\xcd\xf6\xf6\x9b" + - "H\xbe\x9d\xe4\xe1P\x1c\xc3\x00\xd26\xac\x03Hm!" + - "\xf9\xcb$\xaf\xe0\xe2X\x01 \xed\xc1.\x80\xd4n\x92" + - "\xbfFr!\x1c\xa7\xb7\x95\xb4\x17u\x80\xd4\x8fH\xfe" + - "\x06\xc9+G\xc7\xb1\x12@\xdao\xcb_'\xf9\xbb$" + - "\xaf\x1a\x13\xc7*\x00\xe9 \xae\x06H\xbdM\xf2#$" + - "\x1f\x86q\x1c\x06 \x1d\xc6\xad\x00\xa9#$\xff%\xc9" + - "\x87W\xc4q8\x80\xf4\xa9m\xcf1\x92\x7fN\xf2\xea" + - "P\x1c\xab\x01\xa4\x13\xb8\x03 \xf59\xc9\x7fG\xf2\x88" + - "\x10\xc7\x08\x80t\xd6\xf6\xeb\x0c\xc9+\xb9\xb2\xbe\xdeE" + - "TY\xf3\xcek\x86\x972V\xaaqt\xe0\xde\xa6E" + - "\xa9A\xc7\xa8?)\x03\xc4(\xa0U\xd0\xb4\xdc\x9c\xde" + - "H\x8d\x9aJ\xa7\xe1>\x14b\xfe\xf0\x02\x90\x84\xde\xbd" + - "\x0fQMm\xc9xDP\xce:\xae%Y\xa3\xa9h" + - "j\xc5\x02\xd4f\x14\x93e<\xce\xd1\x8b\xeaL]\xcb" + - "\xcfC\xa6\xe7\xb3\xaa\x92\x1b\x84\x8d\xaa\x80\xc3*(Q" + - "\x82\xbb\xf7\xc0\xd4t\xf1g\x8f\x87h\xae\x1c\xd1\xb5\x85" + - "i\xf3\x94\xce\xa1\xf0\xd4d\xbf\x7f\x8b\xaa\x01B\xaa\xed" + - "Vr\xc5\x1fBO\xbd[\x89d\x83\xd3\x8a\x0c\xf6(" + - "pg\x19\x83SI\xef\x86\xb0\xf7\x85\x8a\x811#\x9d" + - "\xc3\x95\xf6\x1f\xb2\xf9\x9d\xcct~\xd1\xeb\x96\xde\x16B" + - "\xf0\x9a\xbf\xb4\xd5IfD\x87\xe2\xba?\xf3\x19\xfc=" + - "\xd4\xcf\xc5\xdf\xcf\xb5\xef\xf6\x9c\x817\x11\xe5~!\x8f" + - "\xf2\xa2@\xeeYk?o\xa2\xa4?\x0c\x11y\xae4" + - "\x0d\xa1\x8b\xa2\xc0\xa3\xbc\x8c\xc3(=^1\xe6\x0f\x87" + - "{\x19\xdd\xfb\xc1NPhQ3\x0cp\x89\x8b\xe6\xc0" + - "\xf5\xe1\x8dI\x07\xef\xce\x86\xe6\xb6\xdb\xf5\x0e\x1apo" + - "\xf4Xv\xf2E\xdfe\x0d\xce\xa1\x84\xb3\xd1\xf6\x1c\xc6" + - "\x1d\xc3\xa2;\xd0\x13\xf7,\x05N\xdc)\xa0?\xaaD" + - "w2)n\xd3\x81\x137\x0b\xc8y\x83mt\x07\xd8" + - "\xe2\xfa\x07\x80\x13\xd7\x0a\xc8{sitGb\xf5=" + - "\xc3\x108q\xb9\x80!o\xde\x8f\xee@M\\\xdc\x05" + - "\x9c\x98\x150\xec\x8d\xbc\xd1\x9d\xb9\x8a\xb7\xaf\x06N\x9c" + - "\xef\x0f~\xa0\xc1\xf1\xa3\x11-\x17\xa3Pk\xa3\xb4\xf7" + - "\x18\xc8\xd1\x02hD\xcb\xed\x81\xf9\x8b5\xc1\xb6\x96;" + - "\xc9\x80hZ1Y#5gN\xfdc\x89\x00\xa0\x11" + - "\xe5\x10\x06\xe6\x89\x00\x97\xfb\x08M\xb2Z;\xcf?\xb4" + - "er\xd7\xff@J\xe2\xfb\xb3\x9a\xce\xf1&b\x81}" + - "\xa9\x0b\xac\xe6Q\x1e\xcd\x0d\xda\xf8\x85.\xe6\x85\x0b\xfe" + - "(-\xa6\xfd\xff\xd6\xdb\xff05N\xef\xf2(\x1f\x0b" + - "\x94\xf5Q\x12~\xc0\xa3|<\xd08}D\xb5~\x8c" + - "G\xf9\xbc?\xe4\xfc\xea\x01\x00\xf9<\x8f\xc9@#\"" + - "\xfe\x89\x14\xbf\xa3\xeb\xdanC\xd0iC\xc2\xb8\x11 " + - "UI\xd7x\xdcnCBN\x1b\"b;@*F" + - "\xf2+\x82m\xc8\x18\xbc\x0d 5\x9a\xe4\xe3\xb1\xf7\xbb" + - "F(\xea~\xa3\x96\xd3:ge\xd5~\xef6w\xea" + - "\x8a\xe6L%\x9b+\xea\x0c\xfc\xab\xb5D6\xcd\x81\xdb" + - "\xde\x19\xc7:\x93\x97\x14\x810\x83\x867\x95\xb9\x84\x17" + - "\xe5\x90n\x9e\x19\xba\xae\xa1^\xd6\xc4N\xf6\x9bX\xaf" + - "\x87\xa5^\xfcf\x1e\xe5y\x94\x8aF'\x15r\xbb\xdf" + - "v\xd7\xa6\x95\xa2\xc1\xfa\xf8\x00<\xd3\xbd)\x80\xb1H" + - "+\xe62I\x06\x82\xa9\xf7\x94\x85`\xd0f6\xc5\xa2" + - ".s9\x13d\xf7\xbf!\xe8\xfe\xd3#0Av\xc7" + - "\xf8\xe8\xfeo\xab\xef\x04\xd9\x8dA\x9f\x09\xb2\xf3\xc1\xc6" + - "h\xef\x09\xf2e<_\x9dk,\xc0\x18\x974X\x1d" + - "\xf2<\xd2\xfb\xf7oY\xa5W]\xee\x98\xc0\xbd\x90\xfe" + - "\x1c\x00\x00\xff\xff\xa1\x1ap\xe9" +type UpdateConfigurationResponse struct{ capnp.Struct } + +// UpdateConfigurationResponse_TypeID is the unique identifier for the type UpdateConfigurationResponse. +const UpdateConfigurationResponse_TypeID = 0xdb58ff694ba05cf9 + +func NewUpdateConfigurationResponse(s *capnp.Segment) (UpdateConfigurationResponse, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 1}) + return UpdateConfigurationResponse{st}, err +} + +func NewRootUpdateConfigurationResponse(s *capnp.Segment) (UpdateConfigurationResponse, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 1}) + return UpdateConfigurationResponse{st}, err +} + +func ReadRootUpdateConfigurationResponse(msg *capnp.Message) (UpdateConfigurationResponse, error) { + root, err := msg.RootPtr() + return UpdateConfigurationResponse{root.Struct()}, err +} + +func (s UpdateConfigurationResponse) String() string { + str, _ := text.Marshal(0xdb58ff694ba05cf9, s.Struct) + return str +} + +func (s UpdateConfigurationResponse) LatestAppliedVersion() int32 { + return int32(s.Struct.Uint32(0)) +} + +func (s UpdateConfigurationResponse) SetLatestAppliedVersion(v int32) { + s.Struct.SetUint32(0, uint32(v)) +} + +func (s UpdateConfigurationResponse) Err() (string, error) { + p, err := s.Struct.Ptr(0) + return p.Text(), err +} + +func (s UpdateConfigurationResponse) HasErr() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s UpdateConfigurationResponse) ErrBytes() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return p.TextBytes(), err +} + +func (s UpdateConfigurationResponse) SetErr(v string) error { + return s.Struct.SetText(0, v) +} + +// UpdateConfigurationResponse_List is a list of UpdateConfigurationResponse. +type UpdateConfigurationResponse_List struct{ capnp.List } + +// NewUpdateConfigurationResponse creates a new list of UpdateConfigurationResponse. +func NewUpdateConfigurationResponse_List(s *capnp.Segment, sz int32) (UpdateConfigurationResponse_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 8, PointerCount: 1}, sz) + return UpdateConfigurationResponse_List{l}, err +} + +func (s UpdateConfigurationResponse_List) At(i int) UpdateConfigurationResponse { + return UpdateConfigurationResponse{s.List.Struct(i)} +} + +func (s UpdateConfigurationResponse_List) Set(i int, v UpdateConfigurationResponse) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s UpdateConfigurationResponse_List) String() string { + str, _ := text.MarshalList(0xdb58ff694ba05cf9, s.List) + return str +} + +// UpdateConfigurationResponse_Promise is a wrapper for a UpdateConfigurationResponse promised by a client call. +type UpdateConfigurationResponse_Promise struct{ *capnp.Pipeline } + +func (p UpdateConfigurationResponse_Promise) Struct() (UpdateConfigurationResponse, error) { + s, err := p.Pipeline.Struct() + return UpdateConfigurationResponse{s}, err +} + +type ConfigurationManager struct{ Client capnp.Client } + +// ConfigurationManager_TypeID is the unique identifier for the type ConfigurationManager. +const ConfigurationManager_TypeID = 0xb48edfbdaa25db04 + +func (c ConfigurationManager) UpdateConfiguration(ctx context.Context, params func(ConfigurationManager_updateConfiguration_Params) error, opts ...capnp.CallOption) ConfigurationManager_updateConfiguration_Results_Promise { + if c.Client == nil { + return ConfigurationManager_updateConfiguration_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0xb48edfbdaa25db04, + MethodID: 0, + InterfaceName: "tunnelrpc/tunnelrpc.capnp:ConfigurationManager", + MethodName: "updateConfiguration", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 8, PointerCount: 1} + call.ParamsFunc = func(s capnp.Struct) error { return params(ConfigurationManager_updateConfiguration_Params{Struct: s}) } + } + return ConfigurationManager_updateConfiguration_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} + +type ConfigurationManager_Server interface { + UpdateConfiguration(ConfigurationManager_updateConfiguration) error +} + +func ConfigurationManager_ServerToClient(s ConfigurationManager_Server) ConfigurationManager { + c, _ := s.(server.Closer) + return ConfigurationManager{Client: server.New(ConfigurationManager_Methods(nil, s), c)} +} + +func ConfigurationManager_Methods(methods []server.Method, s ConfigurationManager_Server) []server.Method { + if cap(methods) == 0 { + methods = make([]server.Method, 0, 1) + } + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0xb48edfbdaa25db04, + MethodID: 0, + InterfaceName: "tunnelrpc/tunnelrpc.capnp:ConfigurationManager", + MethodName: "updateConfiguration", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := ConfigurationManager_updateConfiguration{c, opts, ConfigurationManager_updateConfiguration_Params{Struct: p}, ConfigurationManager_updateConfiguration_Results{Struct: r}} + return s.UpdateConfiguration(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 1}, + }) + + return methods +} + +// ConfigurationManager_updateConfiguration holds the arguments for a server call to ConfigurationManager.updateConfiguration. +type ConfigurationManager_updateConfiguration struct { + Ctx context.Context + Options capnp.CallOptions + Params ConfigurationManager_updateConfiguration_Params + Results ConfigurationManager_updateConfiguration_Results +} + +type ConfigurationManager_updateConfiguration_Params struct{ capnp.Struct } + +// ConfigurationManager_updateConfiguration_Params_TypeID is the unique identifier for the type ConfigurationManager_updateConfiguration_Params. +const ConfigurationManager_updateConfiguration_Params_TypeID = 0xb177ca2526a3ca76 + +func NewConfigurationManager_updateConfiguration_Params(s *capnp.Segment) (ConfigurationManager_updateConfiguration_Params, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 1}) + return ConfigurationManager_updateConfiguration_Params{st}, err +} + +func NewRootConfigurationManager_updateConfiguration_Params(s *capnp.Segment) (ConfigurationManager_updateConfiguration_Params, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 1}) + return ConfigurationManager_updateConfiguration_Params{st}, err +} + +func ReadRootConfigurationManager_updateConfiguration_Params(msg *capnp.Message) (ConfigurationManager_updateConfiguration_Params, error) { + root, err := msg.RootPtr() + return ConfigurationManager_updateConfiguration_Params{root.Struct()}, err +} + +func (s ConfigurationManager_updateConfiguration_Params) String() string { + str, _ := text.Marshal(0xb177ca2526a3ca76, s.Struct) + return str +} + +func (s ConfigurationManager_updateConfiguration_Params) Version() int32 { + return int32(s.Struct.Uint32(0)) +} + +func (s ConfigurationManager_updateConfiguration_Params) SetVersion(v int32) { + s.Struct.SetUint32(0, uint32(v)) +} + +func (s ConfigurationManager_updateConfiguration_Params) Config() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return []byte(p.Data()), err +} + +func (s ConfigurationManager_updateConfiguration_Params) HasConfig() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s ConfigurationManager_updateConfiguration_Params) SetConfig(v []byte) error { + return s.Struct.SetData(0, v) +} + +// ConfigurationManager_updateConfiguration_Params_List is a list of ConfigurationManager_updateConfiguration_Params. +type ConfigurationManager_updateConfiguration_Params_List struct{ capnp.List } + +// NewConfigurationManager_updateConfiguration_Params creates a new list of ConfigurationManager_updateConfiguration_Params. +func NewConfigurationManager_updateConfiguration_Params_List(s *capnp.Segment, sz int32) (ConfigurationManager_updateConfiguration_Params_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 8, PointerCount: 1}, sz) + return ConfigurationManager_updateConfiguration_Params_List{l}, err +} + +func (s ConfigurationManager_updateConfiguration_Params_List) At(i int) ConfigurationManager_updateConfiguration_Params { + return ConfigurationManager_updateConfiguration_Params{s.List.Struct(i)} +} + +func (s ConfigurationManager_updateConfiguration_Params_List) Set(i int, v ConfigurationManager_updateConfiguration_Params) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s ConfigurationManager_updateConfiguration_Params_List) String() string { + str, _ := text.MarshalList(0xb177ca2526a3ca76, s.List) + return str +} + +// ConfigurationManager_updateConfiguration_Params_Promise is a wrapper for a ConfigurationManager_updateConfiguration_Params promised by a client call. +type ConfigurationManager_updateConfiguration_Params_Promise struct{ *capnp.Pipeline } + +func (p ConfigurationManager_updateConfiguration_Params_Promise) Struct() (ConfigurationManager_updateConfiguration_Params, error) { + s, err := p.Pipeline.Struct() + return ConfigurationManager_updateConfiguration_Params{s}, err +} + +type ConfigurationManager_updateConfiguration_Results struct{ capnp.Struct } + +// ConfigurationManager_updateConfiguration_Results_TypeID is the unique identifier for the type ConfigurationManager_updateConfiguration_Results. +const ConfigurationManager_updateConfiguration_Results_TypeID = 0x958096448eb3373e + +func NewConfigurationManager_updateConfiguration_Results(s *capnp.Segment) (ConfigurationManager_updateConfiguration_Results, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return ConfigurationManager_updateConfiguration_Results{st}, err +} + +func NewRootConfigurationManager_updateConfiguration_Results(s *capnp.Segment) (ConfigurationManager_updateConfiguration_Results, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return ConfigurationManager_updateConfiguration_Results{st}, err +} + +func ReadRootConfigurationManager_updateConfiguration_Results(msg *capnp.Message) (ConfigurationManager_updateConfiguration_Results, error) { + root, err := msg.RootPtr() + return ConfigurationManager_updateConfiguration_Results{root.Struct()}, err +} + +func (s ConfigurationManager_updateConfiguration_Results) String() string { + str, _ := text.Marshal(0x958096448eb3373e, s.Struct) + return str +} + +func (s ConfigurationManager_updateConfiguration_Results) Result() (UpdateConfigurationResponse, error) { + p, err := s.Struct.Ptr(0) + return UpdateConfigurationResponse{Struct: p.Struct()}, err +} + +func (s ConfigurationManager_updateConfiguration_Results) HasResult() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s ConfigurationManager_updateConfiguration_Results) SetResult(v UpdateConfigurationResponse) error { + return s.Struct.SetPtr(0, v.Struct.ToPtr()) +} + +// NewResult sets the result field to a newly +// allocated UpdateConfigurationResponse struct, preferring placement in s's segment. +func (s ConfigurationManager_updateConfiguration_Results) NewResult() (UpdateConfigurationResponse, error) { + ss, err := NewUpdateConfigurationResponse(s.Struct.Segment()) + if err != nil { + return UpdateConfigurationResponse{}, err + } + err = s.Struct.SetPtr(0, ss.Struct.ToPtr()) + return ss, err +} + +// ConfigurationManager_updateConfiguration_Results_List is a list of ConfigurationManager_updateConfiguration_Results. +type ConfigurationManager_updateConfiguration_Results_List struct{ capnp.List } + +// NewConfigurationManager_updateConfiguration_Results creates a new list of ConfigurationManager_updateConfiguration_Results. +func NewConfigurationManager_updateConfiguration_Results_List(s *capnp.Segment, sz int32) (ConfigurationManager_updateConfiguration_Results_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}, sz) + return ConfigurationManager_updateConfiguration_Results_List{l}, err +} + +func (s ConfigurationManager_updateConfiguration_Results_List) At(i int) ConfigurationManager_updateConfiguration_Results { + return ConfigurationManager_updateConfiguration_Results{s.List.Struct(i)} +} + +func (s ConfigurationManager_updateConfiguration_Results_List) Set(i int, v ConfigurationManager_updateConfiguration_Results) error { + return s.List.SetStruct(i, v.Struct) +} + +func (s ConfigurationManager_updateConfiguration_Results_List) String() string { + str, _ := text.MarshalList(0x958096448eb3373e, s.List) + return str +} + +// ConfigurationManager_updateConfiguration_Results_Promise is a wrapper for a ConfigurationManager_updateConfiguration_Results promised by a client call. +type ConfigurationManager_updateConfiguration_Results_Promise struct{ *capnp.Pipeline } + +func (p ConfigurationManager_updateConfiguration_Results_Promise) Struct() (ConfigurationManager_updateConfiguration_Results, error) { + s, err := p.Pipeline.Struct() + return ConfigurationManager_updateConfiguration_Results{s}, err +} + +func (p ConfigurationManager_updateConfiguration_Results_Promise) Result() UpdateConfigurationResponse_Promise { + return UpdateConfigurationResponse_Promise{Pipeline: p.Pipeline.GetPipeline(0)} +} + +type CloudflaredServer struct{ Client capnp.Client } + +// CloudflaredServer_TypeID is the unique identifier for the type CloudflaredServer. +const CloudflaredServer_TypeID = 0xf548cef9dea2a4a1 + +func (c CloudflaredServer) RegisterUdpSession(ctx context.Context, params func(SessionManager_registerUdpSession_Params) error, opts ...capnp.CallOption) SessionManager_registerUdpSession_Results_Promise { + if c.Client == nil { + return SessionManager_registerUdpSession_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0x839445a59fb01686, + MethodID: 0, + InterfaceName: "tunnelrpc/tunnelrpc.capnp:SessionManager", + MethodName: "registerUdpSession", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 16, PointerCount: 2} + call.ParamsFunc = func(s capnp.Struct) error { return params(SessionManager_registerUdpSession_Params{Struct: s}) } + } + return SessionManager_registerUdpSession_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} +func (c CloudflaredServer) UnregisterUdpSession(ctx context.Context, params func(SessionManager_unregisterUdpSession_Params) error, opts ...capnp.CallOption) SessionManager_unregisterUdpSession_Results_Promise { + if c.Client == nil { + return SessionManager_unregisterUdpSession_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0x839445a59fb01686, + MethodID: 1, + InterfaceName: "tunnelrpc/tunnelrpc.capnp:SessionManager", + MethodName: "unregisterUdpSession", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 0, PointerCount: 2} + call.ParamsFunc = func(s capnp.Struct) error { return params(SessionManager_unregisterUdpSession_Params{Struct: s}) } + } + return SessionManager_unregisterUdpSession_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} +func (c CloudflaredServer) UpdateConfiguration(ctx context.Context, params func(ConfigurationManager_updateConfiguration_Params) error, opts ...capnp.CallOption) ConfigurationManager_updateConfiguration_Results_Promise { + if c.Client == nil { + return ConfigurationManager_updateConfiguration_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0xb48edfbdaa25db04, + MethodID: 0, + InterfaceName: "tunnelrpc/tunnelrpc.capnp:ConfigurationManager", + MethodName: "updateConfiguration", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 8, PointerCount: 1} + call.ParamsFunc = func(s capnp.Struct) error { return params(ConfigurationManager_updateConfiguration_Params{Struct: s}) } + } + return ConfigurationManager_updateConfiguration_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} + +type CloudflaredServer_Server interface { + RegisterUdpSession(SessionManager_registerUdpSession) error + + UnregisterUdpSession(SessionManager_unregisterUdpSession) error + + UpdateConfiguration(ConfigurationManager_updateConfiguration) error +} + +func CloudflaredServer_ServerToClient(s CloudflaredServer_Server) CloudflaredServer { + c, _ := s.(server.Closer) + return CloudflaredServer{Client: server.New(CloudflaredServer_Methods(nil, s), c)} +} + +func CloudflaredServer_Methods(methods []server.Method, s CloudflaredServer_Server) []server.Method { + if cap(methods) == 0 { + methods = make([]server.Method, 0, 3) + } + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0x839445a59fb01686, + MethodID: 0, + InterfaceName: "tunnelrpc/tunnelrpc.capnp:SessionManager", + MethodName: "registerUdpSession", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := SessionManager_registerUdpSession{c, opts, SessionManager_registerUdpSession_Params{Struct: p}, SessionManager_registerUdpSession_Results{Struct: r}} + return s.RegisterUdpSession(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 1}, + }) + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0x839445a59fb01686, + MethodID: 1, + InterfaceName: "tunnelrpc/tunnelrpc.capnp:SessionManager", + MethodName: "unregisterUdpSession", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := SessionManager_unregisterUdpSession{c, opts, SessionManager_unregisterUdpSession_Params{Struct: p}, SessionManager_unregisterUdpSession_Results{Struct: r}} + return s.UnregisterUdpSession(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 0}, + }) + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0xb48edfbdaa25db04, + MethodID: 0, + InterfaceName: "tunnelrpc/tunnelrpc.capnp:ConfigurationManager", + MethodName: "updateConfiguration", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := ConfigurationManager_updateConfiguration{c, opts, ConfigurationManager_updateConfiguration_Params{Struct: p}, ConfigurationManager_updateConfiguration_Results{Struct: r}} + return s.UpdateConfiguration(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 1}, + }) + + return methods +} + +const schema_db8274f9144abc7e = "x\xda\xccZ{t\x14\xe7u\xbfwfW#\x81V" + + "\xab\xf1\xac\xd1\x03T\xb5:P\x179\xd8\x06Jk\xab" + + "9\xd1\xc3\x12\xb1d\x03\x9a]\x94\xe3cC\x8eG\xbb" + + "\x9f\xa4Qwg\x96\x99Y\x19\x11\x130\x01c\xfb\xb8" + + "\x8eq\xc0\xb1Ih0.\xed\x01\xdb\xad\x89\xdd\xa6\xee" + + "1\xa7\xa6\xcd\xabq\xc0&\x87\xf4\x90@\x9a&\x84>" + + "8\xb8\xae14\x876\xf1\xf4\xdc\x99\x9d\x87v\x17\x09" + + "\x8c\xff\xc8\x7f\xab;\xdf\xe3\xde\xdf\xf7\xbb\x8f\xef~\xba" + + "\xed\xe6\x9a.nq\xf4\xed:\x00\xf9\xa5h\x95\xcd\xda" + + "\x7f\xb0a\xef\x82\x7f\xdc\x02r3\xa2\xfd\xf97\x06\x12" + + "\x97\xad-\xa7 \xca\x0b\x00KW\x08\x1bPR\x04\x01" + + "@Z+\xfc;\xa0\xfd\xc8\x9cW\xbe\xb6\xbfo\xe7\x17" + + "@l\xe6\x83\xc1\x80K\xbb\xab\x07P\x1a\xaa\xa6\x91r" + + "\xf5v\xe9\x10\xfd\xb2\xef\x16o\xbd?\xf1\xce1\x1a\x1d" + + "^:BK?W\xdd\x8e\xd2\x01g\xc2\xfejZ\xfa" + + "\x93\xb9\xb7\xf7\xfd\xc1\xae\xb7\xb6\x82\xd8\xccMYzG" + + "\xcd\x06\x94\xf6\xd7\xd0\xc8\xe7kV\x01\xda\x1f\xecl|" + + "\xf1\xf9c\xdf\xdd\x06\xe2M\x08EM_\xaf\xf91\x02" + + "JGk\xfe\x0a\xd0>z\xe9\xfe\x8b\xaf}{\xd9#" + + " .\xa4\x01H\x03r\xb3\xda8@i\xdb\xacN@" + + "\xfb\xdc\xf9\xff\xdb\xfe\xb9\x85+\x9f\x02y!r\x00Q" + + "\x8eF\xec\x9f\xd5L#\x0e\xcf\"m:\x97\x1f}\xbd" + + "y\xe93;KTw\x06\xee\x99\xdd\x8e\xd2\xcb\xb3I" + + "\xa1\x03\xb3\x1f\x04\xb4?\xf5\x87\xaf>\xd9\xfb\xcc\xe6]" + + " \xde\xea\xef\x17\xab\xbd\x8fV[XK\xfb\xfdO\xdd" + + "W\x8e\x15\xee\xfc\xc63E\x85\x9cU\xfak\xdbi\x80" + + "RK+\xb4M,x\xe0\x1f\xbe\xf5\xea\x97A^\x84" + + "h\x9f\x1e\xbe\xf9\x87\xfc\x9e\x83\xa7`\x08\x05\xd2o\xe9" + + "\x91\xda}d\xddqg\xec\xdb\x9fx\xe3\xef\x9ezu" + + "\xfbW@\xbe\x09\x11\xc0AsY\xec\x7fi@\x7f\x8c" + + "v\xdby\xf2\xf0\xca\xdc\x8e\xdd\xfb\\|\x9c\xef\xebb" + + "\x1c\x07\x11{k\xff/sC/\xa4^(\"\x17\xa5" + + "O,v\x01\x01\x97N\xc6Z\x11\xd0^\xf6\xe3\xb3\xab" + + "V|}\xe4/Bsw\xd5m\xa0\xb9\xdbG.\x1c" + + "\xa9O\xe6^,A\xc41vG\xddA\x94\x0e\xd49" + + "\x87YG*\xbc\xfc[w\xd7\xac?\xbb\xfc\x15\x10\x17" + + "y\xcb|\xab.I\xcbL|\xef\x85\xdf]\xf0\xbd\x07" + + "\x0f\x81|+\xfa`\x1d\xa1o(\xfd\xa4\x8e\xec\x8b\x9c" + + "Zp\xf0\xf0O\x9f|\xad\x8ccw\xc47\xa0\xb4\"" + + "N\xbb\xf4\xc7?-M\xd2/;\xb2\x96\xffPy\xf6" + + "\xef_+\xe5\xaf\x83\xb1\x12\x1fF\xa9\x10w\x10\x88;" + + "\xf6=~d\xf7\xcd\xd5_\xfb\xe0\xaf+\x9d\xebs\xf5" + + "\xc3(\xbd\\\xef\x9ck=irc?\x9e~sq" + + "\xe4\x1ba\xa2\xd5\x88\xe7\x08\xe9\x16\x91\x88\xd6\xf2nO" + + "L{o\xcb\x9b%\xab9\x03\x0f\x8b\x03(\x1d\x17i" + + "\xb5\xa3\xce\xe0\x81\xfb\xbf\xf4t\xf4\xec\x97\xbeC\x9a\x86" + + "\x18\x1e%\x1fX\xaa\xde`\xa0\xb4\xf1\x06\xfa9yC" + + "\x03\x0fh7\xbf\xf2G\x7f\xd9\x93\xf9\xd1[\x154\x95" + + ".\xdfxA\x8a\xce\xa1_8\x87\x14=\xb3\xe8\xd0\xe7" + + "\xfe\xf3O\x8e\x9f(*\xea`\xbav\x8eC\x89us" + + "\xe8<.\xaf\xd9{\xb7j\xdf{\xaa\x14%\xf7\xf4\xe6" + + "|\x1d\xa5\x03\xcer\xfb\x9d\xe5|\xfeU\x1a]\xd30" + + "\x8eRK\x03\x8dnj\xa0\xb5\xb9\xb3J\xd3\xe6\x7f\xfe" + + "\xd4\xe9\x10eZ\x1a~\x8e\x10\xb1W~\xe6\xfe\xf1\x9a" + + "\x8dg\xce\x84\xd5\x12\x1b\x1c\xfc\x168S\xff\xeb\xcf\xcf" + + "}\xf1|.\xf3o\x0e\xed=\x84\xfb\x1a:\x88\x0ck" + + "\x1b\xc8\x0f\x1bZc}m'\x07\xcf\xb9Dr\x97\xb8" + + "\xa3\xb1\x87\x06\xc8\x8d\xb4\xc4\xb2\x07\xba\xd9\x9a\xdb\xef=" + + "W\xc6\x96u\x8d\x1d(=\xdcH\x1366nGi" + + "WS\x03\x80=\xf17;\xee}\xf1\x9b+/\xb8\x9e" + + "\xe8(\xbb\xadi\x09\x11\xf3\xc9\xcf\xf7\xae\xba\xa3\xed\xc8" + + "\x85\xb0\xb2\x1b\x9b\xc87\xa4\x1dM\xb4\xd3\xc8\xed\xe7?" + + "\xbd\xe0\xc9o_\xa8\xe4\x00\x87\x9a\xdaQ:\xd2D\xa0" + + "\x1c\xa6\xc1\xef-\xff\xd3\x13\xcd\xf1\xe6\x8b%\x00V\xd1" + + "\xd8\x9f5\x8d\xa3t\x89\xc6.}\xbf\xe9;D\xca\xe7" + + "\xffl\xdf\xbf\\>v\xd7\xa52\x1b\xce\xce\x1dF\xe9" + + "\xf2\\Z\xf6\xd2\\A\xba4\xf7&\x00\xfb\x91S\x9f" + + "]\xff\x83/|p\xa9\x94G\x8e\"\xef\xceM\xa2\x84" + + "\xf3h\xc6\xaf\xe7\x12\xeb\xbe\xbc\xfa?6\x9d\xdf5\xe7" + + "\x97ek\xef\x997\x8e\xd2!g\xe4\xcb\xf3\xb6K\xb1" + + "\x16\xf2\xa6w\x84\x17\x16\xf7nz\xebr\xc8o/\xcd" + + "\x1b x\x9e\x11\xbezf\xf3O?\xfb\xab0<\xef" + + "\xcf\xfb9\xc1\x13m!x\x1ez\xef\xb9\xbb\xbe\xb8\xe6" + + "\xa5\x0fC4X\xd0\xb2\x85\xa6Z\x05McY#\x1f" + + "I\xdf\xea\xfdL\xdf\x92V\xf2Z\xbe\xa3\xbb`\x8d1" + + "\xcdR\xd3\x8a\xc5\x92\xac\xd3\xcc\xeb\x9a\xc9\x06\x11\xe5z" + + ">\x02\x10A\x00Q\x19\x07\x90\x1f\xe0Q\xcer(\"" + + "&\x88(\xa2J\xc21\x1ee\x8bC\x91\xe3\x12\x14%" + + "\xc5um\x00r\x96Gy=\x87\xc8'\x90\x07\x10\x0b" + + "O\x03\xc8\xeby\x94\xb7rh\xe7\x99\x91S4\xa6A" + + "\xdc\xea3\x0c\xac\x05\x0ek\x01m\x83Y\xc6\xa42\x9c" + + "\x858\x0b\x89\x85\xf1\x07-\x8c\x01\x871@{L/" + + "\x18\xe6\x90f\xa1\x9aM\xb2\x11\x83\x998\x86U\xc0a" + + "\x15\xe0t\xe6\xa5\x98i\xaa\xba\xb6B\xd1\x94Qf\x00" + + "\x90e\xd5|\x14\xc0\xcf@\xe8\xe5*q\xf1n\xe0\xc4" + + "E\x02\x06\xd9\x02=\xb2\x8a\xbfs\x108\xb1E\xb0\x0d" + + "6\xaa\x9a\x163p(\x93w\xd6\xe6u\xad\x0b\xed\x82" + + "\xe6~@f\xb8\x1f\xe2\xb4k\x17\x0eb\xa0\x1d_\xae" + + "\xdd\x9dY\x95iV\xbc_\x1b\xd1K \x1f\xa8\x04\xf9" + + "@\x11\xf2\xad!\xc8\x1f\xee\x01\x90\x1f\xe2Q~\x94C" + + "\x91/b\xbe\xad\x1d@\xde\xcc\xa3\xfc\x04\x87v\xda\xd9" + + "\xa4?\x03\x00>\x9a#L\xb1\x0a\x063IV\x078" + + "\xc8\xa3\x03z\x1d\xe0\xa6\x09f\x90\xee\xde!\xc4\x15#" + + "=\xe6\x1f\xd44H\xf7\xadWMK\xd5FW;\xf2" + + "\xceA=\xab\xa6'\xc9\xaaZG\xcf\x96\x0e\x00D\xf1" + + "\xc6\xfb\x00\x90\x13\xc5\x1e\x80NuT\xd3\x0dfgT" + + "3\xadk\x1a\x03>mm\x1aV\xb2\x8a\x96f\xfeF" + + "U\xe5\x1b\xb9\x1b\xa4\x981\xc1\x8c[\x94\x10}\xe7\x0f" + + "*\x86\xc2\xe7L\xb9\xd6\xc7\xb1\xef>\x00\xb9\x97Gy" + + "0\x84\xe3\x0a\xc2\xf1\x1e\x1e\xe5{C8\x0e\x11\x8e\x83" + + "<\xcak8\xb4uC\x1dU\xb5;\x19\xf0F\x98\x81" + + "\xa6\xa5)9F\x98\x15\xf1\xd8\xa4\xe7-U\xd7L\xac" + + "\x0fr\x0b \xd6\x87\x90\x12f\xe2\xe4-\x1e\xa5(\x09>" + + "\x8e]=\xf8B4LV\xa2\xe1\x12\x009\xc3\xa3\x9c" + + "\xe7\x10\x8b\xe8\xe5zB\xd1\x80G\x97\x85\xebv\x03\xc8" + + "\x16\x8f\xf2f\x0em\xd3\xdd\xa4\x1f0\xe3!\xda\x9a1" + + "\xad\xfe\xbc\xf7\xd7\xa6\x8ci\x0d\xea\x86\x85\x02p(\x00" + + "\xf1V7Y\xf7\x08\xf9T\x7f&\xcb\xeeRy\xcd\xc2" + + "(p\x18\x85i\x9d\xca\xe5G\x9c\x02\x9b\xeb\xed\x9e5" + + "\x0b\x89\x0c\xbf\xc7\xa3\xfc\xfb!k\x16S\x1c\xbb\x8dG" + + "\xf9\x93\x1c\xdaJ:\xad\x174k5\xf0\xcah\x09\xe7" + + "S\x0c\xe2i\x83\x05t\xf0\xb6\xad\xae\xe0\xd6\xba6\xa2" + + "\x8e\x16\x0c\xc5\x0a\x01^\xc8g\x14\x8bM\xf9\xe4\x9cs" + + "\x96\xbf\x8as\xf6\xab\x87k>g/2\x95\x9ct\xdc" + + "Prf\x18\x9bd%l\xe8T?\xc1\xa3|{\xe5" + + "\x03\xdc\x94c\xa6\xa9\x8c\xb2\xb2\xf0\x10\xad\x88\x89\xc6\xd2" + + "du\x92\xb9I\xe6\x16\x83\x99B!k\x91\x16\xb5\xb6" + + "\xed\xaaA\xdc\x9a\xcf\xa3|\x1b\x871\xfc\xd0v\xf5X" + + "\xf4tpF\xad\xcc0t\x03\xeb\x83$\\\x84$]" + + "\xdc\x00u\xad\x97Y\x8a\x9aErK\xbf\xda,\x01n" + + "\xa6\xb8\x12\xc0\xe6\x8a\xe7w\x92w\xe4\xa6\x9c\x14\xd1\xbb" + + "\x9eGy\x1e\x87\xf6\xa8\xa1\xa4\xd9 3P\xd53+" + + "\x15MO\xf1,]F\xd6\xbak\xdd\xd4\xe1\x87e\x82" + + "?k\xfa\xf9\x06+\x82P\x9c>\xd8\xea\xea\x9c\xf0u" + + "\xde\xd8\x16$c\xff\x98\x1f\x1e\x0e\xb2\x85\x1f\x0f\x1f#" + + "gy\x94Gyg(\xaf\xec\xa0\xc8\xf9\x14\x8f\xf2W" + + "9\x14#\x91\x04F\x00\xc4\xe7\x88%;y\x94\xf7r" + + "SS6\x9b`\x9a\xd5\xab\x8e\x82\xc0\xcc@J*\xf6" + + "\xaa\xa3\x0cx\xf3zck\xf5\x0cx\xe8\xc3\xa6\x9ee" + + "\x16\xebe\xe9\xacB.7\xc1\xdc\xefE2z\x87:" + + "\x1do\x93e\xdeC\xfc\x8d{UR\x88\x0em\x81\xe3" + + "\x0a,T\xdcL\xa3\xad\xbb\xb8\x1b\x0c\xca8\x10xL" + + "\x91\x07h~,A\xc7\xb1\x19\xa78\x7fO\xe0u\x1e" + + ")\x16u\x04\x01\xc1\xaf\x09\"\xc0a\x04\xb03\xed," + + "X\x16\x0a#3i\xd5\xe9\xaa\xe5\x02GE\x98w\x17" + + "E\xef\x02/\x8a\xfb\x80\x13c\x82\xedi\x8e\xde|\xa1" + + "\xac\xa0\x8aL\x17eV\xe5-U\xd05\x93\xf6\x0a\xf1" + + "\xbf\xa3\x12\xff\x8d\x80\xff^B{lK\x98\xfe\xc5\x84" + + "\xb6cw\xc0t1\xc2\xb9\xf4\xdf\xb3\x0f@\xde\xcb\xa3" + + "\xfc\x12\x87\x9dn\xad\x85\xf5A\xe3\xa5HY\xb7\xa2\xb8" + + "G\x87\xd6\xb4\x92\x0d\x92\x9em\xb0|VI\xb3>," + + "VO\x80\x08\x1c\xa2\xe3'\xb9\xbc\xc1L\x13U]\x93" + + "\x0bJV\xe5\xadI\xbf\xe2\xd5\x0a\xb9A\x83M\xa8\xa8" + + "\x17\xccn\xcbb9!o\x99WS\x0f\x07\x00Q\x90" + + "\x14\xd4\xacY\x92#\xdb\x03*\xf8\x00-\x1a\x0f\xf2@" + + "\xbcPP\xfd\x04`g\xf5\xb4s\xb2\x10_\xa9\xe4\xca" + + "\xf3@\xd5\x8c\x01kJ\xb8\xf3\xd2\xd2oR\xfd6\xfd" + + "\x95\x89Lw\xee\x14!\x95)\x0et\xf1(\xdf\x13R" + + "\xb9\x7fI\xc8\x0eO\xe5\x15\xc3\x81\x1d\xc2\x1f\xb3IO" + + "\xabV\x96\xa3\xf4\xe5\x81Y4\xa6\x1b\x84\xbb\x831\xd3" + + "\xe9\x17\x8e*\xab\xf2\xad\x8e\x85\xa4\xe3\xed\x9e\x8e\xd2$" + + "\x0e\x00\xa4\xd6#\x8f\xa9\xad\x18\xa8)=\x8c=\x00\xa9" + + "\x87H\xfe(\x06\x9aJ\xdb\xb0\x19 \xb5\x99\xe4O\xa0" + + "\x7f\xb5\x93\x1e\xc3\x83\x00\xa9'H\xfc,\x0d\x8f\xf0\x8e" + + "KH\xbb\x9c\xe5w\x92|/\xc9\xa3\x91\x04F\x01\xa4" + + "=\xd8\x0e\x90z\x96\xe4\xaf\x91\xbc\x8aK`\x15\x80t" + + "\x08\xc7\x01R\xaf\x90\xfc\x0d\x92\x0b\xd1\x04\xddn\xa5\xd7" + + "\xd1\x00H\xfd-\xc9\xbfI\xf2\xea\xc6\x04V\x03HG" + + "\x1c\xf9\x9b$\xff>\xc9k\x9a\x12X\x03 \xfd\x13n" + + "\x01H}\x97\xe4'H>\x0b\x138\x0b@:\x8e\xbb" + + "\x01R'H\xfe\xaf$\x9f]\x95\xc0\xd9\x00\xd2O\x1c" + + "}N\x92\xfc\x17$\xaf\x8d$\xb0\x16@\xfa\x19\xee\x03" + + "H\xfd\x82\xe4\xffM\xf2\x98\x90\xc0\x18\x80\xf4\xaec\xd7" + + "y\x92Ws%7+\x8fQ%\xd7'^7\xfd#" + + "cE\x1fG\x97\xee\x83z\x9c\xaeH\x18\x0f\x1a\xaf\x80" + + "\x18\x07\xb4\xf3\xba\x9e]9\x95\xa9qK\x195\xbd\xab" + + "Z}\xd0\x9a\x02$\xa1_\xfc@\\\xd7\xfa3~ " + + "(\x8d:\x9e&\xaa\xd9]\xb0\xf4B\x1eZ)\xc8f" + + "\xfc\x98c\x14\xb4\xe5\x86\x9e[\x8d\xcc\xc8\xa9\x9a\x92\x9d" + + "!\x1a\xd5\x00\x875P\x0c\x09\xde\xda\xd3\x87\xa6+_" + + "<}Fs\xa5\x8cn\xcdw\xacVF\xaf&N-" + + "\x09rV\\\x0b\x05\xa4\xd6\x09%[\xf8(\xe1ij" + + "=\x95\xect\xeb\xb1\x99\xcau\xaf\xf7T\x12J*T" + + "\x17C\xe5\xf99\xc9\xccV\xbf\x09\x132\xf8`\x10\x83" + + "={\x97\xb5\x85\xee.Y\xc5b\xa6\xd5\x9d\xc7|V" + + "e\x99\xcf0#\x1eN\xd9\x15+\x92\xc8Le\xfa\xd4" + + "2\x07C]r2\x9c+\x1a|\xd5x\x8e2\xcb\xfd" + + "\xd5\xaf\x8d\xe8T\x87\x08\xe1\xe2\xeb\xdaf'\x99\x19\xbf" + + "\x9a\xb3\x08\x9a\x863_\x9d*\x94c\x15\x8a1\xef&" + + "\x10\xba&\x13\x19\xd7\xf0(\x8f\x85\xc8\xc8\x06*\\\x93" + + "\x93A\x7fL\xe4\xb9b\x83\x8c2W\x9eG\xf9!\x0e" + + "\xe3J\xc1\x1a\xc3\xfa\xe0\xf1c\x8a\xd2S{8\xc4\xcd" + + "~-\xc3\x00\xd7{\xee\x15\xcag~W~\xe6\x9a\xf9" + + "\xea\xcc\xf6\xee\"3\x02\xee\xf7\xaeKv\xbe\xe2U\xbd" + + "\xd3\xdd\x94x\xd6\xe8T\x85^\xd7\x1f\xbd\x8e\xb0xh" + + "\x03p\xe2\x01\x01\x83^7z\xadmq\x8f\x01\x9c\xb8" + + "K@\xce\x7f\x97A\xef\xfdE|\xecq\xe0\xc4m\x02" + + "\xf2\xfe\xb3\x0az]\xd2\xc5\x93\xb3\x108q\xa3\x80\x11" + + "\xff=\x0b\xbd\x1e\xab\xb8n\x1c8Q\x150\xea\xbf\xd8" + + "\xa0\xd7\xe2\x17\xd7n\x01N\x1c\x0az\x81\xd0\xe9\xda\xd1" + + "\x85\xb6\xc7QhuX:\xb53\xe8\x8e\x02\xe8B\xdb" + + "\xbb\x99\xf0W\xba\x9a8\xa3\xbc\xe6\x16\xc4\xd3\x8a\xc5\xba" + + "\xa8Zt\x03\x12\x16#\x12t\xa1\x1c\xc1P\x8b\x19\xe0" + + "z[\x03I\xd6\xea\x9c\xf3G\xad\xe1\xbc\xf9\x1f1F" + + "\xf2\x95\xb4\xa6}\xfc&ih]*Kky\x94\x1b" + + "\xb9\x19+\xd1\xc8\x95\xac\xf0\xc8\x1f\xa7\xc9\xb4\xfeo\xfb" + + "\xeb\x1f\xa7\xf0\xfa}\x1e\xe5\x93!\xb7\xfe!\x09\xdf\xe1" + + "Q>\x1d\xaa\xe4~D\xbe~\x92G\xf9b\xd0\xf7~" + + "\xffq\x00\xf9\"\x8f\xc9Pe$\xfe\x9a\x06\xfe\x8a\xea" + + "\x07\xa7.B\xb7.\x8a\xe2\xd3\x00\xa9j\xaa+\x12N" + + "]\x14q\xeb\"\x11\x87\x01R\xf5$\x9f\x17\xae\x8b\x9a" + + "\xf0>\x80T#\xc9\xe7\xe3\xd4\xdb\xa6P0\x82\xca1" + + "\xab\x8f\xde\xa3j\x15\x93\xad\xd7\x88Gk\xb9\xa2f\x0b" + + "\x06\x83 \xd7\x17\x83Mo\xa8\xfcp;\xf4n3." + + "E$\xcc\xa0\xe97\xea\xae\xe1\x9e?]\xe6\xc9\xea\x85" + + "\xccHV1X&\xc5\x0c\xc1\x0d\x08\x83|T\xae\xc6" + + "\xd0\xab7@\xf0:\x19\"\xfb\xb4\x99\xac\xcf0t4" + + "J\xaa\xf4%A\x95\xee\x17\xe9t\xd9\xb8\x8bGy5" + + "\x1dm\x97{\xb4\xf2pp\xafhM+\x05\x93\x95a" + + "\x02<3\xfc^\x8f9\xa6\x17\xb2\x99$\x03\xc12&" + + "K \x9d\xb1ZO\xb1\xb8\x17\x09\xddG\x0a\xefy\x0e" + + "\xbdW\xb8\xd0#\x85\xf7R\x84\xdeSo\xf9#\x85\x87" + + "A\xd9#\x85\xfb\xc1\xe1\xfc\xd4;\xf5u4)\xdc\xb4" + + "\x18:\x94k\xea\xdd_u\xcb\xdb\xffw\x89\x92\xc8Q" + + "s\xbd\xcd /\xc1\xfd\x7f\x00\x00\x00\xff\xff\xf1\xc3d" + + "\xc6" func init() { schemas.Register(schema_db8274f9144abc7e, @@ -4089,6 +4546,7 @@ func init() { 0x8635c6b4f45bf5cd, 0x904e297b87fbecea, 0x9496331ab9cd463f, + 0x958096448eb3373e, 0x96b74375ce9b0ef6, 0x97b3c5c260257622, 0x9b87b390babc2ccf, @@ -4097,6 +4555,8 @@ func init() { 0xa766b24d4fe5da35, 0xab6d5210c1f26687, 0xb046e578094b1ead, + 0xb177ca2526a3ca76, + 0xb48edfbdaa25db04, 0xb4bf9861fe035d04, 0xb5f39f082b9ac18a, 0xb70431c0dc014915, @@ -4104,6 +4564,7 @@ func init() { 0xc793e50592935b4a, 0xcbd96442ae3bb01a, 0xd4d18de97bb12de3, + 0xdb58ff694ba05cf9, 0xdbaa9d03d52b62dc, 0xdc3ed6801961e502, 0xe3e37d096a5b564e, @@ -4114,6 +4575,7 @@ func init() { 0xf2c122394f447e8e, 0xf2c68e2547ec3866, 0xf41a0f001ad49e46, + 0xf548cef9dea2a4a1, 0xf5f383d2785edb86, 0xf71695ec7fe85497, 0xf9cb7f4431a307d0, diff --git a/websocket/connection.go b/websocket/connection.go index 79665902..83468b89 100644 --- a/websocket/connection.go +++ b/websocket/connection.go @@ -3,8 +3,10 @@ package websocket import ( "bytes" "context" + "errors" "fmt" "io" + "sync" "time" gobwas "github.com/gobwas/ws" @@ -14,9 +16,6 @@ import ( ) const ( - // Time allowed to write a message to the peer. - writeWait = 10 * time.Second - // Time allowed to read the next pong message from the peer. defaultPongWait = 60 * time.Second @@ -79,34 +78,20 @@ func (c *GorillaConn) SetDeadline(t time.Time) error { return nil } -// pinger simulates the websocket connection to keep it alive -func (c *GorillaConn) pinger(ctx context.Context) { - ticker := time.NewTicker(defaultPingPeriod) - defer ticker.Stop() - for { - select { - case <-ticker.C: - if err := c.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil { - c.log.Debug().Msgf("failed to send ping message: %s", err) - } - case <-ctx.Done(): - return - } - } -} - type Conn struct { rw io.ReadWriter log *zerolog.Logger - // closed is a channel to indicate if Conn has been fully terminated - shutdownC chan struct{} + // writeLock makes sure + // 1. Only one write at a time. The pinger and Stream function can both call write. + // 2. Close only returns after in progress Write is finished, and no more Write will succeed after calling Close. + writeLock sync.Mutex + done bool } func NewConn(ctx context.Context, rw io.ReadWriter, log *zerolog.Logger) *Conn { c := &Conn{ - rw: rw, - log: log, - shutdownC: make(chan struct{}), + rw: rw, + log: log, } go c.pinger(ctx) return c @@ -121,16 +106,22 @@ func (c *Conn) Read(reader []byte) (int, error) { return copy(reader, data), nil } -// Write will write messages to the websocket connection +// Write will write messages to the websocket connection. +// It will not write to the connection after Close is called to fix TUN-5184 func (c *Conn) Write(p []byte) (int, error) { + c.writeLock.Lock() + defer c.writeLock.Unlock() + if c.done { + return 0, errors.New("write to closed websocket connection") + } if err := wsutil.WriteServerBinary(c.rw, p); err != nil { return 0, err } + return len(p), nil } func (c *Conn) pinger(ctx context.Context) { - defer close(c.shutdownC) pongMessge := wsutil.Message{ OpCode: gobwas.OpPong, Payload: []byte{}, @@ -141,7 +132,11 @@ func (c *Conn) pinger(ctx context.Context) { for { select { case <-ticker.C: - if err := wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{}); err != nil { + done, err := c.ping() + if done { + return + } + if err != nil { c.log.Debug().Err(err).Msgf("failed to write ping message") } if err := wsutil.HandleClientControlMessage(c.rw, pongMessge); err != nil { @@ -153,6 +148,17 @@ func (c *Conn) pinger(ctx context.Context) { } } +func (c *Conn) ping() (bool, error) { + c.writeLock.Lock() + defer c.writeLock.Unlock() + + if c.done { + return true, nil + } + + return false, wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{}) +} + func (c *Conn) pingPeriod(ctx context.Context) time.Duration { if val := ctx.Value(PingPeriodContextKey); val != nil { if period, ok := val.(time.Duration); ok { @@ -162,7 +168,9 @@ func (c *Conn) pingPeriod(ctx context.Context) time.Duration { return defaultPingPeriod } -// Close waits for pinger to terminate -func (c *Conn) WaitForShutdown() { - <-c.shutdownC +// Close waits for the current write to finish. Further writes will return error +func (c *Conn) Close() { + c.writeLock.Lock() + defer c.writeLock.Unlock() + c.done = true } diff --git a/websocket/websocket.go b/websocket/websocket.go index b94b4f54..67c8916b 100644 --- a/websocket/websocket.go +++ b/websocket/websocket.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "sync/atomic" "time" "github.com/gorilla/websocket" @@ -28,28 +29,64 @@ func NewResponseHeader(req *http.Request) http.Header { return header } +type bidirectionalStreamStatus struct { + doneChan chan struct{} + anyDone uint32 +} + +func newBiStreamStatus() *bidirectionalStreamStatus { + return &bidirectionalStreamStatus{ + doneChan: make(chan struct{}, 2), + anyDone: 0, + } +} + +func (s *bidirectionalStreamStatus) markUniStreamDone() { + atomic.StoreUint32(&s.anyDone, 1) + s.doneChan <- struct{}{} +} + +func (s *bidirectionalStreamStatus) waitAnyDone() { + <-s.doneChan +} +func (s *bidirectionalStreamStatus) isAnyDone() bool { + return atomic.LoadUint32(&s.anyDone) > 0 +} + // Stream copies copy data to & from provided io.ReadWriters. func Stream(tunnelConn, originConn io.ReadWriter, log *zerolog.Logger) { - proxyDone := make(chan struct{}, 2) + status := newBiStreamStatus() - go func() { - _, err := copyData(tunnelConn, originConn, "origin->tunnel") - if err != nil { - log.Debug().Msgf("origin to tunnel copy: %v", err) - } - proxyDone <- struct{}{} - }() - - go func() { - _, err := copyData(originConn, tunnelConn, "tunnel->origin") - if err != nil { - log.Debug().Msgf("tunnel to origin copy: %v", err) - } - proxyDone <- struct{}{} - }() + go unidirectionalStream(tunnelConn, originConn, "origin->tunnel", status, log) + go unidirectionalStream(originConn, tunnelConn, "tunnel->origin", status, log) // If one side is done, we are done. - <-proxyDone + status.waitAnyDone() +} + +func unidirectionalStream(dst io.Writer, src io.Reader, dir string, status *bidirectionalStreamStatus, log *zerolog.Logger) { + defer func() { + // The bidirectional streaming spawns 2 goroutines to stream each direction. + // If any ends, the callstack returns, meaning the Tunnel request/stream (depending on http2 vs quic) will + // close. In such case, if the other direction did not stop (due to application level stopping, e.g., if a + // server/origin listens forever until closure), it may read/write from the underlying ReadWriter (backed by + // the Edge<->cloudflared transport) in an unexpected state. + + if status.isAnyDone() { + // Because of this, we set this recover() logic, which kicks-in *only* if any stream is known to have + // exited. In such case, we stop a possible panic from propagating upstream. + if r := recover(); r != nil { + // We handle such unexpected errors only when we detect that one side of the streaming is done. + log.Debug().Msgf("Handled gracefully error %v in Streaming for %s", r, dir) + } + } + }() + + _, err := copyData(dst, src, dir) + if err != nil { + log.Debug().Msgf("%s copy: %v", dir, err) + } + status.markUniStreamDone() } // when set to true, enables logging of content copied to/from origin and tunnel