package rpc import ( "errors" "zombiezen.com/go/capnproto2" "zombiezen.com/go/capnproto2/rpc/internal/refcount" ) // Table IDs type ( questionID uint32 answerID uint32 exportID uint32 importID uint32 embargoID uint32 ) // impent is an entry in the import table. type impent struct { rc *refcount.RefCount refs int } // addImport increases the counter of the times the import ID was sent to this vat. func (c *Conn) addImport(id importID) capnp.Client { if c.imports == nil { c.imports = make(map[importID]*impent) } else if ent := c.imports[id]; ent != nil { ent.refs++ return ent.rc.Ref() } client := &importClient{ id: id, conn: c, } rc, ref := refcount.New(client) c.imports[id] = &impent{rc: rc, refs: 1} return ref } // popImport removes the import ID and returns the number of times the import ID was sent to this vat. func (c *Conn) popImport(id importID) (refs int) { if c.imports == nil { return 0 } ent := c.imports[id] if ent == nil { return 0 } refs = ent.refs delete(c.imports, id) return refs } // An importClient implements capnp.Client for a remote capability. type importClient struct { id importID conn *Conn closed bool // protected by conn.mu } func (ic *importClient) Call(cl *capnp.Call) capnp.Answer { select { case <-ic.conn.mu: if err := ic.conn.startWork(); err != nil { return capnp.ErrorAnswer(err) } case <-cl.Ctx.Done(): return capnp.ErrorAnswer(cl.Ctx.Err()) } ans := ic.lockedCall(cl) ic.conn.workers.Done() ic.conn.mu.Unlock() return ans } // lockedCall is equivalent to Call but assumes that the caller is // already holding onto ic.conn.mu. func (ic *importClient) lockedCall(cl *capnp.Call) capnp.Answer { if ic.closed { return capnp.ErrorAnswer(errImportClosed) } q := ic.conn.newQuestion(cl.Ctx, &cl.Method) msg := newMessage(nil) msgCall, _ := msg.NewCall() msgCall.SetQuestionId(uint32(q.id)) msgCall.SetInterfaceId(cl.Method.InterfaceID) msgCall.SetMethodId(cl.Method.MethodID) target, _ := msgCall.NewTarget() target.SetImportedCap(uint32(ic.id)) payload, _ := msgCall.NewParams() if err := ic.conn.fillParams(payload, cl); err != nil { ic.conn.popQuestion(q.id) return capnp.ErrorAnswer(err) } select { case ic.conn.out <- msg: case <-cl.Ctx.Done(): ic.conn.popQuestion(q.id) return capnp.ErrorAnswer(cl.Ctx.Err()) case <-ic.conn.bg.Done(): ic.conn.popQuestion(q.id) return capnp.ErrorAnswer(ErrConnClosed) } q.start() return q } func (ic *importClient) Close() error { ic.conn.mu.Lock() if err := ic.conn.startWork(); err != nil { ic.conn.mu.Unlock() return err } closed := ic.closed var i int if !closed { i = ic.conn.popImport(ic.id) ic.closed = true } ic.conn.workers.Done() ic.conn.mu.Unlock() if closed { return errImportClosed } if i == 0 { return nil } msg := newMessage(nil) mr, err := msg.NewRelease() if err != nil { return err } mr.SetId(uint32(ic.id)) mr.SetReferenceCount(uint32(i)) select { case ic.conn.out <- msg: return nil case <-ic.conn.bg.Done(): return ErrConnClosed } } type export struct { id exportID rc *refcount.RefCount client capnp.Client wireRefs int } func (c *Conn) findExport(id exportID) *export { if int(id) >= len(c.exports) { return nil } return c.exports[id] } // addExport ensures that the client is present in the table, returning its ID. // If the client is already in the table, the previous ID is returned. func (c *Conn) addExport(client capnp.Client) exportID { for i, e := range c.exports { if e != nil && isSameClient(e.rc.Client, client) { e.wireRefs++ return exportID(i) } } id := exportID(c.exportID.next()) rc, client := refcount.New(client) export := &export{ id: id, rc: rc, client: client, wireRefs: 1, } if int(id) == len(c.exports) { c.exports = append(c.exports, export) } else { c.exports[id] = export } return id } func (c *Conn) releaseExport(id exportID, refs int) { e := c.findExport(id) if e == nil { return } e.wireRefs -= refs if e.wireRefs > 0 { return } if e.wireRefs < 0 { c.errorf("warning: export %v has negative refcount (%d)", id, e.wireRefs) } if err := e.client.Close(); err != nil { c.errorf("export %v close: %v", id, err) } c.exports[id] = nil c.exportID.remove(uint32(id)) } type embargo <-chan struct{} func (c *Conn) newEmbargo() (embargoID, embargo) { id := embargoID(c.embargoID.next()) e := make(chan struct{}) if int(id) == len(c.embargoes) { c.embargoes = append(c.embargoes, e) } else { c.embargoes[id] = e } return id, e } func (c *Conn) disembargo(id embargoID) { if int(id) >= len(c.embargoes) { return } e := c.embargoes[id] if e == nil { return } close(e) c.embargoes[id] = nil c.embargoID.remove(uint32(id)) } // idgen returns a sequence of monotonically increasing IDs with // support for replacement. The zero value is a generator that // starts at zero. type idgen struct { i uint32 free []uint32 } func (gen *idgen) next() uint32 { if n := len(gen.free); n > 0 { i := gen.free[n-1] gen.free = gen.free[:n-1] return i } i := gen.i gen.i++ return i } func (gen *idgen) remove(i uint32) { gen.free = append(gen.free, i) } var errImportClosed = errors.New("rpc: call on closed import")