package rpc

import (
	"errors"
	"sync"

	"golang.org/x/net/context"
	"zombiezen.com/go/capnproto2"
	"zombiezen.com/go/capnproto2/internal/fulfiller"
	"zombiezen.com/go/capnproto2/internal/queue"
	rpccapnp "zombiezen.com/go/capnproto2/std/capnp/rpc"
)

const (
	// callQueueSize caps how many entries may be queued on a single answer
	// or queueClient; once the cap is reached, further calls fail with
	// errQueueFull.
	// TODO(light): make this a ConnOption
	callQueueSize = 64
)

// insertAnswer registers a fresh, unresolved answer under id and returns
// it, creating the c.answers map on first use.  It returns nil when id is
// already taken.  The new answer's pipeline queue has room for
// callQueueSize calls.
func (c *Conn) insertAnswer(id answerID, cancel context.CancelFunc) *answer {
	if c.answers == nil {
		c.answers = make(map[answerID]*answer)
	}
	if _, taken := c.answers[id]; taken {
		return nil
	}
	ans := &answer{
		id:       id,
		cancel:   cancel,
		conn:     c,
		resolved: make(chan struct{}),
		queue:    make([]pcall, 0, callQueueSize),
	}
	c.answers[id] = ans
	return ans
}

// popAnswer removes the answer registered under id and returns it, or
// returns nil if no such answer exists.  Indexing and deleting from a nil
// map are both safe in Go, so no explicit nil-map check is needed.
func (c *Conn) popAnswer(id answerID) *answer {
	ans := c.answers[id]
	delete(c.answers, id)
	return ans
}

// answer tracks the result of an incoming call that this connection
// responds to with a Return message (see fulfill and reject).  Answers are
// registered in Conn.answers by insertAnswer and removed by popAnswer.
type answer struct {
	id         answerID           // identifies this answer in Return messages
	cancel     context.CancelFunc // supplied to insertAnswer; not called in this section
	resultCaps []exportID         // not populated yet (see the TODO in fulfill)
	conn       *Conn              // connection the answer belongs to
	resolved   chan struct{}      // closed once fulfill or reject has run

	mu    sync.RWMutex // guards the fields below
	obj   capnp.Ptr    // result, set by fulfill
	err   error        // failure, set by reject
	done  bool         // true once fulfill or reject has been called
	queue []pcall      // pipelined calls awaiting resolution; nil once resolved
}

// fulfill is called to resolve an answer successfully.  It sends a Return
// message carrying obj and then flushes the pipelined calls queued on the
// answer.  Calls whose target capability can be found in obj are handed to
// a new queueClient installed in obj's capability table, and the rest are
// rejected.  If a.conn.startWork fails, no Return is sent and every queued
// call is rejected with that error.  It returns the first error
// encountered, e.g. if its connection is shut down while sending messages.
// The caller must be holding onto a.conn.mu.
//
// fulfill panics if the answer has already been resolved.
func (a *answer) fulfill(obj capnp.Ptr) error {
	a.mu.Lock()
	if a.done {
		panic("answer.fulfill called more than once")
	}
	a.obj, a.done = obj, true
	// TODO(light): populate resultCaps

	var firstErr error
	if err := a.conn.startWork(); err != nil {
		firstErr = err
		// Queued calls may be remote (pc.a set) or local (pc.f set).  A local
		// call has a nil pc.a, so dispatch on the kind instead of calling
		// pc.a.reject unconditionally.
		for i := range a.queue {
			pc := &a.queue[i]
			switch pc.which() {
			case qcallRemoteCall:
				// firstErr already holds the startWork error, so a send
				// failure here adds nothing.
				pc.a.reject(err)
			case qcallLocalCall:
				pc.f.Reject(err)
			}
		}
		a.queue = nil
	} else {
		retmsg := newReturnMessage(nil, a.id)
		ret, _ := retmsg.Return()
		payload, _ := ret.NewResults()
		payload.SetContentPtr(obj)
		if payloadTab, err := a.conn.makeCapTable(ret.Segment()); err != nil {
			firstErr = err
		} else {
			payload.SetCapTable(payloadTab)
			if err := a.conn.sendMessage(retmsg); err != nil {
				firstErr = err
			}
		}

		queues, err := a.emptyQueue(obj)
		if err != nil && firstErr == nil {
			firstErr = err
		}
		// Replace each targeted capability with a queueClient so that the
		// queued calls are delivered before any new ones.  Only touch the
		// capability table when something was queued: obj may be a null
		// pointer with no segment to look up.
		if len(queues) > 0 {
			ctab := obj.Segment().Message().CapTable
			for capIdx, q := range queues {
				ctab[capIdx] = newQueueClient(a.conn, ctab[capIdx], q)
			}
		}
		a.conn.workers.Done()
	}
	close(a.resolved)
	a.mu.Unlock()
	return firstErr
}

// reject is called to resolve an answer with failure.  It sends a Return
// message carrying err as an exception and rejects every pipelined call
// queued on the answer with the same error.  It returns the first error
// encountered if its connection is shut down while sending messages.  The
// caller must be holding onto a.conn.mu.
//
// reject panics if err is nil or if the answer has already been resolved.
func (a *answer) reject(err error) error {
	if err == nil {
		panic("answer.reject called with nil")
	}
	a.mu.Lock()
	if a.done {
		panic("answer.reject called more than once")
	}
	a.err, a.done = err, true
	m := newReturnMessage(nil, a.id)
	mret, _ := m.Return()
	setReturnException(mret, err)
	var firstErr error
	if sendErr := a.conn.sendMessage(m); sendErr != nil {
		firstErr = sendErr
	}
	// Queued calls may be remote (pc.a set) or local (pc.f set).  A local
	// call has a nil pc.a, so dispatch on the kind instead of calling
	// pc.a.reject unconditionally.
	for i := range a.queue {
		pc := &a.queue[i]
		switch pc.which() {
		case qcallRemoteCall:
			if rerr := pc.a.reject(err); rerr != nil && firstErr == nil {
				firstErr = rerr
			}
		case qcallLocalCall:
			pc.f.Reject(err)
		}
	}
	a.queue = nil
	close(a.resolved)
	a.mu.Unlock()
	return firstErr
}

// emptyQueue splits the queue by which capability it targets and drops
// any invalid calls.  For each queued call it applies the call's transform
// to obj.  Calls whose transform fails, or that select a null capability,
// are rejected with the transform error or capnp.ErrNullClient.  It
// returns the remaining calls grouped by capability ID (in queue order),
// along with the first error from sending a rejection.  Once this function
// returns, a.queue will be nil.  fulfill calls it while holding a.mu.
func (a *answer) emptyQueue(obj capnp.Ptr) (map[capnp.CapabilityID][]qcall, error) {
	var firstErr error
	// fail rejects an undeliverable queued call according to its kind.
	// Local calls have a nil pc.a, so they must be failed through pc.f.
	fail := func(pc *pcall, err error) {
		switch pc.which() {
		case qcallRemoteCall:
			if rerr := pc.a.reject(err); rerr != nil && firstErr == nil {
				firstErr = rerr
			}
		case qcallLocalCall:
			pc.f.Reject(err)
		}
	}
	qs := make(map[capnp.CapabilityID][]qcall, len(a.queue))
	for i := range a.queue {
		pc := &a.queue[i]
		c, err := capnp.TransformPtr(obj, pc.transform)
		if err != nil {
			fail(pc, err)
			continue
		}
		ci := c.Interface()
		if !ci.IsValid() {
			fail(pc, capnp.ErrNullClient)
			continue
		}
		cn := ci.Capability()
		if qs[cn] == nil {
			// Upper bound: every remaining call could target this capability.
			qs[cn] = make([]qcall, 0, len(a.queue)-i)
		}
		qs[cn] = append(qs[cn], pc.qcall)
	}
	a.queue = nil
	return qs, firstErr
}

// queueCallLocked enqueues a call to be made after the answer has been
// resolved.  The answer must not be resolved yet.  pc should have
// transform and one of pc.a or pc.f set; its call field is replaced with a
// copy of call.  It returns errQueueFull when the queue is at capacity, or
// the error from copying call.  The caller must be holding onto a.mu.
func (a *answer) queueCallLocked(call *capnp.Call, pc pcall) error {
	if cap(a.queue) == len(a.queue) {
		return errQueueFull
	}
	copied, err := call.Copy(nil)
	if err != nil {
		return err
	}
	pc.call = copied
	a.queue = append(a.queue, pc)
	return nil
}

// queueDisembargo enqueues a disembargo message behind any entries still
// queued on the capability that transform selects from the resolved
// answer.  It reports queued=true when the disembargo was placed on a
// queueClient's queue.  queued=false with a nil error means nothing is
// pending, so the caller may disembargo immediately.  It fails if the
// answer is still in progress, if it was rejected, or if the target is not
// an import on a.conn.
func (a *answer) queueDisembargo(transform []capnp.PipelineOp, id embargoID, target rpccapnp.MessageTarget) (queued bool, err error) {
	a.mu.Lock()
	defer a.mu.Unlock()
	switch {
	case !a.done:
		return false, errDisembargoOngoingAnswer
	case a.err != nil:
		return false, errDisembargoNonImport
	}
	ptr, err := capnp.TransformPtr(a.obj, transform)
	if err != nil {
		return false, err
	}
	qc, isQueue := ptr.Interface().Client().(*queueClient)
	if !isQueue {
		// No need to embargo, disembargo immediately.
		return false, nil
	}
	if imp := isImport(qc.client); imp == nil || imp.conn != a.conn {
		return false, errDisembargoNonImport
	}
	qc.mu.Lock()
	defer qc.mu.Unlock()
	if qc.isPassthrough() {
		return false, nil
	}
	if err := qc.pushEmbargoLocked(id, target); err != nil {
		return false, err
	}
	return true, nil
}

// pipelineClient returns a client for the capability that transform
// selects out of this answer's eventual result.  Calls made through it
// before the answer resolves are queued on the answer.
func (a *answer) pipelineClient(transform []capnp.PipelineOp) capnp.Client {
	lac := &localAnswerClient{transform: transform, a: a}
	return lac
}

// joinAnswer resolves an RPC answer by waiting on a generic answer: a is
// fulfilled with ca's result struct, or rejected with ca's error.  Errors
// returned by fulfill/reject (from sending messages) are discarded.  The
// caller must not be holding onto a.conn.mu; joinAnswer acquires it.
func joinAnswer(a *answer, ca capnp.Answer) {
	result, err := ca.Struct()
	conn := a.conn
	conn.mu.Lock()
	defer conn.mu.Unlock()
	if err != nil {
		a.reject(err)
		return
	}
	a.fulfill(result.ToPtr())
}

// joinFulfiller resolves a fulfiller by waiting on a generic answer: f is
// fulfilled with ca's result struct, or rejected with ca's error.
func joinFulfiller(f *fulfiller.Fulfiller, ca capnp.Answer) {
	result, err := ca.Struct()
	if err == nil {
		f.Fulfill(result)
		return
	}
	f.Reject(err)
}

// queueClient wraps client so that calls and disembargoes queued before it
// was created are delivered (by flushQueue) ahead of any new calls.  Once
// the queue drains, new calls pass straight through to client.
type queueClient struct {
	client capnp.Client // capability that queued and new calls are sent to
	conn   *Conn        // connection used for disembargo messages

	mu    sync.RWMutex // guards q and calls
	q     queue.Queue  // queue bookkeeping over the slots in calls
	calls qcallList    // slot storage for q, sized callQueueSize
}

// newQueueClient returns a queueClient wrapping client whose queue is
// pre-loaded with the given calls, and starts flushQueue in a new
// goroutine to deliver them.  Only the first callQueueSize entries of
// queue fit in the slot storage.
// NOTE(review): any extra entries would be silently dropped by copy; the
// caller in fulfill passes groups taken from an answer queue of capacity
// callQueueSize, so this presumably never happens.
func newQueueClient(c *Conn, client capnp.Client, queue []qcall) *queueClient {
	calls := make(qcallList, callQueueSize)
	n := copy(calls, queue)
	qc := &queueClient{
		client: client,
		conn:   c,
		calls:  calls,
	}
	qc.q.Init(calls, n)
	go qc.flushQueue()
	return qc
}

// pushCallLocked appends a copy of cl to the queue.  It returns a
// fulfiller that is settled once flushQueue forwards the call.  If copying
// cl fails or the queue is full, it returns an error answer instead.
// Callers hold qc.mu for writing.
func (qc *queueClient) pushCallLocked(cl *capnp.Call) capnp.Answer {
	copied, err := cl.Copy(nil)
	if err != nil {
		return capnp.ErrorAnswer(err)
	}
	slot := qc.q.Push()
	if slot == -1 {
		return capnp.ErrorAnswer(errQueueFull)
	}
	f := new(fulfiller.Fulfiller)
	qc.calls[slot] = qcall{call: copied, f: f}
	return f
}

// pushEmbargoLocked appends a disembargo entry for id/tgt to the queue, so
// that it is sent only after the calls ahead of it.  It returns
// errQueueFull if there is no free slot.  Callers hold qc.mu for writing.
func (qc *queueClient) pushEmbargoLocked(id embargoID, tgt rpccapnp.MessageTarget) error {
	if slot := qc.q.Push(); slot != -1 {
		qc.calls[slot] = qcall{embargoID: id, embargoTarget: tgt}
		return nil
	}
	return errQueueFull
}

// flushQueue is run in its own goroutine (started by newQueueClient).  It
// delivers the queued entries one at a time, in FIFO order, via handle,
// and returns once the queue is empty.
//
// Each entry is handled while it is still at the front of the queue and is
// popped only afterwards.  This keeps qc.q.Len() non-zero, so isPassthrough
// reports false and Call keeps queueing new calls behind it, until the last
// queued entry has been handed to handle.  qc.mu is not held while handle
// runs, so Call can enqueue concurrently.
//
// NOTE(review): if an invalid (zero) entry is ever at the front, the loop
// stops without popping it and the queue never drains.
func (qc *queueClient) flushQueue() {
	var c qcall
	qc.mu.RLock()
	// Front returns -1 when the queue is empty; c then stays the zero qcall.
	if i := qc.q.Front(); i != -1 {
		c = qc.calls[i]
	}
	qc.mu.RUnlock()
	// A zero qcall (which() == qcallInvalid) marks the end of the queue.
	for c.which() != qcallInvalid {
		qc.handle(&c)

		qc.mu.Lock()
		// Pop only after handling, then load the next front entry.
		qc.q.Pop()
		if i := qc.q.Front(); i != -1 {
			c = qc.calls[i]
		} else {
			c = qcall{}
		}
		qc.mu.Unlock()
	}
}

// handle delivers a single queued entry.  Calls are made on qc.client and
// their results are joined into the waiting answer or fulfiller in a new
// goroutine.  A disembargo is sent on qc.conn as a receiverLoopback
// Disembargo message.  Invalid entries are ignored.
func (qc *queueClient) handle(c *qcall) {
	switch kind := c.which(); kind {
	case qcallRemoteCall, qcallLocalCall:
		ans := qc.client.Call(c.call)
		if kind == qcallRemoteCall {
			go joinAnswer(c.a, ans)
		} else {
			go joinFulfiller(c.f, ans)
		}
	case qcallDisembargo:
		msg := newDisembargoMessage(nil, rpccapnp.Disembargo_context_Which_receiverLoopback, c.embargoID)
		d, _ := msg.Disembargo()
		d.SetTarget(c.embargoTarget)
		// NOTE(review): the send error is discarded and qc.conn.mu is not
		// acquired here, whereas rejectQueue sends with qc.conn.mu held.
		// Confirm sendMessage's locking requirements.
		_ = qc.conn.sendMessage(msg)
	}
}

// isPassthrough reports whether the queue has fully drained, so that new
// calls may go straight to the wrapped client.  Callers in this file hold
// qc.mu while calling it.
func (qc *queueClient) isPassthrough() bool {
	pending := qc.q.Len()
	return pending == 0
}

// Call sends cl to the wrapped client once all previously queued entries
// have been delivered.  While the queue is non-empty, cl is queued behind
// them instead.
func (qc *queueClient) Call(cl *capnp.Call) capnp.Answer {
	// Fast path: a read lock is enough to observe that the queue drained.
	qc.mu.RLock()
	drained := qc.isPassthrough()
	qc.mu.RUnlock()
	if !drained {
		// Slow path: re-check under the write lock, because flushQueue may
		// have emptied the queue after the read lock was released.
		qc.mu.Lock()
		if !qc.isPassthrough() {
			ans := qc.pushCallLocked(cl)
			qc.mu.Unlock()
			return ans
		}
		qc.mu.Unlock()
	}
	return qc.client.Call(cl)
}

// tryQueue queues cl if entries are still pending and returns the
// resulting answer.  It returns nil when the queue has already drained,
// leaving the caller to make the call directly.
func (qc *queueClient) tryQueue(cl *capnp.Call) capnp.Answer {
	qc.mu.Lock()
	defer qc.mu.Unlock()
	if qc.isPassthrough() {
		return nil
	}
	return qc.pushCallLocked(cl)
}

// Close rejects every entry still queued (see rejectQueue) and then closes
// the wrapped client.  An error from closing the client takes precedence
// over an error from draining the queue.  If qc.conn.startWork fails,
// Close returns that error without touching the queue or the client.
func (qc *queueClient) Close() error {
	conn := qc.conn
	conn.mu.Lock()
	if err := conn.startWork(); err != nil {
		conn.mu.Unlock()
		return err
	}
	drainErr := qc.rejectQueue()
	conn.workers.Done()
	conn.mu.Unlock()
	if closeErr := qc.client.Close(); closeErr != nil {
		return closeErr
	}
	return drainErr
}

// rejectQueue drains the client's queue, popping entries in FIFO order.
// Queued calls are rejected with errQueueCallCancel, and queued
// disembargoes are sent immediately.  It returns the first error from
// sending a message, e.g. if the connection was shut down.  The caller must
// be holding onto qc.conn.mu.
func (qc *queueClient) rejectQueue() error {
	var firstErr error
	record := func(err error) {
		if err != nil && firstErr == nil {
			firstErr = err
		}
	}
	qc.mu.Lock()
	defer qc.mu.Unlock()
	for qc.q.Len() > 0 {
		entry := qc.calls[qc.q.Front()]
		switch entry.which() {
		case qcallRemoteCall:
			record(entry.a.reject(errQueueCallCancel))
		case qcallLocalCall:
			entry.f.Reject(errQueueCallCancel)
		case qcallDisembargo:
			msg := newDisembargoMessage(nil, rpccapnp.Disembargo_context_Which_receiverLoopback, entry.embargoID)
			d, _ := msg.Disembargo()
			d.SetTarget(entry.embargoTarget)
			record(qc.conn.sendMessage(msg))
		}
		qc.q.Pop()
	}
	return firstErr
}

// pcall is a queued pipeline call: a qcall plus the transform that selects
// the target capability out of the answer's result once it is known (see
// emptyQueue).
type pcall struct {
	transform []capnp.PipelineOp
	qcall
}

// qcall is a queued call or disembargo.  which reports the kind from the
// populated fields: a for a remote call, f for a local call, or a valid
// embargoTarget for a disembargo.  The zero qcall is invalid.
type qcall struct {
	// Calls
	a    *answer              // non-nil if remote call
	f    *fulfiller.Fulfiller // non-nil if local call
	call *capnp.Call          // copy of the call to deliver

	// Disembargo
	embargoID     embargoID              // ID sent in the receiverLoopback disembargo
	embargoTarget rpccapnp.MessageTarget // target of the disembargo message
}

// Kinds of queued entry, as reported by (*qcall).which.
const (
	qcallInvalid    = iota // zero qcall: nothing to deliver
	qcallRemoteCall        // call whose result resolves an *answer
	qcallLocalCall         // call whose result settles a *fulfiller.Fulfiller
	qcallDisembargo        // disembargo message to send
)

// which classifies a queued entry by checking its fields in order: a
// (remote call), then f (local call), then a valid embargoTarget
// (disembargo).  An entry with none of them set is qcallInvalid.
func (c *qcall) which() int {
	if c.a != nil {
		return qcallRemoteCall
	}
	if c.f != nil {
		return qcallLocalCall
	}
	if c.embargoTarget.IsValid() {
		return qcallDisembargo
	}
	return qcallInvalid
}

// qcallList is the fixed-size slot storage behind a queueClient's queue.
type qcallList []qcall

// Len returns the number of slots in the list.
func (ql qcallList) Len() int { return len(ql) }

// Clear resets slot i to the zero qcall, dropping its references.
func (ql qcallList) Clear(i int) {
	var zero qcall
	ql[i] = zero
}

// A localAnswerClient is used to provide a pipelined client of an answer.
// Calls made before the answer resolves are queued on a.  Afterwards they
// go to the client derived from the resolution and transform.
type localAnswerClient struct {
	a         *answer            // answer being pipelined on
	transform []capnp.PipelineOp // selects the target capability from a's result
}

// Call makes call on the capability that lac.transform selects from the
// answer.  If the answer has resolved, the call goes to the client that
// clientFromResolution builds from the answer's result or error.
// Otherwise it is queued on the answer, and the returned fulfiller is
// settled when the queued call is later delivered or rejected.  If the
// call cannot be queued, an error answer carrying the reason is returned.
func (lac *localAnswerClient) Call(call *capnp.Call) capnp.Answer {
	lac.a.mu.Lock()
	if lac.a.done {
		obj, err := lac.a.obj, lac.a.err
		lac.a.mu.Unlock()
		return clientFromResolution(lac.transform, obj, err).Call(call)
	}
	f := new(fulfiller.Fulfiller)
	err := lac.a.queueCallLocked(call, pcall{
		transform: lac.transform,
		qcall:     qcall{f: f},
	})
	lac.a.mu.Unlock()
	if err != nil {
		// queueCallLocked fails with errQueueFull when the queue is at
		// capacity, but also when copying the call fails; report the
		// actual cause rather than assuming the queue was full.
		return capnp.ErrorAnswer(err)
	}
	return f
}

// Close closes the client that the answer's resolution and lac.transform
// select (via clientFromResolution).  If the answer has not been resolved
// yet, Close is a no-op and returns nil.
func (lac *localAnswerClient) Close() error {
	ans := lac.a
	ans.mu.RLock()
	done := ans.done
	obj, err := ans.obj, ans.err
	ans.mu.RUnlock()
	if done {
		return clientFromResolution(lac.transform, obj, err).Close()
	}
	return nil
}

// Errors reported when queueing pipelined calls.
var (
	errQueueFull       = errors.New("rpc: pipeline queue full")
	errQueueCallCancel = errors.New("rpc: queued call canceled")
)

// Errors reported when handling disembargoes.
var (
	errDisembargoOngoingAnswer = errors.New("rpc: disembargo attempted on in-progress answer")
	errDisembargoNonImport     = errors.New("rpc: disembargo attempted on non-import capability")
	errDisembargoMissingAnswer = errors.New("rpc: disembargo attempted on missing answer (finished too early?)")
)