cloudflared-mirror/vendor/zombiezen.com/go/capnproto2/rpc/transport.go

176 lines
4.1 KiB
Go

package rpc
import (
"bytes"
"io"
"time"
"golang.org/x/net/context"
"zombiezen.com/go/capnproto2"
rpccapnp "zombiezen.com/go/capnproto2/std/capnp/rpc"
)
// Transport is the interface that abstracts sending and receiving
// individual messages of the Cap'n Proto RPC protocol.
//
// Within this file, Conn.dispatchSend is the caller of SendMessage and
// Conn.dispatchRecv is the caller of RecvMessage; StreamTransport
// provides an implementation backed by an io.ReadWriteCloser.
type Transport interface {
// SendMessage sends msg.
// The ctx argument lets the caller bound the operation; the
// StreamTransport implementation, for example, applies ctx's deadline
// as a write deadline when the underlying stream supports one.
SendMessage(ctx context.Context, msg rpccapnp.Message) error
// RecvMessage waits to receive a message and returns it.
// Implementations may re-use buffers between calls, so the message is
// only valid until the next call to RecvMessage.
// The StreamTransport implementation returns ctx.Err() if ctx is done
// before a message has been decoded.
RecvMessage(ctx context.Context) (rpccapnp.Message, error)
// Close releases any resources associated with the transport.
Close() error
}
// streamTransport is the Transport implementation returned by
// StreamTransport. Outgoing messages are encoded into wbuf and then
// written to rwc; incoming messages are decoded from rwc.
type streamTransport struct {
// rwc is the underlying stream. Close closes it.
rwc io.ReadWriteCloser
// deadline is rwc viewed as a writeDeadlineSetter, or nil when rwc
// does not implement that interface.
deadline writeDeadlineSetter
// enc encodes outgoing messages into wbuf.
enc *capnp.Encoder
// dec decodes incoming messages from rwc.
dec *capnp.Decoder
// wbuf stages one encoded outgoing message; SendMessage resets it at
// the start of every call.
wbuf bytes.Buffer
}
// StreamTransport returns a Transport that exchanges unpacked Cap'n
// Proto messages over rwc. Messages are decoded directly from rwc and
// encoded into an internal buffer (pre-grown to 4096 bytes) before
// being written out.
//
// If rwc also implements SetWriteDeadline(time.Time) error, the
// transport uses it to apply context deadlines to writes.
//
// Closing the transport will close the underlying ReadWriteCloser.
func StreamTransport(rwc io.ReadWriteCloser) Transport {
	t := new(streamTransport)
	t.rwc = rwc
	// Write deadlines are optional: leave the field nil when rwc does
	// not support them.
	if setter, ok := rwc.(writeDeadlineSetter); ok {
		t.deadline = setter
	}
	t.dec = capnp.NewDecoder(rwc)
	t.wbuf.Grow(4096)
	t.enc = capnp.NewEncoder(&t.wbuf)
	return t
}
// SendMessage encodes msg's underlying Cap'n Proto message into the
// transport's staging buffer and writes the buffer to the stream with a
// single Write call.
//
// When the stream supports write deadlines, the deadline is set from
// ctx before writing; if ctx has no deadline, the zero time.Time is
// passed instead (net.Conn-style implementations treat that as "no
// deadline"). Errors from setting the deadline are ignored.
//
// It returns the encoding error, if any, otherwise the error from Write.
func (s *streamTransport) SendMessage(ctx context.Context, msg rpccapnp.Message) error {
	s.wbuf.Reset()
	err := s.enc.Encode(msg.Segment().Message())
	if err != nil {
		return err
	}

	if s.deadline != nil {
		// Stays at the zero value unless ctx carries a deadline.
		var when time.Time
		if d, ok := ctx.Deadline(); ok {
			when = d
		}
		// TODO(light): log errors
		s.deadline.SetWriteDeadline(when)
	}

	_, err = s.rwc.Write(s.wbuf.Bytes())
	return err
}
// RecvMessage decodes the next message from the stream and returns its
// root as an RPC message.
//
// Decoding runs in a separate goroutine so that the call can return
// early with ctx.Err() when ctx is done first. In that case nothing
// here interrupts the decode: the goroutine keeps running until Decode
// itself returns, and its result is discarded.
func (s *streamTransport) RecvMessage(ctx context.Context) (rpccapnp.Message, error) {
	type decodeResult struct {
		msg *capnp.Message
		err error
	}

	// Buffered so the decoding goroutine can always hand off its result
	// and exit, even when this call has already returned.
	results := make(chan decodeResult, 1)
	go func() {
		m, err := s.dec.Decode()
		results <- decodeResult{msg: m, err: err}
	}()

	select {
	case <-ctx.Done():
		return rpccapnp.Message{}, ctx.Err()
	case r := <-results:
		if r.err != nil {
			return rpccapnp.Message{}, r.err
		}
		return rpccapnp.ReadRootMessage(r.msg)
	}
}
// Close closes the underlying stream and returns its error.
func (s *streamTransport) Close() error {
	err := s.rwc.Close()
	return err
}
// writeDeadlineSetter is the optional interface a stream may implement
// so that streamTransport can apply a context's deadline to writes.
// StreamTransport detects it with a type assertion; SendMessage passes
// the context's deadline, or the zero time.Time when the context has
// none. net.Conn declares a method with this signature.
type writeDeadlineSetter interface {
SetWriteDeadline(t time.Time) error
}
// dispatchSend runs in its own goroutine and sends messages on a
// transport. It drains c.out, handing each message to the transport,
// until c.bg is done. A failed send is reported through c.errorf and
// does not stop the loop. c.workers.Done is called on exit.
func (c *Conn) dispatchSend() {
	defer c.workers.Done()
	for {
		select {
		case <-c.bg.Done():
			return
		case msg := <-c.out:
			if err := c.transport.SendMessage(c.bg, msg); err != nil {
				c.errorf("writing %v: %v", msg.Which(), err)
			}
		}
	}
}
// sendMessage queues msg on c.out for delivery. It blocks until either
// the message has been handed off, in which case it returns nil, or
// c.bg is done, in which case it returns ErrConnClosed.
//
// It is safe to call from multiple goroutines and does not require
// holding c.mu.
func (c *Conn) sendMessage(msg rpccapnp.Message) error {
	select {
	case <-c.bg.Done():
		return ErrConnClosed
	case c.out <- msg:
		return nil
	}
}
// dispatchRecv runs in its own goroutine and receives messages from a
// transport. Each received message is passed to c.handleMessage.
// Errors that report themselves as temporary are logged through
// c.errorf and the loop continues; any other error is passed to
// c.shutdown and ends the loop. c.workers.Done is called on exit.
func (c *Conn) dispatchRecv() {
	defer c.workers.Done()
	for {
		msg, err := c.transport.RecvMessage(c.bg)
		switch {
		case err == nil:
			c.handleMessage(msg)
		case isTemporaryError(err):
			c.errorf("read temporary error: %v", err)
		default:
			c.shutdown(err)
			return
		}
	}
}
// copyMessage returns a new Cap'n Proto message whose segments are
// byte-for-byte copies of msg's segments, stored in a multi-segment
// arena. It panics if any segment of msg cannot be retrieved.
func copyMessage(msg *capnp.Message) *capnp.Message {
	dups := make([][]byte, msg.NumSegments())
	for id := range dups {
		seg, err := msg.Segment(capnp.SegmentID(id))
		if err != nil {
			panic(err)
		}
		data := seg.Data()
		dup := make([]byte, len(data))
		copy(dup, data)
		dups[id] = dup
	}
	return &capnp.Message{Arena: capnp.MultiSegment(dups)}
}
// copyRPCMessage returns an RPC message read from a copy (made with
// copyMessage) of the Cap'n Proto message that holds m. It panics if
// the root of the copy cannot be read as an RPC message.
func copyRPCMessage(m rpccapnp.Message) rpccapnp.Message {
	clone, err := rpccapnp.ReadRootMessage(copyMessage(m.Segment().Message()))
	if err != nil {
		panic(err)
	}
	return clone
}
// isTemporaryError reports whether e has a Temporary() bool method
// that returns true. It returns false for a nil error and for errors
// without such a method. Only e itself is inspected.
func isTemporaryError(e error) bool {
	if t, ok := e.(interface {
		Temporary() bool
	}); ok {
		return t.Temporary()
	}
	return false
}