419 lines
11 KiB
Go
419 lines
11 KiB
Go
package pogs
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"reflect"
|
|
|
|
"zombiezen.com/go/capnproto2"
|
|
"zombiezen.com/go/capnproto2/internal/nodemap"
|
|
"zombiezen.com/go/capnproto2/internal/schema"
|
|
)
|
|
|
|
// Extract copies s into val, a pointer to a Go struct.
|
|
func Extract(val interface{}, typeID uint64, s capnp.Struct) error {
|
|
e := new(extracter)
|
|
err := e.extractStruct(reflect.ValueOf(val), typeID, s)
|
|
if err != nil {
|
|
return fmt.Errorf("pogs: extract @%#x: %v", typeID, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type extracter struct {
|
|
nodes nodemap.Map
|
|
}
|
|
|
|
// clientType is the reflect.Type of the capnp.Client interface itself
// (obtained via a pointer, since TypeOf on a nil interface yields nil).
var clientType = reflect.TypeOf(new(capnp.Client)).Elem()
|
|
|
|
func (e *extracter) extractStruct(val reflect.Value, typeID uint64, s capnp.Struct) error {
|
|
if val.Kind() == reflect.Ptr {
|
|
if val.Type().Elem().Kind() != reflect.Struct {
|
|
return fmt.Errorf("can't extract struct into %v", val.Type())
|
|
}
|
|
switch {
|
|
case !val.CanSet() && val.IsNil():
|
|
// Even if the Cap'n Proto pointer isn't valid, this is probably
|
|
// the caller's fault and will be a bug at some point.
|
|
return errors.New("can't extract struct into nil")
|
|
case !s.IsValid() && val.CanSet():
|
|
val.Set(reflect.Zero(val.Type()))
|
|
return nil
|
|
case s.IsValid() && val.CanSet() && val.IsNil():
|
|
val.Set(reflect.New(val.Type().Elem()))
|
|
}
|
|
val = val.Elem()
|
|
} else if val.Kind() != reflect.Struct {
|
|
return fmt.Errorf("can't extract struct into %v", val.Type())
|
|
}
|
|
if !val.CanSet() {
|
|
return errors.New("can't modify struct, did you pass in a pointer to your struct?")
|
|
}
|
|
n, err := e.nodes.Find(typeID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !n.IsValid() || n.Which() != schema.Node_Which_structNode {
|
|
return fmt.Errorf("cannot find struct type %#x", typeID)
|
|
}
|
|
props, err := mapStruct(val.Type(), n)
|
|
if err != nil {
|
|
return fmt.Errorf("can't extract %s: %v", val.Type(), err)
|
|
}
|
|
var discriminant uint16
|
|
hasWhich := false
|
|
if hasDiscriminant(n) {
|
|
discriminant = s.Uint16(capnp.DataOffset(n.StructNode().DiscriminantOffset() * 2))
|
|
if err := props.setWhich(val, discriminant); err == nil {
|
|
hasWhich = true
|
|
} else if !isNoWhichError(err) {
|
|
return err
|
|
}
|
|
}
|
|
fields, err := n.StructNode().Fields()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for i := 0; i < fields.Len(); i++ {
|
|
f := fields.At(i)
|
|
vf := props.makeFieldByOrdinal(val, i)
|
|
if !vf.IsValid() {
|
|
// Don't have a field for this.
|
|
continue
|
|
}
|
|
if dv := f.DiscriminantValue(); dv != schema.Field_noDiscriminant {
|
|
if !hasWhich {
|
|
return fmt.Errorf("can't extract %s into %v: has union field but no Which field", shortDisplayName(n), val.Type())
|
|
}
|
|
if dv != discriminant {
|
|
continue
|
|
}
|
|
}
|
|
switch f.Which() {
|
|
case schema.Field_Which_slot:
|
|
if err := e.extractField(vf, s, f); err != nil {
|
|
return err
|
|
}
|
|
case schema.Field_Which_group:
|
|
if err := e.extractStruct(vf, f.Group().TypeId(), s); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (e *extracter) extractField(val reflect.Value, s capnp.Struct, f schema.Field) error {
|
|
typ, err := f.Slot().Type()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
dv, err := f.Slot().DefaultValue()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if dv.IsValid() && int(typ.Which()) != int(dv.Which()) {
|
|
name, _ := f.NameBytes()
|
|
return fmt.Errorf("extract field %s: default value is a %v, want %v", name, dv.Which(), typ.Which())
|
|
}
|
|
if !isTypeMatch(val.Type(), typ) {
|
|
name, _ := f.NameBytes()
|
|
return fmt.Errorf("can't extract field %s of type %v into a Go %v", name, typ.Which(), val.Type())
|
|
}
|
|
switch typ.Which() {
|
|
case schema.Type_Which_bool:
|
|
v := s.Bit(capnp.BitOffset(f.Slot().Offset()))
|
|
d := dv.Bool()
|
|
val.SetBool(v != d) // != acts as XOR
|
|
case schema.Type_Which_int8:
|
|
v := int8(s.Uint8(capnp.DataOffset(f.Slot().Offset())))
|
|
d := dv.Int8()
|
|
val.SetInt(int64(v ^ d))
|
|
case schema.Type_Which_int16:
|
|
v := int16(s.Uint16(capnp.DataOffset(f.Slot().Offset() * 2)))
|
|
d := dv.Int16()
|
|
val.SetInt(int64(v ^ d))
|
|
case schema.Type_Which_int32:
|
|
v := int32(s.Uint32(capnp.DataOffset(f.Slot().Offset() * 4)))
|
|
d := dv.Int32()
|
|
val.SetInt(int64(v ^ d))
|
|
case schema.Type_Which_int64:
|
|
v := int64(s.Uint64(capnp.DataOffset(f.Slot().Offset() * 8)))
|
|
d := dv.Int64()
|
|
val.SetInt(v ^ d)
|
|
case schema.Type_Which_uint8:
|
|
v := s.Uint8(capnp.DataOffset(f.Slot().Offset()))
|
|
d := dv.Uint8()
|
|
val.SetUint(uint64(v ^ d))
|
|
case schema.Type_Which_uint16:
|
|
v := s.Uint16(capnp.DataOffset(f.Slot().Offset() * 2))
|
|
d := dv.Uint16()
|
|
val.SetUint(uint64(v ^ d))
|
|
case schema.Type_Which_enum:
|
|
v := s.Uint16(capnp.DataOffset(f.Slot().Offset() * 2))
|
|
d := dv.Enum()
|
|
val.SetUint(uint64(v ^ d))
|
|
case schema.Type_Which_uint32:
|
|
v := s.Uint32(capnp.DataOffset(f.Slot().Offset() * 4))
|
|
d := dv.Uint32()
|
|
val.SetUint(uint64(v ^ d))
|
|
case schema.Type_Which_uint64:
|
|
v := s.Uint64(capnp.DataOffset(f.Slot().Offset() * 8))
|
|
d := dv.Uint64()
|
|
val.SetUint(v ^ d)
|
|
case schema.Type_Which_float32:
|
|
v := s.Uint32(capnp.DataOffset(f.Slot().Offset() * 4))
|
|
d := math.Float32bits(dv.Float32())
|
|
val.SetFloat(float64(math.Float32frombits(v ^ d)))
|
|
case schema.Type_Which_float64:
|
|
v := s.Uint64(capnp.DataOffset(f.Slot().Offset() * 8))
|
|
d := math.Float64bits(dv.Float64())
|
|
val.SetFloat(math.Float64frombits(v ^ d))
|
|
case schema.Type_Which_text:
|
|
p, err := s.Ptr(uint16(f.Slot().Offset()))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var b []byte
|
|
if p.IsValid() {
|
|
b = p.TextBytes()
|
|
} else {
|
|
b, _ = dv.TextBytes()
|
|
}
|
|
if val.Kind() == reflect.String {
|
|
val.SetString(string(b))
|
|
} else {
|
|
// byte slice, as guaranteed by isTypeMatch
|
|
val.SetBytes(b)
|
|
}
|
|
case schema.Type_Which_data:
|
|
p, err := s.Ptr(uint16(f.Slot().Offset()))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var b []byte
|
|
if p.IsValid() {
|
|
b = p.Data()
|
|
} else {
|
|
b, _ = dv.Data()
|
|
}
|
|
val.SetBytes(b)
|
|
case schema.Type_Which_structType:
|
|
p, err := s.Ptr(uint16(f.Slot().Offset()))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
ss := p.Struct()
|
|
if !ss.IsValid() {
|
|
p, _ = dv.StructValuePtr()
|
|
ss = p.Struct()
|
|
}
|
|
return e.extractStruct(val, typ.StructType().TypeId(), ss)
|
|
case schema.Type_Which_list:
|
|
p, err := s.Ptr(uint16(f.Slot().Offset()))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
l := p.List()
|
|
if !l.IsValid() {
|
|
p, _ = dv.ListPtr()
|
|
l = p.List()
|
|
}
|
|
return e.extractList(val, typ, l)
|
|
case schema.Type_Which_interface:
|
|
p, err := s.Ptr(uint16(f.Slot().Offset()))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if val.Type() != clientType {
|
|
// Must be a struct wrapper.
|
|
val = val.FieldByName("Client")
|
|
}
|
|
|
|
client := p.Interface().Client()
|
|
if client == nil {
|
|
val.Set(reflect.Zero(val.Type()))
|
|
} else {
|
|
val.Set(reflect.ValueOf(client))
|
|
}
|
|
default:
|
|
return fmt.Errorf("unknown field type %v", typ.Which())
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (e *extracter) extractList(val reflect.Value, typ schema.Type, l capnp.List) error {
|
|
vt := val.Type()
|
|
elem, err := typ.List().ElementType()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !isTypeMatch(vt, typ) {
|
|
// TODO(light): the error won't be that useful for nested lists.
|
|
return fmt.Errorf("can't extract %v list into a Go %v", elem.Which(), vt)
|
|
}
|
|
if !l.IsValid() {
|
|
val.Set(reflect.Zero(vt))
|
|
return nil
|
|
}
|
|
n := l.Len()
|
|
val.Set(reflect.MakeSlice(vt, n, n))
|
|
switch elem.Which() {
|
|
case schema.Type_Which_bool:
|
|
for i := 0; i < n; i++ {
|
|
val.Index(i).SetBool(capnp.BitList{List: l}.At(i))
|
|
}
|
|
case schema.Type_Which_int8:
|
|
for i := 0; i < n; i++ {
|
|
val.Index(i).SetInt(int64(capnp.Int8List{List: l}.At(i)))
|
|
}
|
|
case schema.Type_Which_int16:
|
|
for i := 0; i < n; i++ {
|
|
val.Index(i).SetInt(int64(capnp.Int16List{List: l}.At(i)))
|
|
}
|
|
case schema.Type_Which_int32:
|
|
for i := 0; i < n; i++ {
|
|
val.Index(i).SetInt(int64(capnp.Int32List{List: l}.At(i)))
|
|
}
|
|
case schema.Type_Which_int64:
|
|
for i := 0; i < n; i++ {
|
|
val.Index(i).SetInt(capnp.Int64List{List: l}.At(i))
|
|
}
|
|
case schema.Type_Which_uint8:
|
|
for i := 0; i < n; i++ {
|
|
val.Index(i).SetUint(uint64(capnp.UInt8List{List: l}.At(i)))
|
|
}
|
|
case schema.Type_Which_uint16, schema.Type_Which_enum:
|
|
for i := 0; i < n; i++ {
|
|
val.Index(i).SetUint(uint64(capnp.UInt16List{List: l}.At(i)))
|
|
}
|
|
case schema.Type_Which_uint32:
|
|
for i := 0; i < n; i++ {
|
|
val.Index(i).SetUint(uint64(capnp.UInt32List{List: l}.At(i)))
|
|
}
|
|
case schema.Type_Which_uint64:
|
|
for i := 0; i < n; i++ {
|
|
val.Index(i).SetUint(capnp.UInt64List{List: l}.At(i))
|
|
}
|
|
case schema.Type_Which_float32:
|
|
for i := 0; i < n; i++ {
|
|
val.Index(i).SetFloat(float64(capnp.Float32List{List: l}.At(i)))
|
|
}
|
|
case schema.Type_Which_float64:
|
|
for i := 0; i < n; i++ {
|
|
val.Index(i).SetFloat(capnp.Float64List{List: l}.At(i))
|
|
}
|
|
case schema.Type_Which_text:
|
|
if val.Type().Elem().Kind() == reflect.String {
|
|
for i := 0; i < n; i++ {
|
|
s, err := capnp.TextList{List: l}.At(i)
|
|
if err != nil {
|
|
// TODO(light): collect errors and finish
|
|
return err
|
|
}
|
|
val.Index(i).SetString(s)
|
|
}
|
|
} else {
|
|
for i := 0; i < n; i++ {
|
|
b, err := capnp.TextList{List: l}.BytesAt(i)
|
|
if err != nil {
|
|
// TODO(light): collect errors and finish
|
|
return err
|
|
}
|
|
val.Index(i).SetBytes(b)
|
|
}
|
|
}
|
|
case schema.Type_Which_data:
|
|
for i := 0; i < n; i++ {
|
|
b, err := capnp.DataList{List: l}.At(i)
|
|
if err != nil {
|
|
// TODO(light): collect errors and finish
|
|
return err
|
|
}
|
|
val.Index(i).SetBytes(b)
|
|
}
|
|
case schema.Type_Which_list:
|
|
for i := 0; i < n; i++ {
|
|
p, err := capnp.PointerList{List: l}.PtrAt(i)
|
|
// TODO(light): collect errors and finish
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := e.extractList(val.Index(i), elem, p.List()); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
case schema.Type_Which_structType:
|
|
if val.Type().Elem().Kind() == reflect.Struct {
|
|
for i := 0; i < n; i++ {
|
|
err := e.extractStruct(val.Index(i), elem.StructType().TypeId(), l.Struct(i))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
} else {
|
|
for i := 0; i < n; i++ {
|
|
newval := reflect.New(val.Type().Elem().Elem())
|
|
val.Index(i).Set(newval)
|
|
err := e.extractStruct(newval, elem.StructType().TypeId(), l.Struct(i))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
default:
|
|
return fmt.Errorf("unknown list type %v", elem.Which())
|
|
}
|
|
return nil
|
|
}
|
|
|
|
var typeMap = map[schema.Type_Which]reflect.Kind{
|
|
schema.Type_Which_bool: reflect.Bool,
|
|
schema.Type_Which_int8: reflect.Int8,
|
|
schema.Type_Which_int16: reflect.Int16,
|
|
schema.Type_Which_int32: reflect.Int32,
|
|
schema.Type_Which_int64: reflect.Int64,
|
|
schema.Type_Which_uint8: reflect.Uint8,
|
|
schema.Type_Which_uint16: reflect.Uint16,
|
|
schema.Type_Which_uint32: reflect.Uint32,
|
|
schema.Type_Which_uint64: reflect.Uint64,
|
|
schema.Type_Which_float32: reflect.Float32,
|
|
schema.Type_Which_float64: reflect.Float64,
|
|
schema.Type_Which_enum: reflect.Uint16,
|
|
}
|
|
|
|
func isTypeMatch(r reflect.Type, s schema.Type) bool {
|
|
switch s.Which() {
|
|
case schema.Type_Which_text:
|
|
return r.Kind() == reflect.String || r.Kind() == reflect.Slice && r.Elem().Kind() == reflect.Uint8
|
|
case schema.Type_Which_data:
|
|
return r.Kind() == reflect.Slice && r.Elem().Kind() == reflect.Uint8
|
|
case schema.Type_Which_structType:
|
|
return isStructOrStructPtr(r)
|
|
case schema.Type_Which_list:
|
|
e, _ := s.List().ElementType()
|
|
return r.Kind() == reflect.Slice && isTypeMatch(r.Elem(), e)
|
|
case schema.Type_Which_interface:
|
|
if r == clientType {
|
|
return true
|
|
}
|
|
|
|
// Otherwise, the type must be a struct with one element named
|
|
// "Client" of type capnp.Client.
|
|
if r.Kind() != reflect.Struct {
|
|
return false
|
|
}
|
|
if r.NumField() != 1 {
|
|
return false
|
|
}
|
|
field, ok := r.FieldByName("Client")
|
|
if !ok {
|
|
return false
|
|
}
|
|
return field.Type == clientType
|
|
}
|
|
k, ok := typeMap[s.Which()]
|
|
return ok && k == r.Kind()
|
|
}
|