package pogs

import (
	"fmt"
	"reflect"
	"strings"

	"zombiezen.com/go/capnproto2/internal/schema"
)

// fieldProps is the result of parsing one Go struct field with
// parseField.
//
// schemaName is the schema field name the Go field maps to.  typ tells
// whether the field is an ordinary mapped field, the union
// discriminant, or an embedded struct.  fixedWhich holds the value of a
// "which=" tag option on a discriminant field and is empty when there
// is none.  tagged reports whether the field carries a non-empty
// "capnp" struct tag.
type fieldProps struct {
	schemaName string // empty if doesn't map to schema
	typ        fieldType
	fixedWhich string
	tagged     bool
}

// fieldType classifies how parseField interpreted a Go struct field.
type fieldType int

const (
	// mappedField is the zero value: a field that maps to a schema
	// field by name, or is omitted when its schemaName is empty.
	mappedField fieldType = iota
	// whichField is an untagged-name field called Which in a struct
	// whose schema has a discriminant.
	whichField
	// embedField is an anonymous struct (or pointer to struct) field
	// with no tag name; its inner fields are visited separately.
	embedField
)

// parseField derives the mapping properties of the Go struct field f
// from its name and its "capnp" struct tag.  hasDiscrim tells whether
// the schema struct being mapped has a union discriminant.
//
// The tag is a comma-separated list whose first element is the name:
//
//	"-"     the field is omitted: schemaName is left empty.
//	""      the name is derived from the Go field (see below).
//	other   used verbatim as the schema field name.
//
// With an empty tag name, an anonymous field of struct or
// pointer-to-struct type becomes an embedField; a field named Which
// (only when hasDiscrim is true) becomes a whichField, and the first
// "which=NAME" option, if any, is stored in fixedWhich; any other field
// maps to its Go name with a leading ASCII uppercase letter lowercased.
func parseField(f reflect.StructField, hasDiscrim bool) fieldProps {
	var p fieldProps
	tag := f.Tag.Get("capnp")
	p.tagged = tag != ""
	tname, opts := nextOpt(tag)
	switch tname {
	case "-":
		// omitted field
	case "":
		if f.Anonymous && isStructOrStructPtr(f.Type) {
			p.typ = embedField
			return p
		}
		if hasDiscrim && f.Name == "Which" {
			p.typ = whichField
			for len(opts) > 0 {
				var curr string
				curr, opts = nextOpt(opts)
				if strings.HasPrefix(curr, "which=") {
					p.fixedWhich = strings.TrimPrefix(curr, "which=")
					break
				}
			}
			return p
		}
		// Fold only a leading ASCII uppercase letter.  Any other first
		// byte (an already-lowercase name such as an embedded "int", or
		// the first byte of a multi-byte rune) is kept verbatim rather
		// than being corrupted by the byte arithmetic.
		p.schemaName = f.Name
		if len(f.Name) > 0 {
			if c := f.Name[0]; 'A' <= c && c <= 'Z' {
				p.schemaName = string(c-'A'+'a') + f.Name[1:]
			}
		}
	default:
		p.schemaName = tname
	}
	return p
}

// nextOpt splits a comma-separated option string at its first comma.
// head is the text before the comma and tail is everything after it.
// When opts contains no comma, head is opts itself and tail is empty.
func nextOpt(opts string) (head, tail string) {
	if comma := strings.IndexByte(opts, ','); comma >= 0 {
		return opts[:comma], opts[comma+1:]
	}
	return opts, ""
}

// fieldLoc identifies a field inside a Go struct, possibly nested
// within other struct fields.
//
// When path is empty, i is the index of a direct field.  When path is
// non-empty it is the full index sequence from the outermost struct
// down to the field.  A negative i marks the location as invalid.
type fieldLoc struct {
	i    int
	path []int
}

// depth returns the number of struct levels the location spans: the
// length of path when it is set, otherwise 1.
func (loc fieldLoc) depth() int {
	if n := len(loc.path); n > 0 {
		return n
	}
	return 1
}

// sub returns the location of field i within the struct found at loc.
// An invalid loc stands for the outermost struct, so the result is then
// simply the direct field i.
func (loc fieldLoc) sub(i int) fieldLoc {
	if !loc.isValid() {
		return fieldLoc{i: i}
	}
	if len(loc.path) == 0 {
		return fieldLoc{path: []int{loc.i, i}}
	}
	// Build a fresh slice so that locations derived from the same
	// parent never share a backing array.
	ext := make([]int, 0, len(loc.path)+1)
	ext = append(ext, loc.path...)
	ext = append(ext, i)
	return fieldLoc{path: ext}
}

// isValid reports whether loc refers to an actual field, i.e. its i is
// not negative.
func (loc fieldLoc) isValid() bool {
	return loc.i >= 0
}

// structProps is the computed mapping from a schema struct to a Go
// struct type, as produced by mapStruct.
//
// fields has one entry per schema field, indexed like the schema's
// field list.  An entry with a negative i is not usable: -1 means no Go
// field maps to the schema field, -2 means two tagged Go fields at the
// same depth collided, and -3 means several untagged Go fields at the
// same depth were ambiguous (path is kept so the depth is still known).
//
// whichLoc locates the discriminant field; when its i is -2 the
// discriminant is fixed by a struct tag and its value is fixedWhich.
type structProps struct {
	fields     []fieldLoc
	whichLoc   fieldLoc // i == -1: none; i == -2: fixed
	fixedWhich uint16
}

// mapStruct computes how the fields of Go struct type t map onto the
// fields of the schema struct node n.
//
// The outermost struct is visited first; embedded structs discovered
// along the way are queued and visited in first-in, first-out order, so
// shallower fields are always seen before deeper ones.  Any error from
// reading the schema fields or from visiting a struct is returned with
// a zero structProps.
func mapStruct(t reflect.Type, n schema.Node) (structProps, error) {
	fields, err := n.StructNode().Fields()
	if err != nil {
		return structProps{}, err
	}

	// Every schema field starts out unmapped.
	locs := make([]fieldLoc, fields.Len())
	for i := range locs {
		locs[i].i = -1
	}
	sp := structProps{
		fields:   locs,
		whichLoc: fieldLoc{i: -1},
	}

	sm := structMapper{
		sp:         &sp,
		t:          t,
		hasDiscrim: hasDiscriminant(n),
		fields:     fields,
	}
	if err := sm.visit(fieldLoc{i: -1}); err != nil {
		return structProps{}, err
	}
	// visit may append to the queue while it is being walked, so the
	// length is re-read on every iteration.
	for next := 0; next < len(sm.embedQueue); next++ {
		if err := sm.visit(sm.embedQueue[next]); err != nil {
			return structProps{}, err
		}
	}
	return sp, nil
}

// structMapper carries the state of one mapStruct run.
//
// sp is the result being filled in, t is the outermost Go struct type,
// hasDiscrim reports whether the schema struct has a discriminant,
// fields is the schema struct's field list, and embedQueue holds the
// locations of embedded struct fields that still have to be visited.
type structMapper struct {
	sp         *structProps
	t          reflect.Type
	hasDiscrim bool
	fields     schema.Field_List
	embedQueue []fieldLoc
}

// visit walks the fields of the struct located at base, or of the
// outermost struct when base is invalid.  Unexported non-anonymous
// fields are skipped, embedded structs are queued for a later visit,
// and every other field is handed to visitField.  The first error from
// visitField stops the walk and is returned.
func (sm *structMapper) visit(base fieldLoc) error {
	st := sm.t
	if base.isValid() {
		st = typeFieldByLoc(sm.t, base).Type
		if st.Kind() == reflect.Ptr {
			st = st.Elem()
		}
	}
	for i, n := 0, st.NumField(); i < n; i++ {
		sf := st.Field(i)
		unexported := sf.PkgPath != ""
		if unexported && !sf.Anonymous {
			continue
		}
		child := base.sub(i)
		props := parseField(sf, sm.hasDiscrim)
		switch props.typ {
		case embedField:
			sm.embedQueue = append(sm.embedQueue, child)
		default:
			if err := sm.visitField(child, sf, props); err != nil {
				return err
			}
		}
	}
	return nil
}

// visitField records the Go field at loc, described by f and p, in
// sm.sp.
//
// For a mappedField, the schema field named p.schemaName is looked up
// and its slot in sm.sp.fields is updated by these rules, where depth
// is the nesting level of the Go field:
//
//   - a shallower field hides deeper ones;
//   - among fields at the same depth, a single tagged field wins over
//     untagged ones;
//   - two tagged fields at the same depth collide (slot i == -2) and
//     several untagged ones are ambiguous (slot i == -3, path kept so
//     the depth is still known); either way the slot stays invalid and
//     later untagged or deeper fields cannot claim it.
//
// The rules rely on fields being visited in order of non-decreasing
// depth, which is how mapStruct drives the walk.
//
// For a whichField, sm.sp.whichLoc is set to loc, or marked as fixed
// (i == -2, value in sm.sp.fixedWhich) when the tag has a "which="
// option.
//
// An error is returned if the field maps to a schema field that does
// not exist, if a second Which field is seen, if a "which=" option
// names an unknown or non-union field, or if a Which field without
// such an option is not of kind uint16.
func (sm *structMapper) visitField(loc fieldLoc, f reflect.StructField, p fieldProps) error {
	switch p.typ {
	case mappedField:
		if p.schemaName == "" {
			return nil
		}
		fi := fieldIndex(sm.fields, p.schemaName)
		if fi < 0 {
			return fmt.Errorf("%v has unknown field %s, maps to %s", sm.t, f.Name, p.schemaName)
		}
		switch oldloc := sm.sp.fields[fi]; {
		case oldloc.i == -2:
			// Prior tag collision, do nothing.
		case oldloc.i == -3 && p.tagged && loc.depth() == oldloc.depth():
			// A tagged field wins over untagged fields.
			sm.sp.fields[fi] = loc
		case oldloc.i == -3:
			// Still ambiguous: another untagged field at this depth, or
			// any field deeper down, must not take over the slot.
		case oldloc.isValid() && oldloc.depth() < loc.depth():
			// This is deeper, don't override.
		case oldloc.isValid() && oldloc.depth() == loc.depth():
			oldp := parseField(typeFieldByLoc(sm.t, oldloc), sm.hasDiscrim)
			if oldp.tagged && p.tagged {
				// Tag collision
				sm.sp.fields[fi] = fieldLoc{i: -2}
			} else if p.tagged {
				sm.sp.fields[fi] = loc
			} else if !oldp.tagged {
				// Multiple untagged fields.  Keep path, because we need it for depth.
				sm.sp.fields[fi].i = -3
			}
		default:
			sm.sp.fields[fi] = loc
		}
	case whichField:
		if sm.sp.whichLoc.i != -1 {
			return fmt.Errorf("%v embeds multiple Which fields", sm.t)
		}
		switch {
		case p.fixedWhich != "":
			fi := fieldIndex(sm.fields, p.fixedWhich)
			if fi < 0 {
				return fmt.Errorf("%v.Which is tagged with unknown field %s", sm.t, p.fixedWhich)
			}
			dv := sm.fields.At(fi).DiscriminantValue()
			if dv == schema.Field_noDiscriminant {
				return fmt.Errorf("%v.Which is tagged with non-union field %s", sm.t, p.fixedWhich)
			}
			sm.sp.whichLoc = fieldLoc{i: -2}
			sm.sp.fixedWhich = dv
		case f.Type.Kind() != reflect.Uint16:
			return fmt.Errorf("%v.Which is type %v, not uint16", sm.t, f.Type)
		default:
			sm.sp.whichLoc = loc
		}
	}
	return nil
}

// fieldByOrdinal returns the field of val that is mapped to the schema
// field at index i.  It returns an invalid value if no Go field is
// mapped to that schema field or if the field is contained inside a nil
// anonymous struct pointer.
func (sp structProps) fieldByOrdinal(val reflect.Value, i int) reflect.Value {
	const mkparents = false
	return fieldByLoc(val, sp.fields[i], mkparents)
}

// makeFieldByOrdinal returns the field of val that is mapped to the
// schema field at index i, allocating any nil anonymous struct pointers
// on the way to it.  It returns an invalid value if no Go field is
// mapped to that schema field.
func (sp structProps) makeFieldByOrdinal(val reflect.Value, i int) reflect.Value {
	const mkparents = true
	return fieldByLoc(val, sp.fields[i], mkparents)
}

// which returns the discriminant of val.  When the discriminant is
// fixed by a struct tag, that fixed value is returned.  Otherwise the
// value is read from the Which field; ok is false if the struct has no
// such field or it sits behind a nil anonymous struct pointer.
func (sp structProps) which(val reflect.Value) (discrim uint16, ok bool) {
	if sp.whichLoc.i == -2 {
		return sp.fixedWhich, true
	}
	if f := fieldByLoc(val, sp.whichLoc, false); f.IsValid() {
		return uint16(f.Uint()), true
	}
	return 0, false
}

// setWhich stores discrim in the discriminant field of val, creating
// its parent anonymous structs if necessary.
//
// When the discriminant is fixed by a struct tag nothing is stored;
// instead an error is returned if discrim differs from the fixed value.
// A noWhichError is returned if the struct has no Which field to set.
func (sp structProps) setWhich(val reflect.Value, discrim uint16) error {
	if sp.whichLoc.i != -2 {
		f := fieldByLoc(val, sp.whichLoc, true)
		if !f.IsValid() {
			return noWhichError{val.Type()}
		}
		f.SetUint(uint64(discrim))
		return nil
	}
	if discrim == sp.fixedWhich {
		return nil
	}
	return fmt.Errorf("extract union field @%d into %v; expected @%d", discrim, val.Type(), sp.fixedWhich)
}

// noWhichError reports that struct type t has no Which field in which a
// discriminant could be stored.
type noWhichError struct {
	t reflect.Type
}

// Error implements the error interface.
func (e noWhichError) Error() string {
	msg := fmt.Sprintf("%v has no field Which", e.t)
	return msg
}

// isNoWhichError reports whether e is a noWhichError value.  Wrapped
// errors and pointers to noWhichError do not count.
func isNoWhichError(e error) bool {
	switch e.(type) {
	case noWhichError:
		return true
	default:
		return false
	}
}

// fieldByLoc returns the field of struct value val found at loc.
//
// An invalid loc yields the zero reflect.Value.  For a nested location,
// every intermediate field that is a pointer is dereferenced before
// descending; a nil pointer is replaced by a newly allocated value when
// mkparents is true, and otherwise ends the lookup with the zero
// reflect.Value.
func fieldByLoc(val reflect.Value, loc fieldLoc, mkparents bool) reflect.Value {
	if !loc.isValid() {
		return reflect.Value{}
	}
	if len(loc.path) == 0 {
		return val.Field(loc.i)
	}
	// The outermost value is indexed directly; only the fields reached
	// along the way may need dereferencing.
	cur := val.Field(loc.path[0])
	for _, idx := range loc.path[1:] {
		if cur.Kind() == reflect.Ptr {
			if cur.IsNil() {
				if !mkparents {
					return reflect.Value{}
				}
				cur.Set(reflect.New(cur.Type().Elem()))
			}
			cur = cur.Elem()
		}
		cur = cur.Field(idx)
	}
	return cur
}

// typeFieldByLoc returns the struct field description of t at loc: the
// direct field loc.i when loc has no path, otherwise the nested field
// addressed by loc.path.
func typeFieldByLoc(t reflect.Type, loc fieldLoc) reflect.StructField {
	if len(loc.path) == 0 {
		return t.Field(loc.i)
	}
	return t.FieldByIndex(loc.path)
}

// hasDiscriminant reports whether n is a struct node with a
// discriminant count greater than zero.
func hasDiscriminant(n schema.Node) bool {
	if n.Which() != schema.Node_Which_structNode {
		return false
	}
	return n.StructNode().DiscriminantCount() > 0
}

// shortDisplayName returns n's display name with its first
// DisplayNamePrefixLength bytes removed.
//
// The lookup is best effort: an error from reading the display name is
// deliberately ignored.  If the recorded prefix length is larger than
// the name that was read, the name is returned unchanged (possibly nil)
// instead of panicking on an out-of-range slice.
func shortDisplayName(n schema.Node) []byte {
	dn, _ := n.DisplayNameBytes()
	prefix := uint64(n.DisplayNamePrefixLength())
	if prefix > uint64(len(dn)) {
		return dn
	}
	return dn[prefix:]
}

// fieldIndex returns the index of the first entry in fields whose name
// equals name, or -1 if there is none.
func fieldIndex(fields schema.Field_List, name string) int {
	for i, n := 0, fields.Len(); i < n; i++ {
		// An error reading a name is ignored; the resulting name is
		// simply compared like any other.
		if b, _ := fields.At(i).NameBytes(); bytesStrEqual(b, name) {
			return i
		}
	}
	return -1
}

// bytesStrEqual reports whether b and s hold exactly the same bytes.
// A nil or empty b equals the empty string.
func bytesStrEqual(b []byte, s string) bool {
	return string(b) == s
}

// isStructOrStructPtr reports whether t is a struct type or a pointer
// whose element type is a struct.
func isStructOrStructPtr(t reflect.Type) bool {
	switch t.Kind() {
	case reflect.Struct:
		return true
	case reflect.Ptr:
		return t.Elem().Kind() == reflect.Struct
	default:
		return false
	}
}