package websocket

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"time"

	gobwas "github.com/gobwas/ws"
	"github.com/gobwas/ws/wsutil"
	"github.com/gorilla/websocket"
	"github.com/rs/zerolog"
)

const (
	// writeWait is the time allowed to write a message to the peer.
	writeWait = 10 * time.Second

	// defaultPongWait is the time allowed to read the next pong message from the peer.
	defaultPongWait = 60 * time.Second

	// defaultPingPeriod is the period at which pings are sent to the peer.
	// It must be less than defaultPongWait; it is derived as 9/10 of that value.
	defaultPingPeriod = (defaultPongWait * 9) / 10

	// PingPeriodContextKey is the context key under which a custom ping period
	// (a time.Duration) may be stored; Conn.pingPeriod looks it up.
	PingPeriodContextKey = PingPeriodContext("pingPeriod")
)

// PingPeriodContext is the dedicated string type of PingPeriodContextKey, the
// context key used to look up a custom ping period.
type PingPeriodContext string

// GorillaConn is a wrapper around the standard gorilla websocket connection
// that adds Read and Write methods so it can be used as an io.ReadWriter.
// This is still used by access carrier.
//
// log receives debug output (for example failed ping writes in pinger).
// readBuf keeps the part of a received message that did not fit into the
// buffer passed to Read, so that the next Read can return it.
type GorillaConn struct {
	*websocket.Conn
	log     *zerolog.Logger
	readBuf bytes.Buffer
}

// Read fills p with bytes received over the websocket connection.
//
// Bytes left over from a previously received message are returned first.
// Only when nothing is pending does Read block on the next message; whatever
// part of that message does not fit into p is kept for later calls.
func (c *GorillaConn) Read(p []byte) (int, error) {
	if c.readBuf.Len() == 0 {
		// Nothing pending: wait for the next message from the peer.
		_, payload, err := c.Conn.ReadMessage()
		if err != nil {
			return 0, err
		}

		n := copy(p, payload)

		// Stash the tail that did not fit (an empty tail is a no-op).
		// bytes.Buffer.Write always returns a nil error; it grows or panics.
		c.readBuf.Write(payload[n:])

		return n, nil
	}

	// Serve the remainder of the previous message before reading a new one.
	return c.readBuf.Read(p)
}

// Write sends p to the peer as a single binary websocket message.
// It reports len(p) on success and 0 together with the error otherwise.
func (c *GorillaConn) Write(p []byte) (int, error) {
	err := c.Conn.WriteMessage(websocket.BinaryMessage, p)
	if err != nil {
		return 0, err
	}

	return len(p), nil
}

// SetDeadline applies t as both the read and the write deadline, matching the
// net.Conn contract: "It is equivalent to calling both SetReadDeadline and
// SetWriteDeadline."
//
// The read deadline is set first; if that fails the write deadline is left
// untouched. No synchronization is performed here, but the gorilla
// implementation isn't thread safe anyway.
func (c *GorillaConn) SetDeadline(t time.Time) error {
	readErr := c.Conn.SetReadDeadline(t)
	if readErr != nil {
		return fmt.Errorf("error setting read deadline: %w", readErr)
	}

	writeErr := c.Conn.SetWriteDeadline(t)
	if writeErr != nil {
		return fmt.Errorf("error setting write deadline: %w", writeErr)
	}

	return nil
}

// pinger keeps the websocket connection alive by sending a ping control
// message every defaultPingPeriod until ctx is done. Each ping is given
// writeWait to complete; a failed ping is only logged at debug level and the
// loop carries on.
func (c *GorillaConn) pinger(ctx context.Context) {
	pingTicker := time.NewTicker(defaultPingPeriod)
	defer pingTicker.Stop()

	done := ctx.Done()
	for {
		select {
		case <-done:
			return
		case <-pingTicker.C:
			deadline := time.Now().Add(writeWait)
			err := c.WriteControl(websocket.PingMessage, []byte{}, deadline)
			if err != nil {
				c.log.Debug().Msgf("failed to send ping message: %s", err)
			}
		}
	}
}

// Conn is the server side of a websocket connection layered over an arbitrary
// io.ReadWriter. It reads client binary messages and writes server binary
// messages through the gobwas wsutil helpers.
type Conn struct {
	rw  io.ReadWriter
	log *zerolog.Logger
	// shutdownC is closed once the pinger goroutine has returned;
	// WaitForShutdown blocks on it.
	shutdownC chan struct{}
	// readBuf holds the tail of a received message that did not fit into the
	// buffer handed to Read, so the next Read can return it instead of
	// dropping it.
	readBuf bytes.Buffer
}

// NewConn wraps rw in a Conn and starts its pinger goroutine, which runs
// until ctx is done.
func NewConn(ctx context.Context, rw io.ReadWriter, log *zerolog.Logger) *Conn {
	c := &Conn{
		rw:        rw,
		log:       log,
		shutdownC: make(chan struct{}),
	}
	go c.pinger(ctx)
	return c
}

// Read will read messages from the websocket connection.
//
// Bytes left over from a previous message are returned first. Otherwise the
// next client binary message is read via wsutil.ReadClientBinary; whatever
// part of it does not fit into reader is buffered for subsequent calls rather
// than being silently discarded.
func (c *Conn) Read(reader []byte) (int, error) {
	// Drain the remainder of the last message before blocking on a new one.
	if c.readBuf.Len() > 0 {
		return c.readBuf.Read(reader)
	}

	data, err := wsutil.ReadClientBinary(c.rw)
	if err != nil {
		return 0, err
	}

	n := copy(reader, data)

	// Keep the unread tail; if everything was copied this is a no-op.
	// bytes.Buffer.Write always returns a nil error; it grows or panics.
	c.readBuf.Write(data[n:])

	return n, nil
}

// Write sends p to the peer as one server-side binary websocket message.
// It reports len(p) on success and 0 together with the error otherwise.
func (c *Conn) Write(p []byte) (int, error) {
	err := wsutil.WriteServerBinary(c.rw, p)
	if err != nil {
		return 0, err
	}

	return len(p), nil
}

// pinger runs until ctx is done. On every tick of the interval returned by
// pingPeriod it writes an empty ping frame to the connection and then passes
// an empty pong message to wsutil.HandleClientControlMessage. Failures of
// either step are only logged at debug level. shutdownC is closed when the
// function returns.
//
// NOTE(review): the pong handed to HandleClientControlMessage is built
// locally rather than read from the peer — confirm this is intended.
// NOTE(review): this goroutine writes to c.rw without any locking visible in
// this file, while Write uses the same c.rw — confirm callers serialize access.
func (c *Conn) pinger(ctx context.Context) {
	defer close(c.shutdownC)

	ticker := time.NewTicker(c.pingPeriod(ctx))
	defer ticker.Stop()

	emptyPong := wsutil.Message{OpCode: gobwas.OpPong, Payload: []byte{}}
	done := ctx.Done()
	for {
		select {
		case <-done:
			return
		case <-ticker.C:
			pingErr := wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{})
			if pingErr != nil {
				c.log.Debug().Err(pingErr).Msg("failed to write ping message")
			}

			pongErr := wsutil.HandleClientControlMessage(c.rw, emptyPong)
			if pongErr != nil {
				c.log.Debug().Err(pongErr).Msg("failed to write pong message")
			}
		}
	}
}

// pingPeriod returns the interval at which pinger should send pings.
//
// A positive time.Duration stored in ctx under PingPeriodContextKey overrides
// the default. A missing value, a value of another type, or a non-positive
// duration yields defaultPingPeriod; the last case matters because the result
// is passed to time.NewTicker, which panics on a non-positive duration.
func (c *Conn) pingPeriod(ctx context.Context) time.Duration {
	if period, ok := ctx.Value(PingPeriodContextKey).(time.Duration); ok && period > 0 {
		return period
	}
	return defaultPingPeriod
}

// WaitForShutdown blocks until shutdownC is closed, which happens when the
// pinger goroutine returns.
func (c *Conn) WaitForShutdown() {
	<-c.shutdownC
}