package supervisor

import (
	"context"
	"errors"
	"net"
	"strings"
	"time"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/quic-go/quic-go"
	"github.com/rs/zerolog"

	"github.com/cloudflare/cloudflared/connection"
	"github.com/cloudflare/cloudflared/edgediscovery"
	"github.com/cloudflare/cloudflared/ingress"
	"github.com/cloudflare/cloudflared/orchestration"
	v3 "github.com/cloudflare/cloudflared/quic/v3"
	"github.com/cloudflare/cloudflared/retry"
	"github.com/cloudflare/cloudflared/signal"
	"github.com/cloudflare/cloudflared/tunnelstate"
)

const (
	// Base waiting time used for the supervisor's backoff before retrying failed tunnel connections
	tunnelRetryDuration = time.Second * 10
	// Pause between starting each additional HA tunnel once the first one is connected
	registrationInterval = time.Second

	// NOTE(review): not referenced in the visible part of this file; presumably a
	// subsystem label for the auth refresh path - confirm against its users.
	subsystemRefreshAuth = "refresh_auth"
	// Maximum exponent for 'Authenticate' exponential backoff
	refreshAuthMaxBackoff = 10
	// Waiting time before retrying a failed 'Authenticate' connection
	refreshAuthRetryDuration = time.Second * 10
)

// Supervisor manages non-declarative tunnels. Establishes TCP connections with the edge, and
// reconnects them if they disconnect.
//
// Field overview:
//   - config, orchestrator: the settings and orchestrator handed to NewSupervisor.
//   - edgeIPs: the edge addresses obtained in NewSupervisor (static or resolved).
//   - edgeTunnelServer: serves a single tunnel connection; called once per tunnel goroutine.
//   - tunnelErrors: unbuffered channel on which every tunnel goroutine reports its
//     final result (a nil err means it returned without error).
//   - tunnelsConnecting: connection index -> channel for tunnels that have been
//     (re)started; entries are removed by waitForNextTunnel.
//   - tunnelsProtocolFallback: connection index -> protocol fallback state passed to
//     the tunnel server for that connection.
//   - log, logTransport: loggers taken from the configuration.
//   - reconnectCh: stored as given to NewSupervisor; not read within this file.
//   - gracefulShutdownC: receives when a graceful shutdown has been requested.
type Supervisor struct {
	config                  *TunnelConfig
	orchestrator            *orchestration.Orchestrator
	edgeIPs                 *edgediscovery.Edge
	edgeTunnelServer        TunnelServer
	tunnelErrors            chan tunnelError
	tunnelsConnecting       map[int]chan struct{}
	tunnelsProtocolFallback map[int]*protocolFallback
	// nextConnectedIndex and nextConnectedSignal are used to wait for all
	// currently-connecting tunnels to finish connecting so we can reset backoff timer
	nextConnectedIndex  int
	nextConnectedSignal chan struct{}

	log          *ConnAwareLogger
	logTransport *zerolog.Logger

	reconnectCh       chan ReconnectSignal
	gracefulShutdownC <-chan struct{}
}

// errEarlyShutdown is returned by initialize when a graceful shutdown is requested
// before the first tunnel has connected. Run treats it as a clean exit.
var errEarlyShutdown = errors.New("shutdown started")

// tunnelError is the message a tunnel goroutine sends on Supervisor.tunnelErrors when
// it finishes: index identifies the connection and err is the result of serving it
// (nil if it returned without error).
type tunnelError struct {
	index int
	err   error
}

// NewSupervisor builds a Supervisor and the EdgeTunnelServer it uses to serve
// individual connections.
//
// Edge addresses come from config.EdgeAddrs when any are configured; otherwise they
// are resolved for config.Region and config.EdgeIPVersion. An error from either
// lookup is returned unchanged and no Supervisor is created.
func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrator, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) {
	var (
		edgeIPs *edgediscovery.Edge
		err     error
	)
	switch {
	case len(config.EdgeAddrs) > 0:
		// Static edge addresses were supplied in the configuration.
		edgeIPs, err = edgediscovery.StaticEdge(config.Log, config.EdgeAddrs)
	default:
		edgeIPs, err = edgediscovery.ResolveEdge(config.Log, config.Region, config.EdgeIPVersion)
	}
	if err != nil {
		return nil, err
	}

	connTracker := tunnelstate.NewConnTracker(config.Log)
	connLog := NewConnAwareLogger(config.Log, connTracker, config.Observer)
	addrFallback := NewIPAddrFallback(config.MaxEdgeAddrRetries)

	metrics := v3.NewMetrics(prometheus.DefaultRegisterer)
	sessions := v3.NewSessionManager(metrics, config.Log, ingress.DialUDPAddrPort)

	// The tunnel server shares the logger, edge addresses and shutdown/reconnect
	// channels with the supervisor itself.
	tunnelServer := &EdgeTunnelServer{
		config:            config,
		orchestrator:      orchestrator,
		sessionManager:    sessions,
		datagramMetrics:   metrics,
		edgeAddrs:         edgeIPs,
		edgeAddrHandler:   addrFallback,
		edgeBindAddr:      config.EdgeBindAddr,
		tracker:           connTracker,
		reconnectCh:       reconnectCh,
		gracefulShutdownC: gracefulShutdownC,
		connAwareLogger:   connLog,
	}

	supervisor := &Supervisor{
		config:                  config,
		orchestrator:            orchestrator,
		edgeIPs:                 edgeIPs,
		edgeTunnelServer:        tunnelServer,
		tunnelErrors:            make(chan tunnelError),
		tunnelsConnecting:       map[int]chan struct{}{},
		tunnelsProtocolFallback: map[int]*protocolFallback{},
		log:                     connLog,
		logTransport:            config.LogTransport,
		reconnectCh:             reconnectCh,
		gracefulShutdownC:       gracefulShutdownC,
	}
	return supervisor, nil
}

// Run starts the tunnels and supervises them until ctx is cancelled or every tunnel
// has exited.
//
// It first starts the ICMP router (when one is configured) in the background, then
// calls initialize to bring up the first tunnel and the remaining HA connections.
// After that it loops, reacting to:
//   - ctx cancellation: waits for every active tunnel goroutine to report, then returns nil;
//   - a tunnel result: reconnects immediately on a ReconnectSignal, otherwise queues
//     the tunnel for a restart after the backoff timer (unless no retry is allowed);
//   - the backoff timer: restarts every queued tunnel;
//   - a tunnel finishing its connection: resets the backoff grace period once nothing
//     is left connecting or waiting;
//   - graceful shutdown: stops restarting tunnels from then on.
//
// connectedSignal is handed to the first tunnel and is notified once it registers.
// Run returns the initialization error, if any, except that a graceful shutdown
// requested during initialization is reported as nil.
func (s *Supervisor) Run(
	ctx context.Context,
	connectedSignal *signal.Signal,
) error {
	if s.config.ICMPRouterServer != nil {
		go func() {
			if err := s.config.ICMPRouterServer.Serve(ctx); err != nil {
				if errors.Is(err, net.ErrClosed) {
					s.log.Logger().Info().Err(err).Msg("icmp router terminated")
				} else {
					s.log.Logger().Err(err).Msg("icmp router terminated")
				}
			}
		}()
	}

	if err := s.initialize(ctx, connectedSignal); err != nil {
		// A graceful shutdown during start-up is an expected way to stop, not a failure.
		if errors.Is(err, errEarlyShutdown) {
			return nil
		}
		return err
	}

	// Indexes of tunnels that failed and are waiting for the backoff timer to fire.
	var tunnelsWaiting []int
	// Number of tunnel goroutines that will still report on s.tunnelErrors.
	tunnelsActive := s.config.HAConnections

	backoff := retry.NewBackoff(s.config.Retries, tunnelRetryDuration, true)
	// nil while no restart is pending; a nil channel never fires in the select below.
	var backoffTimer <-chan time.Time

	// Local copy so the case can be disabled after the first signal. The same channel
	// is shared with the tunnel server, so it has to be closed to reach both readers;
	// a closed channel is always ready and would otherwise make this loop spin until
	// the last tunnel has exited.
	gracefulShutdownC := s.gracefulShutdownC
	shuttingDown := false
	for {
		select {
		// Context cancelled
		case <-ctx.Done():
			// Every active tunnel goroutine sends exactly one result; drain them all
			// so none of them is left blocked on the unbuffered channel.
			for tunnelsActive > 0 {
				<-s.tunnelErrors
				tunnelsActive--
			}
			return nil
		// startTunnel completed with a response
		// (note that this may also be caused by context cancellation)
		case result := <-s.tunnelErrors:
			tunnelsActive--
			if result.err != nil && !shuttingDown {
				if _, isReconnect := result.err.(ReconnectSignal); isReconnect {
					// For tunnels that closed with reconnect signal, we reconnect immediately
					go s.startTunnel(ctx, result.index, s.newConnectedTunnelSignal(result.index))
					tunnelsActive++
					continue
				}
				// Make sure we don't continue if there is no more fallback allowed
				if _, canRetry := s.tunnelsProtocolFallback[result.index].GetMaxBackoffDuration(ctx); !canRetry {
					continue
				}
				s.log.ConnAwareLogger().Err(result.err).Int(connection.LogFieldConnIndex, result.index).Msg("Connection terminated")
				tunnelsWaiting = append(tunnelsWaiting, result.index)
				s.waitForNextTunnel(result.index)

				// Only arm the timer if one is not already pending; all waiting
				// tunnels are restarted together when it fires.
				if backoffTimer == nil {
					backoffTimer = backoff.BackoffTimer()
				}
			} else if tunnelsActive == 0 {
				s.log.ConnAwareLogger().Msg("no more connections active and exiting")
				// All connected tunnels exited gracefully, no more work to do
				return nil
			}
		// Backoff was set and its timer expired
		case <-backoffTimer:
			backoffTimer = nil
			for _, index := range tunnelsWaiting {
				go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
			}
			tunnelsActive += len(tunnelsWaiting)
			tunnelsWaiting = nil
		// Tunnel successfully connected
		case <-s.nextConnectedSignal:
			if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
				// No more tunnels outstanding, clear backoff timer
				backoff.SetGracePeriod()
			}
		// Graceful shutdown requested: stop restarting tunnels from now on
		case <-gracefulShutdownC:
			shuttingDown = true
			// One notification is enough; a nil channel blocks forever in a select.
			gracefulShutdownC = nil
		}
	}
}

// initialize brings up the first tunnel and, once it has connected, starts the
// remaining HA connections.
//
// config.HAConnections is first capped to the number of available edge addresses.
// The call then blocks until the first tunnel connects, fails, the context is
// cancelled or a graceful shutdown is requested:
//   - connected: tunnels 1..HAConnections-1 are started one registrationInterval
//     apart, each beginning with the first tunnel's protocol, and nil is returned;
//   - the first tunnel gave up: its error is returned;
//   - context cancelled: the first tunnel's result is awaited and ctx.Err() returned;
//   - graceful shutdown: errEarlyShutdown is returned.
func (s *Supervisor) initialize(
	ctx context.Context,
	connectedSignal *signal.Signal,
) error {
	if available := s.edgeIPs.AvailableAddrs(); available < s.config.HAConnections {
		s.log.Logger().Info().Msgf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, available)
		s.config.HAConnections = available
	}

	firstFallback := &protocolFallback{
		retry.NewBackoff(s.config.Retries, retry.DefaultBaseTime, true),
		s.config.ProtocolSelector.Current(),
		false,
	}
	s.tunnelsProtocolFallback[0] = firstFallback

	go s.startFirstTunnel(ctx, connectedSignal)

	// Hold off on the other HA tunnels until the first one has given an answer.
	select {
	case <-ctx.Done():
		// The first tunnel goroutine always reports once; collect that before leaving.
		<-s.tunnelErrors
		return ctx.Err()
	case firstResult := <-s.tunnelErrors:
		return firstResult.err
	case <-s.gracefulShutdownC:
		return errEarlyShutdown
	case <-connectedSignal.Wait():
		// First tunnel is up; carry on below.
	}

	// At least one successful connection, so start the rest
	for connIndex := 1; connIndex < s.config.HAConnections; connIndex++ {
		fallback := &protocolFallback{
			retry.NewBackoff(s.config.Retries, retry.DefaultBaseTime, true),
			// Start from the protocol the first tunnel is currently using.
			firstFallback.protocol,
			false,
		}
		s.tunnelsProtocolFallback[connIndex] = fallback
		go s.startTunnel(ctx, connIndex, s.newConnectedTunnelSignal(connIndex))
		time.Sleep(registrationInterval)
	}
	return nil
}

// startFirstTunnel starts the first tunnel connection (index 0) and keeps restarting
// it while the failure looks transient. It is expected to run in its own goroutine:
// the final result is always sent on s.tunnelErrors, and connectedSignal is passed to
// the tunnel server so it can be notified once registration succeeds.
//
// The loop stops when the context is cancelled, the tunnel returns without error,
// no further retry is allowed by the protocol fallback, or the error is not one of
// the retryable kinds listed in the switch below.
func (s *Supervisor) startFirstTunnel(
	ctx context.Context,
	connectedSignal *signal.Signal,
) {
	const firstConnIndex = 0

	var err error
	isStaticEdge := len(s.config.EdgeAddrs) > 0
	defer func() {
		s.tunnelErrors <- tunnelError{index: firstConnIndex, err: err}
	}()

	// Look the fallback state up once instead of on every iteration: initialize adds
	// entries for the other connections to this map from its own goroutine while the
	// loop below may still be running, and Go maps must not be read and written
	// concurrently. The entry for the first connection is never replaced.
	fallback := s.tunnelsProtocolFallback[firstConnIndex]

	// If the first tunnel disconnects, keep restarting it.
	for {
		err = s.edgeTunnelServer.Serve(ctx, firstConnIndex, fallback, connectedSignal)
		if ctx.Err() != nil {
			return
		}
		if err == nil {
			return
		}
		// Make sure we don't continue if there is no more fallback allowed
		if _, canRetry := fallback.GetMaxBackoffDuration(ctx); !canRetry {
			return
		}
		// Try again for Unauthorized errors because we hope them to be
		// transient due to edge propagation lag on new Tunnels.
		if strings.Contains(err.Error(), "Unauthorized") {
			continue
		}
		switch err.(type) {
		case edgediscovery.ErrNoAddressesLeft:
			// If your provided addresses are not available, we will keep trying regardless.
			if !isStaticEdge {
				return
			}
		case connection.DupConnRegisterTunnelError,
			*quic.IdleTimeoutError,
			*quic.ApplicationError,
			edgediscovery.DialError,
			*connection.EdgeQuicDialError:
			// Try again for these types of errors
		default:
			// Uncaught errors should bail startup
			return
		}
	}
}

// startTunnel serves the tunnel connection with the given index and, when serving
// ends, reports the outcome on s.tunnelErrors. It is meant to be launched with `go`;
// the send blocks until the supervisor receives it.
func (s *Supervisor) startTunnel(
	ctx context.Context,
	index int,
	connectedSignal *signal.Signal,
) {
	var serveErr error
	// Deferred so the supervisor hears about this tunnel however the function exits.
	defer func() {
		s.tunnelErrors <- tunnelError{index: index, err: serveErr}
	}()

	fallback := s.tunnelsProtocolFallback[index]
	serveErr = s.edgeTunnelServer.Serve(ctx, uint8(index), fallback, connectedSignal)
}

// newConnectedTunnelSignal creates a fresh notification channel for the tunnel with
// the given index, records it as connecting, makes it the channel the supervisor
// currently waits on, and returns it wrapped in a signal.Signal.
func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
	connectedC := make(chan struct{})

	s.nextConnectedIndex = index
	s.nextConnectedSignal = connectedC
	s.tunnelsConnecting[index] = connectedC

	return signal.New(connectedC)
}

// waitForNextTunnel drops the tunnel with the given index from the set of connecting
// tunnels and points nextConnectedIndex/nextConnectedSignal at any tunnel that is
// still connecting. It returns true if such a tunnel exists; otherwise it clears
// nextConnectedSignal and returns false.
func (s *Supervisor) waitForNextTunnel(index int) bool {
	delete(s.tunnelsConnecting, index)

	// Any remaining entry will do, so take the first one the map yields.
	for pendingIndex, pendingSignal := range s.tunnelsConnecting {
		s.nextConnectedIndex, s.nextConnectedSignal = pendingIndex, pendingSignal
		return true
	}

	s.nextConnectedSignal = nil
	return false
}

// unusedIPs reports whether there are more available edge addresses than configured
// HA connections.
func (s *Supervisor) unusedIPs() bool {
	available := s.edgeIPs.AvailableAddrs()
	return available > s.config.HAConnections
}