package rpc

import (
	"sync"

	"golang.org/x/net/context"
	"zombiezen.com/go/capnproto2"
	"zombiezen.com/go/capnproto2/internal/fulfiller"
	"zombiezen.com/go/capnproto2/internal/queue"
	rpccapnp "zombiezen.com/go/capnproto2/std/capnp/rpc"
)

// newQuestion creates a new question with an unassigned ID.
// newQuestion takes the next ID from c.questionID, registers a new
// in-progress question under that ID in c.questions, and returns it.
func (c *Conn) newQuestion(ctx context.Context, method *capnp.Method) *question {
	q := &question{
		id:       questionID(c.questionID.next()),
		ctx:      ctx,
		conn:     c,
		method:   method,
		resolved: make(chan struct{}),
	}
	// TODO(light): populate paramCaps
	if int(q.id) != len(c.questions) {
		// An ID other than the next index refers to an existing slot
		// (presumably one released by popQuestion); reuse it.
		c.questions[q.id] = q
		return q
	}
	c.questions = append(c.questions, q)
	return q
}

// findQuestion returns the question registered under id, or nil when id
// lies past the end of the question table or its slot is empty.
func (c *Conn) findQuestion(id questionID) *question {
	if i := int(id); i < len(c.questions) {
		return c.questions[i]
	}
	return nil
}

// popQuestion removes the question registered under id from the table,
// hands id back to c.questionID, and returns the removed question. If no
// question is registered under id, it changes nothing and returns nil.
func (c *Conn) popQuestion(id questionID) *question {
	q := c.findQuestion(id)
	if q != nil {
		c.questions[id] = nil
		c.questionID.remove(uint32(id))
	}
	return q
}

// question is an outstanding request this connection has made of the
// remote vat: a method call, or a bootstrap request when method is nil.
//
// resolved is closed exactly once, by fulfill, reject, or cancel.
// paramCaps is not populated yet (newQuestion carries a TODO for it).
// derived holds the distinct transforms recorded by addPromise for calls
// pipelined on this question; fulfill walks them to decide which result
// capabilities need an embargo. obj holds the result after fulfill, err
// holds the failure after reject or cancel, and state records which of
// those happened.
type question struct {
	id        questionID
	ctx       context.Context
	conn      *Conn
	method    *capnp.Method // nil if this is bootstrap
	paramCaps []exportID
	resolved  chan struct{}

	// Protected by conn.mu
	derived [][]capnp.PipelineOp

	// Fields below are protected by mu.
	mu    sync.RWMutex
	obj   capnp.Ptr
	err   error
	state questionState
}

// questionState records where a question is in its lifecycle.
type questionState uint8

// Question states. The zero value, questionInProgress, is the state of a
// freshly created question; fulfill and reject move it to
// questionResolved and cancel moves it to questionCanceled.
const (
	questionInProgress questionState = 0
	questionResolved   questionState = 1
	questionCanceled   questionState = 2
)

// start signals that the question has been sent.
// start signals that the question has been sent.
//
// It spawns a goroutine that waits for the first of three events: q.resolved
// being closed, the connection's bg context finishing, or q.ctx finishing.
// Only the last one requires action. In that case the goroutine waits to
// acquire q.conn.mu (giving up if the question resolves or the connection's
// bg context finishes first) and cancels the question with q.ctx.Err(). If
// that cancel actually moved the question out of the in-progress state, it
// sends a Finish message for q.id with release set to true. The goroutine
// returns once it has handled whichever event happened first.
func (q *question) start() {
	go func() {
		select {
		case <-q.resolved:
			// Resolved naturally, nothing to do.
		case <-q.conn.bg.Done():
			// The connection's background context is done; nothing to do.
		case <-q.ctx.Done():
			// The caller gave up. Cancel the question unless it resolves
			// or the connection finishes before we obtain the lock.
			select {
			case <-q.resolved:
			case <-q.conn.bg.Done():
			case <-q.conn.mu:
				// Receiving from conn.mu acquires the connection lock; every
				// path below releases it with Unlock.
				if err := q.conn.startWork(); err != nil {
					// teardown calls cancel.
					q.conn.mu.Unlock()
					return
				}
				// cancel returns false if the question is no longer in
				// progress; only send Finish when this call canceled it.
				if q.cancel(q.ctx.Err()) {
					q.conn.sendMessage(newFinishMessage(nil, q.id, true /* release */))
				}
				q.conn.workers.Done()
				q.conn.mu.Unlock()
			}
		}
	}()
}

// fulfill is called to resolve a question successfully.
// The caller must be holding onto q.conn.mu.
// fulfill is called to resolve a question successfully.
// The caller must be holding onto q.conn.mu.
//
// Before publishing obj, fulfill walks every transform recorded by
// addPromise. Consider each one that resolves to a capability in obj's
// message that was not imported over this same connection. That cap-table
// entry is replaced with an embargoClient, which queues calls until the
// embargo is lifted. A senderLoopback Disembargo message targeting the
// promised answer is then queued on q.conn.out. Each cap-table entry is
// embargoed at most once. Transforms that fail to apply, do not reach an
// interface pointer, or name a capability index outside the message's cap
// table are skipped.
//
// fulfill panics if the question has already been resolved.
func (q *question) fulfill(obj capnp.Ptr) {
	var ctab []capnp.Client
	if obj.IsValid() {
		ctab = obj.Segment().Message().CapTable
	}
	visited := make([]bool, len(ctab))
	for _, d := range q.derived {
		tgt, err := capnp.TransformPtr(obj, d)
		if err != nil {
			continue
		}
		in := tgt.Interface()
		if !in.IsValid() {
			continue
		}
		if ic := isImport(in.Client()); ic != nil && ic.conn == q.conn {
			// Imported from remote vat.  Don't need to disembargo.
			continue
		}
		cn := in.Capability()
		if uint64(cn) >= uint64(len(ctab)) {
			// The pointer names a capability index that is not in the
			// message's cap table. There is nothing to embargo, and
			// indexing visited/ctab would panic.
			continue
		}
		if visited[cn] {
			continue
		}
		visited[cn] = true
		id, e := q.conn.newEmbargo()
		ctab[cn] = newEmbargoClient(ctab[cn], e, q.conn.bg.Done())
		m := newDisembargoMessage(nil, rpccapnp.Disembargo_context_Which_senderLoopback, id)
		dis, _ := m.Disembargo()
		mt, _ := dis.NewTarget()
		pa, _ := mt.NewPromisedAnswer()
		pa.SetQuestionId(uint32(q.id))
		transformToPromisedAnswer(m.Segment(), pa, d)
		mt.SetPromisedAnswer(pa)

		select {
		case q.conn.out <- m:
		case <-q.conn.bg.Done():
			// TODO(soon): perhaps just drop all embargoes in this case?
		}
	}

	q.mu.Lock()
	if q.state != questionInProgress {
		panic("question.fulfill called more than once")
	}
	q.obj, q.state = obj, questionResolved
	close(q.resolved)
	q.mu.Unlock()
}

// reject is called to resolve a question with failure.
// The caller must be holding onto q.conn.mu.
// reject resolves the question with the failure err.
// The caller must be holding onto q.conn.mu.
//
// It panics if err is nil or if the question is no longer in progress.
func (q *question) reject(err error) {
	if err == nil {
		panic("question.reject called with nil")
	}
	q.mu.Lock()
	if q.state == questionInProgress {
		q.err, q.state = err, questionResolved
		close(q.resolved)
		q.mu.Unlock()
		return
	}
	panic("question.reject called more than once")
}

// cancel is called to resolve a question with cancellation.
// The caller must be holding onto q.conn.mu.
// cancel resolves the question with err and marks it canceled, but only if
// it is still in progress. It reports whether it made that transition;
// for an already resolved or canceled question it does nothing and
// returns false. The caller must be holding onto q.conn.mu.
//
// It panics if err is nil.
func (q *question) cancel(err error) bool {
	if err == nil {
		panic("question.cancel called with nil")
	}
	q.mu.Lock()
	if q.state != questionInProgress {
		q.mu.Unlock()
		return false
	}
	q.err, q.state = err, questionCanceled
	close(q.resolved)
	q.mu.Unlock()
	return true
}

// addPromise records a returned capability as being used for a call.
// This is needed for determining embargoes upon resolution.  The
// caller must be holding onto q.conn.mu.
// addPromise records that a call was pipelined on the capability reached
// by transform, so fulfill can decide which capabilities need an embargo.
// A transform equal to one already recorded is ignored. The caller must be
// holding onto q.conn.mu.
func (q *question) addPromise(transform []capnp.PipelineOp) {
	for i := range q.derived {
		if transformsEqual(q.derived[i], transform) {
			return // already recorded
		}
	}
	q.derived = append(q.derived, transform)
}

// transformsEqual reports whether t and u have the same length and select
// the same Field at every step. Only Field is compared; any other
// PipelineOp fields are ignored.
func transformsEqual(t, u []capnp.PipelineOp) bool {
	if len(t) != len(u) {
		return false
	}
	for i, op := range t {
		if op.Field != u[i].Field {
			return false
		}
	}
	return true
}

// Struct blocks until the question resolves, then returns its result as a
// struct along with the question's error. If the connection's background
// context finishes first, it returns ErrConnClosed.
func (q *question) Struct() (capnp.Struct, error) {
	select {
	case <-q.conn.bg.Done():
		return capnp.Struct{}, ErrConnClosed
	case <-q.resolved:
	}
	q.mu.RLock()
	result, failure := q.obj, q.err
	q.mu.RUnlock()
	return result.Struct(), failure
}

// PipelineCall makes ccall on the capability reached by applying transform
// to this question's result. It waits to acquire q.conn.mu, giving up with
// the context's error if ccall.Ctx finishes first. It then calls startWork,
// delegates to lockedPipelineCall, and balances startWork with
// workers.Done before releasing the lock.
func (q *question) PipelineCall(transform []capnp.PipelineOp, ccall *capnp.Call) capnp.Answer {
	conn := q.conn
	select {
	case <-ccall.Ctx.Done():
		return capnp.ErrorAnswer(ccall.Ctx.Err())
	case <-conn.mu:
		// Receiving from conn.mu acquires the connection lock.
	}
	if err := conn.startWork(); err != nil {
		conn.mu.Unlock()
		return capnp.ErrorAnswer(err)
	}
	answer := q.lockedPipelineCall(transform, ccall)
	conn.workers.Done()
	conn.mu.Unlock()
	return answer
}

// lockedPipelineCall is equivalent to PipelineCall but assumes that the
// caller is already holding onto q.conn.mu.
// lockedPipelineCall is equivalent to PipelineCall but assumes that the
// caller is already holding onto q.conn.mu.
//
// If q is no longer registered in the connection's question table, the
// call is made directly on the client resolved from q's result. Otherwise
// a new question is allocated, and a Call message targeting q's promised
// answer (with transform applied) is built and sent. If building or
// sending that message fails, only the newly allocated pipelined question
// is released (q stays outstanding) and an error answer is returned.
func (q *question) lockedPipelineCall(transform []capnp.PipelineOp, ccall *capnp.Call) capnp.Answer {
	if q.conn.findQuestion(q.id) != q {
		// Question has been finished.  The call should happen as if it is
		// back in application code.
		q.mu.RLock()
		obj, err, state := q.obj, q.err, q.state
		q.mu.RUnlock()
		if state == questionInProgress {
			panic("question popped but not done")
		}
		client := clientFromResolution(transform, obj, err)
		return q.conn.lockedCall(client, ccall)
	}

	pipeq := q.conn.newQuestion(ccall.Ctx, &ccall.Method)
	msg := newMessage(nil)
	msgCall, _ := msg.NewCall()
	msgCall.SetQuestionId(uint32(pipeq.id))
	msgCall.SetInterfaceId(ccall.Method.InterfaceID)
	msgCall.SetMethodId(ccall.Method.MethodID)
	target, _ := msgCall.NewTarget()
	a, _ := target.NewPromisedAnswer()
	a.SetQuestionId(uint32(q.id))
	err := transformToPromisedAnswer(a.Segment(), a, transform)
	if err != nil {
		q.conn.popQuestion(pipeq.id)
		return capnp.ErrorAnswer(err)
	}
	payload, _ := msgCall.NewParams()
	if err := q.conn.fillParams(payload, ccall); err != nil {
		// Release the pipelined question allocated above, not q: q is
		// still outstanding and must stay in the question table.
		q.conn.popQuestion(pipeq.id)
		return capnp.ErrorAnswer(err)
	}

	select {
	case q.conn.out <- msg:
	case <-ccall.Ctx.Done():
		q.conn.popQuestion(pipeq.id)
		return capnp.ErrorAnswer(ccall.Ctx.Err())
	case <-q.conn.bg.Done():
		q.conn.popQuestion(pipeq.id)
		return capnp.ErrorAnswer(ErrConnClosed)
	}
	q.addPromise(transform)
	pipeq.start()
	return pipeq
}

// PipelineClose waits for the question to resolve, then closes the client
// reached by applying transform to the result. It returns, in priority
// order:
//   - the question's own error;
//   - an error from applying transform;
//   - capnp.ErrNullClient if the result has no client at that location;
//   - otherwise, the result of Close.
func (q *question) PipelineClose(transform []capnp.PipelineOp) error {
	<-q.resolved
	q.mu.RLock()
	result, failure := q.obj, q.err
	q.mu.RUnlock()
	if failure != nil {
		return failure
	}
	ptr, err := capnp.TransformPtr(result, transform)
	if err != nil {
		return err
	}
	if client := ptr.Interface().Client(); client != nil {
		return client.Close()
	}
	return capnp.ErrNullClient
}

// embargoClient is a client that waits until an embargo signal is
// received to deliver calls.
// embargoClient is a client that waits until an embargo signal is
// received to deliver calls.
//
// While the embargo is in effect (a receive from embargo would block) or
// calls are still queued, Call copies incoming calls into a bounded queue.
// flushQueue later delivers them to client in order. After the embargo is
// lifted and the queue is empty, calls go straight to client. Closing
// cancel makes flushQueue abandon the queued calls. mu guards q and calls;
// q holds the queue bookkeeping over the calls slice (initialized by
// newEmbargoClient via q.Init).
type embargoClient struct {
	cancel  <-chan struct{}
	client  capnp.Client
	embargo embargo

	mu    sync.RWMutex
	q     queue.Queue
	calls ecallList
}

// newEmbargoClient wraps client so that calls are queued until e is lifted
// and then delivered in order. It allocates a queue of callQueueSize slots
// and starts the flushQueue goroutine, which abandons the queue if cancel
// is closed first.
func newEmbargoClient(client capnp.Client, e embargo, cancel <-chan struct{}) *embargoClient {
	ec := new(embargoClient)
	ec.client = client
	ec.embargo = e
	ec.cancel = cancel
	ec.calls = make(ecallList, callQueueSize)
	ec.q.Init(ec.calls, 0)
	go ec.flushQueue()
	return ec
}

// push copies cl into the next free queue slot and returns a fulfiller
// that will carry the call's answer once it is delivered. It returns an
// error answer if cl cannot be copied or the queue is full. The caller
// must hold ec.mu for writing.
func (ec *embargoClient) push(cl *capnp.Call) capnp.Answer {
	copied, err := cl.Copy(nil)
	if err != nil {
		return capnp.ErrorAnswer(err)
	}
	slot := ec.q.Push()
	if slot == -1 {
		return capnp.ErrorAnswer(errQueueFull)
	}
	f := new(fulfiller.Fulfiller)
	ec.calls[slot] = ecall{call: copied, f: f}
	return f
}

// Call forwards cl directly to the wrapped client once the embargo is
// lifted and the queue has drained. Until then it queues a copy of cl and
// returns the queued call's answer.
func (ec *embargoClient) Call(cl *capnp.Call) capnp.Answer {
	// A cheap check under the read lock handles the common case.
	ec.mu.RLock()
	flushed := ec.isPassthrough()
	ec.mu.RUnlock()
	if !flushed {
		// Re-check under the write lock: the state may have changed since
		// the read lock was released.
		ec.mu.Lock()
		if !ec.isPassthrough() {
			queued := ec.push(cl)
			ec.mu.Unlock()
			return queued
		}
		ec.mu.Unlock()
	}
	return ec.client.Call(cl)
}

// tryQueue queues cl and returns its answer while the client is not yet
// in passthrough mode. Once calls may go straight to the wrapped client,
// it returns nil without queuing anything.
func (ec *embargoClient) tryQueue(cl *capnp.Call) capnp.Answer {
	var queued capnp.Answer
	ec.mu.Lock()
	if !ec.isPassthrough() {
		queued = ec.push(cl)
	}
	ec.mu.Unlock()
	return queued
}

// isPassthrough reports whether calls may bypass the queue: the embargo
// must be lifted (a receive from it does not block) and no calls may
// remain queued. The caller must hold ec.mu.
func (ec *embargoClient) isPassthrough() bool {
	select {
	case <-ec.embargo:
		return ec.q.Len() == 0
	default:
		return false
	}
}

// Close rejects every call still waiting in the queue with
// errQueueCallCancel, then closes the wrapped client and returns its error.
func (ec *embargoClient) Close() error {
	ec.mu.Lock()
	for ec.q.Len() > 0 {
		ec.calls[ec.q.Front()].f.Reject(errQueueCallCancel)
		ec.q.Pop()
	}
	ec.mu.Unlock()
	return ec.client.Close()
}

// flushQueue is run in its own goroutine.
// flushQueue is run in its own goroutine.
//
// It waits until either the embargo is lifted or cancel is closed.
//
// On cancellation, every queued call is rejected with errQueueCallCancel,
// the same error Close uses, and then removed. Callers waiting on the
// answers returned by push therefore do not block forever.
//
// Once the embargo is lifted, queued calls are delivered to ec.client one
// at a time in queue order. Each call stays at the front of the queue until
// it has been handed to the client. That keeps isPassthrough false, so Call
// cannot overtake calls that are still queued.
func (ec *embargoClient) flushQueue() {
	select {
	case <-ec.embargo:
	case <-ec.cancel:
		ec.mu.Lock()
		for ec.q.Len() > 0 {
			// Reject rather than silently drop: nothing else would ever
			// resolve this call's fulfiller.
			ec.calls[ec.q.Front()].f.Reject(errQueueCallCancel)
			ec.q.Pop()
		}
		ec.mu.Unlock()
		return
	}
	var c ecall
	ec.mu.RLock()
	if i := ec.q.Front(); i != -1 {
		c = ec.calls[i]
	}
	ec.mu.RUnlock()
	for c.call != nil {
		ans := ec.client.Call(c.call)
		go joinFulfiller(c.f, ans)

		ec.mu.Lock()
		ec.q.Pop()
		if i := ec.q.Front(); i != -1 {
			c = ec.calls[i]
		} else {
			c = ecall{}
		}
		ec.mu.Unlock()
	}
}

// ecall is a call waiting in an embargoClient's queue, paired with the
// fulfiller that receives its answer once it is delivered.
type ecall struct {
	call *capnp.Call
	f    *fulfiller.Fulfiller
}

// ecallList is the fixed-size backing storage handed to queue.Queue for an
// embargoClient's pending calls.
type ecallList []ecall

// Len reports the number of slots in the list (presumably used by
// queue.Queue as its capacity).
func (el ecallList) Len() int { return len(el) }

// Clear resets slot i to the zero ecall, dropping its references to the
// call and fulfiller (presumably invoked by queue.Queue when a slot is
// popped).
func (el ecallList) Clear(i int) {
	var empty ecall
	el[i] = empty
}