176 lines
4.1 KiB
Go
176 lines
4.1 KiB
Go
package rpc
|
|
|
|
import (
|
|
"bytes"
|
|
"io"
|
|
"time"
|
|
|
|
"golang.org/x/net/context"
|
|
"zombiezen.com/go/capnproto2"
|
|
rpccapnp "zombiezen.com/go/capnproto2/std/capnp/rpc"
|
|
)
|
|
|
|
// Transport is the interface that abstracts sending and receiving
|
|
// individual messages of the Cap'n Proto RPC protocol.
|
|
type Transport interface {
|
|
// SendMessage sends msg.
|
|
SendMessage(ctx context.Context, msg rpccapnp.Message) error
|
|
|
|
// RecvMessage waits to receive a message and returns it.
|
|
// Implementations may re-use buffers between calls, so the message is
|
|
// only valid until the next call to RecvMessage.
|
|
RecvMessage(ctx context.Context) (rpccapnp.Message, error)
|
|
|
|
// Close releases any resources associated with the transport.
|
|
Close() error
|
|
}
|
|
|
|
type streamTransport struct {
|
|
rwc io.ReadWriteCloser
|
|
deadline writeDeadlineSetter
|
|
|
|
enc *capnp.Encoder
|
|
dec *capnp.Decoder
|
|
wbuf bytes.Buffer
|
|
}
|
|
|
|
// StreamTransport creates a transport that sends and receives messages
|
|
// by serializing and deserializing unpacked Cap'n Proto messages.
|
|
// Closing the transport will close the underlying ReadWriteCloser.
|
|
func StreamTransport(rwc io.ReadWriteCloser) Transport {
|
|
d, _ := rwc.(writeDeadlineSetter)
|
|
s := &streamTransport{
|
|
rwc: rwc,
|
|
deadline: d,
|
|
dec: capnp.NewDecoder(rwc),
|
|
}
|
|
s.wbuf.Grow(4096)
|
|
s.enc = capnp.NewEncoder(&s.wbuf)
|
|
return s
|
|
}
|
|
|
|
func (s *streamTransport) SendMessage(ctx context.Context, msg rpccapnp.Message) error {
|
|
s.wbuf.Reset()
|
|
if err := s.enc.Encode(msg.Segment().Message()); err != nil {
|
|
return err
|
|
}
|
|
if s.deadline != nil {
|
|
// TODO(light): log errors
|
|
if d, ok := ctx.Deadline(); ok {
|
|
s.deadline.SetWriteDeadline(d)
|
|
} else {
|
|
s.deadline.SetWriteDeadline(time.Time{})
|
|
}
|
|
}
|
|
_, err := s.rwc.Write(s.wbuf.Bytes())
|
|
return err
|
|
}
|
|
|
|
func (s *streamTransport) RecvMessage(ctx context.Context) (rpccapnp.Message, error) {
|
|
var (
|
|
msg *capnp.Message
|
|
err error
|
|
)
|
|
read := make(chan struct{})
|
|
go func() {
|
|
msg, err = s.dec.Decode()
|
|
close(read)
|
|
}()
|
|
select {
|
|
case <-read:
|
|
case <-ctx.Done():
|
|
return rpccapnp.Message{}, ctx.Err()
|
|
}
|
|
if err != nil {
|
|
return rpccapnp.Message{}, err
|
|
}
|
|
return rpccapnp.ReadRootMessage(msg)
|
|
}
|
|
|
|
func (s *streamTransport) Close() error {
|
|
return s.rwc.Close()
|
|
}
|
|
|
|
// writeDeadlineSetter is the optional interface a stream satisfies when
// it supports write deadlines (net.Conn, for example, has this method).
type writeDeadlineSetter interface {
	SetWriteDeadline(deadline time.Time) error
}
|
|
|
|
// dispatchSend runs in its own goroutine and sends messages on a transport.
|
|
func (c *Conn) dispatchSend() {
|
|
defer c.workers.Done()
|
|
for {
|
|
select {
|
|
case msg := <-c.out:
|
|
err := c.transport.SendMessage(c.bg, msg)
|
|
if err != nil {
|
|
c.errorf("writing %v: %v", msg.Which(), err)
|
|
}
|
|
case <-c.bg.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// sendMessage enqueues a message to be sent or returns an error if the
|
|
// connection is shut down before the message is queued. It is safe to
|
|
// call from multiple goroutines and does not require holding c.mu.
|
|
func (c *Conn) sendMessage(msg rpccapnp.Message) error {
|
|
select {
|
|
case c.out <- msg:
|
|
return nil
|
|
case <-c.bg.Done():
|
|
return ErrConnClosed
|
|
}
|
|
}
|
|
|
|
// dispatchRecv runs in its own goroutine and receives messages from a transport.
|
|
func (c *Conn) dispatchRecv() {
|
|
defer c.workers.Done()
|
|
for {
|
|
msg, err := c.transport.RecvMessage(c.bg)
|
|
if err == nil {
|
|
c.handleMessage(msg)
|
|
} else if isTemporaryError(err) {
|
|
c.errorf("read temporary error: %v", err)
|
|
} else {
|
|
c.shutdown(err)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// copyMessage clones a Cap'n Proto buffer.
|
|
func copyMessage(msg *capnp.Message) *capnp.Message {
|
|
n := msg.NumSegments()
|
|
segments := make([][]byte, n)
|
|
for i := range segments {
|
|
s, err := msg.Segment(capnp.SegmentID(i))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
segments[i] = make([]byte, len(s.Data()))
|
|
copy(segments[i], s.Data())
|
|
}
|
|
return &capnp.Message{Arena: capnp.MultiSegment(segments)}
|
|
}
|
|
|
|
// copyRPCMessage clones an RPC packet.
|
|
func copyRPCMessage(m rpccapnp.Message) rpccapnp.Message {
|
|
mm := copyMessage(m.Segment().Message())
|
|
rpcMsg, err := rpccapnp.ReadRootMessage(mm)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return rpcMsg
|
|
}
|
|
|
|
// isTemporaryError reports whether e provides a Temporary() bool method
// and that method returns true. It returns false for a nil error and
// for errors without such a method.
func isTemporaryError(e error) bool {
	if t, ok := e.(interface {
		Temporary() bool
	}); ok {
		return t.Temporary()
	}
	return false
}
|