351 lines
8.0 KiB
Go
351 lines
8.0 KiB
Go
package pogs
|
|
|
|
import (
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
|
|
"zombiezen.com/go/capnproto2/internal/schema"
|
|
)
|
|
|
|
type fieldProps struct {
|
|
schemaName string // empty if doesn't map to schema
|
|
typ fieldType
|
|
fixedWhich string
|
|
tagged bool
|
|
}
|
|
|
|
type fieldType int
|
|
|
|
const (
|
|
mappedField fieldType = iota
|
|
whichField
|
|
embedField
|
|
)
|
|
|
|
func parseField(f reflect.StructField, hasDiscrim bool) fieldProps {
|
|
var p fieldProps
|
|
tag := f.Tag.Get("capnp")
|
|
p.tagged = tag != ""
|
|
tname, opts := nextOpt(tag)
|
|
switch tname {
|
|
case "-":
|
|
// omitted field
|
|
case "":
|
|
if f.Anonymous && isStructOrStructPtr(f.Type) {
|
|
p.typ = embedField
|
|
return p
|
|
}
|
|
if hasDiscrim && f.Name == "Which" {
|
|
p.typ = whichField
|
|
for len(opts) > 0 {
|
|
var curr string
|
|
curr, opts = nextOpt(opts)
|
|
if strings.HasPrefix(curr, "which=") {
|
|
p.fixedWhich = strings.TrimPrefix(curr, "which=")
|
|
break
|
|
}
|
|
}
|
|
return p
|
|
}
|
|
// TODO(light): check it's uppercase.
|
|
x := f.Name[0] - 'A' + 'a'
|
|
p.schemaName = string(x) + f.Name[1:]
|
|
default:
|
|
p.schemaName = tname
|
|
}
|
|
return p
|
|
}
|
|
|
|
// nextOpt splits a comma-separated option list at its first comma. head is
// the text before the comma and tail the text after it. When opts has no
// comma, head is all of opts and tail is empty.
func nextOpt(opts string) (head, tail string) {
	if comma := strings.IndexByte(opts, ','); comma >= 0 {
		return opts[:comma], opts[comma+1:]
	}
	return opts, ""
}
|
|
|
|
// fieldLoc identifies a Go struct field, possibly one promoted from nested
// embedded structs.
type fieldLoc struct {
	// i is the field index on the outermost struct when path is empty.
	// Negative values are sentinels used by the mapper: -1 means "no
	// field", and -2 / -3 mark tag collisions / ambiguous untagged fields
	// (or, for a Which location, -2 means a fixed discriminant).
	i int
	// path, when non-empty, is the full index sequence from the outermost
	// struct down through embedded structs.
	path []int
}

// depth is the number of struct levels between the outermost struct and the
// field: 1 for a direct field, len(path) otherwise.
func (loc fieldLoc) depth() int {
	if len(loc.path) == 0 {
		return 1
	}
	return len(loc.path)
}

// sub returns the location of field i inside the struct found at loc. An
// invalid loc stands for the outermost struct itself. The returned path is
// freshly allocated and never shares storage with loc.path.
func (loc fieldLoc) sub(i int) fieldLoc {
	if !loc.isValid() {
		return fieldLoc{i: i}
	}
	depth := len(loc.path)
	if depth == 0 {
		return fieldLoc{path: []int{loc.i, i}}
	}
	extended := make([]int, depth+1)
	copy(extended, loc.path)
	extended[depth] = i
	return fieldLoc{path: extended}
}

// isValid reports whether loc refers to an actual field rather than one of
// the negative sentinel values.
func (loc fieldLoc) isValid() bool {
	return 0 <= loc.i
}
|
|
|
|
type structProps struct {
|
|
fields []fieldLoc
|
|
whichLoc fieldLoc // i == -1: none; i == -2: fixed
|
|
fixedWhich uint16
|
|
}
|
|
|
|
func mapStruct(t reflect.Type, n schema.Node) (structProps, error) {
|
|
fields, err := n.StructNode().Fields()
|
|
if err != nil {
|
|
return structProps{}, err
|
|
}
|
|
sp := structProps{
|
|
fields: make([]fieldLoc, fields.Len()),
|
|
whichLoc: fieldLoc{i: -1},
|
|
}
|
|
for i := range sp.fields {
|
|
sp.fields[i] = fieldLoc{i: -1}
|
|
}
|
|
sm := structMapper{
|
|
sp: &sp,
|
|
t: t,
|
|
hasDiscrim: hasDiscriminant(n),
|
|
fields: fields,
|
|
}
|
|
if err := sm.visit(fieldLoc{i: -1}); err != nil {
|
|
return structProps{}, err
|
|
}
|
|
for len(sm.embedQueue) > 0 {
|
|
loc := sm.embedQueue[0]
|
|
copy(sm.embedQueue, sm.embedQueue[1:])
|
|
sm.embedQueue = sm.embedQueue[:len(sm.embedQueue)-1]
|
|
if err := sm.visit(loc); err != nil {
|
|
return structProps{}, err
|
|
}
|
|
}
|
|
return sp, nil
|
|
}
|
|
|
|
type structMapper struct {
|
|
sp *structProps
|
|
t reflect.Type
|
|
hasDiscrim bool
|
|
fields schema.Field_List
|
|
embedQueue []fieldLoc
|
|
}
|
|
|
|
func (sm *structMapper) visit(base fieldLoc) error {
|
|
t := sm.t
|
|
if base.isValid() {
|
|
t = typeFieldByLoc(t, base).Type
|
|
if t.Kind() == reflect.Ptr {
|
|
t = t.Elem()
|
|
}
|
|
}
|
|
for i := 0; i < t.NumField(); i++ {
|
|
f := t.Field(i)
|
|
if f.PkgPath != "" && !f.Anonymous {
|
|
// unexported field
|
|
continue
|
|
}
|
|
loc := base.sub(i)
|
|
p := parseField(f, sm.hasDiscrim)
|
|
if p.typ == embedField {
|
|
sm.embedQueue = append(sm.embedQueue, loc)
|
|
continue
|
|
}
|
|
if err := sm.visitField(loc, f, p); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (sm *structMapper) visitField(loc fieldLoc, f reflect.StructField, p fieldProps) error {
|
|
switch p.typ {
|
|
case mappedField:
|
|
if p.schemaName == "" {
|
|
return nil
|
|
}
|
|
fi := fieldIndex(sm.fields, p.schemaName)
|
|
if fi < 0 {
|
|
return fmt.Errorf("%v has unknown field %s, maps to %s", sm.t, f.Name, p.schemaName)
|
|
}
|
|
switch oldloc := sm.sp.fields[fi]; {
|
|
case oldloc.i == -2:
|
|
// Prior tag collision, do nothing.
|
|
case oldloc.i == -3 && p.tagged && loc.depth() == len(oldloc.path):
|
|
// A tagged field wins over untagged fields.
|
|
sm.sp.fields[fi] = loc
|
|
case oldloc.isValid() && oldloc.depth() < loc.depth():
|
|
// This is deeper, don't override.
|
|
case oldloc.isValid() && oldloc.depth() == loc.depth():
|
|
oldp := parseField(typeFieldByLoc(sm.t, oldloc), sm.hasDiscrim)
|
|
if oldp.tagged && p.tagged {
|
|
// Tag collision
|
|
sm.sp.fields[fi] = fieldLoc{i: -2}
|
|
} else if p.tagged {
|
|
sm.sp.fields[fi] = loc
|
|
} else if !oldp.tagged {
|
|
// Multiple untagged fields. Keep path, because we need it for depth.
|
|
sm.sp.fields[fi].i = -3
|
|
}
|
|
default:
|
|
sm.sp.fields[fi] = loc
|
|
}
|
|
case whichField:
|
|
if sm.sp.whichLoc.i != -1 {
|
|
return fmt.Errorf("%v embeds multiple Which fields", sm.t)
|
|
}
|
|
switch {
|
|
case p.fixedWhich != "":
|
|
fi := fieldIndex(sm.fields, p.fixedWhich)
|
|
if fi < 0 {
|
|
return fmt.Errorf("%v.Which is tagged with unknown field %s", sm.t, p.fixedWhich)
|
|
}
|
|
dv := sm.fields.At(fi).DiscriminantValue()
|
|
if dv == schema.Field_noDiscriminant {
|
|
return fmt.Errorf("%v.Which is tagged with non-union field %s", sm.t, p.fixedWhich)
|
|
}
|
|
sm.sp.whichLoc = fieldLoc{i: -2}
|
|
sm.sp.fixedWhich = dv
|
|
case f.Type.Kind() != reflect.Uint16:
|
|
return fmt.Errorf("%v.Which is type %v, not uint16", sm.t, f.Type)
|
|
default:
|
|
sm.sp.whichLoc = loc
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// fieldBySchemaName returns the field for the given name.
|
|
// Returns an invalid value if the field was not found or it is
|
|
// contained inside a nil anonymous struct pointer.
|
|
func (sp structProps) fieldByOrdinal(val reflect.Value, i int) reflect.Value {
|
|
return fieldByLoc(val, sp.fields[i], false)
|
|
}
|
|
|
|
// makeFieldBySchemaName returns the field for the given name, creating
|
|
// its parent anonymous structs if necessary. Returns an invalid value
|
|
// if the field was not found.
|
|
func (sp structProps) makeFieldByOrdinal(val reflect.Value, i int) reflect.Value {
|
|
return fieldByLoc(val, sp.fields[i], true)
|
|
}
|
|
|
|
// which returns the value of the discriminator field.
|
|
func (sp structProps) which(val reflect.Value) (discrim uint16, ok bool) {
|
|
if sp.whichLoc.i == -2 {
|
|
return sp.fixedWhich, true
|
|
}
|
|
f := fieldByLoc(val, sp.whichLoc, false)
|
|
if !f.IsValid() {
|
|
return 0, false
|
|
}
|
|
return uint16(f.Uint()), true
|
|
}
|
|
|
|
// setWhich sets the value of the discriminator field, creating its
|
|
// parent anonymous structs if necessary. Returns whether the struct
|
|
// had a field to set.
|
|
func (sp structProps) setWhich(val reflect.Value, discrim uint16) error {
|
|
if sp.whichLoc.i == -2 {
|
|
if discrim != sp.fixedWhich {
|
|
return fmt.Errorf("extract union field @%d into %v; expected @%d", discrim, val.Type(), sp.fixedWhich)
|
|
}
|
|
return nil
|
|
}
|
|
f := fieldByLoc(val, sp.whichLoc, true)
|
|
if !f.IsValid() {
|
|
return noWhichError{val.Type()}
|
|
}
|
|
f.SetUint(uint64(discrim))
|
|
return nil
|
|
}
|
|
|
|
// noWhichError reports that a Go struct type has no field in which to store
// a union discriminant.
type noWhichError struct {
	t reflect.Type // the Go type lacking a Which field
}

// Error implements the error interface.
func (e noWhichError) Error() string {
	return fmt.Sprintf("%v has no field Which", e.t)
}

// isNoWhichError reports whether e's dynamic type is exactly noWhichError.
// A *noWhichError or a wrapped noWhichError does not match.
func isNoWhichError(e error) bool {
	switch e.(type) {
	case noWhichError:
		return true
	default:
		return false
	}
}
|
|
|
|
func fieldByLoc(val reflect.Value, loc fieldLoc, mkparents bool) reflect.Value {
|
|
if !loc.isValid() {
|
|
return reflect.Value{}
|
|
}
|
|
if len(loc.path) > 0 {
|
|
for i, x := range loc.path {
|
|
if i > 0 {
|
|
if val.Kind() == reflect.Ptr {
|
|
if val.IsNil() {
|
|
if !mkparents {
|
|
return reflect.Value{}
|
|
}
|
|
val.Set(reflect.New(val.Type().Elem()))
|
|
}
|
|
val = val.Elem()
|
|
}
|
|
}
|
|
val = val.Field(x)
|
|
}
|
|
return val
|
|
}
|
|
return val.Field(loc.i)
|
|
}
|
|
|
|
func typeFieldByLoc(t reflect.Type, loc fieldLoc) reflect.StructField {
|
|
if len(loc.path) > 0 {
|
|
return t.FieldByIndex(loc.path)
|
|
}
|
|
return t.Field(loc.i)
|
|
}
|
|
|
|
func hasDiscriminant(n schema.Node) bool {
|
|
return n.Which() == schema.Node_Which_structNode && n.StructNode().DiscriminantCount() > 0
|
|
}
|
|
|
|
func shortDisplayName(n schema.Node) []byte {
|
|
dn, _ := n.DisplayNameBytes()
|
|
return dn[n.DisplayNamePrefixLength():]
|
|
}
|
|
|
|
func fieldIndex(fields schema.Field_List, name string) int {
|
|
for i := 0; i < fields.Len(); i++ {
|
|
b, _ := fields.At(i).NameBytes()
|
|
if bytesStrEqual(b, name) {
|
|
return i
|
|
}
|
|
}
|
|
return -1
|
|
}
|
|
|
|
// bytesStrEqual reports whether b holds exactly the bytes of s. The compiler
// performs this []byte-to-string comparison without copying b.
func bytesStrEqual(b []byte, s string) bool {
	return string(b) == s
}
|
|
|
|
// isStructOrStructPtr reports whether t is a struct type or a pointer to a
// struct type (one level of pointer only).
func isStructOrStructPtr(t reflect.Type) bool {
	switch t.Kind() {
	case reflect.Struct:
		return true
	case reflect.Ptr:
		return t.Elem().Kind() == reflect.Struct
	default:
		return false
	}
}
|