443 lines
9.5 KiB
Go
443 lines
9.5 KiB
Go
|
package rpc
|
||
|
|
||
|
import (
|
||
|
"sync"
|
||
|
|
||
|
"golang.org/x/net/context"
|
||
|
"zombiezen.com/go/capnproto2"
|
||
|
"zombiezen.com/go/capnproto2/internal/fulfiller"
|
||
|
"zombiezen.com/go/capnproto2/internal/queue"
|
||
|
rpccapnp "zombiezen.com/go/capnproto2/std/capnp/rpc"
|
||
|
)
|
||
|
|
||
|
// newQuestion creates a new question with an unassigned ID.
|
||
|
func (c *Conn) newQuestion(ctx context.Context, method *capnp.Method) *question {
|
||
|
id := questionID(c.questionID.next())
|
||
|
q := &question{
|
||
|
ctx: ctx,
|
||
|
conn: c,
|
||
|
method: method,
|
||
|
resolved: make(chan struct{}),
|
||
|
id: id,
|
||
|
}
|
||
|
// TODO(light): populate paramCaps
|
||
|
if int(id) == len(c.questions) {
|
||
|
c.questions = append(c.questions, q)
|
||
|
} else {
|
||
|
c.questions[id] = q
|
||
|
}
|
||
|
return q
|
||
|
}
|
||
|
|
||
|
func (c *Conn) findQuestion(id questionID) *question {
|
||
|
if int(id) >= len(c.questions) {
|
||
|
return nil
|
||
|
}
|
||
|
return c.questions[id]
|
||
|
}
|
||
|
|
||
|
func (c *Conn) popQuestion(id questionID) *question {
|
||
|
q := c.findQuestion(id)
|
||
|
if q == nil {
|
||
|
return nil
|
||
|
}
|
||
|
c.questions[id] = nil
|
||
|
c.questionID.remove(uint32(id))
|
||
|
return q
|
||
|
}
|
||
|
|
||
|
// question is a call (or bootstrap request) this connection has sent to the
// remote vat. It is identified by id in the connection's question table.
// It is returned to callers as a capnp.Answer (see PipelineCall), serving
// Struct, PipelineCall and PipelineClose.
type question struct {
	id     questionID
	ctx    context.Context // when done before resolution, start cancels the question
	conn   *Conn
	method *capnp.Method // nil if this is bootstrap
	// NOTE(review): not populated by newQuestion yet (see its TODO).
	paramCaps []exportID
	// Closed at most once, by whichever of fulfill, reject or cancel first
	// moves state out of questionInProgress.
	resolved chan struct{}

	// Protected by conn.mu
	// Pipeline transforms used against this question's result; recorded by
	// addPromise and consulted by fulfill to decide on embargoes.
	derived [][]capnp.PipelineOp

	// Fields below are protected by mu.
	mu    sync.RWMutex
	obj   capnp.Ptr // result, set by fulfill
	err   error     // failure, set by reject or cancel
	state questionState
}
|
||
|
|
||
|
// questionState records how far a question has progressed toward
// resolution. It is stored in question.state under question.mu.
type questionState uint8

// Question states
const (
	// questionInProgress: awaiting a result; the only state from which
	// fulfill, reject or cancel may move the question.
	questionInProgress questionState = iota
	// questionResolved: finished by fulfill (success) or reject (failure).
	questionResolved
	// questionCanceled: finished by cancel while still in progress.
	questionCanceled
)
|
||
|
|
||
|
// start signals that the question has been sent.
//
// It spawns a goroutine that watches for q.ctx ending before the question
// resolves. The goroutine does nothing if q resolves first or the
// connection's background context (q.conn.bg) is done first.
//
// If q.ctx is done first, the goroutine does the following:
//   - Acquires q.conn.mu, unless q resolves or the connection shuts down
//     while it waits.
//   - Registers as a connection worker via startWork.
//   - Cancels q with the context's error.
//   - If that cancellation took effect, sends a Finish message for q.id to
//     the peer.
func (q *question) start() {
	go func() {
		select {
		case <-q.resolved:
			// Resolved naturally, nothing to do.
		case <-q.conn.bg.Done():
			// Connection is shutting down; nothing to do here.
		case <-q.ctx.Done():
			// Caller gave up. Re-check resolution and shutdown while
			// waiting for the connection lock, so this goroutine never
			// blocks forever on q.conn.mu.
			select {
			case <-q.resolved:
			case <-q.conn.bg.Done():
			case <-q.conn.mu:
				if err := q.conn.startWork(); err != nil {
					// teardown calls cancel.
					q.conn.mu.Unlock()
					return
				}
				// cancel reports false if a result or error already
				// arrived; only send Finish when we actually canceled.
				if q.cancel(q.ctx.Err()) {
					q.conn.sendMessage(newFinishMessage(nil, q.id, true /* release */))
				}
				q.conn.workers.Done()
				q.conn.mu.Unlock()
			}
		}
	}()
}
|
||
|
|
||
|
// fulfill is called to resolve a question successfully.
|
||
|
// The caller must be holding onto q.conn.mu.
|
||
|
func (q *question) fulfill(obj capnp.Ptr) {
|
||
|
var ctab []capnp.Client
|
||
|
if obj.IsValid() {
|
||
|
ctab = obj.Segment().Message().CapTable
|
||
|
}
|
||
|
visited := make([]bool, len(ctab))
|
||
|
for _, d := range q.derived {
|
||
|
tgt, err := capnp.TransformPtr(obj, d)
|
||
|
if err != nil {
|
||
|
continue
|
||
|
}
|
||
|
in := tgt.Interface()
|
||
|
if !in.IsValid() {
|
||
|
continue
|
||
|
}
|
||
|
if ic := isImport(in.Client()); ic != nil && ic.conn == q.conn {
|
||
|
// Imported from remote vat. Don't need to disembargo.
|
||
|
continue
|
||
|
}
|
||
|
cn := in.Capability()
|
||
|
if visited[cn] {
|
||
|
continue
|
||
|
}
|
||
|
visited[cn] = true
|
||
|
id, e := q.conn.newEmbargo()
|
||
|
ctab[cn] = newEmbargoClient(ctab[cn], e, q.conn.bg.Done())
|
||
|
m := newDisembargoMessage(nil, rpccapnp.Disembargo_context_Which_senderLoopback, id)
|
||
|
dis, _ := m.Disembargo()
|
||
|
mt, _ := dis.NewTarget()
|
||
|
pa, _ := mt.NewPromisedAnswer()
|
||
|
pa.SetQuestionId(uint32(q.id))
|
||
|
transformToPromisedAnswer(m.Segment(), pa, d)
|
||
|
mt.SetPromisedAnswer(pa)
|
||
|
|
||
|
select {
|
||
|
case q.conn.out <- m:
|
||
|
case <-q.conn.bg.Done():
|
||
|
// TODO(soon): perhaps just drop all embargoes in this case?
|
||
|
}
|
||
|
}
|
||
|
|
||
|
q.mu.Lock()
|
||
|
if q.state != questionInProgress {
|
||
|
panic("question.fulfill called more than once")
|
||
|
}
|
||
|
q.obj, q.state = obj, questionResolved
|
||
|
close(q.resolved)
|
||
|
q.mu.Unlock()
|
||
|
}
|
||
|
|
||
|
// reject is called to resolve a question with failure.
|
||
|
// The caller must be holding onto q.conn.mu.
|
||
|
func (q *question) reject(err error) {
|
||
|
if err == nil {
|
||
|
panic("question.reject called with nil")
|
||
|
}
|
||
|
q.mu.Lock()
|
||
|
if q.state != questionInProgress {
|
||
|
panic("question.reject called more than once")
|
||
|
}
|
||
|
q.err = err
|
||
|
q.state = questionResolved
|
||
|
close(q.resolved)
|
||
|
q.mu.Unlock()
|
||
|
}
|
||
|
|
||
|
// cancel is called to resolve a question with cancellation.
|
||
|
// The caller must be holding onto q.conn.mu.
|
||
|
func (q *question) cancel(err error) bool {
|
||
|
if err == nil {
|
||
|
panic("question.cancel called with nil")
|
||
|
}
|
||
|
q.mu.Lock()
|
||
|
canceled := q.state == questionInProgress
|
||
|
if canceled {
|
||
|
q.err = err
|
||
|
q.state = questionCanceled
|
||
|
close(q.resolved)
|
||
|
}
|
||
|
q.mu.Unlock()
|
||
|
return canceled
|
||
|
}
|
||
|
|
||
|
// addPromise records a returned capability as being used for a call.
|
||
|
// This is needed for determining embargoes upon resolution. The
|
||
|
// caller must be holding onto q.conn.mu.
|
||
|
func (q *question) addPromise(transform []capnp.PipelineOp) {
|
||
|
for _, d := range q.derived {
|
||
|
if transformsEqual(transform, d) {
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
q.derived = append(q.derived, transform)
|
||
|
}
|
||
|
|
||
|
func transformsEqual(t, u []capnp.PipelineOp) bool {
|
||
|
if len(t) != len(u) {
|
||
|
return false
|
||
|
}
|
||
|
for i := range t {
|
||
|
if t[i].Field != u[i].Field {
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
func (q *question) Struct() (capnp.Struct, error) {
|
||
|
select {
|
||
|
case <-q.resolved:
|
||
|
case <-q.conn.bg.Done():
|
||
|
return capnp.Struct{}, ErrConnClosed
|
||
|
}
|
||
|
q.mu.RLock()
|
||
|
s, err := q.obj.Struct(), q.err
|
||
|
q.mu.RUnlock()
|
||
|
return s, err
|
||
|
}
|
||
|
|
||
|
func (q *question) PipelineCall(transform []capnp.PipelineOp, ccall *capnp.Call) capnp.Answer {
|
||
|
select {
|
||
|
case <-q.conn.mu:
|
||
|
if err := q.conn.startWork(); err != nil {
|
||
|
q.conn.mu.Unlock()
|
||
|
return capnp.ErrorAnswer(err)
|
||
|
}
|
||
|
case <-ccall.Ctx.Done():
|
||
|
return capnp.ErrorAnswer(ccall.Ctx.Err())
|
||
|
}
|
||
|
ans := q.lockedPipelineCall(transform, ccall)
|
||
|
q.conn.workers.Done()
|
||
|
q.conn.mu.Unlock()
|
||
|
return ans
|
||
|
}
|
||
|
|
||
|
// lockedPipelineCall is equivalent to PipelineCall but assumes that the
|
||
|
// caller is already holding onto q.conn.mu.
|
||
|
func (q *question) lockedPipelineCall(transform []capnp.PipelineOp, ccall *capnp.Call) capnp.Answer {
|
||
|
if q.conn.findQuestion(q.id) != q {
|
||
|
// Question has been finished. The call should happen as if it is
|
||
|
// back in application code.
|
||
|
q.mu.RLock()
|
||
|
obj, err, state := q.obj, q.err, q.state
|
||
|
q.mu.RUnlock()
|
||
|
if state == questionInProgress {
|
||
|
panic("question popped but not done")
|
||
|
}
|
||
|
client := clientFromResolution(transform, obj, err)
|
||
|
return q.conn.lockedCall(client, ccall)
|
||
|
}
|
||
|
|
||
|
pipeq := q.conn.newQuestion(ccall.Ctx, &ccall.Method)
|
||
|
msg := newMessage(nil)
|
||
|
msgCall, _ := msg.NewCall()
|
||
|
msgCall.SetQuestionId(uint32(pipeq.id))
|
||
|
msgCall.SetInterfaceId(ccall.Method.InterfaceID)
|
||
|
msgCall.SetMethodId(ccall.Method.MethodID)
|
||
|
target, _ := msgCall.NewTarget()
|
||
|
a, _ := target.NewPromisedAnswer()
|
||
|
a.SetQuestionId(uint32(q.id))
|
||
|
err := transformToPromisedAnswer(a.Segment(), a, transform)
|
||
|
if err != nil {
|
||
|
q.conn.popQuestion(pipeq.id)
|
||
|
return capnp.ErrorAnswer(err)
|
||
|
}
|
||
|
payload, _ := msgCall.NewParams()
|
||
|
if err := q.conn.fillParams(payload, ccall); err != nil {
|
||
|
q.conn.popQuestion(q.id)
|
||
|
return capnp.ErrorAnswer(err)
|
||
|
}
|
||
|
|
||
|
select {
|
||
|
case q.conn.out <- msg:
|
||
|
case <-ccall.Ctx.Done():
|
||
|
q.conn.popQuestion(pipeq.id)
|
||
|
return capnp.ErrorAnswer(ccall.Ctx.Err())
|
||
|
case <-q.conn.bg.Done():
|
||
|
q.conn.popQuestion(pipeq.id)
|
||
|
return capnp.ErrorAnswer(ErrConnClosed)
|
||
|
}
|
||
|
q.addPromise(transform)
|
||
|
pipeq.start()
|
||
|
return pipeq
|
||
|
}
|
||
|
|
||
|
func (q *question) PipelineClose(transform []capnp.PipelineOp) error {
|
||
|
<-q.resolved
|
||
|
q.mu.RLock()
|
||
|
obj, err := q.obj, q.err
|
||
|
q.mu.RUnlock()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
x, err := capnp.TransformPtr(obj, transform)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
c := x.Interface().Client()
|
||
|
if c == nil {
|
||
|
return capnp.ErrNullClient
|
||
|
}
|
||
|
return c.Close()
|
||
|
}
|
||
|
|
||
|
// embargoClient is a client that waits until an embargo signal is
// received to deliver calls.
//
// While the embargo holds, calls are copied into a bounded queue (see
// push). Once the embargo channel is ready, flushQueue replays them against
// the wrapped client, and later calls bypass the queue (see isPassthrough).
type embargoClient struct {
	cancel  <-chan struct{} // when ready before the embargo, queued calls are abandoned (see flushQueue)
	client  capnp.Client    // wrapped client that ultimately receives calls
	embargo embargo         // NOTE(review): presumably closed when the embargo is lifted

	mu    sync.RWMutex // guards q and calls
	q     queue.Queue  // queue bookkeeping over the slots in calls
	calls ecallList    // backing storage for queued calls (callQueueSize slots)
}
|
||
|
|
||
|
func newEmbargoClient(client capnp.Client, e embargo, cancel <-chan struct{}) *embargoClient {
|
||
|
ec := &embargoClient{
|
||
|
client: client,
|
||
|
embargo: e,
|
||
|
cancel: cancel,
|
||
|
calls: make(ecallList, callQueueSize),
|
||
|
}
|
||
|
ec.q.Init(ec.calls, 0)
|
||
|
go ec.flushQueue()
|
||
|
return ec
|
||
|
}
|
||
|
|
||
|
func (ec *embargoClient) push(cl *capnp.Call) capnp.Answer {
|
||
|
f := new(fulfiller.Fulfiller)
|
||
|
cl, err := cl.Copy(nil)
|
||
|
if err != nil {
|
||
|
return capnp.ErrorAnswer(err)
|
||
|
}
|
||
|
i := ec.q.Push()
|
||
|
if i == -1 {
|
||
|
return capnp.ErrorAnswer(errQueueFull)
|
||
|
}
|
||
|
ec.calls[i] = ecall{cl, f}
|
||
|
return f
|
||
|
}
|
||
|
|
||
|
func (ec *embargoClient) Call(cl *capnp.Call) capnp.Answer {
|
||
|
// Fast path: queue is flushed.
|
||
|
ec.mu.RLock()
|
||
|
ok := ec.isPassthrough()
|
||
|
ec.mu.RUnlock()
|
||
|
if ok {
|
||
|
return ec.client.Call(cl)
|
||
|
}
|
||
|
|
||
|
ec.mu.Lock()
|
||
|
if ec.isPassthrough() {
|
||
|
ec.mu.Unlock()
|
||
|
return ec.client.Call(cl)
|
||
|
}
|
||
|
ans := ec.push(cl)
|
||
|
ec.mu.Unlock()
|
||
|
return ans
|
||
|
}
|
||
|
|
||
|
func (ec *embargoClient) tryQueue(cl *capnp.Call) capnp.Answer {
|
||
|
ec.mu.Lock()
|
||
|
if ec.isPassthrough() {
|
||
|
ec.mu.Unlock()
|
||
|
return nil
|
||
|
}
|
||
|
ans := ec.push(cl)
|
||
|
ec.mu.Unlock()
|
||
|
return ans
|
||
|
}
|
||
|
|
||
|
func (ec *embargoClient) isPassthrough() bool {
|
||
|
select {
|
||
|
case <-ec.embargo:
|
||
|
default:
|
||
|
return false
|
||
|
}
|
||
|
return ec.q.Len() == 0
|
||
|
}
|
||
|
|
||
|
func (ec *embargoClient) Close() error {
|
||
|
ec.mu.Lock()
|
||
|
for ; ec.q.Len() > 0; ec.q.Pop() {
|
||
|
c := ec.calls[ec.q.Front()]
|
||
|
c.f.Reject(errQueueCallCancel)
|
||
|
}
|
||
|
ec.mu.Unlock()
|
||
|
return ec.client.Close()
|
||
|
}
|
||
|
|
||
|
// flushQueue is run in its own goroutine.
|
||
|
func (ec *embargoClient) flushQueue() {
|
||
|
select {
|
||
|
case <-ec.embargo:
|
||
|
case <-ec.cancel:
|
||
|
ec.mu.Lock()
|
||
|
for ec.q.Len() > 0 {
|
||
|
ec.q.Pop()
|
||
|
}
|
||
|
ec.mu.Unlock()
|
||
|
return
|
||
|
}
|
||
|
var c ecall
|
||
|
ec.mu.RLock()
|
||
|
if i := ec.q.Front(); i != -1 {
|
||
|
c = ec.calls[i]
|
||
|
}
|
||
|
ec.mu.RUnlock()
|
||
|
for c.call != nil {
|
||
|
ans := ec.client.Call(c.call)
|
||
|
go joinFulfiller(c.f, ans)
|
||
|
|
||
|
ec.mu.Lock()
|
||
|
ec.q.Pop()
|
||
|
if i := ec.q.Front(); i != -1 {
|
||
|
c = ec.calls[i]
|
||
|
} else {
|
||
|
c = ecall{}
|
||
|
}
|
||
|
ec.mu.Unlock()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// ecall is a call held in an embargoClient's queue while the embargo is in
// effect. push builds it with a positional literal (ecall{cl, f}), so the
// field order matters.
type ecall struct {
	call *capnp.Call          // copy of the caller's call, made by push
	f    *fulfiller.Fulfiller // answer returned to the caller; joined or rejected later
}
|
||
|
|
||
|
// ecallList is the fixed-size slot storage behind an embargoClient's queue.
// It is passed to queue.Queue.Init; its Len and Clear methods presumably
// form the interface that Init expects (TODO confirm against internal/queue).
type ecallList []ecall
|
||
|
|
||
|
func (el ecallList) Len() int {
|
||
|
return len(el)
|
||
|
}
|
||
|
|
||
|
func (el ecallList) Clear(i int) {
|
||
|
el[i] = ecall{}
|
||
|
}
|