package connection

import (
	"context"
	"fmt"
	"io"
	"time"

	"github.com/rs/zerolog"
	"zombiezen.com/go/capnproto2/rpc"

	"github.com/cloudflare/cloudflared/tunnelrpc"
	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)

// tunnelServerClient pairs the pogs TunnelServer client with the RPC transport
// it was built on, so that Close can release both.
type tunnelServerClient struct {
	// client issues the TunnelServer RPCs (Authenticate, RegisterTunnel, ...).
	client    tunnelpogs.TunnelServer_PogsClient
	// transport is the logging transport wrapped around the underlying stream.
	transport rpc.Transport
}

// NewTunnelServerClient creates and returns a new RPC client that communicates
// over the given stream. The stream is wrapped in a logging capnp transport and
// a single rpc.Conn on that transport backs both the embedded registration
// client and the tunnel server client.
// This function is exported for supervisor to call the Authenticate RPC.
func NewTunnelServerClient(
	ctx context.Context,
	stream io.ReadWriteCloser,
	log *zerolog.Logger,
) *tunnelServerClient {
	loggedTransport := tunnelrpc.NewTransportLogger(log, rpc.StreamTransport(stream))
	rpcConn := rpc.NewConn(loggedTransport, tunnelrpc.ConnLog(log))

	// Each client gets its own bootstrap capability; the registration client
	// is bootstrapped first.
	registration := tunnelpogs.RegistrationServer_PogsClient{
		Client: rpcConn.Bootstrap(ctx),
		Conn:   rpcConn,
	}
	server := tunnelpogs.TunnelServer_PogsClient{
		RegistrationServer_PogsClient: registration,
		Client:                        rpcConn.Bootstrap(ctx),
		Conn:                          rpcConn,
	}

	return &tunnelServerClient{client: server, transport: loggedTransport}
}

// Authenticate issues the Authenticate RPC using the classic tunnel's origin
// certificate and hostname together with the given registration options.
// It returns the outcome carried by the response, or a nil outcome and the
// error when the RPC fails.
func (tsc *tunnelServerClient) Authenticate(ctx context.Context, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) (tunnelpogs.AuthOutcome, error) {
	resp, err := tsc.client.Authenticate(
		ctx,
		classicTunnel.OriginCert,
		classicTunnel.Hostname,
		registrationOptions,
	)
	if err != nil {
		return nil, err
	}

	outcome := resp.Outcome()
	return outcome, nil
}

// Close releases the RPC client and then its transport. Errors from either
// close call are deliberately discarded.
func (tsc *tunnelServerClient) Close() {
	// Closing the client will also close the connection
	_ = tsc.client.Close()
	// The transport is closed explicitly after the client.
	_ = tsc.transport.Close()
}

// NamedTunnelRPCClient is the RPC surface used by named tunnel connections:
// registering a connection, requesting a graceful shutdown and closing the
// client.
type NamedTunnelRPCClient interface {
	// RegisterConnection registers the connection with index connIndex using
	// the given named tunnel config and connection options. The observer is
	// handed to the implementation for reporting the outcome. A non-nil error
	// means the registration failed.
	RegisterConnection(
		ctx context.Context,
		config *NamedTunnelConfig,
		options *tunnelpogs.ConnectionOptions,
		connIndex uint8,
		observer *Observer,
	) error
	// GracefulShutdown asks the server to unregister the connection; it
	// reports no result to the caller.
	GracefulShutdown(ctx context.Context, gracePeriod time.Duration)
	// Close releases the resources held by the client.
	Close()
}

// registrationServerClient pairs the pogs RegistrationServer client with the
// RPC transport it was built on, so that Close can release both. It is the
// NamedTunnelRPCClient returned by newRegistrationRPCClient.
type registrationServerClient struct {
	// client issues the RegistrationServer RPCs.
	client    tunnelpogs.RegistrationServer_PogsClient
	// transport is the logging transport wrapped around the underlying stream.
	transport rpc.Transport
}

// newRegistrationRPCClient builds a NamedTunnelRPCClient that speaks the
// registration RPCs over the given stream. The stream is wrapped in a logging
// capnp transport, and the client is bootstrapped from a new rpc.Conn on it.
func newRegistrationRPCClient(
	ctx context.Context,
	stream io.ReadWriteCloser,
	log *zerolog.Logger,
) NamedTunnelRPCClient {
	loggedTransport := tunnelrpc.NewTransportLogger(log, rpc.StreamTransport(stream))
	rpcConn := rpc.NewConn(loggedTransport, tunnelrpc.ConnLog(log))

	registration := tunnelpogs.RegistrationServer_PogsClient{
		Client: rpcConn.Bootstrap(ctx),
		Conn:   rpcConn,
	}
	return &registrationServerClient{client: registration, transport: loggedTransport}
}

// RegisterConnection issues the RegisterConnection RPC with the tunnel's
// credentials and ID, the connection index and the connection options.
//
// On failure it increments the regFail metric and returns
// errDuplicationConnection when the error text equals
// DuplicateConnectionError, otherwise the result of
// serverRegistrationErrorFromRPC. On success it increments regSuccess, logs
// the server location with the connection UUID, emits the connected event and
// returns nil.
func (rsc *registrationServerClient) RegisterConnection(
	ctx context.Context,
	config *NamedTunnelConfig,
	options *tunnelpogs.ConnectionOptions,
	connIndex uint8,
	observer *Observer,
) error {
	// Metric label identifying this RPC.
	const rpcLabel = "registerConnection"

	details, err := rsc.client.RegisterConnection(
		ctx,
		config.Credentials.Auth(),
		config.Credentials.TunnelID,
		connIndex,
		options,
	)
	if err != nil {
		if err.Error() != DuplicateConnectionError {
			observer.metrics.regFail.WithLabelValues("server_error", rpcLabel).Inc()
			return serverRegistrationErrorFromRPC(err)
		}
		observer.metrics.regFail.WithLabelValues("dup_edge_conn", rpcLabel).Inc()
		return errDuplicationConnection
	}

	observer.metrics.regSuccess.WithLabelValues(rpcLabel).Inc()

	registeredMsg := fmt.Sprintf("Connection %s registered", details.UUID)
	observer.logServerInfo(connIndex, details.Location, registeredMsg)
	observer.sendConnectedEvent(connIndex, details.Location)

	return nil
}

// GracefulShutdown issues the UnregisterConnection RPC, bounded by a timeout
// of gracePeriod derived from ctx. The RPC's result is discarded.
func (rsc *registrationServerClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) {
	shutdownCtx, release := context.WithTimeout(ctx, gracePeriod)
	defer release()

	// Best-effort: this method has no way to report the outcome to its caller.
	_ = rsc.client.UnregisterConnection(shutdownCtx)
}

// Close releases the RPC client and then its transport. Errors from either
// close call are deliberately discarded.
func (rsc *registrationServerClient) Close() {
	// Closing the client will also close the connection
	_ = rsc.client.Close()
	// Closing the transport also closes the stream
	_ = rsc.transport.Close()
}

// rpcName identifies an RPC; in this file it is passed to newRPCStream and
// used as a metrics label value.
type rpcName string

// Names of the RPCs issued over the tunnel's RPC stream. The values carry no
// surrounding whitespace so that labels built from them stay consistent.
const (
	register     rpcName = "register"
	reconnect    rpcName = "reconnect"
	unregister   rpcName = "unregister"
	authenticate rpcName = "authenticate"
)

// registerTunnel registers a classic tunnel over a newly opened RPC stream.
//
// It emits the registering event, opens the stream, logs server info on a
// best-effort basis and then issues the RegisterTunnel RPC with the tunnel's
// origin certificate, hostname and registration options. A registration error
// is translated by processRegisterTunnelError; otherwise the returned event
// digest is stored in credentialSetter and processRegistrationSuccess
// finishes the job.
func (h *h2muxConnection) registerTunnel(ctx context.Context, credentialSetter CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error {
	h.observer.sendRegisteringEvent(registrationOptions.ConnectionID)

	rpcStream, err := h.newRPCStream(ctx, register)
	if err != nil {
		return err
	}
	tunnelClient := NewTunnelServerClient(ctx, rpcStream, h.observer.log)
	defer tunnelClient.Close()

	// logServerInfo logs its own failures; a failure there does not stop the
	// registration, so the returned error is dropped.
	_ = h.logServerInfo(ctx, tunnelClient)

	result := tunnelClient.client.RegisterTunnel(ctx, classicTunnel.OriginCert, classicTunnel.Hostname, registrationOptions)
	registrationErr := result.DeserializeError()
	if registrationErr != nil {
		// RegisterTunnel RPC failure
		return h.processRegisterTunnelError(registrationErr, register)
	}

	credentialSetter.SetEventDigest(h.connIndex, result.EventDigest)
	return h.processRegistrationSuccess(result, register, credentialSetter, classicTunnel)
}

// CredentialManager provides and stores the credentials used when registering
// and reconnecting classic tunnels. In this file the digests are written after
// a successful registration and read back, together with the reconnect token,
// by reconnectTunnel.
type CredentialManager interface {
	// ReconnectToken returns the token passed to the ReconnectTunnel RPC.
	ReconnectToken() ([]byte, error)
	// EventDigest returns the event digest for the connection connID.
	EventDigest(connID uint8) ([]byte, error)
	// SetEventDigest records the event digest for the connection connID.
	SetEventDigest(connID uint8, digest []byte)
	// ConnDigest returns the connection digest for the connection connID.
	ConnDigest(connID uint8) ([]byte, error)
	// SetConnDigest records the connection digest for the connection connID.
	SetConnDigest(connID uint8, digest []byte)
}

// processRegistrationSuccess performs the bookkeeping after a successful
// register or reconnect RPC: it relays the server-provided log lines, records
// the tunnel ID (when one was returned) in the HA metrics, stores the
// connection digest in credentialManager, and bumps the hostname and
// regSuccess metrics, the latter labelled with name.
// classicTunnel is not used by this function. It always returns nil.
func (h *h2muxConnection) processRegistrationSuccess(
	registration *tunnelpogs.TunnelRegistration,
	name rpcName,
	credentialManager CredentialManager, classicTunnel *ClassicTunnelConfig,
) error {
	logger := h.observer.log

	for i := range registration.LogLines {
		logger.Info().Msg(registration.LogLines[i])
	}

	if tunnelID := registration.TunnelID; tunnelID != "" {
		h.observer.metrics.tunnelsHA.AddTunnelID(h.connIndex, tunnelID)
		logger.Info().Msgf("Each HA connection's tunnel IDs: %v", h.observer.metrics.tunnelsHA.String())
	}

	credentialManager.SetConnDigest(h.connIndex, registration.ConnDigest)
	h.observer.metrics.userHostnamesCounts.WithLabelValues(registration.Url).Inc()

	logger.Info().Msg("Route propagating, it may take up to 1 minute for your new route to become functional")
	h.observer.metrics.regSuccess.WithLabelValues(string(name)).Inc()
	return nil
}

// processRegisterTunnelError records a failed register/reconnect RPC in the
// regFail metric (labelled with name) and converts err into the error returned
// to the caller: errDuplicationConnection when the error text equals
// DuplicateConnectionError, otherwise a ServerRegisterTunnelError that wraps
// err and carries its permanence flag.
func (h *h2muxConnection) processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, name rpcName) error {
	rpcLabel := string(name)

	if err.Error() != DuplicateConnectionError {
		h.observer.metrics.regFail.WithLabelValues("server_error", rpcLabel).Inc()
		return ServerRegisterTunnelError{
			Cause:     err,
			Permanent: err.IsPermanent(),
		}
	}

	h.observer.metrics.regFail.WithLabelValues("dup_edge_conn", rpcLabel).Inc()
	return errDuplicationConnection
}

// reconnectTunnel re-establishes a classic tunnel using previously stored
// credentials.
//
// It fetches the reconnect token and the event and connection digests for this
// connection from credentialManager, returning the first error encountered.
// It then opens an RPC stream, logs server info on a best-effort basis and
// issues the ReconnectTunnel RPC. A registration error is translated by
// processRegisterTunnelError; otherwise processRegistrationSuccess finishes
// the job. Both are told the RPC was a reconnect.
func (h *h2muxConnection) reconnectTunnel(ctx context.Context, credentialManager CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error {
	token, err := credentialManager.ReconnectToken()
	if err != nil {
		return err
	}
	eventDigest, err := credentialManager.EventDigest(h.connIndex)
	if err != nil {
		return err
	}
	connDigest, err := credentialManager.ConnDigest(h.connIndex)
	if err != nil {
		return err
	}

	h.observer.log.Debug().Msg("initiating RPC stream to reconnect")
	// The stream is opened under the reconnect name so it matches the RPC
	// issued below and the name used for metrics.
	stream, err := h.newRPCStream(ctx, reconnect)
	if err != nil {
		return err
	}
	rpcClient := NewTunnelServerClient(ctx, stream, h.observer.log)
	defer rpcClient.Close()

	// logServerInfo logs its own failures; a failure there does not stop the
	// reconnect, so the returned error is dropped.
	_ = h.logServerInfo(ctx, rpcClient)
	registration := rpcClient.client.ReconnectTunnel(
		ctx,
		token,
		eventDigest,
		connDigest,
		classicTunnel.Hostname,
		registrationOptions,
	)
	if registrationErr := registration.DeserializeError(); registrationErr != nil {
		// ReconnectTunnel RPC failure
		return h.processRegisterTunnelError(registrationErr, reconnect)
	}
	return h.processRegistrationSuccess(registration, reconnect, credentialManager, classicTunnel)
}

// logServerInfo requests the server info over rpcClient and, on success, logs
// the server's location name for this connection with the message
// "Connection established". A failure to fetch or unmarshal the info is logged
// and returned.
//
// The capnp client is called directly instead of the pogs wrapper.
// NOTE(review): the previous comment said this avoids blocking tunnel
// registration, but the result is awaited via Struct() before returning;
// confirm the intended behavior.
func (h *h2muxConnection) logServerInfo(ctx context.Context, rpcClient *tunnelServerClient) error {
	// Both failure paths log the same message and hand the error back.
	reportFailure := func(cause error) error {
		h.observer.log.Err(cause).Msg("Failed to retrieve server information")
		return cause
	}

	// GetServerInfo takes no parameters, so the params callback does nothing.
	noParams := func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
		return nil
	}
	capnpClient := tunnelrpc.TunnelServer{Client: rpcClient.client.Client}
	promise := capnpClient.GetServerInfo(ctx, noParams)

	message, err := promise.Result().Struct()
	if err != nil {
		return reportFailure(err)
	}
	info, err := tunnelpogs.UnmarshalServerInfo(message)
	if err != nil {
		return reportFailure(err)
	}

	h.observer.logServerInfo(h.connIndex, info.LocationName, "Connection established")
	return nil
}

// registerNamedTunnel opens an RPC stream, builds a client with
// h.newRPCClientFunc and registers this connection for the given named tunnel.
// The client is closed before returning. It returns the stream-open error or
// whatever RegisterConnection returns.
func (h *h2muxConnection) registerNamedTunnel(
	ctx context.Context,
	namedTunnel *NamedTunnelConfig,
	connOptions *tunnelpogs.ConnectionOptions,
) error {
	rpcStream, err := h.newRPCStream(ctx, register)
	if err != nil {
		return err
	}

	client := h.newRPCClientFunc(ctx, rpcStream, h.observer.log)
	defer client.Close()

	return client.RegisterConnection(ctx, namedTunnel, connOptions, h.connIndex, h.observer)
}

// unregister tells the server that this connection is going away.
//
// It emits the unregistering event and opens an RPC stream under a context
// that times out after the configured grace period. Named tunnels then call
// GracefulShutdown on the registration client; classic tunnels call
// UnregisterTunnel. Unregistering is best-effort: a failure to open the stream
// or a failed UnregisterTunnel call is logged as a warning and the function
// returns without reporting the connection as unregistered.
func (h *h2muxConnection) unregister(isNamedTunnel bool) {
	h.observer.sendUnregisteringEvent(h.connIndex)

	unregisterCtx, cancel := context.WithTimeout(context.Background(), h.config.GracePeriod)
	defer cancel()

	stream, err := h.newRPCStream(unregisterCtx, unregister)
	if err != nil {
		h.observer.log.Warn().Err(err).Uint8(LogFieldConnIndex, h.connIndex).Msg("Failed to open RPC stream to unregister tunnel connection")
		return
	}
	defer stream.Close()

	if isNamedTunnel {
		rpcClient := h.newRPCClientFunc(unregisterCtx, stream, h.observer.log)
		defer rpcClient.Close()

		// GracefulShutdown reports no result, so its outcome cannot be checked here.
		rpcClient.GracefulShutdown(unregisterCtx, h.config.GracePeriod)
	} else {
		rpcClient := NewTunnelServerClient(unregisterCtx, stream, h.observer.log)
		defer rpcClient.Close()

		// gracePeriod is encoded in int64 using capnproto
		if err := rpcClient.client.UnregisterTunnel(unregisterCtx, h.config.GracePeriod.Nanoseconds()); err != nil {
			h.observer.log.Warn().Err(err).Uint8(LogFieldConnIndex, h.connIndex).Msg("Failed to unregister tunnel connection")
			return
		}
	}

	h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unregistered tunnel connection")
}