cloudflared-mirror/vendor/zombiezen.com/go/capnproto2/rpc/tables.go

256 lines
5.2 KiB
Go

package rpc
import (
"errors"
"zombiezen.com/go/capnproto2"
"zombiezen.com/go/capnproto2/rpc/internal/refcount"
)
// Table IDs. Every table kind has its own named type so that an ID from one
// table cannot be used where another table's ID is expected without an
// explicit conversion. All of them are backed by uint32.

// questionID identifies an entry in the question table.
type questionID uint32

// answerID identifies an entry in the answer table.
type answerID uint32

// exportID identifies an entry in the export table.
type exportID uint32

// importID identifies an entry in the import table.
type importID uint32

// embargoID identifies an outstanding embargo.
type embargoID uint32
// impent is an entry in the import table (Conn.imports), keyed by importID.
type impent struct {
	// rc wraps the importClient for this ID; addImport hands out a new
	// reference from it each time the same ID is added again.
	rc *refcount.RefCount
	// refs counts how many times the remote vat has sent this import ID to
	// us. popImport returns it, and importClient.Close reports it as the
	// reference count of the Release message.
	refs int
}
// addImport records that the remote vat has sent import ID id to this vat
// one more time and returns a new reference to the client for that import.
// The first time an ID is seen, a fresh importClient is created, wrapped in a
// reference count, and stored in c.imports with a count of one.
func (c *Conn) addImport(id importID) capnp.Client {
	if c.imports == nil {
		c.imports = make(map[importID]*impent)
	}
	if existing := c.imports[id]; existing != nil {
		existing.refs++
		return existing.rc.Ref()
	}
	rc, ref := refcount.New(&importClient{id: id, conn: c})
	c.imports[id] = &impent{rc: rc, refs: 1}
	return ref
}
// popImport deletes id from the import table and returns how many times the
// remote vat sent that import ID to this vat. When id is not in the table
// (including when the table was never created) it returns 0 and changes
// nothing.
func (c *Conn) popImport(id importID) (refs int) {
	// Indexing a nil map is legal in Go and yields the zero value (nil).
	ent := c.imports[id]
	if ent == nil {
		return 0
	}
	delete(c.imports, id)
	return ent.refs
}
// An importClient implements capnp.Client for a remote capability.
type importClient struct {
	// id is the import ID: the key in conn.imports, the imported-cap target
	// of outgoing calls, and the ID carried by the Release sent from Close.
	id importID
	// conn is the connection that owns this import (set by addImport).
	conn *Conn
	// closed is set by the first Close; afterwards calls and further Closes
	// fail with errImportClosed.
	closed bool // protected by conn.mu
}
// Call implements capnp.Client by sending cl to the remote capability
// identified by ic.id.
//
// The connection lock is acquired by receiving from ic.conn.mu inside a
// select, so that cancellation of cl.Ctx can abort the wait. Once acquired,
// the lock is always released before Call returns. This includes the path
// where the connection is shutting down and startWork refuses new work.
// Leaking the lock there would block every later user of conn.mu.
func (ic *importClient) Call(cl *capnp.Call) capnp.Answer {
	select {
	case <-ic.conn.mu:
		// Lock acquired.
	case <-cl.Ctx.Done():
		return capnp.ErrorAnswer(cl.Ctx.Err())
	}
	defer ic.conn.mu.Unlock()
	if err := ic.conn.startWork(); err != nil {
		return capnp.ErrorAnswer(err)
	}
	// Deferred after the Unlock above, so it runs first: the worker is
	// marked done while the lock is still held, as before.
	defer ic.conn.workers.Done()
	return ic.lockedCall(cl)
}
// lockedCall is equivalent to Call but assumes that the caller is
// already holding onto ic.conn.mu.
//
// It allocates a question, encodes a Call message that targets the imported
// capability ic.id, and queues the message on ic.conn.out. If any step fails
// before the message is queued, the question is removed again with
// popQuestion and an error answer is returned. Failures include message
// construction, parameter filling, context cancellation, and connection
// shutdown. On success the question is started and returned as the answer.
func (ic *importClient) lockedCall(cl *capnp.Call) capnp.Answer {
	if ic.closed {
		return capnp.ErrorAnswer(errImportClosed)
	}
	q := ic.conn.newQuestion(cl.Ctx, &cl.Method)
	// fail undoes newQuestion and reports err to the caller.
	fail := func(err error) capnp.Answer {
		ic.conn.popQuestion(q.id)
		return capnp.ErrorAnswer(err)
	}

	msg := newMessage(nil)
	msgCall, err := msg.NewCall()
	if err != nil {
		return fail(err)
	}
	msgCall.SetQuestionId(uint32(q.id))
	msgCall.SetInterfaceId(cl.Method.InterfaceID)
	msgCall.SetMethodId(cl.Method.MethodID)
	target, err := msgCall.NewTarget()
	if err != nil {
		return fail(err)
	}
	target.SetImportedCap(uint32(ic.id))
	payload, err := msgCall.NewParams()
	if err != nil {
		return fail(err)
	}
	if err := ic.conn.fillParams(payload, cl); err != nil {
		return fail(err)
	}

	select {
	case ic.conn.out <- msg:
	case <-cl.Ctx.Done():
		return fail(cl.Ctx.Err())
	case <-ic.conn.bg.Done():
		return fail(ErrConnClosed)
	}
	q.start()
	return q
}
// Close releases this import. The first Close marks the client closed and
// removes the import from the connection's import table. If the remote vat
// had sent the ID to this vat at least once, Close then queues a Release
// message carrying that count. A second Close returns errImportClosed. If the
// connection refuses new work, that error is returned and nothing changes.
func (ic *importClient) Close() error {
	conn := ic.conn
	conn.mu.Lock()
	if err := conn.startWork(); err != nil {
		conn.mu.Unlock()
		return err
	}
	alreadyClosed := ic.closed
	refCount := 0
	if !alreadyClosed {
		ic.closed = true
		refCount = conn.popImport(ic.id)
	}
	conn.workers.Done()
	conn.mu.Unlock()

	switch {
	case alreadyClosed:
		return errImportClosed
	case refCount == 0:
		// Nothing was ever received for this ID, so there is nothing to release.
		return nil
	}

	msg := newMessage(nil)
	release, err := msg.NewRelease()
	if err != nil {
		return err
	}
	release.SetId(uint32(ic.id))
	release.SetReferenceCount(uint32(refCount))
	select {
	case <-conn.bg.Done():
		return ErrConnClosed
	case conn.out <- msg:
		return nil
	}
}
// export is an entry in the export table (Conn.exports), stored at index id.
type export struct {
	// id is this entry's index in Conn.exports.
	id exportID
	// rc is the reference count wrapping the exported client; addExport
	// compares rc.Client against new clients to reuse an existing entry.
	rc *refcount.RefCount
	// client is the reference returned by refcount.New; releaseExport
	// closes it once wireRefs drops to zero or below.
	client capnp.Client
	// wireRefs counts the references the remote vat holds to this export:
	// incremented by addExport, decremented by releaseExport.
	wireRefs int
}
// findExport returns the export table entry for id, or nil when id is out of
// range or its slot has been released.
//
// The bounds check is done in uint64 because id may be supplied by the remote
// peer. On 32-bit platforms, int(id) wraps to a negative value for
// id >= 1<<31. That value would pass an int comparison against len and then
// panic on indexing.
func (c *Conn) findExport(id exportID) *export {
	if uint64(id) >= uint64(len(c.exports)) {
		return nil
	}
	return c.exports[id]
}
// addExport ensures that client is present in the export table and returns
// its ID. If an equivalent client (per isSameClient) is already exported, its
// wire reference count is incremented and the existing ID is returned.
// Otherwise a new ID is taken from c.exportID, and the client is wrapped in a
// reference count and stored with one wire reference.
func (c *Conn) addExport(client capnp.Client) exportID {
	for idx, existing := range c.exports {
		if existing == nil || !isSameClient(existing.rc.Client, client) {
			continue
		}
		existing.wireRefs++
		return exportID(idx)
	}

	id := exportID(c.exportID.next())
	rc, ref := refcount.New(client)
	entry := &export{
		id:       id,
		rc:       rc,
		client:   ref,
		wireRefs: 1,
	}
	if int(id) != len(c.exports) {
		// Recycled ID: reuse its existing slot.
		c.exports[id] = entry
	} else {
		c.exports = append(c.exports, entry)
	}
	return id
}
// releaseExport drops refs wire references from export id. Unknown or
// already-released IDs are ignored. When the remaining count reaches zero,
// the exported client is closed, its slot is cleared, and id is handed back
// to the ID generator for reuse. A negative count is logged as a warning but
// otherwise treated like zero.
func (c *Conn) releaseExport(id exportID, refs int) {
	e := c.findExport(id)
	if e == nil {
		return
	}
	e.wireRefs -= refs
	switch remaining := e.wireRefs; {
	case remaining > 0:
		return
	case remaining < 0:
		c.errorf("warning: export %v has negative refcount (%d)", id, remaining)
	}
	if err := e.client.Close(); err != nil {
		c.errorf("export %v close: %v", id, err)
	}
	c.exports[id] = nil
	c.exportID.remove(uint32(id))
}
// embargo is the receive-only view of a channel created by newEmbargo and
// closed by disembargo; a receive on it unblocks once the embargo is lifted.
type embargo <-chan struct{}
// newEmbargo allocates an embargo ID and a fresh channel for it. It records
// the channel in c.embargoes, appending when the ID is new and reusing the
// slot when the ID was recycled. It returns the ID together with the
// receive-only view of the channel.
func (c *Conn) newEmbargo() (embargoID, embargo) {
	ch := make(chan struct{})
	id := embargoID(c.embargoID.next())
	if int(id) != len(c.embargoes) {
		c.embargoes[id] = ch
		return id, ch
	}
	c.embargoes = append(c.embargoes, ch)
	return id, ch
}
// disembargo lifts embargo id. It closes the embargo's channel, which wakes
// every receiver, then clears the slot and returns id to the ID generator.
// Unknown or already-lifted IDs are ignored.
//
// The bounds check is done in uint64 because id may be supplied by the remote
// peer. On 32-bit platforms, int(id) wraps to a negative value for
// id >= 1<<31. That value would pass an int comparison against len and then
// panic on indexing.
func (c *Conn) disembargo(id embargoID) {
	if uint64(id) >= uint64(len(c.embargoes)) {
		return
	}
	e := c.embargoes[id]
	if e == nil {
		return
	}
	close(e)
	c.embargoes[id] = nil
	c.embargoID.remove(uint32(id))
}
// idgen hands out uint32 IDs. IDs passed to remove are recycled first, most
// recently removed first. When none are waiting, never-issued IDs are handed
// out in increasing order starting from zero. The zero value is ready to use.
type idgen struct {
	i    uint32   // next never-issued ID
	free []uint32 // recycled IDs, used as a LIFO stack
}

// next returns the most recently recycled ID if there is one, and otherwise
// the next never-issued ID.
func (gen *idgen) next() uint32 {
	if top := len(gen.free) - 1; top >= 0 {
		id := gen.free[top]
		gen.free = gen.free[:top]
		return id
	}
	id := gen.i
	gen.i++
	return id
}

// remove makes i available to be returned again by next. Duplicates are not
// detected: removing the same ID twice makes next return it twice.
func (gen *idgen) remove(i uint32) {
	gen.free = append(gen.free, i)
}
// errImportClosed is returned when an importClient is used after Close:
// by lockedCall for a new call, and by Close when called a second time.
var (
	errImportClosed = errors.New("rpc: call on closed import")
)