package v3

import (
	"errors"
	"net"
	"net/netip"
	"sync"

	"github.com/rs/zerolog"
)

// Sentinel errors reported by SessionManager; compare against them with errors.Is.
var (
	// ErrSessionAlreadyRegistered is returned by RegisterSession when the request ID is
	// already registered to the same connection that is trying to register it again.
	ErrSessionAlreadyRegistered = errors.New("flow is already registered for this connection")
	// ErrSessionBoundToOtherConn is returned by RegisterSession when the request ID is
	// already registered to a different connection.
	ErrSessionBoundToOtherConn = errors.New("flow is in use by another connection")
	// ErrSessionNotFound is returned by GetSession when no session is registered for the
	// request ID.
	ErrSessionNotFound = errors.New("flow not found")
)

// SessionManager tracks the UDP sessions registered over datagram connections, keyed by the
// request ID each session was registered with.
type SessionManager interface {
	// RegisterSession registers a new session for request.RequestID if one does not already exist.
	// While creating the session it also binds the origin UDP socket for request.Dest.
	// If a session is already registered for the request ID, it returns [ErrSessionAlreadyRegistered]
	// when that session belongs to conn, or [ErrSessionBoundToOtherConn] when it belongs to a
	// different connection. If binding the origin socket fails, that error is returned and no
	// session is registered.
	RegisterSession(request *UDPSessionRegistrationDatagram, conn DatagramConn) (Session, error)
	// GetSession returns the session registered for requestID, or [ErrSessionNotFound] if there is
	// none. The lookup is by request ID only and does not check which connection owns the session;
	// callers that care about ownership must compare the session's ConnectionID themselves.
	GetSession(requestID RequestID) (Session, error)
	// UnregisterSession removes the session registered for requestID, if any, from the session
	// manager and attempts to close it. Errors from closing the session are not reported.
	// Unregistering a request ID that has no session is a no-op.
	UnregisterSession(requestID RequestID)
}

// DialUDP binds the origin-side UDP socket for a session whose destination is dest.
// RegisterSession calls it once per new session and hands the returned connection to that
// session. When err is nil the returned *net.UDPConn must be non-nil, because RegisterSession
// reads its RemoteAddr and LocalAddr. A non-nil err aborts the registration and is returned
// to the caller.
type DialUDP func(dest netip.AddrPort) (*net.UDPConn, error)

// sessionManager is the SessionManager implementation returned by NewSessionManager.
//
// Fields:
//   - sessions: the registered sessions, keyed by the request ID they were registered under.
//   - mutex: guards sessions; RegisterSession and UnregisterSession take the write lock,
//     GetSession takes the read lock.
//   - originDialer: binds the origin UDP socket for each new session in RegisterSession.
//   - metrics, log: passed through to every session created by RegisterSession.
//
// It embeds a sync.RWMutex by value, so it must not be copied; the methods in this file all
// use pointer receivers.
type sessionManager struct {
	sessions     map[RequestID]Session
	mutex        sync.RWMutex
	originDialer DialUDP
	metrics      Metrics
	log          *zerolog.Logger
}

// NewSessionManager returns a SessionManager with no registered sessions.
// originDialer is what RegisterSession uses to bind the origin UDP socket of each new
// session; metrics and log are handed to every session the manager creates.
func NewSessionManager(metrics Metrics, log *zerolog.Logger, originDialer DialUDP) SessionManager {
	manager := &sessionManager{
		log:          log,
		metrics:      metrics,
		originDialer: originDialer,
	}
	// The map must exist before the first RegisterSession writes to it.
	manager.sessions = make(map[RequestID]Session)
	return manager
}

// RegisterSession creates and records a session for request.RequestID, owned by conn.
//
// The whole operation runs under the write lock, so the existence check, the origin dial
// and the map insert are atomic with respect to other callers. An existing registration is
// reported as ErrSessionAlreadyRegistered when it belongs to conn and as
// ErrSessionBoundToOtherConn otherwise. A dial failure is returned unchanged and nothing is
// recorded.
func (s *sessionManager) RegisterSession(request *UDPSessionRegistrationDatagram, conn DatagramConn) (Session, error) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if existing, found := s.sessions[request.RequestID]; found {
		if conn.ID() != existing.ConnectionID() {
			return nil, ErrSessionBoundToOtherConn
		}
		return nil, ErrSessionAlreadyRegistered
	}

	// NOTE(review): the dial happens while the write lock is held, so a slow originDialer
	// stalls every concurrent RegisterSession, GetSession and UnregisterSession call.
	udpConn, dialErr := s.originDialer(request.Dest)
	if dialErr != nil {
		return nil, dialErr
	}

	created := NewSession(
		request.RequestID,
		request.IdleDurationHint,
		udpConn,
		udpConn.RemoteAddr(),
		udpConn.LocalAddr(),
		conn,
		s.metrics,
		s.log,
	)
	s.sessions[request.RequestID] = created
	return created, nil
}

// GetSession looks requestID up under the read lock. It reports ErrSessionNotFound when
// nothing is registered for that ID; which connection owns the session is not checked here.
func (s *sessionManager) GetSession(requestID RequestID) (Session, error) {
	s.mutex.RLock()
	found, ok := s.sessions[requestID]
	s.mutex.RUnlock()

	if !ok {
		return nil, ErrSessionNotFound
	}
	return found, nil
}

// UnregisterSession removes the session registered for requestID, if any, and then closes it.
// Errors from closing are deliberately ignored: unregistering is best-effort cleanup and the
// entry is removed from the map regardless. Calling it for an unknown requestID is a no-op.
func (s *sessionManager) UnregisterSession(requestID RequestID) {
	// Detach the session from the map while holding the lock, but close it only after the
	// lock is released. Session.Close is implemented outside this file (presumably it
	// releases the origin socket handed to NewSession — not visible here). Holding the
	// exclusive lock across it would block every concurrent GetSession/RegisterSession, and
	// would deadlock if Close ever called back into the manager, since sync.RWMutex is not
	// reentrant. Only the caller that actually removed the entry closes it.
	s.mutex.Lock()
	session, exists := s.sessions[requestID]
	delete(s.sessions, requestID)
	s.mutex.Unlock()

	if exists {
		// We ignore any errors when attempting to close the session.
		_ = session.Close()
	}
}