cloudflared-mirror/vendor/zombiezen.com/go/capnproto2/rpc/answer.go

499 lines
11 KiB
Go

package rpc
import (
"errors"
"sync"
"golang.org/x/net/context"
"zombiezen.com/go/capnproto2"
"zombiezen.com/go/capnproto2/internal/fulfiller"
"zombiezen.com/go/capnproto2/internal/queue"
rpccapnp "zombiezen.com/go/capnproto2/std/capnp/rpc"
)
// callQueueSize bounds how many calls may wait on a single answer or
// queueClient. It sizes each answer's pipeline queue and each
// queueClient's call buffer; once that space is used up, further
// queuing fails with errQueueFull.
// TODO(light): make this a ConnOption
const callQueueSize = 1 << 6
// insertAnswer registers a fresh, unresolved answer under id on c and
// returns it. If id is already present in the answers table, nothing is
// changed and nil is returned. The table itself is allocated lazily on
// first use.
func (c *Conn) insertAnswer(id answerID, cancel context.CancelFunc) *answer {
	if c.answers != nil {
		if _, taken := c.answers[id]; taken {
			return nil
		}
	} else {
		c.answers = make(map[answerID]*answer)
	}
	ans := &answer{
		conn:     c,
		id:       id,
		cancel:   cancel,
		resolved: make(chan struct{}),
		queue:    make([]pcall, 0, callQueueSize),
	}
	c.answers[id] = ans
	return ans
}
// popAnswer removes the answer registered under id from c and returns it.
// It returns nil when no answer is registered under id, including when
// the answers table has never been allocated.
func (c *Conn) popAnswer(id answerID) *answer {
	a, found := c.answers[id]
	if !found {
		return nil
	}
	delete(c.answers, id)
	return a
}
// answer is one entry in a Conn's answers table (see insertAnswer). It is
// resolved exactly once, by either fulfill or reject; resolving it a second
// time panics.
type answer struct {
// id is this answer's key in Conn.answers; it is also passed to
// newReturnMessage when the answer is resolved.
id answerID
// cancel is the cancel function handed to insertAnswer. It is not
// invoked anywhere in this part of the file.
cancel context.CancelFunc
// resultCaps is not populated yet (see the TODO in fulfill).
resultCaps []exportID
// conn is the connection that owns this answer.
conn *Conn
// resolved is closed once fulfill or reject has finished.
resolved chan struct{}
// mu guards obj, err, done, and queue.
mu sync.RWMutex
obj capnp.Ptr // result, set by fulfill
err error // failure, set by reject
done bool // true once fulfill or reject has been called
// queue holds pipelined calls made before resolution. insertAnswer
// gives it capacity callQueueSize; it is set to nil on resolution.
queue []pcall
}
// fulfill is called to resolve an answer successfully. It returns an
// error if its connection is shut down while sending messages. The
// caller must be holding onto a.conn.mu.
//
// On success, fulfill sends a Return message carrying obj and then
// dispatches every pipelined call queued on a. Calls whose target cannot
// be resolved within obj are rejected. The remaining calls are grouped by
// target capability, and each targeted entry in obj's capability table is
// replaced by a queueClient that delivers those calls first. If
// a.conn.startWork fails, every queued call is rejected with that error
// instead. Calling fulfill on an already-resolved answer panics.
func (a *answer) fulfill(obj capnp.Ptr) error {
	a.mu.Lock()
	if a.done {
		panic("answer.fulfill called more than once")
	}
	a.obj, a.done = obj, true
	// TODO(light): populate resultCaps

	var firstErr error
	if err := a.conn.startWork(); err != nil {
		firstErr = err
		// A queued pcall is either remote (pc.a set) or local (pc.f set,
		// pc.a nil). Dispatch on its kind so that local calls do not
		// dereference a nil *answer.
		for i := range a.queue {
			pc := &a.queue[i]
			switch pc.which() {
			case qcallRemoteCall:
				pc.a.reject(err)
			case qcallLocalCall:
				pc.f.Reject(err)
			}
		}
		a.queue = nil
	} else {
		retmsg := newReturnMessage(nil, a.id)
		ret, _ := retmsg.Return()
		payload, _ := ret.NewResults()
		payload.SetContentPtr(obj)
		if payloadTab, err := a.conn.makeCapTable(ret.Segment()); err != nil {
			firstErr = err
		} else {
			payload.SetCapTable(payloadTab)
			if err := a.conn.sendMessage(retmsg); err != nil {
				firstErr = err
			}
		}

		queues, err := a.emptyQueue(obj)
		if err != nil && firstErr == nil {
			firstErr = err
		}
		// Only reach into obj's message when there is something to install;
		// this avoids touching the segment of a result that targets nothing.
		if len(queues) > 0 {
			ctab := obj.Segment().Message().CapTable
			for capIdx, q := range queues {
				ctab[capIdx] = newQueueClient(a.conn, ctab[capIdx], q)
			}
		}
		a.conn.workers.Done()
	}
	close(a.resolved)
	a.mu.Unlock()
	return firstErr
}
// reject is called to resolve an answer with failure. It returns an
// error if its connection is shut down while sending messages. The
// caller must be holding onto a.conn.mu.
//
// reject sends a Return message carrying err as an exception. It then
// fails every pipelined call queued on a with the same err: remote calls
// through their answer, and local calls through their fulfiller. It panics
// if err is nil or if a has already been resolved.
func (a *answer) reject(err error) error {
	if err == nil {
		panic("answer.reject called with nil")
	}
	a.mu.Lock()
	if a.done {
		panic("answer.reject called more than once")
	}
	a.err, a.done = err, true
	m := newReturnMessage(nil, a.id)
	mret, _ := m.Return()
	setReturnException(mret, err)
	var firstErr error
	if sendErr := a.conn.sendMessage(m); sendErr != nil {
		firstErr = sendErr
	}
	// A queued pcall is either remote (pc.a set) or local (pc.f set,
	// pc.a nil); calling pc.a.reject on a local call would panic.
	for i := range a.queue {
		pc := &a.queue[i]
		switch pc.which() {
		case qcallRemoteCall:
			if rejErr := pc.a.reject(err); rejErr != nil && firstErr == nil {
				firstErr = rejErr
			}
		case qcallLocalCall:
			pc.f.Reject(err)
		}
	}
	a.queue = nil
	close(a.resolved)
	a.mu.Unlock()
	return firstErr
}
// emptyQueue splits the queue by which capability it targets
// and drops any invalid calls. Once this function returns, a.queue
// will be nil.
//
// Each queued call's transform is applied to obj. A call is rejected
// when the transform fails (with that error) or when it does not reach a
// valid interface (with capnp.ErrNullClient). Remote calls are rejected
// through their answer and local calls through their fulfiller. Surviving
// calls are returned grouped by capability ID, preserving queue order.
// The returned error is the first error reported while rejecting a remote
// call. The caller must hold a.mu (fulfill does).
func (a *answer) emptyQueue(obj capnp.Ptr) (map[capnp.CapabilityID][]qcall, error) {
	var firstErr error
	// rejectCall fails one undeliverable queued call according to its kind.
	rejectCall := func(pc *pcall, err error) {
		switch pc.which() {
		case qcallRemoteCall:
			if rerr := pc.a.reject(err); rerr != nil && firstErr == nil {
				firstErr = rerr
			}
		case qcallLocalCall:
			pc.f.Reject(err)
		}
	}

	qs := make(map[capnp.CapabilityID][]qcall, len(a.queue))
	for i := range a.queue {
		pc := &a.queue[i]
		c, err := capnp.TransformPtr(obj, pc.transform)
		if err != nil {
			rejectCall(pc, err)
			continue
		}
		ci := c.Interface()
		if !ci.IsValid() {
			rejectCall(pc, capnp.ErrNullClient)
			continue
		}
		cn := ci.Capability()
		if qs[cn] == nil {
			// At most the remaining calls can target cn.
			qs[cn] = make([]qcall, 0, len(a.queue)-i)
		}
		qs[cn] = append(qs[cn], pc.qcall)
	}
	a.queue = nil
	return qs, firstErr
}
// queueCallLocked appends a copy of call to a's pipeline queue so that it
// can be delivered once a resolves. The answer must not be resolved yet,
// and pc must carry a transform plus one of pc.a or pc.f. It fails with
// errQueueFull when the queue is at capacity, or with the error from
// copying call. The caller must hold a.mu.
func (a *answer) queueCallLocked(call *capnp.Call, pc pcall) error {
	if len(a.queue) == cap(a.queue) {
		return errQueueFull
	}
	copied, err := call.Copy(nil)
	if err != nil {
		return err
	}
	pc.call = copied
	a.queue = append(a.queue, pc)
	return nil
}
// queueDisembargo enqueues a disembargo message behind the calls already
// queued for the capability that transform selects in a's result.
//
// It fails if a has not resolved yet (errDisembargoOngoingAnswer). It also
// fails (errDisembargoNonImport) if a resolved with an error, or if the
// target is a queueClient that does not wrap an import from a.conn. It
// returns queued == false with a nil error when nothing needs to wait,
// either because the target is not a queueClient or because its queue is
// already empty; the disembargo can then be sent right away.
func (a *answer) queueDisembargo(transform []capnp.PipelineOp, id embargoID, target rpccapnp.MessageTarget) (queued bool, err error) {
	a.mu.Lock()
	defer a.mu.Unlock()
	switch {
	case !a.done:
		return false, errDisembargoOngoingAnswer
	case a.err != nil:
		return false, errDisembargoNonImport
	}

	targetPtr, terr := capnp.TransformPtr(a.obj, transform)
	if terr != nil {
		return false, terr
	}
	qc, isQueue := targetPtr.Interface().Client().(*queueClient)
	if !isQueue {
		// No need to embargo, disembargo immediately.
		return false, nil
	}
	if ic := isImport(qc.client); ic == nil || ic.conn != a.conn {
		return false, errDisembargoNonImport
	}

	qc.mu.Lock()
	if qc.isPassthrough() {
		qc.mu.Unlock()
		return false, nil
	}
	err = qc.pushEmbargoLocked(id, target)
	qc.mu.Unlock()
	return err == nil, err
}
// pipelineClient returns a client for the capability that transform
// selects in a's eventual result. Calls made through it before a
// resolves are queued on a.
func (a *answer) pipelineClient(transform []capnp.PipelineOp) capnp.Client {
	lac := new(localAnswerClient)
	lac.a = a
	lac.transform = transform
	return lac
}
// joinAnswer waits on the generic answer ca and settles the RPC answer a
// with its outcome: reject when ca failed, fulfill otherwise. Errors
// returned by fulfill or reject are dropped. The caller must not be
// holding onto a.conn.mu, since joinAnswer acquires it.
func joinAnswer(a *answer, ca capnp.Answer) {
	result, err := ca.Struct()
	a.conn.mu.Lock()
	if err != nil {
		a.reject(err)
	} else {
		a.fulfill(result.ToPtr())
	}
	a.conn.mu.Unlock()
}
// joinFulfiller waits on the generic answer ca and settles f with its
// outcome: Fulfill with the result struct, or Reject with the error.
func joinFulfiller(f *fulfiller.Fulfiller, ca capnp.Answer) {
	result, err := ca.Struct()
	if err == nil {
		f.Fulfill(result)
		return
	}
	f.Reject(err)
}
// queueClient wraps a client so that calls (and disembargos) queued before
// it was created are delivered first, in order, by flushQueue. Once the
// queue is empty, Call forwards directly to client.
type queueClient struct {
// client is the underlying capability that queued calls are delivered to.
client capnp.Client
// conn is used to send disembargo messages.
conn *Conn
// mu guards q and calls.
mu sync.RWMutex
// q tracks which slots of calls hold pending items (see internal/queue).
q queue.Queue
// calls is the backing storage for q; newQueueClient sizes it to callQueueSize.
calls qcallList
}
// newQueueClient wraps client so that the calls in queue are delivered to
// it before any later call, and starts the flushQueue goroutine that
// drains them. Only the first callQueueSize entries of queue fit into the
// buffer and are copied in.
func newQueueClient(c *Conn, client capnp.Client, queue []qcall) *queueClient {
	buf := make(qcallList, callQueueSize)
	pending := copy(buf, queue)
	qc := &queueClient{conn: c, client: client, calls: buf}
	qc.q.Init(qc.calls, pending)
	go qc.flushQueue()
	return qc
}
// pushCallLocked appends a copy of cl to the queue as a local call and
// returns a fulfiller that resolves once the call is delivered. If
// copying fails or the queue is full, the error is reported through the
// returned answer. The caller must hold qc.mu.
func (qc *queueClient) pushCallLocked(cl *capnp.Call) capnp.Answer {
	copied, err := cl.Copy(nil)
	if err != nil {
		return capnp.ErrorAnswer(err)
	}
	slot := qc.q.Push()
	if slot == -1 {
		return capnp.ErrorAnswer(errQueueFull)
	}
	f := new(fulfiller.Fulfiller)
	qc.calls[slot] = qcall{call: copied, f: f}
	return f
}
// pushEmbargoLocked appends a disembargo for embargo id, addressed to tgt,
// behind everything already queued. It returns errQueueFull when no slot
// is free. The caller must hold qc.mu.
func (qc *queueClient) pushEmbargoLocked(id embargoID, tgt rpccapnp.MessageTarget) error {
	slot := qc.q.Push()
	if slot == -1 {
		return errQueueFull
	}
	qc.calls[slot] = qcall{embargoTarget: tgt, embargoID: id}
	return nil
}
// flushQueue is run in its own goroutine (started by newQueueClient).
// It delivers queued items one at a time, in queue order, via handle,
// and returns once it finds the queue empty.
//
// The item being handled stays at the front of the queue until handle
// returns, so isPassthrough remains false and Call keeps queuing new calls
// behind it. This is what preserves ordering between queued and new calls.
//
// NOTE(review): rejectQueue (via Close) pops every entry under qc.mu,
// including the one currently being handled here. In that case the same
// item may be both delivered and rejected, and the Pop below then runs on
// an already-drained queue. Confirm whether Close can race with an active
// flush.
func (qc *queueClient) flushQueue() {
// c is a copy of the front slot, so handle can run without holding qc.mu.
var c qcall
qc.mu.RLock()
if i := qc.q.Front(); i != -1 {
c = qc.calls[i]
}
qc.mu.RUnlock()
// A zero qcall (qcallInvalid) means the queue was empty.
for c.which() != qcallInvalid {
qc.handle(&c)
qc.mu.Lock()
// Only now remove the delivered item, then fetch the next one.
qc.q.Pop()
if i := qc.q.Front(); i != -1 {
c = qc.calls[i]
} else {
c = qcall{}
}
qc.mu.Unlock()
}
}
// handle delivers one queued item. A call is made on qc.client, and its
// result is joined back to the waiting answer (remote call) or fulfiller
// (local call) on a new goroutine. A disembargo is sent over qc.conn as a
// receiverLoopback Disembargo message; its send error is ignored. Invalid
// items are ignored.
func (qc *queueClient) handle(c *qcall) {
	kind := c.which()
	if kind == qcallRemoteCall || kind == qcallLocalCall {
		ans := qc.client.Call(c.call)
		if kind == qcallRemoteCall {
			go joinAnswer(c.a, ans)
		} else {
			go joinFulfiller(c.f, ans)
		}
		return
	}
	if kind == qcallDisembargo {
		msg := newDisembargoMessage(nil, rpccapnp.Disembargo_context_Which_receiverLoopback, c.embargoID)
		d, _ := msg.Disembargo()
		d.SetTarget(c.embargoTarget)
		qc.conn.sendMessage(msg)
	}
}
// isPassthrough reports whether nothing is pending in the queue, meaning
// new calls can go straight to the underlying client. The caller must hold
// qc.mu (read or write).
func (qc *queueClient) isPassthrough() bool {
	pending := qc.q.Len()
	return pending == 0
}
// Call forwards cl to the underlying client when nothing is queued.
// Otherwise it queues cl behind the pending items, so that calls reach the
// underlying client in order.
func (qc *queueClient) Call(cl *capnp.Call) capnp.Answer {
	// Cheap check under the read lock first.
	qc.mu.RLock()
	direct := qc.isPassthrough()
	qc.mu.RUnlock()
	if !direct {
		// Re-check under the write lock: flushQueue may have drained the
		// queue after the read lock was released.
		qc.mu.Lock()
		if !qc.isPassthrough() {
			ans := qc.pushCallLocked(cl)
			qc.mu.Unlock()
			return ans
		}
		qc.mu.Unlock()
	}
	return qc.client.Call(cl)
}
// tryQueue queues cl behind the pending items and returns its answer. It
// returns nil, without queuing anything, when the queue is empty.
func (qc *queueClient) tryQueue(cl *capnp.Call) capnp.Answer {
	var ans capnp.Answer
	qc.mu.Lock()
	if !qc.isPassthrough() {
		ans = qc.pushCallLocked(cl)
	}
	qc.mu.Unlock()
	return ans
}
// Close rejects everything still queued on qc and then closes the
// underlying client. If qc.conn.startWork fails, that error is returned
// and neither the queue nor the client is touched. When both steps fail,
// the error from closing the client is returned.
func (qc *queueClient) Close() error {
	conn := qc.conn
	conn.mu.Lock()
	if err := conn.startWork(); err != nil {
		conn.mu.Unlock()
		return err
	}
	queueErr := qc.rejectQueue()
	conn.workers.Done()
	conn.mu.Unlock()

	if closeErr := qc.client.Close(); closeErr != nil {
		return closeErr
	}
	return queueErr
}
// rejectQueue drains qc's queue, failing every pending item. Remote and
// local calls are rejected with errQueueCallCancel, and pending
// disembargos are sent right away as receiverLoopback Disembargo messages.
// It returns the first error reported while rejecting or sending, if any.
// The caller must be holding onto qc.conn.mu.
func (qc *queueClient) rejectQueue() error {
	var firstErr error
	keep := func(err error) {
		if err != nil && firstErr == nil {
			firstErr = err
		}
	}

	qc.mu.Lock()
	for qc.q.Len() > 0 {
		item := qc.calls[qc.q.Front()]
		switch item.which() {
		case qcallRemoteCall:
			keep(item.a.reject(errQueueCallCancel))
		case qcallLocalCall:
			item.f.Reject(errQueueCallCancel)
		case qcallDisembargo:
			msg := newDisembargoMessage(nil, rpccapnp.Disembargo_context_Which_receiverLoopback, item.embargoID)
			d, _ := msg.Disembargo()
			d.SetTarget(item.embargoTarget)
			keep(qc.conn.sendMessage(msg))
		}
		qc.q.Pop()
	}
	qc.mu.Unlock()
	return firstErr
}
// pcall is a queued pipeline call: a call waiting on an unresolved answer,
// together with the path to its target inside that answer's result.
type pcall struct {
// transform selects the target capability within the answer's result;
// emptyQueue applies it with capnp.TransformPtr.
transform []capnp.PipelineOp
// qcall is the call itself and where its result should go.
qcall
}
// qcall is a queued item: either a call or a disembargo. which reports
// the kind based on which fields are set.
type qcall struct {
// Calls
a *answer // non-nil if remote call; settled via joinAnswer when delivered
f *fulfiller.Fulfiller // non-nil if local call; settled via joinFulfiller when delivered
call *capnp.Call // copy of the call, taken when it was queued
// Disembargo
embargoID embargoID
embargoTarget rpccapnp.MessageTarget // a valid target (with a and f nil) marks a disembargo
}
// Queued call types, as reported by qcall.which.
const (
	qcallInvalid    = 0 // empty slot: neither a call nor a disembargo
	qcallRemoteCall = 1 // call whose result settles an *answer (a != nil)
	qcallLocalCall  = 2 // call whose result settles a Fulfiller (f != nil)
	qcallDisembargo = 3 // disembargo to send once earlier items are delivered
)
// which classifies c by the fields it has set. A non-nil answer wins over
// a non-nil fulfiller, and both win over a valid embargo target. If none
// of these is set, c is qcallInvalid.
func (c *qcall) which() int {
	if c.a != nil {
		return qcallRemoteCall
	}
	if c.f != nil {
		return qcallLocalCall
	}
	if c.embargoTarget.IsValid() {
		return qcallDisembargo
	}
	return qcallInvalid
}
// qcallList is the slot storage handed to a queueClient's queue.Queue.
type qcallList []qcall

// Len reports the number of slots in the list.
func (ql qcallList) Len() int {
	n := len(ql)
	return n
}

// Clear resets slot i to the zero qcall, dropping its references.
func (ql qcallList) Clear(i int) {
	var zero qcall
	ql[i] = zero
}
// A localAnswerClient is used to provide a pipelined client of an answer.
// answer.pipelineClient creates it.
type localAnswerClient struct {
a *answer // the answer whose result is targeted
transform []capnp.PipelineOp // selects the target capability within a's result
}
// Call makes call on the capability that lac.transform selects in lac.a's
// result. If lac.a has already resolved, the call goes straight to the
// client derived from that resolution. Otherwise it is queued on lac.a,
// and the returned answer settles once the queued call is delivered. A
// queuing failure (a full queue, or an error copying call) is reported
// through the returned answer.
func (lac *localAnswerClient) Call(call *capnp.Call) capnp.Answer {
	lac.a.mu.Lock()
	if lac.a.done {
		obj, err := lac.a.obj, lac.a.err
		lac.a.mu.Unlock()
		return clientFromResolution(lac.transform, obj, err).Call(call)
	}
	f := new(fulfiller.Fulfiller)
	err := lac.a.queueCallLocked(call, pcall{
		transform: lac.transform,
		qcall:     qcall{f: f},
	})
	lac.a.mu.Unlock()
	if err != nil {
		// Propagate the real cause: queueCallLocked also fails when the
		// call cannot be copied, not only when the queue is full.
		return capnp.ErrorAnswer(err)
	}
	return f
}
// Close closes the client derived from lac.a's resolution. If lac.a has
// not resolved yet, Close does nothing and returns nil.
func (lac *localAnswerClient) Close() error {
	lac.a.mu.RLock()
	done := lac.a.done
	obj, resErr := lac.a.obj, lac.a.err
	lac.a.mu.RUnlock()
	if done {
		return clientFromResolution(lac.transform, obj, resErr).Close()
	}
	return nil
}
// errQueueFull is returned when a call or disembargo cannot be queued
// because the queue has no free slot.
var errQueueFull = errors.New("rpc: pipeline queue full")

// errQueueCallCancel is used to reject items drained by queueClient.rejectQueue.
var errQueueCallCancel = errors.New("rpc: queued call canceled")

// errDisembargoOngoingAnswer is returned by queueDisembargo when the
// answer has not resolved yet.
var errDisembargoOngoingAnswer = errors.New("rpc: disembargo attempted on in-progress answer")

// errDisembargoNonImport is returned by queueDisembargo when the answer
// failed or its target is not an import from the same connection.
var errDisembargoNonImport = errors.New("rpc: disembargo attempted on non-import capability")

// errDisembargoMissingAnswer is not referenced in this part of the file.
var errDisembargoMissingAnswer = errors.New("rpc: disembargo attempted on missing answer (finished too early?)")