// Package schemas provides a container for Cap'n Proto reflection data.
// The code generated by capnpc-go will register its schema in the
// default registry (unless disabled at generation time).
//
// Most programs will use the default registry.  However, a program
// could dynamically build up a registry, perhaps by invoking the capnp
// tool or querying a service.
package schemas

import (
	"bufio"
	"bytes"
	"compress/zlib"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"strings"
	"sync"

	"zombiezen.com/go/capnproto2/internal/packed"
)

// A Schema is a collection of schema nodes parsed by the capnp tool,
// in the form accepted by Registry.Register.
type Schema struct {
	// Either String or Bytes must be populated with a CodeGeneratorRequest
	// message in the standard Cap'n Proto framing format.  Setting both is
	// an error.  Bytes is retained by the registry without being copied,
	// so it must not be modified after registration.
	String string
	Bytes  []byte

	// If true, the input is assumed to be zlib-compressed and packed:
	// on first lookup it is inflated with zlib and then unpacked from the
	// Cap'n Proto packed encoding.  Decoding happens lazily, not at
	// registration time.
	Compressed bool

	// Node IDs that are contained in this schema.  The registry indexes
	// the schema under each of these IDs.
	Nodes []uint64
}

// A Registry is a mapping of IDs to schema blobs.  It is safe to read
// from multiple goroutines: Find may be called concurrently with other
// calls to Find.  Register writes the index without locking, so it must
// not run concurrently with Find or with another Register.  The zero
// value is an empty registry.
type Registry struct {
	// m maps each node ID to the record holding its schema blob.  All IDs
	// registered by a single Register call share one record, so the blob
	// is decoded at most once.  It is allocated lazily by Register.
	m map[uint64]*record
}

// Register indexes a schema in the registry under every ID in s.Nodes.
// It is an error to register schemas with overlapping IDs, or to set both
// s.String and s.Bytes.  If Register returns an error, the registry is
// left exactly as it was before the call.
//
// s.Bytes is retained rather than copied.  Register must not be called
// concurrently with Find or with another Register.
func (reg *Registry) Register(s *Schema) error {
	if len(s.String) > 0 && len(s.Bytes) > 0 {
		return errors.New("schemas: schema should have only one of string or bytes")
	}
	r := &record{
		s:          s.String,
		data:       s.Bytes,
		compressed: s.Compressed,
	}
	if reg.m == nil {
		reg.m = make(map[uint64]*record)
	}
	for i, id := range s.Nodes {
		if _, dup := reg.m[id]; dup {
			// Undo the IDs this call already inserted so a failed
			// registration does not leave a partially indexed schema.
			// Every ID in s.Nodes[:i] passed the duplicate check, so none
			// of them belonged to a previously registered schema.
			for _, added := range s.Nodes[:i] {
				delete(reg.m, added)
			}
			return &dupeError{id: id}
		}
		reg.m[id] = r
	}
	return nil
}

// Find looks up the schema registered under id and returns its
// CodeGeneratorRequest message, suitable for capnp.Unmarshal.  The blob
// is decoded on the first lookup and the result is shared by all later
// lookups, so the returned byte slice should not be modified.
//
// An unknown id yields an error recognized by IsNotFound; a blob that
// cannot be decoded yields an error naming the id.
func (reg *Registry) Find(id uint64) ([]byte, error) {
	rec := reg.m[id]
	if rec == nil {
		return nil, &notFoundError{id: id}
	}
	msg, decodeErr := rec.read()
	if decodeErr != nil {
		return nil, &decompressError{id: id, err: decodeErr}
	}
	return msg, nil
}

// record is one registered schema blob together with the cached result
// of decoding it.  Register fills in the input fields at construction.
// After that, fields are only mutated inside once.Do (see read), which
// makes concurrent reads safe.
type record struct {
	s          string // input blob when supplied as a string; cleared by the first read
	compressed bool   // whether the input is zlib-compressed packed data

	data []byte // input blob when supplied as bytes; replaced by the decoded message
	err  error  // decoding failure, if any

	once sync.Once // guards the one-time decode
}

// read returns the decoded CodeGeneratorRequest bytes for r, doing the
// work on the first call only; later calls return the cached result.
//
// Uncompressed string input is converted to bytes; uncompressed byte
// input is returned as-is.  Compressed input is inflated with zlib and
// then unpacked.  On failure, data is nil and the error is cached.
func (r *record) read() ([]byte, error) {
	r.once.Do(func() {
		if !r.compressed {
			if r.s != "" {
				r.data = []byte(r.s)
				r.s = ""
			}
			return
		}
		var in io.Reader
		if r.s != "" {
			in = strings.NewReader(r.s)
			r.s = "" // drop the reference to the compressed input
		} else {
			in = bytes.NewReader(r.data)
		}
		z, err := zlib.NewReader(in)
		if err != nil {
			r.data, r.err = nil, err
			return
		}
		// zlib.NewReader documents that the caller must Close the reader.
		// Its error is not checked because stream errors are already
		// surfaced through ReadAll below.
		defer z.Close()
		data, err := ioutil.ReadAll(packed.NewReader(bufio.NewReader(z)))
		if err != nil {
			r.data, r.err = nil, err
			return
		}
		r.data = data
	})
	return r.data, r.err
}

// DefaultRegistry is the process-wide registry used by the package-level
// Register and Find functions.  Like any Registry, its zero value is an
// empty registry ready for use.
var DefaultRegistry Registry

// Register is called by generated code to associate a blob of zlib-
// compressed, packed Cap'n Proto data for a CodeGeneratorRequest with
// the IDs it contains.  It should only be called during init().
func Register(data string, ids ...uint64) {
	err := DefaultRegistry.Register(&Schema{
		String:     data,
		Nodes:      ids,
		Compressed: true,
	})
	if err != nil {
		panic(err)
	}
}

// Find returns the CodeGeneratorRequest message for the given ID from
// DefaultRegistry, suitable for capnp.Unmarshal, or nil if the ID was not
// found.  It panics if the registered blob cannot be decompressed.
//
// Find may be called from multiple goroutines at once.  The returned byte
// slice is shared between callers and must not be modified.  Find must
// not be called concurrently with Register.
func Find(id uint64) []byte {
	msg, err := DefaultRegistry.Find(id)
	switch {
	case err == nil:
		return msg
	case IsNotFound(err):
		return nil
	default:
		panic(err)
	}
}

// IsNotFound reports whether e is the error a Registry returns when no
// schema is registered for the requested ID.  It inspects e directly and
// does not look through wrapped errors; a nil e reports false.
func IsNotFound(e error) bool {
	switch e.(type) {
	case *notFoundError:
		return true
	default:
		return false
	}
}

// dupeError is returned by Register when a node ID is already present in
// the registry.
type dupeError struct {
	id uint64 // the conflicting node ID
}

// Error implements the error interface.
func (de *dupeError) Error() string {
	const format = "schemas: registered @%#x twice"
	return fmt.Sprintf(format, de.id)
}

// notFoundError is returned by Find when no schema is registered under
// the requested ID.  IsNotFound recognizes it.
type notFoundError struct {
	id uint64 // the ID that was looked up
}

// Error implements the error interface.
func (nf *notFoundError) Error() string {
	const format = "schemas: could not find @%#x"
	return fmt.Sprintf(format, nf.id)
}

// decompressError is returned by Find when the blob registered for a
// node ID cannot be decoded.
type decompressError struct {
	id  uint64 // the ID that was looked up
	err error  // the underlying decoding error
}

// Error implements the error interface.
func (e *decompressError) Error() string {
	return fmt.Sprintf("schemas: decompressing schema for @%#x: %v", e.id, e.err)
}

// Unwrap returns the underlying decoding error, so callers can inspect it
// (for example with errors.Is or errors.As).
func (e *decompressError) Unwrap() error {
	return e.err
}