// Package rpc implements the Cap'n Proto RPC protocol.
package rpc // import "zombiezen.com/go/capnproto2/rpc"

import (
	"fmt"
	"io"
	"sync"

	"golang.org/x/net/context"
	"zombiezen.com/go/capnproto2"
	"zombiezen.com/go/capnproto2/rpc/internal/refcount"
	rpccapnp "zombiezen.com/go/capnproto2/std/capnp/rpc"
)

// A Conn is a connection to another Cap'n Proto vat.
// It is safe to use from multiple goroutines.
type Conn struct {
	transport  Transport
	log        Logger
	mainFunc   func(context.Context) (capnp.Client, error)
	mainCloser io.Closer
	death      chan struct{} // closed after state is connDead
	out        chan rpccapnp.Message
	bg         context.Context // canceled when the connection starts shutting down
	bgCancel   context.CancelFunc
	workers    sync.WaitGroup // in-flight background work; waited on in teardown

	// Mutable state protected by stateMu
	// If you need to acquire both mu and stateMu, acquire mu first.
	stateMu  sync.RWMutex
	state    connState
	closeErr error // set when shutdown begins; reported by Err

	// Mutable state protected by mu
	mu         chanMutex
	questions  []*question
	questionID idgen
	exports    []*export
	exportID   idgen
	embargoes  []chan<- struct{}
	embargoID  idgen
	answers    map[answerID]*answer
	imports    map[importID]*impent
}

// connParams holds the settings gathered from ConnOptions before the
// Conn itself is constructed.
type connParams struct {
	log            Logger
	mainFunc       func(context.Context) (capnp.Client, error)
	mainCloser     io.Closer
	sendBufferSize int
}

// A ConnOption is an option for opening a connection.
type ConnOption struct {
	f func(*connParams)
}

// MainInterface specifies that the connection should use client when
// receiving bootstrap messages. By default, all bootstrap messages will
// fail. The client will be closed when the connection is closed.
func MainInterface(client capnp.Client) ConnOption {
	// Wrap client in a refcount: ref1 is the client returned on every
	// bootstrap, ref2 is released via mainCloser when the connection closes.
	rc, ref1 := refcount.New(client)
	ref2 := rc.Ref()
	return ConnOption{func(c *connParams) {
		c.mainFunc = func(ctx context.Context) (capnp.Client, error) {
			return ref1, nil
		}
		c.mainCloser = ref2
	}}
}

// BootstrapFunc specifies the function to call to create a capability
// for handling bootstrap messages. This function should not make any
// RPCs or block.
func BootstrapFunc(f func(context.Context) (capnp.Client, error)) ConnOption {
	return ConnOption{func(c *connParams) {
		c.mainFunc = f
	}}
}

// SendBufferSize sets the number of outgoing messages to buffer on the
// connection. This is in addition to whatever buffering the connection's
// transport performs.
func SendBufferSize(numMsgs int) ConnOption {
	return ConnOption{func(c *connParams) {
		c.sendBufferSize = numMsgs
	}}
}

// NewConn creates a new connection that communicates on t.
// Closing the connection will cause t to be closed.
func NewConn(t Transport, options ...ConnOption) *Conn {
	p := &connParams{
		log:            defaultLogger{},
		sendBufferSize: 4,
	}
	for _, o := range options {
		o.f(p)
	}
	conn := &Conn{
		transport:  t,
		out:        make(chan rpccapnp.Message, p.sendBufferSize),
		mainFunc:   p.mainFunc,
		mainCloser: p.mainCloser,
		log:        p.log,
		death:      make(chan struct{}),
		mu:         newChanMutex(),
	}
	conn.bg, conn.bgCancel = context.WithCancel(context.Background())
	// Register both dispatch goroutines before starting them so a
	// concurrent teardown's workers.Wait() cannot miss them.
	conn.workers.Add(2)
	go conn.dispatchRecv()
	go conn.dispatchSend()
	return conn
}

// Wait waits until the connection is closed or aborted by the remote vat.
// Wait will always return an error, usually ErrConnClosed or of type Abort.
func (c *Conn) Wait() error {
	<-c.Done()
	return c.Err()
}

// Done is a channel that is closed once the connection is fully shut down.
func (c *Conn) Done() <-chan struct{} {
	return c.death
}

// Err returns the error that caused the connection to close.
// Err returns nil while the connection is alive.
func (c *Conn) Err() error {
	c.stateMu.RLock()
	var err error
	// Report the close error once shutdown has begun (dying or dead).
	// The previous check (state != connDead) made Err — and therefore
	// Wait — return nil after teardown completed, violating Wait's
	// contract of always returning an error.
	if c.state != connAlive {
		err = c.closeErr
	}
	c.stateMu.RUnlock()
	return err
}

// Close closes the connection and the underlying transport.
func (c *Conn) Close() error {
	// Transition alive -> dying under stateMu; only the first of
	// Close/shutdown/abort performs the transition.
	c.stateMu.Lock()
	alive := c.state == connAlive
	if alive {
		c.bgCancel()
		c.closeErr = ErrConnClosed
		c.state = connDying
	}
	c.stateMu.Unlock()
	if !alive {
		return ErrConnClosed
	}
	// Send an abort message to the remote vat and move to the dead state.
	c.teardown(newAbortMessage(nil, errShutdown))
	c.stateMu.RLock()
	err := c.closeErr
	c.stateMu.RUnlock()
	// teardown may have replaced closeErr with a transport error; surface it.
	if err != ErrConnClosed {
		return err
	}
	return nil
}

// shutdown cancels the background context and sets closeErr to e.
// No abort message will be sent on the transport. After shutdown
// returns, the Conn will be in the dying or dead state. Calling
// shutdown on a dying or dead Conn is a no-op.
func (c *Conn) shutdown(e error) {
	c.stateMu.Lock()
	if c.state == connAlive {
		c.bgCancel()
		c.closeErr = e
		c.state = connDying
		// An invalid (zero) message makes teardown skip sending an abort.
		go c.teardown(rpccapnp.Message{})
	}
	c.stateMu.Unlock()
}

// abort cancels the background context, sets closeErr to e, and queues
// an abort message to be sent on the transport before the Conn goes
// into the dead state. After abort returns, the Conn will be in the
// dying or dead state. Calling abort on a dying or dead Conn is a
// no-op.
func (c *Conn) abort(e error) {
	c.stateMu.Lock()
	if c.state == connAlive {
		c.bgCancel()
		c.closeErr = e
		c.state = connDying
		go c.teardown(newAbortMessage(nil, e))
	}
	c.stateMu.Unlock()
}

// startWork adds a new worker if c is not dying or dead.
// Otherwise, it returns the close error.
// The caller is responsible for calling c.workers.Done().
// The caller must not be holding onto c.stateMu.
func (c *Conn) startWork() error {
	var err error
	c.stateMu.RLock()
	if c.state == connAlive {
		c.workers.Add(1)
	} else {
		err = c.closeErr
	}
	c.stateMu.RUnlock()
	return err
}

// teardown moves the connection from the dying to the dead state.
func (c *Conn) teardown(abort rpccapnp.Message) { c.workers.Wait() c.mu.Lock() for _, q := range c.questions { if q != nil { q.cancel(ErrConnClosed) } } c.questions = nil exps := c.exports c.exports = nil c.embargoes = nil for _, a := range c.answers { a.cancel() } c.answers = nil c.imports = nil c.mainFunc = nil c.mu.Unlock() if c.mainCloser != nil { if err := c.mainCloser.Close(); err != nil { c.errorf("closing main interface: %v", err) } c.mainCloser = nil } // Closing an export may try to lock the Conn, so run it outside // critical section. for id, e := range exps { if e == nil { continue } if err := e.client.Close(); err != nil { c.errorf("export %v close: %v", id, err) } } exps = nil var werr error if abort.IsValid() { werr = c.transport.SendMessage(context.Background(), abort) } cerr := c.transport.Close() c.stateMu.Lock() if c.closeErr == ErrConnClosed { if cerr != nil { c.closeErr = cerr } else if werr != nil { c.closeErr = werr } } c.state = connDead close(c.death) c.stateMu.Unlock() } // Bootstrap returns the receiver's main interface. func (c *Conn) Bootstrap(ctx context.Context) capnp.Client { // TODO(light): Create a client that returns immediately. select { case <-c.mu: // Locked. defer c.mu.Unlock() if err := c.startWork(); err != nil { return capnp.ErrorClient(err) } defer c.workers.Done() case <-ctx.Done(): return capnp.ErrorClient(ctx.Err()) case <-c.bg.Done(): return capnp.ErrorClient(ErrConnClosed) } q := c.newQuestion(ctx, nil /* method */) msg := newMessage(nil) boot, _ := msg.NewBootstrap() boot.SetQuestionId(uint32(q.id)) // The mutex must be held while sending so that call order is preserved. // Worst case, this blocks until a message is sent on the transport. // Common case, this just adds to the channel queue. 
select { case c.out <- msg: q.start() return capnp.NewPipeline(q).Client() case <-ctx.Done(): c.popQuestion(q.id) return capnp.ErrorClient(ctx.Err()) case <-c.bg.Done(): c.popQuestion(q.id) return capnp.ErrorClient(ErrConnClosed) } } // handleMessage is run from the receive goroutine to process a single // message. m cannot be held onto past the return of handleMessage, and // c.mu is not held at the start of handleMessage. func (c *Conn) handleMessage(m rpccapnp.Message) { switch m.Which() { case rpccapnp.Message_Which_unimplemented: // no-op for now to avoid feedback loop case rpccapnp.Message_Which_abort: a, err := copyAbort(m) if err != nil { c.errorf("decode abort: %v", err) // Keep going, since we're trying to abort anyway. } c.infof("abort: %v", a) c.shutdown(a) case rpccapnp.Message_Which_return: m = copyRPCMessage(m) c.mu.Lock() err := c.handleReturnMessage(m) c.mu.Unlock() if err != nil { c.errorf("handle return: %v", err) } case rpccapnp.Message_Which_finish: mfin, err := m.Finish() if err != nil { c.errorf("decode finish: %v", err) return } id := answerID(mfin.QuestionId()) c.mu.Lock() a := c.popAnswer(id) if a == nil { c.mu.Unlock() c.errorf("finish called for unknown answer %d", id) return } a.cancel() if mfin.ReleaseResultCaps() { for _, id := range a.resultCaps { c.releaseExport(id, 1) } } c.mu.Unlock() case rpccapnp.Message_Which_bootstrap: boot, err := m.Bootstrap() if err != nil { c.errorf("decode bootstrap: %v", err) return } id := answerID(boot.QuestionId()) c.mu.Lock() err = c.handleBootstrapMessage(id) c.mu.Unlock() if err != nil { c.errorf("handle bootstrap: %v", err) } case rpccapnp.Message_Which_call: m = copyRPCMessage(m) c.mu.Lock() err := c.handleCallMessage(m) c.mu.Unlock() if err != nil { c.errorf("handle call: %v", err) } case rpccapnp.Message_Which_release: rel, err := m.Release() if err != nil { c.errorf("decode release: %v", err) return } id := exportID(rel.Id()) refs := int(rel.ReferenceCount()) c.mu.Lock() c.releaseExport(id, 
refs) c.mu.Unlock() case rpccapnp.Message_Which_disembargo: m = copyRPCMessage(m) c.mu.Lock() err := c.handleDisembargoMessage(m) c.mu.Unlock() if err != nil { // Any failure in a disembargo is a protocol violation. c.abort(err) } default: c.infof("received unimplemented message, which = %v", m.Which()) um := newUnimplementedMessage(nil, m) c.sendMessage(um) } } func newUnimplementedMessage(buf []byte, m rpccapnp.Message) rpccapnp.Message { n := newMessage(buf) n.SetUnimplemented(m) return n } func (c *Conn) fillParams(payload rpccapnp.Payload, cl *capnp.Call) error { params, err := cl.PlaceParams(payload.Segment()) if err != nil { return err } if err := payload.SetContent(params); err != nil { return err } ctab, err := c.makeCapTable(payload.Segment()) if err != nil { return err } if err := payload.SetCapTable(ctab); err != nil { return err } return nil } func transformToPromisedAnswer(s *capnp.Segment, answer rpccapnp.PromisedAnswer, transform []capnp.PipelineOp) error { opList, err := rpccapnp.NewPromisedAnswer_Op_List(s, int32(len(transform))) if err != nil { return err } for i, op := range transform { opList.At(i).SetGetPointerField(uint16(op.Field)) } err = answer.SetTransform(opList) return err } // handleReturnMessage is to handle a received return message. // The caller is holding onto c.mu. func (c *Conn) handleReturnMessage(m rpccapnp.Message) error { ret, err := m.Return() if err != nil { return err } id := questionID(ret.AnswerId()) q := c.popQuestion(id) if q == nil { return fmt.Errorf("received return for unknown question id=%d", id) } if ret.ReleaseParamCaps() { for _, id := range q.paramCaps { c.releaseExport(id, 1) } } q.mu.RLock() qstate := q.state q.mu.RUnlock() if qstate == questionCanceled { // We already sent the finish message. 
return nil } releaseResultCaps := true switch ret.Which() { case rpccapnp.Return_Which_results: releaseResultCaps = false results, err := ret.Results() if err != nil { return err } if err := c.populateMessageCapTable(results); err == errUnimplemented { um := newUnimplementedMessage(nil, m) c.sendMessage(um) return errUnimplemented } else if err != nil { c.abort(err) return err } content, err := results.ContentPtr() if err != nil { return err } q.fulfill(content) case rpccapnp.Return_Which_exception: exc, err := ret.Exception() if err != nil { return err } e := error(Exception{exc}) if q.method != nil { e = &capnp.MethodError{ Method: q.method, Err: e, } } else { e = bootstrapError{e} } q.reject(e) case rpccapnp.Return_Which_canceled: err := &questionError{ id: id, method: q.method, err: fmt.Errorf("receiver reported canceled"), } c.errorf("%v", err) q.reject(err) return nil default: um := newUnimplementedMessage(nil, m) c.sendMessage(um) return errUnimplemented } fin := newFinishMessage(nil, id, releaseResultCaps) c.sendMessage(fin) return nil } func newFinishMessage(buf []byte, questionID questionID, release bool) rpccapnp.Message { m := newMessage(buf) f, _ := m.NewFinish() f.SetQuestionId(uint32(questionID)) f.SetReleaseResultCaps(release) return m } // populateMessageCapTable converts the descriptors in the payload into // clients and sets it on the message the payload is a part of. func (c *Conn) populateMessageCapTable(payload rpccapnp.Payload) error { msg := payload.Segment().Message() ctab, err := payload.CapTable() if err != nil { return err } for i, n := 0, ctab.Len(); i < n; i++ { desc := ctab.At(i) switch desc.Which() { case rpccapnp.CapDescriptor_Which_none: msg.AddCap(nil) case rpccapnp.CapDescriptor_Which_senderHosted: id := importID(desc.SenderHosted()) client := c.addImport(id) msg.AddCap(client) case rpccapnp.CapDescriptor_Which_senderPromise: // We do the same thing as senderHosted, above. 
@kentonv suggested this on // issue #2; this let's messages be delivered properly, although it's a bit // of a hack, and as Kenton describes, it has some disadvantages: // // > * Apps sometimes want to wait for promise resolution, and to find out if // > it resolved to an exception. You won't be able to provide that API. But, // > usually, it isn't needed. // > * If the promise resolves to a capability hosted on the receiver, // > messages sent to it will uselessly round-trip over the network // > rather than being delivered locally. id := importID(desc.SenderPromise()) client := c.addImport(id) msg.AddCap(client) case rpccapnp.CapDescriptor_Which_receiverHosted: id := exportID(desc.ReceiverHosted()) e := c.findExport(id) if e == nil { return fmt.Errorf("rpc: capability table references unknown export ID %d", id) } msg.AddCap(e.rc.Ref()) case rpccapnp.CapDescriptor_Which_receiverAnswer: recvAns, err := desc.ReceiverAnswer() if err != nil { return err } id := answerID(recvAns.QuestionId()) a := c.answers[id] if a == nil { return fmt.Errorf("rpc: capability table references unknown answer ID %d", id) } recvTransform, err := recvAns.Transform() if err != nil { return err } transform := promisedAnswerOpsToTransform(recvTransform) msg.AddCap(a.pipelineClient(transform)) default: c.errorf("unknown capability type %v", desc.Which()) return errUnimplemented } } return nil } // makeCapTable converts the clients in the segment's message into capability descriptors. func (c *Conn) makeCapTable(s *capnp.Segment) (rpccapnp.CapDescriptor_List, error) { msgtab := s.Message().CapTable t, err := rpccapnp.NewCapDescriptor_List(s, int32(len(msgtab))) if err != nil { return rpccapnp.CapDescriptor_List{}, nil } for i, client := range msgtab { desc := t.At(i) if client == nil { desc.SetNone() continue } c.descriptorForClient(desc, client) } return t, nil } // handleBootstrapMessage handles a received bootstrap message. // The caller holds onto c.mu. 
func (c *Conn) handleBootstrapMessage(id answerID) error {
	ctx, cancel := c.newContext()
	defer cancel()
	a := c.insertAnswer(id, cancel)
	if a == nil {
		// Question ID reused, error out.
		retmsg := newReturnMessage(nil, id)
		r, _ := retmsg.Return()
		setReturnException(r, errQuestionReused)
		return c.sendMessage(retmsg)
	}
	if c.mainFunc == nil {
		return a.reject(errNoMainInterface)
	}
	main, err := c.mainFunc(ctx)
	if err != nil {
		return a.reject(errNoMainInterface)
	}
	// Build a message whose root pointer is an interface referencing the
	// main client (cap index 0 in the cap table).
	m := &capnp.Message{
		Arena:    capnp.SingleSegment(make([]byte, 0)),
		CapTable: []capnp.Client{main},
	}
	s, _ := m.Segment(0)
	in := capnp.NewInterface(s, 0)
	return a.fulfill(in.ToPtr())
}

// handleCallMessage handles a received call message. It mutates the
// capability table of its parameter. The caller holds onto c.mu.
func (c *Conn) handleCallMessage(m rpccapnp.Message) error {
	mcall, err := m.Call()
	if err != nil {
		return err
	}
	mt, err := mcall.Target()
	if err != nil {
		return err
	}
	// Only importedCap and promisedAnswer targets are supported; anything
	// else is answered with an unimplemented message.
	if mt.Which() != rpccapnp.MessageTarget_Which_importedCap && mt.Which() != rpccapnp.MessageTarget_Which_promisedAnswer {
		um := newUnimplementedMessage(nil, m)
		return c.sendMessage(um)
	}
	mparams, err := mcall.Params()
	if err != nil {
		return err
	}
	if err := c.populateMessageCapTable(mparams); err == errUnimplemented {
		um := newUnimplementedMessage(nil, m)
		return c.sendMessage(um)
	} else if err != nil {
		c.abort(err)
		return err
	}
	ctx, cancel := c.newContext()
	id := answerID(mcall.QuestionId())
	a := c.insertAnswer(id, cancel)
	if a == nil {
		// Question ID reused, error out.
		c.abort(errQuestionReused)
		return errQuestionReused
	}
	meth := capnp.Method{
		InterfaceID: mcall.InterfaceId(),
		MethodID:    mcall.MethodId(),
	}
	paramContent, err := mparams.ContentPtr()
	if err != nil {
		return err
	}
	cl := &capnp.Call{
		Ctx:    ctx,
		Method: meth,
		Params: paramContent.Struct(),
	}
	if err := c.routeCallMessage(a, mt, cl); err != nil {
		return a.reject(err)
	}
	return nil
}

// routeCallMessage dispatches cl to the capability named by mt, arranging
// for the outcome to be joined into result. The caller holds onto c.mu.
func (c *Conn) routeCallMessage(result *answer, mt rpccapnp.MessageTarget, cl *capnp.Call) error {
	switch mt.Which() {
	case rpccapnp.MessageTarget_Which_importedCap:
		id := exportID(mt.ImportedCap())
		e := c.findExport(id)
		if e == nil {
			return errBadTarget
		}
		answer := c.lockedCall(e.client, cl)
		go joinAnswer(result, answer)
	case rpccapnp.MessageTarget_Which_promisedAnswer:
		mpromise, err := mt.PromisedAnswer()
		if err != nil {
			return err
		}
		id := answerID(mpromise.QuestionId())
		if id == result.id {
			// Grandfather paradox.
			return errBadTarget
		}
		pa := c.answers[id]
		if pa == nil {
			return errBadTarget
		}
		mtrans, err := mpromise.Transform()
		if err != nil {
			return err
		}
		transform := promisedAnswerOpsToTransform(mtrans)
		pa.mu.Lock()
		if pa.done {
			// Target answer already resolved: call its resolution directly.
			obj, err := pa.obj, pa.err
			pa.mu.Unlock()
			client := clientFromResolution(transform, obj, err)
			answer := c.lockedCall(client, cl)
			go joinAnswer(result, answer)
		} else {
			// Target answer still pending: queue the call behind it.
			err = pa.queueCallLocked(cl, pcall{transform: transform, qcall: qcall{a: result}})
			pa.mu.Unlock()
		}
		return err
	default:
		panic("unreachable")
	}
	return nil
}

// handleDisembargoMessage handles a received disembargo message.
// The caller holds onto c.mu.
func (c *Conn) handleDisembargoMessage(msg rpccapnp.Message) error {
	d, err := msg.Disembargo()
	if err != nil {
		return err
	}
	dtarget, err := d.Target()
	if err != nil {
		return err
	}
	switch d.Context().Which() {
	case rpccapnp.Disembargo_context_Which_senderLoopback:
		id := embargoID(d.Context().SenderLoopback())
		if dtarget.Which() != rpccapnp.MessageTarget_Which_promisedAnswer {
			return errDisembargoNonImport
		}
		dpa, err := dtarget.PromisedAnswer()
		if err != nil {
			return err
		}
		aid := answerID(dpa.QuestionId())
		a := c.answers[aid]
		if a == nil {
			return errDisembargoMissingAnswer
		}
		dtrans, err := dpa.Transform()
		if err != nil {
			return err
		}
		transform := promisedAnswerOpsToTransform(dtrans)
		queued, err := a.queueDisembargo(transform, id, dtarget)
		if err != nil {
			return err
		}
		if !queued {
			// There's nothing to embargo; everything's been delivered.
			resp := newDisembargoMessage(nil, rpccapnp.Disembargo_context_Which_receiverLoopback, id)
			rd, _ := resp.Disembargo()
			if err := rd.SetTarget(dtarget); err != nil {
				return err
			}
			c.sendMessage(resp)
		}
	case rpccapnp.Disembargo_context_Which_receiverLoopback:
		id := embargoID(d.Context().ReceiverLoopback())
		c.disembargo(id)
	default:
		um := newUnimplementedMessage(nil, msg)
		c.sendMessage(um)
	}
	return nil
}

// newDisembargoMessage creates a disembargo message. Its target will be left blank.
func newDisembargoMessage(buf []byte, which rpccapnp.Disembargo_context_Which, id embargoID) rpccapnp.Message {
	msg := newMessage(buf)
	d, _ := msg.NewDisembargo()
	switch which {
	case rpccapnp.Disembargo_context_Which_senderLoopback:
		d.Context().SetSenderLoopback(uint32(id))
	case rpccapnp.Disembargo_context_Which_receiverLoopback:
		d.Context().SetReceiverLoopback(uint32(id))
	default:
		panic("unreachable")
	}
	return msg
}

// newContext creates a new context for a local call.
func (c *Conn) newContext() (context.Context, context.CancelFunc) { return context.WithCancel(c.bg) } func promisedAnswerOpsToTransform(list rpccapnp.PromisedAnswer_Op_List) []capnp.PipelineOp { n := list.Len() transform := make([]capnp.PipelineOp, 0, n) for i := 0; i < n; i++ { op := list.At(i) switch op.Which() { case rpccapnp.PromisedAnswer_Op_Which_getPointerField: transform = append(transform, capnp.PipelineOp{ Field: op.GetPointerField(), }) case rpccapnp.PromisedAnswer_Op_Which_noop: // no-op } } return transform } func newAbortMessage(buf []byte, err error) rpccapnp.Message { n := newMessage(buf) e, _ := n.NewAbort() toException(e, err) return n } func newReturnMessage(buf []byte, id answerID) rpccapnp.Message { retmsg := newMessage(buf) ret, _ := retmsg.NewReturn() ret.SetAnswerId(uint32(id)) ret.SetReleaseParamCaps(false) return retmsg } func setReturnException(ret rpccapnp.Return, err error) rpccapnp.Exception { e, _ := rpccapnp.NewException(ret.Segment()) toException(e, err) ret.SetException(e) return e } // clientFromResolution retrieves a client from a resolved question or // answer by applying a transform. func clientFromResolution(transform []capnp.PipelineOp, obj capnp.Ptr, err error) capnp.Client { if err != nil { return capnp.ErrorClient(err) } out, err := capnp.TransformPtr(obj, transform) if err != nil { return capnp.ErrorClient(err) } c := out.Interface().Client() if c == nil { return capnp.ErrorClient(capnp.ErrNullClient) } return c } func newMessage(buf []byte) rpccapnp.Message { _, s, err := capnp.NewMessage(capnp.SingleSegment(buf)) if err != nil { panic(err) } m, err := rpccapnp.NewRootMessage(s) if err != nil { panic(err) } return m } // chanMutex is a mutex backed by a channel so that it can be used in a select. // A receive is a lock and a send is an unlock. 
type chanMutex chan struct{}

// connState tracks a Conn's lifecycle: alive, then dying while shutting
// down, then dead once teardown completes.
type connState int

const (
	connAlive connState = iota
	connDying
	connDead
)

// newChanMutex returns an unlocked mutex: the channel holds a single
// token, and whoever receives the token owns the lock.
func newChanMutex() chanMutex {
	m := make(chanMutex, 1)
	m <- struct{}{}
	return m
}

// Lock blocks until the token is acquired.
func (m chanMutex) Lock() {
	<-m
}

// TryLock acquires the mutex unless ctx is done first, in which case it
// returns ctx's error without acquiring the lock.
func (m chanMutex) TryLock(ctx context.Context) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-m:
		return nil
	}
}

// Unlock releases the token.
func (m chanMutex) Unlock() {
	m <- struct{}{}
}