cloudflared-mirror/vendor/zombiezen.com/go/capnproto2/pogs/insert.go

486 lines
14 KiB
Go

package pogs
import (
"fmt"
"math"
"reflect"
"zombiezen.com/go/capnproto2"
"zombiezen.com/go/capnproto2/internal/nodemap"
"zombiezen.com/go/capnproto2/internal/schema"
)
// Insert copies val, a pointer to a Go struct, into s.
//
// typeID is the schema node ID of the Cap'n Proto struct type to insert as.
// Any failure is returned as an error prefixed with "pogs: insert @<typeID>:".
func Insert(typeID uint64, s capnp.Struct, val interface{}) error {
	if err := new(inserter).insertStruct(typeID, s, reflect.ValueOf(val)); err != nil {
		return fmt.Errorf("pogs: insert @%#x: %v", typeID, err)
	}
	return nil
}
// inserter holds the state used while copying Go values into Cap'n Proto
// structs.
type inserter struct {
	// nodes resolves schema node IDs (via Find) when looking up struct
	// types and their sizes.
	nodes nodemap.Map
}
// insertStruct copies the Go struct val (or the struct val points to) into
// s, whose layout is described by the schema struct node typeID.
//
// Each schema field that has a corresponding Go field (as resolved by
// mapStruct / props.fieldByOrdinal) is written into s. Slot fields go
// through insertField. Group fields recurse into the same s using the
// group's type ID. Schema fields with no Go counterpart are skipped.
//
// When hasDiscriminant(n) reports a union, the discriminant returned by
// props.which is written first. Union members other than the selected one
// are then skipped. A union member that has a Go field while no Which value
// is available is an error.
func (ins *inserter) insertStruct(typeID uint64, s capnp.Struct, val reflect.Value) error {
	if val.Kind() == reflect.Ptr {
		// TODO(light): ignore if nil?
		// A nil pointer yields an invalid Value here, which the Kind check
		// below rejects with an error.
		val = val.Elem()
	}
	if val.Kind() != reflect.Struct {
		return fmt.Errorf("can't insert %v into a struct", val.Kind())
	}
	n, err := ins.nodes.Find(typeID)
	if err != nil {
		return err
	}
	if !n.IsValid() || n.Which() != schema.Node_Which_structNode {
		return fmt.Errorf("cannot find struct type %#x", typeID)
	}
	props, err := mapStruct(val.Type(), n)
	if err != nil {
		return fmt.Errorf("can't insert into %v: %v", val.Type(), err)
	}
	var discriminant uint16
	hasWhich := false
	if hasDiscriminant(n) {
		discriminant, hasWhich = props.which(val)
		if hasWhich {
			// DiscriminantOffset is in 16-bit units; scale it to a byte offset.
			off := capnp.DataOffset(n.StructNode().DiscriminantOffset() * 2)
			if s.Size().DataSize < capnp.Size(off+2) {
				return fmt.Errorf("can't set discriminant for %s: allocated struct is too small", shortDisplayName(n))
			}
			s.SetUint16(off, discriminant)
		}
	}
	fields, err := n.StructNode().Fields()
	if err != nil {
		return err
	}
	for i := 0; i < fields.Len(); i++ {
		f := fields.At(i)
		vf := props.fieldByOrdinal(val, i)
		if !vf.IsValid() {
			// Don't have a field for this.
			continue
		}
		if dv := f.DiscriminantValue(); dv != schema.Field_noDiscriminant {
			// f is a union member: it can only be written when the Go value
			// says which member is active, and only if it is that member.
			if !hasWhich {
				sname, _ := f.NameBytes()
				return fmt.Errorf("can't insert %s from %v: has union field %s but no Which field", shortDisplayName(n), val.Type(), sname)
			}
			if dv != discriminant {
				continue
			}
		}
		switch f.Which() {
		case schema.Field_Which_slot:
			if err := ins.insertField(s, f, vf); err != nil {
				return err
			}
		case schema.Field_Which_group:
			// Groups share their parent's storage, so they are written into
			// s itself rather than a newly allocated struct.
			if err := ins.insertStruct(f.Group().TypeId(), s, vf); err != nil {
				return err
			}
		}
	}
	return nil
}
// insertField writes val into the slot field f of s.
//
// Non-pointer values are stored XOR'd with the field's default value, so a
// value equal to its default is written as zero bits. Void fields occupy no
// storage and are a no-op. Pointer-typed values (text, data, structs,
// lists, interfaces) are allocated in s's segment and stored in the field's
// pointer slot.
//
// It returns an error in these cases:
//   - the schema default value's kind disagrees with the field type;
//   - val's Go type does not match the field type;
//   - s was allocated too small to hold the field;
//   - the field type is unknown;
//   - writing a pointer-typed value fails.
func (ins *inserter) insertField(s capnp.Struct, f schema.Field, val reflect.Value) error {
	typ, err := f.Slot().Type()
	if err != nil {
		return err
	}
	dv, err := f.Slot().DefaultValue()
	if err != nil {
		return err
	}
	// This comparison relies on schema.Type and schema.Value using the same
	// union numbering for their Which values.
	if dv.IsValid() && int(typ.Which()) != int(dv.Which()) {
		name, _ := f.NameBytes()
		return fmt.Errorf("insert field %s: default value is a %v, want %v", name, dv.Which(), typ.Which())
	}
	if !isTypeMatch(val.Type(), typ) {
		name, _ := f.NameBytes()
		return fmt.Errorf("can't insert field %s of type Go %v into a %v", name, val.Type(), typ.Which())
	}
	if !isFieldInBounds(s.Size(), f.Slot().Offset(), typ) {
		name, _ := f.NameBytes()
		return fmt.Errorf("can't insert field %s: allocated struct is too small", name)
	}
	switch typ.Which() {
	case schema.Type_Which_void:
		// Void fields occupy no storage; there is nothing to write.
	case schema.Type_Which_bool:
		v := val.Bool()
		d := dv.Bool()
		s.SetBit(capnp.BitOffset(f.Slot().Offset()), v != d) // != acts as XOR
	case schema.Type_Which_int8:
		v := int8(val.Int())
		d := dv.Int8()
		s.SetUint8(capnp.DataOffset(f.Slot().Offset()), uint8(v^d))
	case schema.Type_Which_int16:
		v := int16(val.Int())
		d := dv.Int16()
		s.SetUint16(capnp.DataOffset(f.Slot().Offset()*2), uint16(v^d))
	case schema.Type_Which_int32:
		v := int32(val.Int())
		d := dv.Int32()
		s.SetUint32(capnp.DataOffset(f.Slot().Offset()*4), uint32(v^d))
	case schema.Type_Which_int64:
		v := val.Int()
		d := dv.Int64()
		s.SetUint64(capnp.DataOffset(f.Slot().Offset()*8), uint64(v^d))
	case schema.Type_Which_uint8:
		v := uint8(val.Uint())
		d := dv.Uint8()
		s.SetUint8(capnp.DataOffset(f.Slot().Offset()), v^d)
	case schema.Type_Which_uint16:
		v := uint16(val.Uint())
		d := dv.Uint16()
		s.SetUint16(capnp.DataOffset(f.Slot().Offset()*2), v^d)
	case schema.Type_Which_enum:
		v := uint16(val.Uint())
		d := dv.Enum()
		s.SetUint16(capnp.DataOffset(f.Slot().Offset()*2), v^d)
	case schema.Type_Which_uint32:
		v := uint32(val.Uint())
		d := dv.Uint32()
		s.SetUint32(capnp.DataOffset(f.Slot().Offset()*4), v^d)
	case schema.Type_Which_uint64:
		v := val.Uint()
		d := dv.Uint64()
		s.SetUint64(capnp.DataOffset(f.Slot().Offset()*8), v^d)
	case schema.Type_Which_float32:
		v := math.Float32bits(float32(val.Float()))
		d := math.Float32bits(dv.Float32())
		s.SetUint32(capnp.DataOffset(f.Slot().Offset()*4), v^d)
	case schema.Type_Which_float64:
		v := math.Float64bits(val.Float())
		d := math.Float64bits(dv.Float64())
		s.SetUint64(capnp.DataOffset(f.Slot().Offset()*8), v^d)
	case schema.Type_Which_text:
		off := uint16(f.Slot().Offset())
		if val.Len() == 0 {
			// A null pointer reads back as the field's default. An empty
			// value therefore only gets real storage when the default is
			// non-empty (or absent).
			if !isEmptyValue(dv) {
				return s.SetNewText(off, "")
			}
			return s.SetText(off, "")
		}
		if val.Kind() == reflect.String {
			return s.SetText(off, val.String())
		}
		return s.SetTextFromBytes(off, val.Bytes())
	case schema.Type_Which_data:
		b := val.Bytes()
		// As with text: presumably SetData stores nil as a null pointer,
		// which would read back as a non-empty default, so use a non-nil
		// empty slice in that case.
		if b == nil && !isEmptyValue(dv) {
			b = []byte{}
		}
		off := uint16(f.Slot().Offset())
		return s.SetData(off, b)
	case schema.Type_Which_structType:
		off := uint16(f.Slot().Offset())
		sval := val
		if val.Kind() == reflect.Ptr {
			if val.IsNil() {
				return s.SetPtr(off, capnp.Ptr{})
			}
			sval = val.Elem()
		}
		id := typ.StructType().TypeId()
		sz, err := ins.structSize(id)
		if err != nil {
			return err
		}
		ss, err := capnp.NewStruct(s.Segment(), sz)
		if err != nil {
			return err
		}
		if err := s.SetPtr(off, ss.ToPtr()); err != nil {
			return err
		}
		return ins.insertStruct(id, ss, sval)
	case schema.Type_Which_list:
		off := uint16(f.Slot().Offset())
		if val.IsNil() && isEmptyValue(dv) {
			return s.SetPtr(off, capnp.Ptr{})
		}
		elem, err := typ.List().ElementType()
		if err != nil {
			return err
		}
		l, err := ins.newList(s.Segment(), elem, int32(val.Len()))
		if err != nil {
			return err
		}
		if err := s.SetPtr(off, l.ToPtr()); err != nil {
			return err
		}
		return ins.insertList(l, typ, val)
	case schema.Type_Which_interface:
		off := uint16(f.Slot().Offset())
		ptr := capPtr(s.Segment(), val)
		if err := s.SetPtr(off, ptr); err != nil {
			return err
		}
	default:
		return fmt.Errorf("unknown field type %v", typ.Which())
	}
	return nil
}
// capPtr adds the capability held by val to seg's message and returns an
// interface pointer referring to it.
//
// val is either a capnp.Client itself or a struct with an exported
// "Client" field holding one. NOTE(review): the struct form is presumably
// the generated client wrapper type; confirm against isTypeMatch.
//
// A nil client, or a value of any other shape, yields a null pointer.
// Before this guard, a nil capnp.Client interface value reached
// FieldByName on a non-struct Value and panicked.
func capPtr(seg *capnp.Segment, val reflect.Value) capnp.Ptr {
	client, ok := val.Interface().(capnp.Client)
	if !ok {
		// FieldByName panics on non-struct kinds (such as a nil interface
		// value), so only structs may use the "Client" field fallback.
		if val.Kind() != reflect.Struct {
			return capnp.Ptr{}
		}
		field := val.FieldByName("Client")
		if !field.IsValid() {
			return capnp.Ptr{}
		}
		client, ok = field.Interface().(capnp.Client)
		if !ok {
			// interface is nil.
			return capnp.Ptr{}
		}
	}
	capID := seg.Message().AddCap(client)
	iface := capnp.NewInterface(seg, capID)
	return iface.ToPtr()
}
// insertList copies the elements of the Go slice val into l, whose schema
// list type is typ. Callers in this file allocate l with val.Len() elements
// (via newList) before calling it.
//
// Nested lists and struct elements are inserted recursively. Empty []byte
// elements of text and data lists are stored as null pointers. An error is
// returned if val's Go type does not match typ, if the element type is
// unknown, or if writing any element fails; elements after the failure are
// not written.
func (ins *inserter) insertList(l capnp.List, typ schema.Type, val reflect.Value) error {
	elem, err := typ.List().ElementType()
	if err != nil {
		return err
	}
	if !isTypeMatch(val.Type(), typ) {
		// TODO(light): the error won't be that useful for nested lists.
		return fmt.Errorf("can't insert Go %v into a %v list", val.Type(), elem.Which())
	}
	n := val.Len()
	switch elem.Which() {
	case schema.Type_Which_void:
		// Void elements carry no data.
	case schema.Type_Which_bool:
		for i := 0; i < n; i++ {
			capnp.BitList{List: l}.Set(i, val.Index(i).Bool())
		}
	case schema.Type_Which_int8:
		for i := 0; i < n; i++ {
			capnp.Int8List{List: l}.Set(i, int8(val.Index(i).Int()))
		}
	case schema.Type_Which_int16:
		for i := 0; i < n; i++ {
			capnp.Int16List{List: l}.Set(i, int16(val.Index(i).Int()))
		}
	case schema.Type_Which_int32:
		for i := 0; i < n; i++ {
			capnp.Int32List{List: l}.Set(i, int32(val.Index(i).Int()))
		}
	case schema.Type_Which_int64:
		for i := 0; i < n; i++ {
			capnp.Int64List{List: l}.Set(i, val.Index(i).Int())
		}
	case schema.Type_Which_uint8:
		for i := 0; i < n; i++ {
			capnp.UInt8List{List: l}.Set(i, uint8(val.Index(i).Uint()))
		}
	case schema.Type_Which_uint16, schema.Type_Which_enum:
		for i := 0; i < n; i++ {
			capnp.UInt16List{List: l}.Set(i, uint16(val.Index(i).Uint()))
		}
	case schema.Type_Which_uint32:
		for i := 0; i < n; i++ {
			capnp.UInt32List{List: l}.Set(i, uint32(val.Index(i).Uint()))
		}
	case schema.Type_Which_uint64:
		for i := 0; i < n; i++ {
			capnp.UInt64List{List: l}.Set(i, val.Index(i).Uint())
		}
	case schema.Type_Which_float32:
		for i := 0; i < n; i++ {
			capnp.Float32List{List: l}.Set(i, float32(val.Index(i).Float()))
		}
	case schema.Type_Which_float64:
		for i := 0; i < n; i++ {
			capnp.Float64List{List: l}.Set(i, val.Index(i).Float())
		}
	case schema.Type_Which_text:
		if val.Type().Elem().Kind() == reflect.String {
			for i := 0; i < n; i++ {
				err := capnp.TextList{List: l}.Set(i, val.Index(i).String())
				if err != nil {
					// TODO(light): collect errors and finish
					return err
				}
			}
		} else {
			for i := 0; i < n; i++ {
				b := val.Index(i).Bytes()
				if len(b) == 0 {
					// Store empty elements as null pointers. Without the
					// continue, this null would be immediately overwritten
					// by a freshly allocated empty text below.
					err := capnp.PointerList{List: l}.SetPtr(i, capnp.Ptr{})
					if err != nil {
						// TODO(light): collect errors and finish
						return err
					}
					continue
				}
				t, err := capnp.NewTextFromBytes(l.Segment(), b)
				if err != nil {
					// TODO(light): collect errors and finish
					return err
				}
				err = capnp.PointerList{List: l}.SetPtr(i, t.ToPtr())
				if err != nil {
					// TODO(light): collect errors and finish
					return err
				}
			}
		}
	case schema.Type_Which_data:
		for i := 0; i < n; i++ {
			b := val.Index(i).Bytes()
			if len(b) == 0 {
				// Store empty elements as null pointers instead of letting
				// DataList.Set below overwrite the null.
				err := capnp.PointerList{List: l}.SetPtr(i, capnp.Ptr{})
				if err != nil {
					// TODO(light): collect errors and finish
					return err
				}
				continue
			}
			err := capnp.DataList{List: l}.Set(i, b)
			if err != nil {
				// TODO(light): collect errors and finish
				return err
			}
		}
	case schema.Type_Which_list:
		pl := capnp.PointerList{List: l}
		for i := 0; i < n; i++ {
			vi := val.Index(i)
			if vi.IsNil() {
				if err := pl.SetPtr(i, capnp.Ptr{}); err != nil {
					return err
				}
				continue
			}
			ee, err := elem.List().ElementType()
			if err != nil {
				return err
			}
			li, err := ins.newList(l.Segment(), ee, int32(vi.Len()))
			if err != nil {
				return err
			}
			if err := pl.SetPtr(i, li.ToPtr()); err != nil {
				return err
			}
			if err := ins.insertList(li, elem, vi); err != nil {
				return err
			}
		}
	case schema.Type_Which_structType:
		id := elem.StructType().TypeId()
		for i := 0; i < n; i++ {
			err := ins.insertStruct(id, l.Struct(i), val.Index(i))
			if err != nil {
				// TODO(light): collect errors and finish
				return err
			}
		}
	case schema.Type_Which_interface:
		pl := capnp.PointerList{List: l}
		for i := 0; i < n; i++ {
			ptr := capPtr(l.Segment(), val.Index(i))
			if err := pl.SetPtr(i, ptr); err != nil {
				// TODO(zenhack): collect errors and finish
				return err
			}
		}
	default:
		return fmt.Errorf("unknown list type %v", elem.Which())
	}
	return nil
}
// newList allocates, in segment s, a list of count elements of schema type
// t, choosing the list encoding from the element type:
//   - struct elements get a composite list sized from the schema;
//   - text, data, list, interface and anyPointer elements get a pointer list;
//   - numeric and enum elements get the unsigned primitive list of matching
//     width (signed and float types share it);
//   - Bool elements get a bit list;
//   - Void elements get a void list.
//
// Any other element type is an error.
func (ins *inserter) newList(s *capnp.Segment, t schema.Type, count int32) (capnp.List, error) {
	switch t.Which() {
	case schema.Type_Which_structType:
		size, err := ins.structSize(t.StructType().TypeId())
		if err != nil {
			return capnp.List{}, err
		}
		return capnp.NewCompositeList(s, size, count)
	case schema.Type_Which_text, schema.Type_Which_data, schema.Type_Which_list, schema.Type_Which_interface, schema.Type_Which_anyPointer:
		ptrs, err := capnp.NewPointerList(s, count)
		return ptrs.List, err
	case schema.Type_Which_int64, schema.Type_Which_uint64, schema.Type_Which_float64:
		words, err := capnp.NewUInt64List(s, count)
		return words.List, err
	case schema.Type_Which_int32, schema.Type_Which_uint32, schema.Type_Which_float32:
		quads, err := capnp.NewUInt32List(s, count)
		return quads.List, err
	case schema.Type_Which_int16, schema.Type_Which_uint16, schema.Type_Which_enum:
		pairs, err := capnp.NewUInt16List(s, count)
		return pairs.List, err
	case schema.Type_Which_int8, schema.Type_Which_uint8:
		octets, err := capnp.NewUInt8List(s, count)
		return octets.List, err
	case schema.Type_Which_bool:
		bits, err := capnp.NewBitList(s, count)
		return bits.List, err
	case schema.Type_Which_void:
		voids := capnp.NewVoidList(s, count)
		return voids.List, nil
	default:
		return capnp.List{}, fmt.Errorf("new list: unknown element type: %v", t.Which())
	}
}
// structSize returns the object size declared by the schema for the struct
// type id: its data section in bytes (word count * 8) and its pointer
// count. It fails if the node cannot be found or is not a struct.
func (ins *inserter) structSize(id uint64) (capnp.ObjectSize, error) {
	node, err := ins.nodes.Find(id)
	switch {
	case err != nil:
		return capnp.ObjectSize{}, err
	case node.Which() != schema.Node_Which_structNode:
		return capnp.ObjectSize{}, fmt.Errorf("insert struct: sizing: node @%#x is not a struct", id)
	}
	sn := node.StructNode()
	dataBytes := capnp.Size(sn.DataWordCount()) * 8
	return capnp.ObjectSize{DataSize: dataBytes, PointerCount: sn.PointerCount()}, nil
}
// isFieldInBounds reports whether a field of type t at slot offset off fits
// within a struct of size sz.
//
// off is in units of the field's own size: bits for Bool, elements of the
// primitive width for numeric and enum types, and pointer slots for
// pointer types. Void always fits; unknown types never do.
//
// The comparisons are done in uint64. A large offset previously could wrap
// around, via 32-bit multiplication or the uint16 truncation of the pointer
// index, and be reported as in bounds.
func isFieldInBounds(sz capnp.ObjectSize, off uint32, t schema.Type) bool {
	dataBytes := uint64(sz.DataSize)
	slot := uint64(off)
	switch t.Which() {
	case schema.Type_Which_void:
		return true
	case schema.Type_Which_bool:
		return dataBytes >= slot/8+1
	case schema.Type_Which_int8, schema.Type_Which_uint8:
		return dataBytes >= slot+1
	case schema.Type_Which_int16, schema.Type_Which_uint16, schema.Type_Which_enum:
		return dataBytes >= (slot+1)*2
	case schema.Type_Which_int32, schema.Type_Which_uint32, schema.Type_Which_float32:
		return dataBytes >= (slot+1)*4
	case schema.Type_Which_int64, schema.Type_Which_uint64, schema.Type_Which_float64:
		return dataBytes >= (slot+1)*8
	case schema.Type_Which_text, schema.Type_Which_data, schema.Type_Which_list, schema.Type_Which_structType, schema.Type_Which_interface, schema.Type_Which_anyPointer:
		return uint64(sz.PointerCount) >= slot+1
	default:
		return false
	}
}
// isEmptyValue reports whether the schema default value v is a text, data,
// or list value of length zero.
//
// It returns false for an invalid (absent) value and for every other value
// kind. Errors from the value accessors are discarded, and the accessor's
// result is measured as-is.
func isEmptyValue(v schema.Value) bool {
	if !v.IsValid() {
		return false
	}
	var size int
	switch v.Which() {
	case schema.Value_Which_text:
		text, _ := v.TextBytes()
		size = len(text)
	case schema.Value_Which_data:
		data, _ := v.Data()
		size = len(data)
	case schema.Value_Which_list:
		ptr, _ := v.ListPtr()
		size = ptr.List().Len()
	default:
		return false
	}
	return size == 0
}