package rpc

import (
	"sync"

	"golang.org/x/net/context"
	"zombiezen.com/go/capnproto2"
	"zombiezen.com/go/capnproto2/internal/fulfiller"
	"zombiezen.com/go/capnproto2/internal/queue"
	rpccapnp "zombiezen.com/go/capnproto2/std/capnp/rpc"
)

// newQuestion builds a question for ctx and method, gives it the next
// ID handed out by c.questionID, and records it in c.questions at that
// index.  The returned question is in progress and its resolved channel
// is open.
//
// NOTE(review): c.questions is mutated without any locking here;
// presumably the caller holds c.mu -- confirm against callers.
func (c *Conn) newQuestion(ctx context.Context, method *capnp.Method) *question {
	q := &question{
		id:       questionID(c.questionID.next()),
		ctx:      ctx,
		conn:     c,
		method:   method,
		resolved: make(chan struct{}),
	}
	// TODO(light): populate paramCaps

	// An ID that does not extend the table by exactly one slot is
	// stored in place.
	if int(q.id) != len(c.questions) {
		c.questions[q.id] = q
		return q
	}
	c.questions = append(c.questions, q)
	return q
}

// findQuestion returns the question stored under id, or nil if id is
// past the end of the table or the slot is empty.
func (c *Conn) findQuestion(id questionID) *question {
	if int(id) < len(c.questions) {
		return c.questions[id]
	}
	return nil
}

// popQuestion removes the question stored under id from the table,
// hands the ID back to c.questionID, and returns the removed question.
// It returns nil and changes nothing if no question is stored under id.
func (c *Conn) popQuestion(id questionID) *question {
	found := c.findQuestion(id)
	if found != nil {
		c.questions[id] = nil
		c.questionID.remove(uint32(id))
	}
	return found
}

// question is the local record of a call sent on conn.
type question struct {
	id        questionID
	ctx       context.Context
	conn      *Conn
	method    *capnp.Method // nil if this is bootstrap
	paramCaps []exportID
	resolved  chan struct{} // closed by fulfill, reject, or cancel

	// Protected by conn.mu
	derived [][]capnp.PipelineOp

	// Fields below are protected by mu.
	mu    sync.RWMutex
	obj   capnp.Ptr
	err   error
	state questionState
}

// questionState is the resolution state of a question.
type questionState uint8

// Question states
const (
	questionInProgress questionState = iota
	questionResolved
	questionCanceled
)

// start signals that the question has been sent.  It spawns a goroutine
// that exits once the question resolves or the connection's background
// context is done.  If the question's own context is done first, the
// goroutine cancels the question and, when that cancellation took
// effect, sends a finish message for it.
func (q *question) start() {
	go func() {
		// Phase 1: only cancellation of the caller's context needs
		// any further action.
		select {
		case <-q.resolved:
			// Resolved naturally, nothing to do.
			return
		case <-q.conn.bg.Done():
			return
		case <-q.ctx.Done():
		}

		// Phase 2: take the connection lock (by receiving from
		// q.conn.mu) unless the question resolves or the connection
		// shuts down while waiting for it.
		select {
		case <-q.resolved:
			return
		case <-q.conn.bg.Done():
			return
		case <-q.conn.mu:
		}

		if err := q.conn.startWork(); err != nil {
			// teardown calls cancel.
			q.conn.mu.Unlock()
			return
		}
		if canceled := q.cancel(q.ctx.Err()); canceled {
			finish := newFinishMessage(nil, q.id, true /* release */)
			q.conn.sendMessage(finish)
		}
		q.conn.workers.Done()
		q.conn.mu.Unlock()
	}()
}

// fulfill is called to resolve a question successfully.
// The caller must be holding onto q.conn.mu. func (q *question) fulfill(obj capnp.Ptr) { var ctab []capnp.Client if obj.IsValid() { ctab = obj.Segment().Message().CapTable } visited := make([]bool, len(ctab)) for _, d := range q.derived { tgt, err := capnp.TransformPtr(obj, d) if err != nil { continue } in := tgt.Interface() if !in.IsValid() { continue } if ic := isImport(in.Client()); ic != nil && ic.conn == q.conn { // Imported from remote vat. Don't need to disembargo. continue } cn := in.Capability() if visited[cn] { continue } visited[cn] = true id, e := q.conn.newEmbargo() ctab[cn] = newEmbargoClient(ctab[cn], e, q.conn.bg.Done()) m := newDisembargoMessage(nil, rpccapnp.Disembargo_context_Which_senderLoopback, id) dis, _ := m.Disembargo() mt, _ := dis.NewTarget() pa, _ := mt.NewPromisedAnswer() pa.SetQuestionId(uint32(q.id)) transformToPromisedAnswer(m.Segment(), pa, d) mt.SetPromisedAnswer(pa) select { case q.conn.out <- m: case <-q.conn.bg.Done(): // TODO(soon): perhaps just drop all embargoes in this case? } } q.mu.Lock() if q.state != questionInProgress { panic("question.fulfill called more than once") } q.obj, q.state = obj, questionResolved close(q.resolved) q.mu.Unlock() } // reject is called to resolve a question with failure. // The caller must be holding onto q.conn.mu. func (q *question) reject(err error) { if err == nil { panic("question.reject called with nil") } q.mu.Lock() if q.state != questionInProgress { panic("question.reject called more than once") } q.err = err q.state = questionResolved close(q.resolved) q.mu.Unlock() } // cancel is called to resolve a question with cancellation. // The caller must be holding onto q.conn.mu. 
func (q *question) cancel(err error) bool { if err == nil { panic("question.cancel called with nil") } q.mu.Lock() canceled := q.state == questionInProgress if canceled { q.err = err q.state = questionCanceled close(q.resolved) } q.mu.Unlock() return canceled } // addPromise records a returned capability as being used for a call. // This is needed for determining embargoes upon resolution. The // caller must be holding onto q.conn.mu. func (q *question) addPromise(transform []capnp.PipelineOp) { for _, d := range q.derived { if transformsEqual(transform, d) { return } } q.derived = append(q.derived, transform) } func transformsEqual(t, u []capnp.PipelineOp) bool { if len(t) != len(u) { return false } for i := range t { if t[i].Field != u[i].Field { return false } } return true } func (q *question) Struct() (capnp.Struct, error) { select { case <-q.resolved: case <-q.conn.bg.Done(): return capnp.Struct{}, ErrConnClosed } q.mu.RLock() s, err := q.obj.Struct(), q.err q.mu.RUnlock() return s, err } func (q *question) PipelineCall(transform []capnp.PipelineOp, ccall *capnp.Call) capnp.Answer { select { case <-q.conn.mu: if err := q.conn.startWork(); err != nil { q.conn.mu.Unlock() return capnp.ErrorAnswer(err) } case <-ccall.Ctx.Done(): return capnp.ErrorAnswer(ccall.Ctx.Err()) } ans := q.lockedPipelineCall(transform, ccall) q.conn.workers.Done() q.conn.mu.Unlock() return ans } // lockedPipelineCall is equivalent to PipelineCall but assumes that the // caller is already holding onto q.conn.mu. func (q *question) lockedPipelineCall(transform []capnp.PipelineOp, ccall *capnp.Call) capnp.Answer { if q.conn.findQuestion(q.id) != q { // Question has been finished. The call should happen as if it is // back in application code. 
q.mu.RLock() obj, err, state := q.obj, q.err, q.state q.mu.RUnlock() if state == questionInProgress { panic("question popped but not done") } client := clientFromResolution(transform, obj, err) return q.conn.lockedCall(client, ccall) } pipeq := q.conn.newQuestion(ccall.Ctx, &ccall.Method) msg := newMessage(nil) msgCall, _ := msg.NewCall() msgCall.SetQuestionId(uint32(pipeq.id)) msgCall.SetInterfaceId(ccall.Method.InterfaceID) msgCall.SetMethodId(ccall.Method.MethodID) target, _ := msgCall.NewTarget() a, _ := target.NewPromisedAnswer() a.SetQuestionId(uint32(q.id)) err := transformToPromisedAnswer(a.Segment(), a, transform) if err != nil { q.conn.popQuestion(pipeq.id) return capnp.ErrorAnswer(err) } payload, _ := msgCall.NewParams() if err := q.conn.fillParams(payload, ccall); err != nil { q.conn.popQuestion(q.id) return capnp.ErrorAnswer(err) } select { case q.conn.out <- msg: case <-ccall.Ctx.Done(): q.conn.popQuestion(pipeq.id) return capnp.ErrorAnswer(ccall.Ctx.Err()) case <-q.conn.bg.Done(): q.conn.popQuestion(pipeq.id) return capnp.ErrorAnswer(ErrConnClosed) } q.addPromise(transform) pipeq.start() return pipeq } func (q *question) PipelineClose(transform []capnp.PipelineOp) error { <-q.resolved q.mu.RLock() obj, err := q.obj, q.err q.mu.RUnlock() if err != nil { return err } x, err := capnp.TransformPtr(obj, transform) if err != nil { return err } c := x.Interface().Client() if c == nil { return capnp.ErrNullClient } return c.Close() } // embargoClient is a client that waits until an embargo signal is // received to deliver calls. 
type embargoClient struct { cancel <-chan struct{} client capnp.Client embargo embargo mu sync.RWMutex q queue.Queue calls ecallList } func newEmbargoClient(client capnp.Client, e embargo, cancel <-chan struct{}) *embargoClient { ec := &embargoClient{ client: client, embargo: e, cancel: cancel, calls: make(ecallList, callQueueSize), } ec.q.Init(ec.calls, 0) go ec.flushQueue() return ec } func (ec *embargoClient) push(cl *capnp.Call) capnp.Answer { f := new(fulfiller.Fulfiller) cl, err := cl.Copy(nil) if err != nil { return capnp.ErrorAnswer(err) } i := ec.q.Push() if i == -1 { return capnp.ErrorAnswer(errQueueFull) } ec.calls[i] = ecall{cl, f} return f } func (ec *embargoClient) Call(cl *capnp.Call) capnp.Answer { // Fast path: queue is flushed. ec.mu.RLock() ok := ec.isPassthrough() ec.mu.RUnlock() if ok { return ec.client.Call(cl) } ec.mu.Lock() if ec.isPassthrough() { ec.mu.Unlock() return ec.client.Call(cl) } ans := ec.push(cl) ec.mu.Unlock() return ans } func (ec *embargoClient) tryQueue(cl *capnp.Call) capnp.Answer { ec.mu.Lock() if ec.isPassthrough() { ec.mu.Unlock() return nil } ans := ec.push(cl) ec.mu.Unlock() return ans } func (ec *embargoClient) isPassthrough() bool { select { case <-ec.embargo: default: return false } return ec.q.Len() == 0 } func (ec *embargoClient) Close() error { ec.mu.Lock() for ; ec.q.Len() > 0; ec.q.Pop() { c := ec.calls[ec.q.Front()] c.f.Reject(errQueueCallCancel) } ec.mu.Unlock() return ec.client.Close() } // flushQueue is run in its own goroutine. 
func (ec *embargoClient) flushQueue() { select { case <-ec.embargo: case <-ec.cancel: ec.mu.Lock() for ec.q.Len() > 0 { ec.q.Pop() } ec.mu.Unlock() return } var c ecall ec.mu.RLock() if i := ec.q.Front(); i != -1 { c = ec.calls[i] } ec.mu.RUnlock() for c.call != nil { ans := ec.client.Call(c.call) go joinFulfiller(c.f, ans) ec.mu.Lock() ec.q.Pop() if i := ec.q.Front(); i != -1 { c = ec.calls[i] } else { c = ecall{} } ec.mu.Unlock() } } type ecall struct { call *capnp.Call f *fulfiller.Fulfiller } type ecallList []ecall func (el ecallList) Len() int { return len(el) } func (el ecallList) Clear(i int) { el[i] = ecall{} }