package pogs

import (
	"errors"
	"fmt"
	"math"
	"reflect"

	"zombiezen.com/go/capnproto2"
	"zombiezen.com/go/capnproto2/internal/nodemap"
	"zombiezen.com/go/capnproto2/internal/schema"
)

// Extract copies s into val, a pointer to a Go struct.
//
// typeID is the Cap'n Proto type ID used to look up the schema node that
// describes s. Any failure is returned as an error prefixed with
// "pogs: extract @<typeID>". Passing a nil val is reported as an error
// rather than causing a panic.
func Extract(val interface{}, typeID uint64, s capnp.Struct) error {
	if val == nil {
		// reflect.ValueOf(nil) yields the zero reflect.Value, and calling
		// Type on it panics, so reject an untyped nil up front.
		return fmt.Errorf("pogs: extract @%#x: can't extract struct into nil", typeID)
	}
	e := new(extracter)
	if err := e.extractStruct(reflect.ValueOf(val), typeID, s); err != nil {
		return fmt.Errorf("pogs: extract @%#x: %v", typeID, err)
	}
	return nil
}

// extracter carries the state for an extraction. Extract allocates a
// fresh one for every call.
type extracter struct {
	// nodes is queried by type ID to find the schema node describing a
	// struct that is being extracted.
	nodes nodemap.Map
}

// clientType is the reflect.Type of capnp.Client. It is compared against
// Go field types to decide whether a field can receive an interface
// (capability) value directly.
var clientType = reflect.TypeOf((*capnp.Client)(nil)).Elem()

// extractStruct fills val, a Go struct or pointer to a Go struct, from the
// Cap'n Proto struct s whose schema node is identified by typeID.
//
// Pointer handling:
//   - a nil pointer that cannot be set is an error;
//   - a settable pointer is reset to nil when s is not valid, and the
//     function returns immediately;
//   - a settable nil pointer is allocated when s is valid.
//
// After unwrapping, the struct itself must be settable. Fields for which
// the Go struct has no counterpart are skipped, as are union members whose
// discriminant value does not match the one stored in s.
func (e *extracter) extractStruct(val reflect.Value, typeID uint64, s capnp.Struct) error {
	switch val.Kind() {
	case reflect.Ptr:
		if val.Type().Elem().Kind() != reflect.Struct {
			return fmt.Errorf("can't extract struct into %v", val.Type())
		}
		if settable := val.CanSet(); !settable {
			if val.IsNil() {
				// Even if the Cap'n Proto pointer isn't valid, this is
				// probably the caller's fault and will be a bug at some
				// point.
				return errors.New("can't extract struct into nil")
			}
		} else if !s.IsValid() {
			val.Set(reflect.Zero(val.Type()))
			return nil
		} else if val.IsNil() {
			val.Set(reflect.New(val.Type().Elem()))
		}
		val = val.Elem()
	case reflect.Struct:
		// Already a struct value; nothing to unwrap.
	default:
		return fmt.Errorf("can't extract struct into %v", val.Type())
	}
	if !val.CanSet() {
		return errors.New("can't modify struct, did you pass in a pointer to your struct?")
	}

	node, err := e.nodes.Find(typeID)
	if err != nil {
		return err
	}
	if !node.IsValid() || node.Which() != schema.Node_Which_structNode {
		return fmt.Errorf("cannot find struct type %#x", typeID)
	}
	goType := val.Type()
	props, err := mapStruct(goType, node)
	if err != nil {
		return fmt.Errorf("can't extract %s: %v", goType, err)
	}

	// Read the union discriminant, if the struct has one, and record it in
	// the Go struct's Which field when such a field exists.
	var (
		which    uint16
		whichSet bool
	)
	if hasDiscriminant(node) {
		which = s.Uint16(capnp.DataOffset(node.StructNode().DiscriminantOffset() * 2))
		switch err := props.setWhich(val, which); {
		case err == nil:
			whichSet = true
		case !isNoWhichError(err):
			return err
		}
	}

	fields, err := node.StructNode().Fields()
	if err != nil {
		return err
	}
	for i, count := 0, fields.Len(); i < count; i++ {
		field := fields.At(i)
		dst := props.makeFieldByOrdinal(val, i)
		if !dst.IsValid() {
			// Don't have a field for this.
			continue
		}
		if dv := field.DiscriminantValue(); dv != schema.Field_noDiscriminant {
			if !whichSet {
				return fmt.Errorf("can't extract %s into %v: has union field but no Which field", shortDisplayName(node), goType)
			}
			if dv != which {
				// Inactive union member.
				continue
			}
		}
		var ferr error
		switch field.Which() {
		case schema.Field_Which_slot:
			ferr = e.extractField(dst, s, field)
		case schema.Field_Which_group:
			// A group shares the parent's storage, so pass s through.
			ferr = e.extractStruct(dst, field.Group().TypeId(), s)
		}
		if ferr != nil {
			return ferr
		}
	}
	return nil
}

// extractField copies the slot field f of the Cap'n Proto struct s into
// the Go value val.
//
// It returns an error when the schema type or default value cannot be
// read, when the default value's kind disagrees with the field type, when
// val's Go type is not compatible with the field type (per isTypeMatch),
// or when the field type is not one handled here. Data-section values are
// XORed with the schema default before being stored; pointer fields fall
// back to the schema default when the pointer in s is not valid.
func (e *extracter) extractField(val reflect.Value, s capnp.Struct, f schema.Field) error {
	slot := f.Slot()
	typ, err := slot.Type()
	if err != nil {
		return err
	}
	dv, err := slot.DefaultValue()
	if err != nil {
		return err
	}
	which := typ.Which()
	if dv.IsValid() && int(which) != int(dv.Which()) {
		name, _ := f.NameBytes()
		return fmt.Errorf("extract field %s: default value is a %v, want %v", name, dv.Which(), which)
	}
	if !isTypeMatch(val.Type(), typ) {
		name, _ := f.NameBytes()
		return fmt.Errorf("can't extract field %s of type %v into a Go %v", name, which, val.Type())
	}

	// off is the slot offset in units of the field's own size; data
	// accesses below scale it by the element width in bytes.
	off := slot.Offset()
	switch which {
	case schema.Type_Which_bool:
		// != acts as XOR
		val.SetBool(s.Bit(capnp.BitOffset(off)) != dv.Bool())
	case schema.Type_Which_int8:
		raw := int8(s.Uint8(capnp.DataOffset(off)))
		val.SetInt(int64(raw ^ dv.Int8()))
	case schema.Type_Which_int16:
		raw := int16(s.Uint16(capnp.DataOffset(off * 2)))
		val.SetInt(int64(raw ^ dv.Int16()))
	case schema.Type_Which_int32:
		raw := int32(s.Uint32(capnp.DataOffset(off * 4)))
		val.SetInt(int64(raw ^ dv.Int32()))
	case schema.Type_Which_int64:
		raw := int64(s.Uint64(capnp.DataOffset(off * 8)))
		val.SetInt(raw ^ dv.Int64())
	case schema.Type_Which_uint8:
		raw := s.Uint8(capnp.DataOffset(off))
		val.SetUint(uint64(raw ^ dv.Uint8()))
	case schema.Type_Which_uint16:
		raw := s.Uint16(capnp.DataOffset(off * 2))
		val.SetUint(uint64(raw ^ dv.Uint16()))
	case schema.Type_Which_enum:
		raw := s.Uint16(capnp.DataOffset(off * 2))
		val.SetUint(uint64(raw ^ dv.Enum()))
	case schema.Type_Which_uint32:
		raw := s.Uint32(capnp.DataOffset(off * 4))
		val.SetUint(uint64(raw ^ dv.Uint32()))
	case schema.Type_Which_uint64:
		raw := s.Uint64(capnp.DataOffset(off * 8))
		val.SetUint(raw ^ dv.Uint64())
	case schema.Type_Which_float32:
		// XOR is applied to the bit patterns, not the float values.
		bits := s.Uint32(capnp.DataOffset(off*4)) ^ math.Float32bits(dv.Float32())
		val.SetFloat(float64(math.Float32frombits(bits)))
	case schema.Type_Which_float64:
		bits := s.Uint64(capnp.DataOffset(off*8)) ^ math.Float64bits(dv.Float64())
		val.SetFloat(math.Float64frombits(bits))
	case schema.Type_Which_text:
		p, err := s.Ptr(uint16(off))
		if err != nil {
			return err
		}
		var text []byte
		if p.IsValid() {
			text = p.TextBytes()
		} else {
			text, _ = dv.TextBytes()
		}
		switch val.Kind() {
		case reflect.String:
			val.SetString(string(text))
		default:
			// byte slice, as guaranteed by isTypeMatch
			val.SetBytes(text)
		}
	case schema.Type_Which_data:
		p, err := s.Ptr(uint16(off))
		if err != nil {
			return err
		}
		var data []byte
		if p.IsValid() {
			data = p.Data()
		} else {
			data, _ = dv.Data()
		}
		val.SetBytes(data)
	case schema.Type_Which_structType:
		p, err := s.Ptr(uint16(off))
		if err != nil {
			return err
		}
		sub := p.Struct()
		if !sub.IsValid() {
			// Fall back to the default struct from the schema.
			p, _ = dv.StructValuePtr()
			sub = p.Struct()
		}
		return e.extractStruct(val, typ.StructType().TypeId(), sub)
	case schema.Type_Which_list:
		p, err := s.Ptr(uint16(off))
		if err != nil {
			return err
		}
		list := p.List()
		if !list.IsValid() {
			// Fall back to the default list from the schema.
			p, _ = dv.ListPtr()
			list = p.List()
		}
		return e.extractList(val, typ, list)
	case schema.Type_Which_interface:
		p, err := s.Ptr(uint16(off))
		if err != nil {
			return err
		}
		if val.Type() != clientType {
			// Must be a struct wrapper.
			val = val.FieldByName("Client")
		}
		if client := p.Interface().Client(); client != nil {
			val.Set(reflect.ValueOf(client))
		} else {
			// reflect.ValueOf(nil) is not assignable, so store the zero
			// value of the destination type instead.
			val.Set(reflect.Zero(val.Type()))
		}
	default:
		return fmt.Errorf("unknown field type %v", which)
	}
	return nil
}

// extractList copies the Cap'n Proto list l, whose schema type is typ,
// into the Go slice val.
//
// When val's type does not match typ (per isTypeMatch) an error is
// returned. A list that is not valid sets val to its zero value.
// Otherwise val is replaced with a new slice of l.Len() elements and each
// element is converted according to the list's element type. Nested lists
// and structs recurse into extractList and extractStruct respectively;
// the first element error aborts the extraction.
func (e *extracter) extractList(val reflect.Value, typ schema.Type, l capnp.List) error {
	sliceType := val.Type()
	elem, err := typ.List().ElementType()
	if err != nil {
		return err
	}
	if !isTypeMatch(sliceType, typ) {
		// TODO(light): the error won't be that useful for nested lists.
		return fmt.Errorf("can't extract %v list into a Go %v", elem.Which(), sliceType)
	}
	if !l.IsValid() {
		val.Set(reflect.Zero(sliceType))
		return nil
	}
	size := l.Len()
	val.Set(reflect.MakeSlice(sliceType, size, size))

	switch elem.Which() {
	case schema.Type_Which_bool:
		src := capnp.BitList{List: l}
		for i := 0; i < size; i++ {
			val.Index(i).SetBool(src.At(i))
		}
	case schema.Type_Which_int8:
		src := capnp.Int8List{List: l}
		for i := 0; i < size; i++ {
			val.Index(i).SetInt(int64(src.At(i)))
		}
	case schema.Type_Which_int16:
		src := capnp.Int16List{List: l}
		for i := 0; i < size; i++ {
			val.Index(i).SetInt(int64(src.At(i)))
		}
	case schema.Type_Which_int32:
		src := capnp.Int32List{List: l}
		for i := 0; i < size; i++ {
			val.Index(i).SetInt(int64(src.At(i)))
		}
	case schema.Type_Which_int64:
		src := capnp.Int64List{List: l}
		for i := 0; i < size; i++ {
			val.Index(i).SetInt(src.At(i))
		}
	case schema.Type_Which_uint8:
		src := capnp.UInt8List{List: l}
		for i := 0; i < size; i++ {
			val.Index(i).SetUint(uint64(src.At(i)))
		}
	case schema.Type_Which_uint16, schema.Type_Which_enum:
		// Enums are read as their underlying 16-bit value.
		src := capnp.UInt16List{List: l}
		for i := 0; i < size; i++ {
			val.Index(i).SetUint(uint64(src.At(i)))
		}
	case schema.Type_Which_uint32:
		src := capnp.UInt32List{List: l}
		for i := 0; i < size; i++ {
			val.Index(i).SetUint(uint64(src.At(i)))
		}
	case schema.Type_Which_uint64:
		src := capnp.UInt64List{List: l}
		for i := 0; i < size; i++ {
			val.Index(i).SetUint(src.At(i))
		}
	case schema.Type_Which_float32:
		src := capnp.Float32List{List: l}
		for i := 0; i < size; i++ {
			val.Index(i).SetFloat(float64(src.At(i)))
		}
	case schema.Type_Which_float64:
		src := capnp.Float64List{List: l}
		for i := 0; i < size; i++ {
			val.Index(i).SetFloat(src.At(i))
		}
	case schema.Type_Which_text:
		src := capnp.TextList{List: l}
		// Text may land in either []string or [][]byte.
		asString := sliceType.Elem().Kind() == reflect.String
		for i := 0; i < size; i++ {
			if asString {
				str, err := src.At(i)
				if err != nil {
					// TODO(light): collect errors and finish
					return err
				}
				val.Index(i).SetString(str)
				continue
			}
			raw, err := src.BytesAt(i)
			if err != nil {
				// TODO(light): collect errors and finish
				return err
			}
			val.Index(i).SetBytes(raw)
		}
	case schema.Type_Which_data:
		src := capnp.DataList{List: l}
		for i := 0; i < size; i++ {
			raw, err := src.At(i)
			if err != nil {
				// TODO(light): collect errors and finish
				return err
			}
			val.Index(i).SetBytes(raw)
		}
	case schema.Type_Which_list:
		src := capnp.PointerList{List: l}
		for i := 0; i < size; i++ {
			p, err := src.PtrAt(i)
			// TODO(light): collect errors and finish
			if err != nil {
				return err
			}
			if err := e.extractList(val.Index(i), elem, p.List()); err != nil {
				return err
			}
		}
	case schema.Type_Which_structType:
		structID := elem.StructType().TypeId()
		// Elements are either struct values or pointers to structs; the
		// latter need a fresh allocation per element.
		byValue := sliceType.Elem().Kind() == reflect.Struct
		for i := 0; i < size; i++ {
			item := val.Index(i)
			if !byValue {
				fresh := reflect.New(sliceType.Elem().Elem())
				item.Set(fresh)
				item = fresh
			}
			if err := e.extractStruct(item, structID, l.Struct(i)); err != nil {
				return err
			}
		}
	default:
		return fmt.Errorf("unknown list type %v", elem.Which())
	}
	return nil
}

// typeMap gives, for each fixed-width schema type, the reflect.Kind a Go
// value must have to receive it. Enums share the Uint16 kind with uint16.
// isTypeMatch consults this table only for schema types it does not
// handle explicitly (text, data, structs, lists and interfaces).
var typeMap = map[schema.Type_Which]reflect.Kind{
	schema.Type_Which_bool:    reflect.Bool,
	schema.Type_Which_int8:    reflect.Int8,
	schema.Type_Which_int16:   reflect.Int16,
	schema.Type_Which_int32:   reflect.Int32,
	schema.Type_Which_int64:   reflect.Int64,
	schema.Type_Which_uint8:   reflect.Uint8,
	schema.Type_Which_uint16:  reflect.Uint16,
	schema.Type_Which_uint32:  reflect.Uint32,
	schema.Type_Which_uint64:  reflect.Uint64,
	schema.Type_Which_float32: reflect.Float32,
	schema.Type_Which_float64: reflect.Float64,
	schema.Type_Which_enum:    reflect.Uint16,
}

// isTypeMatch reports whether a Go value of type r can hold a Cap'n Proto
// value of schema type s.
//
//   - text matches string or a slice of bytes;
//   - data matches a slice of bytes;
//   - structs defer to isStructOrStructPtr;
//   - lists match a slice whose element type matches the list's element
//     type;
//   - interfaces match capnp.Client itself, or a struct whose only field
//     is named "Client" and has type capnp.Client;
//   - every other type must have the kind listed in typeMap.
func isTypeMatch(r reflect.Type, s schema.Type) bool {
	switch which := s.Which(); which {
	case schema.Type_Which_text:
		switch r.Kind() {
		case reflect.String:
			return true
		case reflect.Slice:
			return r.Elem().Kind() == reflect.Uint8
		}
		return false
	case schema.Type_Which_data:
		if r.Kind() != reflect.Slice {
			return false
		}
		return r.Elem().Kind() == reflect.Uint8
	case schema.Type_Which_structType:
		return isStructOrStructPtr(r)
	case schema.Type_Which_list:
		elem, _ := s.List().ElementType()
		if r.Kind() != reflect.Slice {
			return false
		}
		return isTypeMatch(r.Elem(), elem)
	case schema.Type_Which_interface:
		if r == clientType {
			return true
		}
		// Otherwise, the type must be a struct with one element named
		// "Client" of type capnp.Client.
		if r.Kind() != reflect.Struct || r.NumField() != 1 {
			return false
		}
		field, ok := r.FieldByName("Client")
		return ok && field.Type == clientType
	default:
		kind, ok := typeMap[which]
		return ok && kind == r.Kind()
	}
}