Merge af04ee52f1 into d2a87e9b93
This commit is contained in:
commit
c47b569166
|
|
@ -3,6 +3,7 @@ package tunnel
|
|||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
|
|
@ -32,6 +33,7 @@ import (
|
|||
"github.com/cloudflare/cloudflared/diagnostic"
|
||||
"github.com/cloudflare/cloudflared/edgediscovery"
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
"github.com/cloudflare/cloudflared/k8s"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/management"
|
||||
"github.com/cloudflare/cloudflared/metrics"
|
||||
|
|
@ -174,6 +176,7 @@ func Commands() []*cli.Command {
|
|||
buildCleanupCommand(),
|
||||
buildTokenCommand(),
|
||||
buildDiagCommand(),
|
||||
buildKubernetesSubcommand(),
|
||||
proxydns.Command(), // removed feature, only here for error message
|
||||
cliutil.RemovedCommand("db-connect"),
|
||||
}
|
||||
|
|
@ -445,6 +448,45 @@ func StartServer(
|
|||
return err
|
||||
}
|
||||
|
||||
// Start Kubernetes service watcher if enabled
|
||||
cfg := config.GetConfiguration()
|
||||
if cfg.Kubernetes.Enabled {
|
||||
k8sCfg := &k8s.Config{
|
||||
Enabled: cfg.Kubernetes.Enabled,
|
||||
Namespace: cfg.Kubernetes.Namespace,
|
||||
BaseDomain: cfg.Kubernetes.BaseDomain,
|
||||
KubeconfigPath: cfg.Kubernetes.KubeconfigPath,
|
||||
ExposeAPIServer: cfg.Kubernetes.ExposeAPIServer,
|
||||
APIServerHostname: cfg.Kubernetes.APIServerHostname,
|
||||
LabelSelector: cfg.Kubernetes.LabelSelector,
|
||||
}
|
||||
if err := k8sCfg.Validate(); err != nil {
|
||||
log.Warn().Err(err).Msg("Kubernetes config validation failed, watcher will not start")
|
||||
} else {
|
||||
k8sWatcher := k8s.NewWatcher(k8sCfg, log, func(services []k8s.ServiceInfo) {
|
||||
log.Info().Int("count", len(services)).Msg("Kubernetes service change detected, updating ingress rules")
|
||||
k8sRules := k8s.GenerateIngressRules(services, log)
|
||||
updatedIngress := k8s.MergeWithExistingRules(cfg.Ingress, k8sRules)
|
||||
newConfigBytes, err := json.Marshal(ingress.RemoteConfigJSON{
|
||||
IngressRules: updatedIngress,
|
||||
WarpRouting: cfg.WarpRouting,
|
||||
})
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to marshal updated K8s ingress config")
|
||||
return
|
||||
}
|
||||
resp := orchestrator.UpdateK8sConfig(newConfigBytes)
|
||||
if resp.Err != nil {
|
||||
log.Err(resp.Err).Msg("Failed to apply K8s ingress config update")
|
||||
} else {
|
||||
log.Info().Int("services", len(services)).Msg("Successfully applied K8s ingress config update")
|
||||
}
|
||||
})
|
||||
go k8sWatcher.Run(ctx)
|
||||
log.Info().Msg("Kubernetes service watcher started")
|
||||
}
|
||||
}
|
||||
|
||||
metricsListener, err := metrics.CreateMetricsListener(&listeners, c.String("metrics"))
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Error opening metrics server listener")
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ import (
|
|||
"github.com/cloudflare/cloudflared/features"
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
"github.com/cloudflare/cloudflared/ingress/origins"
|
||||
"github.com/cloudflare/cloudflared/k8s"
|
||||
"github.com/cloudflare/cloudflared/orchestration"
|
||||
"github.com/cloudflare/cloudflared/supervisor"
|
||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||
|
|
@ -151,6 +152,36 @@ func prepareTunnelConfig(
|
|||
}
|
||||
|
||||
cfg := config.GetConfiguration()
|
||||
|
||||
// If Kubernetes integration is enabled, discover services and merge with config
|
||||
if cfg.Kubernetes.Enabled {
|
||||
k8sCfg := &k8s.Config{
|
||||
Enabled: cfg.Kubernetes.Enabled,
|
||||
Namespace: cfg.Kubernetes.Namespace,
|
||||
BaseDomain: cfg.Kubernetes.BaseDomain,
|
||||
KubeconfigPath: cfg.Kubernetes.KubeconfigPath,
|
||||
ExposeAPIServer: cfg.Kubernetes.ExposeAPIServer,
|
||||
APIServerHostname: cfg.Kubernetes.APIServerHostname,
|
||||
LabelSelector: cfg.Kubernetes.LabelSelector,
|
||||
}
|
||||
if err := k8sCfg.Validate(); err != nil {
|
||||
log.Warn().Err(err).Msg("Kubernetes integration config validation failed, skipping K8s discovery")
|
||||
} else {
|
||||
// Use a timeout so K8s discovery doesn't block tunnel startup indefinitely.
|
||||
k8sCtx, k8sCancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
services, err := k8s.DiscoverServices(k8sCtx, k8sCfg, log)
|
||||
k8sCancel()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to discover Kubernetes services, continuing without K8s rules")
|
||||
} else if len(services) > 0 {
|
||||
k8sRules := k8s.GenerateIngressRules(services, log)
|
||||
cfg.Ingress = k8s.MergeWithExistingRules(cfg.Ingress, k8sRules)
|
||||
log.Info().Int("k8sServices", len(services)).Int("totalRules", len(cfg.Ingress)).
|
||||
Msg("Merged Kubernetes-discovered services into ingress rules")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ingressRules, err := ingress.ParseIngressFromConfigAndCLI(cfg, c, log)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
|
|
|||
|
|
@ -0,0 +1,295 @@
|
|||
package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/urfave/cli/v2"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||
"github.com/cloudflare/cloudflared/k8s"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
)
|
||||
|
||||
// CLI flag names shared by all `tunnel kubernetes` subcommands.
const (
	k8sBaseDomainFlag        = "k8s-base-domain"
	k8sNamespaceFlag         = "k8s-namespace"
	k8sKubeconfigFlag        = "k8s-kubeconfig"
	k8sExposeAPIServerFlag   = "k8s-expose-api-server"
	k8sAPIServerHostnameFlag = "k8s-api-server-hostname"
	k8sLabelSelectorFlag     = "k8s-label-selector"
	k8sOutputFormatFlag      = "k8s-output"
)
|
||||
|
||||
// buildKubernetesSubcommand returns the `tunnel kubernetes` (alias `k8s`)
// command group, wiring up the discover, watch, and generate-config
// subcommands.
func buildKubernetesSubcommand() *cli.Command {
	return &cli.Command{
		Name:     "kubernetes",
		Aliases:  []string{"k8s"},
		Category: "Tunnel",
		Usage:    "Discover and manage Kubernetes services exposed through Cloudflare Tunnel",
		Description: ` The kubernetes subcommand provides native integration between cloudflared and
 Kubernetes clusters. It can automatically discover annotated Kubernetes
 services and generate ingress rules for them.

 To mark a service for exposure through the tunnel, add the annotation:
     cloudflared.cloudflare.com/tunnel: "true"

 Optional annotations:
     cloudflared.cloudflare.com/hostname: Override the public hostname
     cloudflared.cloudflare.com/path: Path regex for the ingress rule
     cloudflared.cloudflare.com/scheme: Origin scheme (http/https)
     cloudflared.cloudflare.com/port: Select which port to proxy
     cloudflared.cloudflare.com/no-tls-verify: Disable TLS verification
     cloudflared.cloudflare.com/origin-server-name: Set SNI for TLS

 Example:
     # Discover services from the current cluster
     cloudflared tunnel kubernetes discover --k8s-base-domain example.com

     # Watch for changes continuously
     cloudflared tunnel kubernetes watch --k8s-base-domain example.com

     # Generate an ingress config YAML snippet
     cloudflared tunnel kubernetes generate-config --k8s-base-domain example.com`,
		Subcommands: []*cli.Command{
			buildK8sDiscoverCommand(),
			buildK8sWatchCommand(),
			buildK8sGenerateConfigCommand(),
		},
	}
}
|
||||
|
||||
// k8sFlags returns the CLI flags shared by every `tunnel kubernetes`
// subcommand (discover, watch, generate-config).
func k8sFlags() []cli.Flag {
	return []cli.Flag{
		&cli.StringFlag{
			Name:    k8sBaseDomainFlag,
			Usage:   "Base domain for auto-generated hostnames (e.g. example.com). Services will be exposed as <name>-<namespace>.example.com",
			EnvVars: []string{"TUNNEL_K8S_BASE_DOMAIN"},
		},
		&cli.StringFlag{
			Name:    k8sNamespaceFlag,
			Usage:   "Limit discovery to a specific Kubernetes namespace. Empty means all namespaces.",
			EnvVars: []string{"TUNNEL_K8S_NAMESPACE"},
		},
		&cli.StringFlag{
			Name:    k8sKubeconfigFlag,
			Usage:   "Path to a kubeconfig file. When empty, in-cluster config is used.",
			EnvVars: []string{"KUBECONFIG"},
		},
		&cli.BoolFlag{
			Name:    k8sExposeAPIServerFlag,
			Usage:   "Also expose the Kubernetes API server through the tunnel",
			EnvVars: []string{"TUNNEL_K8S_EXPOSE_API_SERVER"},
		},
		&cli.StringFlag{
			Name:    k8sAPIServerHostnameFlag,
			Usage:   "Public hostname for the Kubernetes API server (required when --k8s-expose-api-server is set)",
			EnvVars: []string{"TUNNEL_K8S_API_SERVER_HOSTNAME"},
		},
		&cli.StringFlag{
			Name:    k8sLabelSelectorFlag,
			Usage:   "Kubernetes label selector to filter services (e.g. app=web)",
			EnvVars: []string{"TUNNEL_K8S_LABEL_SELECTOR"},
		},
		// Note: output format is CLI-only (no env var) and defaults to "table".
		&cli.StringFlag{
			Name:  k8sOutputFormatFlag,
			Usage: "Output format: json, yaml, or table (default: table)",
			Value: "table",
		},
	}
}
|
||||
|
||||
func k8sConfigFromCLI(c *cli.Context) *k8s.Config {
|
||||
return &k8s.Config{
|
||||
Enabled: true,
|
||||
BaseDomain: c.String(k8sBaseDomainFlag),
|
||||
Namespace: c.String(k8sNamespaceFlag),
|
||||
KubeconfigPath: c.String(k8sKubeconfigFlag),
|
||||
ExposeAPIServer: c.Bool(k8sExposeAPIServerFlag),
|
||||
APIServerHostname: c.String(k8sAPIServerHostnameFlag),
|
||||
LabelSelector: c.String(k8sLabelSelectorFlag),
|
||||
}
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
// discover subcommand
// -----------------------------------------------------------------------

// buildK8sDiscoverCommand returns the `discover` subcommand, a one-shot
// listing of annotated Kubernetes services.
func buildK8sDiscoverCommand() *cli.Command {
	return &cli.Command{
		Name:   "discover",
		Usage:  "Discover annotated Kubernetes services",
		Flags:  k8sFlags(),
		Action: cliutil.ConfiguredAction(k8sDiscoverAction),
	}
}
|
||||
|
||||
func k8sDiscoverAction(c *cli.Context) error {
|
||||
log := logger.CreateLoggerFromContext(c, logger.EnableTerminalLog)
|
||||
cfg := k8sConfigFromCLI(c)
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(c.Context)
|
||||
defer cancel()
|
||||
|
||||
services, err := k8s.DiscoverServices(ctx, cfg, log)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return printServices(services, c.String(k8sOutputFormatFlag), log)
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
// watch subcommand
// -----------------------------------------------------------------------

// buildK8sWatchCommand returns the `watch` subcommand, which keeps running
// and reports service changes as they happen.
func buildK8sWatchCommand() *cli.Command {
	return &cli.Command{
		Name:   "watch",
		Usage:  "Continuously watch for Kubernetes service changes",
		Flags:  k8sFlags(),
		Action: cliutil.ConfiguredAction(k8sWatchAction),
	}
}
|
||||
|
||||
func k8sWatchAction(c *cli.Context) error {
|
||||
log := logger.CreateLoggerFromContext(c, logger.EnableTerminalLog)
|
||||
cfg := k8sConfigFromCLI(c)
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(c.Context)
|
||||
defer cancel()
|
||||
|
||||
outputFormat := c.String(k8sOutputFormatFlag)
|
||||
|
||||
watcher := k8s.NewWatcher(cfg, log, func(services []k8s.ServiceInfo) {
|
||||
log.Info().Int("count", len(services)).Msg("Service change detected")
|
||||
if err := printServices(services, outputFormat, log); err != nil {
|
||||
log.Err(err).Msg("Failed to print services")
|
||||
}
|
||||
})
|
||||
|
||||
// Handle OS signals for graceful shutdown
|
||||
sigC := make(chan os.Signal, 1)
|
||||
signal.Notify(sigC, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigC
|
||||
log.Info().Msg("Received shutdown signal, stopping watcher...")
|
||||
cancel()
|
||||
}()
|
||||
|
||||
watcher.Run(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
// generate-config subcommand
// -----------------------------------------------------------------------

// buildK8sGenerateConfigCommand returns the `generate-config` subcommand,
// which emits a cloudflared ingress YAML snippet for discovered services.
func buildK8sGenerateConfigCommand() *cli.Command {
	return &cli.Command{
		Name:   "generate-config",
		Usage:  "Generate cloudflared ingress configuration from discovered Kubernetes services",
		Flags:  k8sFlags(),
		Action: cliutil.ConfiguredAction(k8sGenerateConfigAction),
	}
}
|
||||
|
||||
func k8sGenerateConfigAction(c *cli.Context) error {
|
||||
log := logger.CreateLoggerFromContext(c, logger.EnableTerminalLog)
|
||||
cfg := k8sConfigFromCLI(c)
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(c.Context)
|
||||
defer cancel()
|
||||
|
||||
services, err := k8s.DiscoverServices(ctx, cfg, log)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(services) == 0 {
|
||||
log.Warn().Msg("No annotated Kubernetes services found")
|
||||
return nil
|
||||
}
|
||||
|
||||
rules := k8s.GenerateIngressRules(services, log)
|
||||
|
||||
// Output as YAML config snippet
|
||||
fmt.Println("# Auto-generated cloudflared ingress configuration from Kubernetes services")
|
||||
fmt.Println("# Add the following to your cloudflared config.yml under the 'ingress' key:")
|
||||
fmt.Println("ingress:")
|
||||
for _, r := range rules {
|
||||
if r.Hostname != "" {
|
||||
fmt.Printf(" - hostname: %s\n", r.Hostname)
|
||||
} else {
|
||||
fmt.Println(" - hostname: \"*\"")
|
||||
}
|
||||
if r.Path != "" {
|
||||
fmt.Printf(" path: %s\n", r.Path)
|
||||
}
|
||||
fmt.Printf(" service: %s\n", r.Service)
|
||||
|
||||
hasNoTLS := r.OriginRequest.NoTLSVerify != nil && *r.OriginRequest.NoTLSVerify
|
||||
hasSNI := r.OriginRequest.OriginServerName != nil && *r.OriginRequest.OriginServerName != ""
|
||||
if hasNoTLS || hasSNI {
|
||||
fmt.Println(" originRequest:")
|
||||
if hasNoTLS {
|
||||
fmt.Println(" noTLSVerify: true")
|
||||
}
|
||||
if hasSNI {
|
||||
fmt.Printf(" originServerName: %s\n", *r.OriginRequest.OriginServerName)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add catch-all
|
||||
fmt.Println(" - service: http_status:404")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Output helpers
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
func printServices(services []k8s.ServiceInfo, format string, log *zerolog.Logger) error {
|
||||
if len(services) == 0 {
|
||||
log.Info().Msg("No annotated Kubernetes services found")
|
||||
return nil
|
||||
}
|
||||
|
||||
switch format {
|
||||
case "json":
|
||||
enc := json.NewEncoder(os.Stdout)
|
||||
enc.SetIndent("", " ")
|
||||
return enc.Encode(services)
|
||||
case "yaml":
|
||||
for _, s := range services {
|
||||
fmt.Printf("- name: %s\n namespace: %s\n hostname: %s\n origin: %s\n",
|
||||
s.Name, s.Namespace, s.Hostname, s.OriginURL())
|
||||
if s.Path != "" {
|
||||
fmt.Printf(" path: %s\n", s.Path)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
default: // table
|
||||
fmt.Printf("%-30s %-15s %-40s %-35s %s\n", "SERVICE", "NAMESPACE", "HOSTNAME", "ORIGIN", "PATH")
|
||||
fmt.Printf("%-30s %-15s %-40s %-35s %s\n", "-------", "---------", "--------", "------", "----")
|
||||
for _, s := range services {
|
||||
fmt.Printf("%-30s %-15s %-40s %-35s %s\n", s.Name, s.Namespace, s.Hostname, s.OriginURL(), s.Path)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
@ -257,9 +257,30 @@ type Configuration struct {
|
|||
Ingress []UnvalidatedIngressRule
|
||||
WarpRouting WarpRoutingConfig `yaml:"warp-routing"`
|
||||
OriginRequest OriginRequestConfig `yaml:"originRequest"`
|
||||
Kubernetes KubernetesConfig `yaml:"kubernetes"`
|
||||
sourceFile string
|
||||
}
|
||||
|
||||
// KubernetesConfig holds the configuration for the Kubernetes service discovery
// integration. When enabled, cloudflared will discover annotated Kubernetes
// services and automatically generate ingress rules for them.
//
// NOTE(review): the fields here appear to mirror k8s.Config field-for-field;
// keep the two structs in sync when adding options.
type KubernetesConfig struct {
	// Enabled turns the Kubernetes watcher on.
	Enabled bool `yaml:"enabled" json:"enabled"`
	// Namespace limits discovery to a single namespace. Empty means all namespaces.
	Namespace string `yaml:"namespace,omitempty" json:"namespace,omitempty"`
	// BaseDomain is the base domain appended when generating hostnames.
	BaseDomain string `yaml:"baseDomain,omitempty" json:"baseDomain,omitempty"`
	// KubeconfigPath is an optional path to a kubeconfig file.
	KubeconfigPath string `yaml:"kubeconfigPath,omitempty" json:"kubeconfigPath,omitempty"`
	// ExposeAPIServer, when true, creates an ingress rule for the K8s API server.
	ExposeAPIServer bool `yaml:"exposeAPIServer,omitempty" json:"exposeAPIServer,omitempty"`
	// APIServerHostname is the public hostname for the K8s API server.
	APIServerHostname string `yaml:"apiServerHostname,omitempty" json:"apiServerHostname,omitempty"`
	// LabelSelector is an optional Kubernetes label selector.
	LabelSelector string `yaml:"labelSelector,omitempty" json:"labelSelector,omitempty"`
}
|
||||
|
||||
type WarpRoutingConfig struct {
|
||||
ConnectTimeout *CustomDuration `yaml:"connectTimeout" json:"connectTimeout,omitempty"`
|
||||
MaxActiveFlows *uint64 `yaml:"maxActiveFlows" json:"maxActiveFlows,omitempty"`
|
||||
|
|
|
|||
|
|
@ -0,0 +1,94 @@
|
|||
// Package k8s provides Kubernetes service discovery and automatic ingress rule
|
||||
// generation for Cloudflare Tunnel. When running inside (or with access to) a
|
||||
// Kubernetes cluster, this package can watch for annotated Services and
|
||||
// automatically expose them through the tunnel without manual ingress
|
||||
// configuration.
|
||||
package k8s
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
	// AnnotationEnabled must be set to "true" on a Kubernetes Service for it
	// to be picked up and exposed through the tunnel.
	AnnotationEnabled = "cloudflared.cloudflare.com/tunnel"

	// AnnotationHostname overrides the public hostname used in the generated
	// ingress rule; when absent a hostname is synthesised from the service
	// name, namespace, and configured base domain.
	AnnotationHostname = "cloudflared.cloudflare.com/hostname"

	// AnnotationPath supplies an optional path regex for the ingress rule.
	AnnotationPath = "cloudflared.cloudflare.com/path"

	// AnnotationScheme overrides the scheme used to reach the origin
	// ("http" for non-TLS ports, "https" for port 443 by default).
	AnnotationScheme = "cloudflared.cloudflare.com/scheme"

	// AnnotationPort selects which service port to route to when several are
	// exposed; the first port is used when unset.
	AnnotationPort = "cloudflared.cloudflare.com/port"

	// AnnotationNoTLSVerify disables TLS verification for the origin.
	AnnotationNoTLSVerify = "cloudflared.cloudflare.com/no-tls-verify"

	// AnnotationOriginServerName sets the SNI for TLS connections to the origin.
	AnnotationOriginServerName = "cloudflared.cloudflare.com/origin-server-name"

	// DefaultResyncPeriod is the fallback interval at which the informer
	// re-lists all Services even without watch events.
	DefaultResyncPeriod = 30 * time.Second
)

// Config holds the user-facing configuration for the Kubernetes integration.
type Config struct {
	// Enabled turns the Kubernetes watcher on.
	Enabled bool `yaml:"enabled" json:"enabled"`

	// Namespace restricts discovery to one namespace; empty means all.
	Namespace string `yaml:"namespace,omitempty" json:"namespace,omitempty"`

	// BaseDomain is appended when generating hostnames, e.g. "example.com"
	// yields "<svc>-<ns>.example.com".
	BaseDomain string `yaml:"baseDomain,omitempty" json:"baseDomain,omitempty"`

	// KubeconfigPath optionally points at a kubeconfig file; when empty the
	// in-cluster configuration is used instead.
	KubeconfigPath string `yaml:"kubeconfigPath,omitempty" json:"kubeconfigPath,omitempty"`

	// ExposeAPIServer additionally creates an ingress rule for the Kubernetes
	// API server (typically at https://kubernetes.default.svc).
	ExposeAPIServer bool `yaml:"exposeAPIServer,omitempty" json:"exposeAPIServer,omitempty"`

	// APIServerHostname is the public hostname through which the API server
	// will be reachable; required whenever ExposeAPIServer is true.
	APIServerHostname string `yaml:"apiServerHostname,omitempty" json:"apiServerHostname,omitempty"`

	// LabelSelector optionally filters services (e.g. "app=web") in addition
	// to the annotation check.
	LabelSelector string `yaml:"labelSelector,omitempty" json:"labelSelector,omitempty"`

	// ResyncPeriod controls how often the full service list is
	// re-synchronised; zero means DefaultResyncPeriod.
	ResyncPeriod time.Duration `yaml:"resyncPeriod,omitempty" json:"resyncPeriod,omitempty"`
}

// Validate checks that the configuration is internally consistent. As a side
// effect it fills in DefaultResyncPeriod when no resync period was provided,
// so callers can rely on ResyncPeriod being non-zero afterwards.
func (c *Config) Validate() error {
	if !c.Enabled {
		// Nothing else matters while the integration is switched off.
		return nil
	}
	switch {
	case c.BaseDomain == "":
		return fmt.Errorf("kubernetes.baseDomain is required when kubernetes integration is enabled")
	case c.ExposeAPIServer && c.APIServerHostname == "":
		return fmt.Errorf("kubernetes.apiServerHostname is required when exposeAPIServer is true")
	}
	if c.ResyncPeriod == 0 {
		c.ResyncPeriod = DefaultResyncPeriod
	}
	return nil
}
|
||||
|
|
@ -0,0 +1,440 @@
|
|||
package k8s
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// ServiceInfo is the distilled view of a discovered Kubernetes Service,
// carrying just enough information to build an ingress rule for it.
type ServiceInfo struct {
	Name      string `json:"name"`
	Namespace string `json:"namespace"`
	// ClusterIP is the internal IP of the service.
	ClusterIP string `json:"clusterIP"`
	// Port is the port selected for proxying.
	Port int32 `json:"port"`
	// PortName is the name of the selected port (if any).
	PortName string `json:"portName,omitempty"`
	// Scheme is http or https.
	Scheme string `json:"scheme"`
	// Hostname is the fully-qualified public hostname.
	Hostname string `json:"hostname"`
	// Path is an optional path regex from the annotation.
	Path string `json:"path,omitempty"`
	// NoTLSVerify disables TLS certificate verification for the origin.
	NoTLSVerify bool `json:"noTLSVerify,omitempty"`
	// OriginServerName is the SNI server name for TLS.
	OriginServerName string `json:"originServerName,omitempty"`
}

// OriginURL renders the scheme://clusterIP:port address that cloudflared
// should proxy traffic to.
func (s *ServiceInfo) OriginURL() string {
	origin := fmt.Sprintf("%s://%s:%d", s.Scheme, s.ClusterIP, s.Port)
	return origin
}
|
||||
|
||||
// -----------------------------------------------------------------------
// Lightweight Kubernetes client — no dependency on client-go
// -----------------------------------------------------------------------

// kubeClient is a minimal Kubernetes REST client that can list and watch
// Service resources.
type kubeClient struct {
	// baseURL is the API server root, e.g. "https://10.0.0.1:443".
	baseURL string
	// httpClient carries the request timeout and (for in-cluster use) the
	// TLS configuration trusting the cluster CA.
	httpClient *http.Client
	// token is the bearer token attached to each request; may be empty.
	token string
	// log may be nil; code using it must nil-check before logging.
	log *zerolog.Logger
}
|
||||
|
||||
// newInClusterClient builds a kubeClient from the standard in-cluster service
|
||||
// account files.
|
||||
func newInClusterClient(log *zerolog.Logger) (*kubeClient, error) {
|
||||
const (
|
||||
tokenPath = "/var/run/secrets/kubernetes.io/serviceaccount/token" //nolint:gosec // Not a credential, this is a well-known file path
|
||||
caPath = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt"
|
||||
nsPath = "/var/run/secrets/kubernetes.io/serviceaccount/namespace"
|
||||
serviceEnv = "KUBERNETES_SERVICE_HOST"
|
||||
portEnv = "KUBERNETES_SERVICE_PORT"
|
||||
)
|
||||
|
||||
host := os.Getenv(serviceEnv)
|
||||
port := os.Getenv(portEnv)
|
||||
if host == "" || port == "" {
|
||||
return nil, fmt.Errorf("not running inside a Kubernetes cluster (KUBERNETES_SERVICE_HOST/PORT not set)")
|
||||
}
|
||||
|
||||
tokenBytes, err := os.ReadFile(tokenPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot read service account token: %w", err)
|
||||
}
|
||||
|
||||
// Load the cluster CA certificate for TLS verification against the API server.
|
||||
httpClient := &http.Client{Timeout: 30 * time.Second}
|
||||
caCert, err := os.ReadFile(caPath)
|
||||
if err != nil {
|
||||
// If CA cert is not available, fall back to default system trust store
|
||||
// but log a warning — TLS may fail for self-signed API server certs.
|
||||
if log != nil {
|
||||
log.Warn().Err(err).Msg("Could not load in-cluster CA cert, falling back to system trust store")
|
||||
}
|
||||
} else {
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AppendCertsFromPEM(caCert)
|
||||
httpClient.Transport = &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: caCertPool,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return &kubeClient{
|
||||
baseURL: fmt.Sprintf("https://%s:%s", host, port),
|
||||
httpClient: httpClient,
|
||||
token: strings.TrimSpace(string(tokenBytes)),
|
||||
log: log,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// newKubeconfigClient builds a kubeClient from a kubeconfig-style file.
|
||||
// This is a simplified parser that reads the first cluster/user.
|
||||
func newKubeconfigClient(kubeconfigPath string, log *zerolog.Logger) (*kubeClient, error) {
|
||||
data, err := os.ReadFile(kubeconfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot read kubeconfig %s: %w", kubeconfigPath, err)
|
||||
}
|
||||
|
||||
kc, err := parseKubeconfig(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &kubeClient{
|
||||
baseURL: kc.server,
|
||||
httpClient: &http.Client{Timeout: 30 * time.Second},
|
||||
token: kc.token,
|
||||
log: log,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// kubeconfigInfo holds the minimal info parsed from a kubeconfig.
type kubeconfigInfo struct {
	server string
	token  string
}

// parseKubeconfig extracts the server URL and bearer token for the
// current-context from a kubeconfig document. Only JSON-encoded kubeconfigs
// are understood by this simplified parser; a production implementation would
// use "k8s.io/client-go/tools/clientcmd".
func parseKubeconfig(data []byte) (*kubeconfigInfo, error) {
	// Minimal mirror of the kubeconfig schema — just the fields we need.
	type namedCluster struct {
		Name    string `json:"name"`
		Cluster struct {
			Server string `json:"server"`
		} `json:"cluster"`
	}
	type namedUser struct {
		Name string `json:"name"`
		User struct {
			Token string `json:"token"`
		} `json:"user"`
	}
	type namedContext struct {
		Name    string `json:"name"`
		Context struct {
			Cluster string `json:"cluster"`
			User    string `json:"user"`
		} `json:"context"`
	}
	type kubeConfig struct {
		CurrentContext string         `json:"current-context"`
		Clusters       []namedCluster `json:"clusters"`
		Users          []namedUser    `json:"users"`
		Contexts       []namedContext `json:"contexts"`
	}

	var cfg kubeConfig
	if err := json.Unmarshal(data, &cfg); err != nil {
		// Not valid JSON – return a generic error for now.
		return nil, fmt.Errorf("failed to parse kubeconfig: %w (only JSON format is supported in this implementation)", err)
	}

	// Resolve current context to a cluster/user pair.
	var clusterName, userName string
	for _, namedCtx := range cfg.Contexts {
		if namedCtx.Name != cfg.CurrentContext {
			continue
		}
		clusterName = namedCtx.Context.Cluster
		userName = namedCtx.Context.User
		break
	}
	if clusterName == "" {
		return nil, fmt.Errorf("current-context %q not found in kubeconfig", cfg.CurrentContext)
	}

	info := &kubeconfigInfo{}
	for _, cl := range cfg.Clusters {
		if cl.Name == clusterName {
			info.server = cl.Cluster.Server
			break
		}
	}
	// Token may legitimately be absent (e.g. cert-based auth); no error.
	for _, u := range cfg.Users {
		if u.Name == userName {
			info.token = u.User.Token
			break
		}
	}
	if info.server == "" {
		return nil, fmt.Errorf("cluster %q server URL not found in kubeconfig", clusterName)
	}

	return info, nil
}
|
||||
|
||||
// do executes an authenticated HTTP request against the API server.
|
||||
func (kc *kubeClient) do(ctx context.Context, method, path string) ([]byte, error) {
|
||||
url := kc.baseURL + path
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if kc.token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+kc.token)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := kc.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("k8s API request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading k8s API response: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("k8s API returned HTTP %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
// K8s API response types (minimal)
// -----------------------------------------------------------------------

// serviceList mirrors the subset of a Kubernetes ServiceList API response
// that discovery needs.
type serviceList struct {
	Items []serviceItem `json:"items"`
}

// serviceItem is a single Service entry from the list response.
type serviceItem struct {
	Metadata objectMeta `json:"metadata"`
	Spec serviceSpec `json:"spec"`
}

// objectMeta carries the identifying metadata and the annotations that
// drive discovery.
type objectMeta struct {
	Name string `json:"name"`
	Namespace string `json:"namespace"`
	Labels map[string]string `json:"labels"`
	Annotations map[string]string `json:"annotations"`
}

// serviceSpec holds the addressing fields of a Service spec.
type serviceSpec struct {
	ClusterIP string `json:"clusterIP"`
	Ports []servicePort `json:"ports"`
	Type string `json:"type"`
}

// servicePort is one entry of a Service's port list.
type servicePort struct {
	Name string `json:"name"`
	Port int32 `json:"port"`
	Protocol string `json:"protocol"`
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Discovery logic
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
// DiscoverServices queries the Kubernetes API for Services annotated with
// AnnotationEnabled = "true" and returns ServiceInfo descriptors.
//
// Scope is cluster-wide unless cfg.Namespace narrows it; cfg.LabelSelector
// is forwarded to the API server as a labelSelector query parameter.
// Services that fail conversion are logged and skipped rather than failing
// the whole discovery pass.
func DiscoverServices(ctx context.Context, cfg *Config, log *zerolog.Logger) ([]ServiceInfo, error) {
	client, err := buildClient(cfg, log)
	if err != nil {
		return nil, err
	}

	// Cluster-wide listing by default; restrict to one namespace if set.
	path := "/api/v1/services"
	if cfg.Namespace != "" {
		path = fmt.Sprintf("/api/v1/namespaces/%s/services", cfg.Namespace)
	}
	if cfg.LabelSelector != "" {
		path += "?labelSelector=" + url.QueryEscape(cfg.LabelSelector)
	}

	body, err := client.do(ctx, http.MethodGet, path)
	if err != nil {
		return nil, fmt.Errorf("listing services: %w", err)
	}

	var list serviceList
	if err := json.Unmarshal(body, &list); err != nil {
		return nil, fmt.Errorf("parsing service list: %w", err)
	}

	services := make([]ServiceInfo, 0, len(list.Items))
	for _, item := range list.Items {
		// Only services that explicitly opt in via the enable annotation
		// are exposed.
		ann := item.Metadata.Annotations
		if ann == nil {
			continue
		}
		enabled, ok := ann[AnnotationEnabled]
		if !ok || !isTrue(enabled) {
			continue
		}

		si, err := serviceInfoFromItem(item, cfg)
		if err != nil {
			// Best effort: one malformed service must not block the rest.
			log.Warn().Err(err).
				Str("service", item.Metadata.Name).
				Str("namespace", item.Metadata.Namespace).
				Msg("Skipping service due to error")
			continue
		}
		services = append(services, *si)
	}

	// Optionally expose the API server itself.
	if cfg.ExposeAPIServer && cfg.APIServerHostname != "" {
		apiSvc := ServiceInfo{
			Name:      "kubernetes-api",
			Namespace: "default",
			// The API server address is recovered from the client's base
			// URL; the scheme prefix is stripped to leave host[:port].
			ClusterIP:   strings.TrimPrefix(strings.TrimPrefix(client.baseURL, "https://"), "http://"),
			Port:        443,
			Scheme:      "https",
			Hostname:    cfg.APIServerHostname,
			NoTLSVerify: true, // API server cert may not match the public hostname
		}
		// If the baseURL contains host:port, split it.
		if hp := strings.SplitN(apiSvc.ClusterIP, ":", 2); len(hp) == 2 {
			apiSvc.ClusterIP = hp[0]
			if p, err := parseInt32(hp[1]); err == nil {
				apiSvc.Port = p
			}
		}
		services = append(services, apiSvc)
	}

	return services, nil
}
|
||||
|
||||
// serviceInfoFromItem converts a raw Kubernetes service item into a ServiceInfo.
|
||||
func serviceInfoFromItem(item serviceItem, cfg *Config) (*ServiceInfo, error) {
|
||||
ann := item.Metadata.Annotations
|
||||
spec := item.Spec
|
||||
|
||||
if spec.ClusterIP == "" || spec.ClusterIP == "None" {
|
||||
return nil, fmt.Errorf("service %s/%s has no ClusterIP (headless services are not supported)",
|
||||
item.Metadata.Namespace, item.Metadata.Name)
|
||||
}
|
||||
|
||||
port, portName, err := selectPort(spec.Ports, ann[AnnotationPort])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
scheme := "http"
|
||||
if v, ok := ann[AnnotationScheme]; ok {
|
||||
scheme = v
|
||||
} else if port == 443 {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
hostname := ann[AnnotationHostname]
|
||||
if hostname == "" {
|
||||
hostname = fmt.Sprintf("%s-%s.%s", item.Metadata.Name, item.Metadata.Namespace, cfg.BaseDomain)
|
||||
}
|
||||
|
||||
si := &ServiceInfo{
|
||||
Name: item.Metadata.Name,
|
||||
Namespace: item.Metadata.Namespace,
|
||||
ClusterIP: spec.ClusterIP,
|
||||
Port: port,
|
||||
PortName: portName,
|
||||
Scheme: scheme,
|
||||
Hostname: hostname,
|
||||
Path: ann[AnnotationPath],
|
||||
}
|
||||
|
||||
if v, ok := ann[AnnotationNoTLSVerify]; ok && isTrue(v) {
|
||||
si.NoTLSVerify = true
|
||||
}
|
||||
if v, ok := ann[AnnotationOriginServerName]; ok {
|
||||
si.OriginServerName = v
|
||||
}
|
||||
|
||||
return si, nil
|
||||
}
|
||||
|
||||
// selectPort picks the port to use from the service's port list.
|
||||
func selectPort(ports []servicePort, portAnnotation string) (int32, string, error) {
|
||||
if len(ports) == 0 {
|
||||
return 0, "", fmt.Errorf("service has no ports")
|
||||
}
|
||||
if portAnnotation == "" {
|
||||
return ports[0].Port, ports[0].Name, nil
|
||||
}
|
||||
// Match by name first, then by number.
|
||||
for _, p := range ports {
|
||||
if p.Name == portAnnotation {
|
||||
return p.Port, p.Name, nil
|
||||
}
|
||||
}
|
||||
portNum, err := parseInt32(portAnnotation)
|
||||
if err == nil {
|
||||
for _, p := range ports {
|
||||
if p.Port == portNum {
|
||||
return p.Port, p.Name, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0, "", fmt.Errorf("port %q not found in service", portAnnotation)
|
||||
}
|
||||
|
||||
func buildClient(cfg *Config, log *zerolog.Logger) (*kubeClient, error) {
|
||||
if cfg.KubeconfigPath != "" {
|
||||
path := cfg.KubeconfigPath
|
||||
if strings.HasPrefix(path, "~") {
|
||||
if home, err := os.UserHomeDir(); err == nil {
|
||||
path = filepath.Join(home, path[1:])
|
||||
}
|
||||
}
|
||||
return newKubeconfigClient(path, log)
|
||||
}
|
||||
return newInClusterClient(log)
|
||||
}
|
||||
|
||||
// isTrue reports whether s spells an affirmative value. The comparison is
// case-insensitive and ignores surrounding whitespace; accepted spellings
// are "true", "1", and "yes".
func isTrue(s string) bool {
	switch strings.ToLower(strings.TrimSpace(s)) {
	case "true", "1", "yes":
		return true
	default:
		return false
	}
}
|
||||
|
||||
// parseInt32 parses s as a base-10 int32. Unlike a bare fmt.Sscanf with a
// single %d — which stops at the first non-digit and would silently accept
// annotation values such as "80abc" as 80 — trailing characters are
// detected and reported as an error.
func parseInt32(s string) (int32, error) {
	var (
		v    int32
		rest string
	)
	// The trailing %s only matches when non-whitespace input remains after
	// the integer, which is exactly the garbage case we want to reject.
	n, err := fmt.Sscanf(s, "%d%s", &v, &rest)
	if n >= 2 {
		return 0, fmt.Errorf("invalid int32 %q: trailing characters", s)
	}
	if n < 1 {
		return 0, fmt.Errorf("invalid int32 %q: %v", s, err)
	}
	return v, nil
}
|
||||
|
|
@ -0,0 +1,247 @@
|
|||
package k8s
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestConfigValidate covers Config.Validate across the enable/required-field
// matrix: disabled configs always pass, enabled configs require BaseDomain,
// and ExposeAPIServer additionally requires APIServerHostname.
func TestConfigValidate(t *testing.T) {
	tests := []struct {
		name    string
		cfg     Config
		wantErr bool
		errMsg  string // substring expected in the validation error
	}{
		{
			name: "disabled config is always valid",
			cfg:  Config{Enabled: false},
		},
		{
			name:    "enabled without baseDomain fails",
			cfg:     Config{Enabled: true},
			wantErr: true,
			errMsg:  "baseDomain",
		},
		{
			name:    "exposeAPIServer without apiServerHostname fails",
			cfg:     Config{Enabled: true, BaseDomain: "example.com", ExposeAPIServer: true},
			wantErr: true,
			errMsg:  "apiServerHostname",
		},
		{
			name: "valid minimal config",
			cfg:  Config{Enabled: true, BaseDomain: "example.com"},
		},
		{
			name: "valid full config",
			cfg: Config{
				Enabled:           true,
				BaseDomain:        "example.com",
				Namespace:         "default",
				ExposeAPIServer:   true,
				APIServerHostname: "k8s.example.com",
			},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			err := tc.cfg.Validate()
			if tc.wantErr {
				require.Error(t, err)
				assert.Contains(t, err.Error(), tc.errMsg)
			} else {
				require.NoError(t, err)
			}
		})
	}
}
|
||||
|
||||
// TestSelectPort exercises the port-resolution rules of selectPort: the
// first port is the default, an annotation matches by name before number,
// and unknown annotations or an empty port list are errors.
func TestSelectPort(t *testing.T) {
	// Shared multi-port fixture used by most cases.
	ports := []servicePort{
		{Name: "http", Port: 80, Protocol: "TCP"},
		{Name: "https", Port: 443, Protocol: "TCP"},
		{Name: "grpc", Port: 9090, Protocol: "TCP"},
	}

	tests := []struct {
		name           string
		ports          []servicePort
		portAnnotation string
		wantPort       int32
		wantName       string
		wantErr        bool
	}{
		{
			name:     "no annotation selects first port",
			ports:    ports,
			wantPort: 80,
			wantName: "http",
		},
		{
			name:           "select by name",
			ports:          ports,
			portAnnotation: "https",
			wantPort:       443,
			wantName:       "https",
		},
		{
			name:           "select by number",
			ports:          ports,
			portAnnotation: "9090",
			wantPort:       9090,
			wantName:       "grpc",
		},
		{
			name:           "non-existent port name fails",
			ports:          ports,
			portAnnotation: "nonexistent",
			wantErr:        true,
		},
		{
			name:    "empty port list fails",
			ports:   nil,
			wantErr: true,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			port, name, err := selectPort(tc.ports, tc.portAnnotation)
			if tc.wantErr {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
				assert.Equal(t, tc.wantPort, port)
				assert.Equal(t, tc.wantName, name)
			}
		})
	}
}
|
||||
|
||||
// TestServiceInfoFromItem verifies the service-item-to-ServiceInfo
// conversion: default hostname/scheme derivation, annotation overrides for
// hostname, port and scheme, the no-tls-verify flag, and rejection of
// headless services.
func TestServiceInfoFromItem(t *testing.T) {
	cfg := &Config{
		Enabled:    true,
		BaseDomain: "example.com",
	}

	// Defaults: http scheme, first port, hostname derived as
	// <name>-<namespace>.<baseDomain>.
	t.Run("basic service", func(t *testing.T) {
		item := serviceItem{
			Metadata: objectMeta{
				Name:        "web",
				Namespace:   "default",
				Annotations: map[string]string{AnnotationEnabled: "true"},
			},
			Spec: serviceSpec{
				ClusterIP: "10.96.0.1",
				Ports:     []servicePort{{Name: "http", Port: 80, Protocol: "TCP"}},
			},
		}

		si, err := serviceInfoFromItem(item, cfg)
		require.NoError(t, err)
		assert.Equal(t, "web", si.Name)
		assert.Equal(t, "default", si.Namespace)
		assert.Equal(t, "10.96.0.1", si.ClusterIP)
		assert.Equal(t, int32(80), si.Port)
		assert.Equal(t, "http", si.Scheme)
		assert.Equal(t, "web-default.example.com", si.Hostname)
		assert.Equal(t, "http://10.96.0.1:80", si.OriginURL())
	})

	// The hostname annotation overrides the derived name; port 443 implies
	// https when no scheme annotation is present.
	t.Run("service with custom hostname", func(t *testing.T) {
		item := serviceItem{
			Metadata: objectMeta{
				Name:      "api",
				Namespace: "prod",
				Annotations: map[string]string{
					AnnotationEnabled:  "true",
					AnnotationHostname: "api.mycompany.com",
				},
			},
			Spec: serviceSpec{
				ClusterIP: "10.96.0.2",
				Ports:     []servicePort{{Name: "https", Port: 443, Protocol: "TCP"}},
			},
		}

		si, err := serviceInfoFromItem(item, cfg)
		require.NoError(t, err)
		assert.Equal(t, "api.mycompany.com", si.Hostname)
		assert.Equal(t, "https", si.Scheme)
	})

	// ClusterIP "None" marks a headless service, which has no VIP to proxy.
	t.Run("headless service is rejected", func(t *testing.T) {
		item := serviceItem{
			Metadata: objectMeta{
				Name:        "headless",
				Namespace:   "default",
				Annotations: map[string]string{AnnotationEnabled: "true"},
			},
			Spec: serviceSpec{
				ClusterIP: "None",
				Ports:     []servicePort{{Port: 80}},
			},
		}
		_, err := serviceInfoFromItem(item, cfg)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "headless")
	})

	// The port annotation selects a named port out of a multi-port spec.
	t.Run("custom port annotation", func(t *testing.T) {
		item := serviceItem{
			Metadata: objectMeta{
				Name:      "multi-port",
				Namespace: "default",
				Annotations: map[string]string{
					AnnotationEnabled: "true",
					AnnotationPort:    "grpc",
				},
			},
			Spec: serviceSpec{
				ClusterIP: "10.96.0.3",
				Ports: []servicePort{
					{Name: "http", Port: 80},
					{Name: "grpc", Port: 9090},
				},
			},
		}
		si, err := serviceInfoFromItem(item, cfg)
		require.NoError(t, err)
		assert.Equal(t, int32(9090), si.Port)
		assert.Equal(t, "grpc", si.PortName)
	})

	// The no-tls-verify annotation is honored alongside an explicit scheme.
	t.Run("no-tls-verify annotation", func(t *testing.T) {
		item := serviceItem{
			Metadata: objectMeta{
				Name:      "insecure",
				Namespace: "default",
				Annotations: map[string]string{
					AnnotationEnabled:     "true",
					AnnotationNoTLSVerify: "true",
					AnnotationScheme:      "https",
				},
			},
			Spec: serviceSpec{
				ClusterIP: "10.96.0.4",
				Ports:     []servicePort{{Port: 8443}},
			},
		}
		si, err := serviceInfoFromItem(item, cfg)
		require.NoError(t, err)
		assert.True(t, si.NoTLSVerify)
		assert.Equal(t, "https", si.Scheme)
	})
}
|
||||
|
||||
func TestIsTrue(t *testing.T) {
|
||||
for _, v := range []string{"true", "True", "TRUE", "1", "yes", "YES"} {
|
||||
assert.True(t, isTrue(v), "expected isTrue(%q) to be true", v)
|
||||
}
|
||||
for _, v := range []string{"false", "0", "no", "", "random"} {
|
||||
assert.False(t, isTrue(v), "expected isTrue(%q) to be false", v)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,103 @@
|
|||
package k8s
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
)
|
||||
|
||||
// GenerateIngressRules converts a slice of discovered Kubernetes ServiceInfo
// into cloudflared-compatible UnvalidatedIngressRule entries. The caller is
// responsible for appending a catch-all rule.
func GenerateIngressRules(services []ServiceInfo, log *zerolog.Logger) []config.UnvalidatedIngressRule {
	rules := make([]config.UnvalidatedIngressRule, 0, len(services))

	for _, svc := range services {
		originURL := svc.OriginURL()
		rule := config.UnvalidatedIngressRule{
			Hostname: svc.Hostname,
			Service:  originURL,
			Path:     svc.Path,
		}

		// Apply per-service origin request overrides from annotations.
		if svc.NoTLSVerify {
			// Fresh variable per iteration so each rule holds its own pointer.
			noTLS := true
			rule.OriginRequest.NoTLSVerify = &noTLS
		}
		if svc.OriginServerName != "" {
			// NOTE(review): takes the address of the range variable's field —
			// safe under Go 1.22+ per-iteration loop variables; verify go.mod
			// targets >= 1.22, otherwise all rules would alias the last svc.
			rule.OriginRequest.OriginServerName = &svc.OriginServerName
		}

		log.Info().
			Str("service", fmt.Sprintf("%s/%s", svc.Namespace, svc.Name)).
			Str("hostname", svc.Hostname).
			Str("origin", originURL).
			Msg("Generated ingress rule from Kubernetes service")

		rules = append(rules, rule)
	}

	return rules
}
|
||||
|
||||
// MergeWithExistingRules takes user-defined ingress rules and the auto-discovered
|
||||
// Kubernetes rules and produces a combined set. Kubernetes-generated rules are
|
||||
// prepended so that they take priority, but user-defined catch-all rules are
|
||||
// always kept at the end.
|
||||
func MergeWithExistingRules(
|
||||
existing []config.UnvalidatedIngressRule,
|
||||
k8sRules []config.UnvalidatedIngressRule,
|
||||
) []config.UnvalidatedIngressRule {
|
||||
if len(k8sRules) == 0 {
|
||||
return existing
|
||||
}
|
||||
if len(existing) == 0 {
|
||||
return k8sRules
|
||||
}
|
||||
|
||||
// Separate the catch-all rule (last rule) from the rest.
|
||||
var catchAll *config.UnvalidatedIngressRule
|
||||
rest := existing
|
||||
if len(existing) > 0 {
|
||||
last := existing[len(existing)-1]
|
||||
if isCatchAll(last) {
|
||||
catchAll = &last
|
||||
rest = existing[:len(existing)-1]
|
||||
}
|
||||
}
|
||||
|
||||
// Deduplicate: remove any K8s rule that duplicates an existing user rule.
|
||||
existingSet := make(map[string]struct{}, len(rest))
|
||||
for _, r := range rest {
|
||||
existingSet[r.Hostname+"#"+r.Path] = struct{}{}
|
||||
}
|
||||
|
||||
merged := make([]config.UnvalidatedIngressRule, 0, len(rest)+len(k8sRules)+1)
|
||||
// User rules first (higher priority for user-specified).
|
||||
merged = append(merged, rest...)
|
||||
// Then K8s rules.
|
||||
for _, kr := range k8sRules {
|
||||
key := kr.Hostname + "#" + kr.Path
|
||||
if _, dup := existingSet[key]; !dup {
|
||||
merged = append(merged, kr)
|
||||
}
|
||||
}
|
||||
// Append catch-all.
|
||||
if catchAll != nil {
|
||||
merged = append(merged, *catchAll)
|
||||
} else {
|
||||
// Always ensure there's a catch-all rule.
|
||||
merged = append(merged, config.UnvalidatedIngressRule{
|
||||
Service: "http_status:503",
|
||||
})
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
func isCatchAll(r config.UnvalidatedIngressRule) bool {
|
||||
return (r.Hostname == "" || r.Hostname == "*") && r.Path == ""
|
||||
}
|
||||
|
|
@ -0,0 +1,124 @@
|
|||
package k8s
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
)
|
||||
|
||||
// TestGenerateIngressRules checks rule generation for three representative
// services: a plain http service, an https service with TLS overrides
// (NoTLSVerify + OriginServerName), and a path-scoped service.
func TestGenerateIngressRules(t *testing.T) {
	log := zerolog.Nop()

	services := []ServiceInfo{
		{
			Name:      "web",
			Namespace: "default",
			ClusterIP: "10.96.0.1",
			Port:      80,
			Scheme:    "http",
			Hostname:  "web-default.example.com",
		},
		{
			Name:             "api",
			Namespace:        "prod",
			ClusterIP:        "10.96.0.2",
			Port:             443,
			Scheme:           "https",
			Hostname:         "api.example.com",
			NoTLSVerify:      true,
			OriginServerName: "api.internal",
		},
		{
			Name:      "docs",
			Namespace: "default",
			ClusterIP: "10.96.0.3",
			Port:      8080,
			Scheme:    "http",
			Hostname:  "docs.example.com",
			Path:      "/docs/.*",
		},
	}

	rules := GenerateIngressRules(services, &log)
	require.Len(t, rules, 3)

	// Check first rule: plain service, no origin-request overrides.
	assert.Equal(t, "web-default.example.com", rules[0].Hostname)
	assert.Equal(t, "http://10.96.0.1:80", rules[0].Service)
	assert.Empty(t, rules[0].Path)
	assert.Nil(t, rules[0].OriginRequest.NoTLSVerify)

	// Check second rule with TLS overrides propagated as pointers.
	assert.Equal(t, "api.example.com", rules[1].Hostname)
	assert.Equal(t, "https://10.96.0.2:443", rules[1].Service)
	require.NotNil(t, rules[1].OriginRequest.NoTLSVerify)
	assert.True(t, *rules[1].OriginRequest.NoTLSVerify)
	require.NotNil(t, rules[1].OriginRequest.OriginServerName)
	assert.Equal(t, "api.internal", *rules[1].OriginRequest.OriginServerName)

	// Check third rule with path passed through verbatim.
	assert.Equal(t, "docs.example.com", rules[2].Hostname)
	assert.Equal(t, "/docs/.*", rules[2].Path)
}
|
||||
|
||||
// TestMergeWithExistingRules covers the merge semantics: pass-through when
// either side is empty, user rules ordered before K8s rules, catch-all kept
// (or synthesized) last, and user rules winning hostname+path duplicates.
func TestMergeWithExistingRules(t *testing.T) {
	k8sRules := []config.UnvalidatedIngressRule{
		{Hostname: "k8s-svc.example.com", Service: "http://10.96.0.1:80"},
	}

	// With no user rules, the K8s rules are returned unchanged (note: no
	// catch-all is synthesized on this path).
	t.Run("empty existing rules", func(t *testing.T) {
		merged := MergeWithExistingRules(nil, k8sRules)
		require.Len(t, merged, 1)
		assert.Equal(t, "k8s-svc.example.com", merged[0].Hostname)
	})

	// With no K8s rules, the user rules are returned unchanged.
	t.Run("empty k8s rules", func(t *testing.T) {
		existing := []config.UnvalidatedIngressRule{
			{Hostname: "www.example.com", Service: "http://localhost:8080"},
			{Service: "http_status:404"},
		}
		merged := MergeWithExistingRules(existing, nil)
		assert.Equal(t, existing, merged)
	})

	// The user's catch-all is detached and re-appended after the K8s rules.
	t.Run("merge with catch-all", func(t *testing.T) {
		existing := []config.UnvalidatedIngressRule{
			{Hostname: "www.example.com", Service: "http://localhost:8080"},
			{Service: "http_status:404"}, // catch-all
		}
		merged := MergeWithExistingRules(existing, k8sRules)
		require.Len(t, merged, 3)
		// User rule first
		assert.Equal(t, "www.example.com", merged[0].Hostname)
		// K8s rule
		assert.Equal(t, "k8s-svc.example.com", merged[1].Hostname)
		// Catch-all last
		assert.Equal(t, "http_status:404", merged[2].Service)
	})

	// Without a user catch-all, a default 503 rule is appended.
	t.Run("no catch-all adds default", func(t *testing.T) {
		existing := []config.UnvalidatedIngressRule{
			{Hostname: "www.example.com", Service: "http://localhost:8080"},
		}
		merged := MergeWithExistingRules(existing, k8sRules)
		require.Len(t, merged, 3)
		// Should have a catch-all appended
		assert.Equal(t, "http_status:503", merged[2].Service)
	})

	// A K8s rule whose hostname+path matches a user rule is dropped.
	t.Run("deduplication", func(t *testing.T) {
		existing := []config.UnvalidatedIngressRule{
			{Hostname: "k8s-svc.example.com", Service: "http://override:9090"},
			{Service: "http_status:404"},
		}
		merged := MergeWithExistingRules(existing, k8sRules)
		// K8s rule for k8s-svc.example.com should be deduplicated
		require.Len(t, merged, 2)
		// The user-defined one takes priority
		assert.Equal(t, "http://override:9090", merged[0].Service)
	})
}
|
||||
|
|
@ -0,0 +1,127 @@
|
|||
package k8s
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// ServiceChangeHandler is called whenever the set of discovered services changes.
type ServiceChangeHandler func(services []ServiceInfo)

// Watcher periodically polls the Kubernetes API for service changes and
// notifies registered handlers.
type Watcher struct {
	cfg     *Config
	log     *zerolog.Logger
	handler ServiceChangeHandler // invoked outside the lock when a sync detects a change

	mu       sync.Mutex    // guards services
	services []ServiceInfo // most recent discovery snapshot

	stopOnce sync.Once     // makes Stop idempotent
	stopC    chan struct{} // closed by Stop to terminate Run
}
|
||||
|
||||
// NewWatcher creates a Watcher that will poll the Kubernetes API at the
// configured resync interval.
//
// NOTE(review): a zero ResyncPeriod is defaulted by writing through the
// caller's *Config, so the mutation is visible to the caller — confirm no
// caller shares one cfg across multiple watchers with different expectations.
func NewWatcher(cfg *Config, log *zerolog.Logger, handler ServiceChangeHandler) *Watcher {
	if cfg.ResyncPeriod == 0 {
		cfg.ResyncPeriod = DefaultResyncPeriod
	}
	return &Watcher{
		cfg:     cfg,
		log:     log,
		handler: handler,
		stopC:   make(chan struct{}),
	}
}
|
||||
|
||||
// Run starts the watch loop. It blocks until ctx is cancelled or Stop is called.
func (w *Watcher) Run(ctx context.Context) {
	w.log.Info().
		Str("namespace", w.cfg.Namespace).
		Str("baseDomain", w.cfg.BaseDomain).
		Dur("resyncPeriod", w.cfg.ResyncPeriod).
		Msg("Starting Kubernetes service watcher")

	// Initial sync so the first set of services is available immediately
	// rather than after a full resync period.
	w.sync(ctx)

	ticker := time.NewTicker(w.cfg.ResyncPeriod)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			w.log.Info().Msg("Kubernetes service watcher stopped (context cancelled)")
			return
		case <-w.stopC:
			w.log.Info().Msg("Kubernetes service watcher stopped")
			return
		case <-ticker.C:
			w.sync(ctx)
		}
	}
}
|
||||
|
||||
// Stop signals the watcher to stop. It is safe to call multiple times and
// from multiple goroutines; only the first call closes the stop channel.
func (w *Watcher) Stop() {
	w.stopOnce.Do(func() {
		close(w.stopC)
	})
}
|
||||
|
||||
// Services returns a snapshot of the currently discovered services.
|
||||
func (w *Watcher) Services() []ServiceInfo {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
out := make([]ServiceInfo, len(w.services))
|
||||
copy(out, w.services)
|
||||
return out
|
||||
}
|
||||
|
||||
// sync performs one discovery cycle: query the API, store the new snapshot,
// and fire the change handler if the set of services differs from the
// previous cycle.
func (w *Watcher) sync(ctx context.Context) {
	services, err := DiscoverServices(ctx, w.cfg, w.log)
	if err != nil {
		// Transient failure: keep the previous snapshot; the next tick retries.
		w.log.Err(err).Msg("Failed to discover Kubernetes services")
		return
	}

	// Swap in the new snapshot under the lock; the change decision is made
	// against the snapshot being replaced.
	w.mu.Lock()
	changed := !servicesEqual(w.services, services)
	w.services = services
	w.mu.Unlock()

	w.log.Info().Int("count", len(services)).Bool("changed", changed).Msg("Kubernetes service sync complete")

	// Handler runs outside the lock so it may call back into the watcher
	// (e.g. Services()) without deadlocking.
	// NOTE(review): the handler receives the same backing slice stored in
	// w.services — confirm handlers treat it as read-only.
	if changed && w.handler != nil {
		w.handler(services)
	}
}
|
||||
|
||||
// servicesEqual performs a simple equality check on two ServiceInfo slices.
|
||||
func servicesEqual(a, b []ServiceInfo) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
// Build a set from a, check against b.
|
||||
set := make(map[string]struct{}, len(a))
|
||||
for _, s := range a {
|
||||
set[s.key()] = struct{}{}
|
||||
}
|
||||
for _, s := range b {
|
||||
if _, ok := set[s.key()]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// key returns a stable string representation for comparison. It folds in
// every field that affects routing — identity (namespace/name), origin URL,
// public hostname, and path — so any material change alters the key.
func (s *ServiceInfo) key() string {
	return s.Namespace + "/" + s.Name + ":" + s.OriginURL() + "@" + s.Hostname + "#" + s.Path
}
|
||||
|
|
@ -0,0 +1,154 @@
|
|||
package k8s
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// fakeK8sServer returns an httptest.Server that responds to /api/v1/services
|
||||
// with the given service list.
|
||||
func fakeK8sServer(t *testing.T, services serviceList) *httptest.Server {
|
||||
t.Helper()
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(services); err != nil {
|
||||
t.Fatalf("failed to encode services: %v", err)
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
// TestDiscoverServicesWithMockServer drives the HTTP client against a fake
// API server and then replays the discovery filter pipeline by hand:
// services without the enable annotation are skipped, and annotation-driven
// overrides (hostname, scheme, port) are applied to the rest.
func TestDiscoverServicesWithMockServer(t *testing.T) {
	log := zerolog.Nop()

	// Three fixtures: an opted-in http service, a service with no tunnel
	// annotation (must be skipped), and an opted-in multi-port service with
	// hostname/scheme/port overrides.
	svcList := serviceList{
		Items: []serviceItem{
			{
				Metadata: objectMeta{
					Name:      "web",
					Namespace: "default",
					Annotations: map[string]string{
						AnnotationEnabled: "true",
					},
				},
				Spec: serviceSpec{
					ClusterIP: "10.96.0.1",
					Ports:     []servicePort{{Name: "http", Port: 80, Protocol: "TCP"}},
				},
			},
			{
				Metadata: objectMeta{
					Name:      "skipped",
					Namespace: "default",
					Annotations: map[string]string{
						// No tunnel annotation
					},
				},
				Spec: serviceSpec{
					ClusterIP: "10.96.0.2",
					Ports:     []servicePort{{Port: 80}},
				},
			},
			{
				Metadata: objectMeta{
					Name:      "api",
					Namespace: "prod",
					Annotations: map[string]string{
						AnnotationEnabled:  "true",
						AnnotationHostname: "api.mycompany.com",
						AnnotationScheme:   "https",
						AnnotationPort:     "443",
					},
				},
				Spec: serviceSpec{
					ClusterIP: "10.96.1.5",
					Ports: []servicePort{
						{Name: "http", Port: 80},
						{Name: "https", Port: 443},
					},
				},
			},
		},
	}

	server := fakeK8sServer(t, svcList)
	defer server.Close()

	cfg := &Config{
		Enabled:    true,
		BaseDomain: "example.com",
	}

	// Override the client builder for testing: point a kubeClient directly
	// at the httptest server instead of a real API server.
	client := &kubeClient{
		baseURL:    server.URL,
		httpClient: server.Client(),
		log:        &log,
	}

	ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
	defer cancel()

	// Raw transport check: the list endpoint returns all three items.
	body, err := client.do(ctx, http.MethodGet, "/api/v1/services")
	require.NoError(t, err)

	var list serviceList
	require.NoError(t, json.Unmarshal(body, &list))
	require.Len(t, list.Items, 3)

	// Now test the full discovery pipeline by filtering (mirrors the
	// annotation gate inside DiscoverServices).
	services := make([]ServiceInfo, 0, len(list.Items))
	for _, item := range list.Items {
		ann := item.Metadata.Annotations
		if ann == nil {
			continue
		}
		enabled, ok := ann[AnnotationEnabled]
		if !ok || !isTrue(enabled) {
			continue
		}
		si, err := serviceInfoFromItem(item, cfg)
		if err != nil {
			continue
		}
		services = append(services, *si)
	}

	// "skipped" is filtered out; "web" and "api" survive.
	require.Len(t, services, 2)

	// web service: all defaults derived.
	assert.Equal(t, "web", services[0].Name)
	assert.Equal(t, "web-default.example.com", services[0].Hostname)
	assert.Equal(t, "http", services[0].Scheme)
	assert.Equal(t, int32(80), services[0].Port)

	// api service: hostname/scheme/port taken from annotations.
	assert.Equal(t, "api", services[1].Name)
	assert.Equal(t, "api.mycompany.com", services[1].Hostname)
	assert.Equal(t, "https", services[1].Scheme)
	assert.Equal(t, int32(443), services[1].Port)
}
|
||||
|
||||
// TestWatcherServicesEqual checks the snapshot comparison used by the
// watcher to decide whether a sync changed anything: equal content compares
// equal, nil only equals nil, and an extra element breaks equality.
func TestWatcherServicesEqual(t *testing.T) {
	a := []ServiceInfo{
		{Name: "web", Namespace: "default", ClusterIP: "10.0.0.1", Port: 80, Scheme: "http", Hostname: "web.example.com"},
	}
	b := []ServiceInfo{
		{Name: "web", Namespace: "default", ClusterIP: "10.0.0.1", Port: 80, Scheme: "http", Hostname: "web.example.com"},
	}

	assert.True(t, servicesEqual(a, b))
	assert.True(t, servicesEqual(nil, nil))
	assert.False(t, servicesEqual(a, nil))
	assert.False(t, servicesEqual(nil, b))

	// append here allocates a new backing array (b is at capacity), so b
	// itself is not mutated.
	c := append(b, ServiceInfo{Name: "api", Namespace: "default", ClusterIP: "10.0.0.2", Port: 443, Scheme: "https", Hostname: "api.example.com"})
	assert.False(t, servicesEqual(a, c))
}
|
||||
|
|
@ -216,6 +216,24 @@ func (o *Orchestrator) GetConfigJSON() ([]byte, error) {
|
|||
return json.Marshal(c)
|
||||
}
|
||||
|
||||
// UpdateK8sConfig applies a Kubernetes-triggered configuration update by
// reading the current version, bumping it, and handing the payload to
// UpdateConfig.
//
// NOTE(review): this is not a single locked section — the version is read
// under the read lock, which is released before UpdateConfig re-acquires
// its own lock. A concurrent remote update can land in that window; the
// case is benign because the remote config is then newer, UpdateConfig
// rejects our stale version, and the next K8s sync cycle retries.
func (o *Orchestrator) UpdateK8sConfig(config []byte) *pogs.UpdateConfigurationResponse {
	// Read the latest applied version immediately before writing to keep
	// the race window with concurrent remote updates as small as possible.
	o.lock.RLock()
	nextVersion := o.currentVersion + 1
	o.lock.RUnlock()
	return o.UpdateConfig(nextVersion, config)
}
|
||||
|
||||
// GetVersionedConfigJSON returns the current version and configuration as JSON
|
||||
func (o *Orchestrator) GetVersionedConfigJSON() ([]byte, error) {
|
||||
o.lock.RLock()
|
||||
|
|
|
|||
Loading…
Reference in New Issue