cloudflared-mirror/vendor/github.com/cisco/go-tls-syntax/decode.go

457 lines
9.8 KiB
Go

package syntax
import (
"bytes"
"fmt"
"reflect"
"runtime"
"sync"
)
// Unmarshal decodes the TLS-syntax encoding found at the start of data into
// the value that v points to.  v must be a non-nil pointer.  The returned int
// is the read count reported by the decoder; a non-nil error means decoding
// failed.
//
// The input is copied into a private buffer before decoding begins.
func Unmarshal(data []byte, v interface{}) (int, error) {
	var state decodeState
	state.Write(data)
	return state.unmarshal(v)
}
// Unmarshaler is implemented by types that know how to decode their own TLS
// representation.  In contrast to the JSON unmarshaler interface, the caller
// cannot tell in advance how many bytes belong to the value, so UnmarshalTLS
// must report how many bytes of its input it consumed.
type Unmarshaler interface {
	UnmarshalTLS(data []byte) (read int, err error)
}
// decodeState holds the input that remains to be decoded.  The decoders in
// this file consume it through the methods of the embedded bytes.Buffer
// (Next, Len, Bytes).
type decodeState struct {
	bytes.Buffer
}
// unmarshal decodes the buffered input into the value that v points to.
//
// v must be a non-nil pointer; otherwise an error is returned and nothing is
// read.  The decoders report failures by panicking with an error value; the
// deferred handler converts such a panic into the returned err, in which case
// read is 0.  Every other panic value -- a runtime.Error, a string, or any
// other non-error value -- is re-raised unchanged so its cause stays visible.
func (d *decodeState) unmarshal(v interface{}) (read int, err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		// runtime.Error also satisfies error, so it has to be ruled out first.
		if _, isRuntime := r.(runtime.Error); isRuntime {
			panic(r)
		}
		decodeErr, isErr := r.(error)
		if !isErr {
			// An unchecked r.(error) assertion here would replace the
			// original panic value with a type-assertion failure.
			panic(r)
		}
		err = decodeErr
	}()

	rv := reflect.ValueOf(v)
	if rv.Kind() != reflect.Ptr || rv.IsNil() {
		return 0, fmt.Errorf("Invalid unmarshal target (non-pointer or nil)")
	}

	read = d.value(rv)
	return read, nil
}
// value decodes into the value that the pointer v refers to.  It looks up the
// decoder for v's element type, runs it with zero-valued field options, and
// returns that decoder's result.
func (d *decodeState) value(v reflect.Value) int {
	dec := valueDecoder(v)
	return dec(d, v, fieldOptions{})
}
// decoderFunc decodes input from e into v under the field options opts and
// returns an int; the decoders in this file return the number of bytes read.
// The decoders in this file access the target through v.Elem() or
// v.Interface().
type decoderFunc func(e *decodeState, v reflect.Value, opts fieldOptions) int
// valueDecoder returns the decoder for the type that v's type points to
// (v.Type().Elem()).
func valueDecoder(v reflect.Value) decoderFunc {
	target := v.Type().Elem()
	return typeDecoder(target)
}
// decoderCache memoizes the decoder built for each type.
var decoderCache sync.Map // map[reflect.Type]decoderFunc

// typeDecoder returns the decoder for t, building it with newTypeDecoder and
// caching it the first time t is seen.
//
// XXX(RLB): Wait group based support for recursive types omitted.  A decoder
// is stored only after newTypeDecoder has returned, so nothing is cached for
// a type while its decoder is still being constructed.
func typeDecoder(t reflect.Type) decoderFunc {
	cached, found := decoderCache.Load(t)
	if !found {
		built := newTypeDecoder(t)
		decoderCache.Store(t, built)
		return built
	}
	return cached.(decoderFunc)
}
// Reflected types that the decoders compare against.
var (
	unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
	uint8Type       = reflect.TypeOf(uint8(0))
)
// newTypeDecoder builds the decoder for t without consulting the cache.
//
// A non-pointer type whose pointer type implements Unmarshaler is decoded
// through unmarshalerDecoder.  Otherwise the decoder is chosen by kind:
// uint8/16/32/64, arrays, slices, maps, structs and pointers are supported,
// and any other kind panics with an "Unsupported type" error.  If the pointer
// to t implements the validator interface, the chosen decoder is wrapped with
// newValidatorDecoder.
func newTypeDecoder(t reflect.Type) decoderFunc {
	ptrType := reflect.PtrTo(t)
	kind := t.Kind()

	var dec decoderFunc
	switch {
	case kind != reflect.Ptr && ptrType.Implements(unmarshalerType):
		dec = unmarshalerDecoder
	case kind == reflect.Uint8 || kind == reflect.Uint16 ||
		kind == reflect.Uint32 || kind == reflect.Uint64:
		dec = uintDecoder
	case kind == reflect.Array:
		dec = newArrayDecoder(t)
	case kind == reflect.Slice:
		dec = newSliceDecoder(t)
	case kind == reflect.Map:
		dec = newMapDecoder(t)
	case kind == reflect.Struct:
		dec = newStructDecoder(t)
	case kind == reflect.Ptr:
		dec = newPointerDecoder(t)
	default:
		panic(fmt.Errorf("Unsupported type (%s)", t))
	}

	if ptrType.Implements(validatorType) {
		return newValidatorDecoder(dec)
	}
	return dec
}
///// Specific decoders below
// omitDecoder is the decoder installed for struct fields whose options have
// omit set: it ignores all of its arguments, reads nothing and returns 0.
func omitDecoder(_ *decodeState, _ reflect.Value, _ fieldOptions) int {
	return 0
}
//////////
// unmarshalerDecoder decodes a value through its UnmarshalTLS method.
//
// v must hold a value that implements Unmarshaler.  The method is handed all
// of the unread input and reports how many bytes it used; exactly that many
// bytes are then consumed from d and the count is returned.  opts is not used.
//
// Failures are reported by panicking with an error: the error returned by
// UnmarshalTLS, or an error describing a consumed count that is negative or
// larger than the available input.
func unmarshalerDecoder(d *decodeState, v reflect.Value, opts fieldOptions) int {
	um, ok := v.Interface().(Unmarshaler)
	if !ok {
		panic(fmt.Errorf("Non-Unmarshaler passed to unmarshalerEncoder"))
	}

	read, err := um.UnmarshalTLS(d.Bytes())
	if err != nil {
		panic(err)
	}

	// A negative count has to be rejected here: bytes.Buffer.Next would slice
	// out of range on it and raise a runtime panic instead of a decode error.
	if read < 0 || read > d.Len() {
		panic(fmt.Errorf("Invalid return value from UnmarshalTLS"))
	}

	d.Next(read)
	return read
}
//////////
// newValidatorDecoder wraps raw in a decoder that first runs raw and then
// calls ValidForTLS on the decoded value.  The wrapper panics with an error if
// v does not hold a Validator or if validation reports an error; otherwise it
// returns whatever raw returned.
func newValidatorDecoder(raw decoderFunc) decoderFunc {
	validating := func(d *decodeState, v reflect.Value, opts fieldOptions) int {
		n := raw(d, v, opts)

		validator, isValidator := v.Interface().(Validator)
		if !isValidator {
			panic(fmt.Errorf("Non-Validator passed to validatorDecoder"))
		}

		err := validator.ValidForTLS()
		if err != nil {
			panic(fmt.Errorf("Decoded invalid TLS value: %v", err))
		}
		return n
	}
	return validating
}
//////////
// uintDecoder decodes an unsigned integer into the value v points to.
//
// When opts.varint is set the work is delegated to varintDecoder.  Otherwise
// exactly as many bytes as the target type occupies are read and stored via
// setUintFromBuffer, whose result is returned.  It panics with an error if
// fewer bytes than that remain in d.
func uintDecoder(d *decodeState, v reflect.Value, opts fieldOptions) int {
	if opts.varint {
		return varintDecoder(d, v, opts)
	}

	size := int(v.Elem().Type().Size())
	raw := d.Next(size)
	if len(raw) != size {
		panic(fmt.Errorf("Insufficient data to read uint"))
	}

	return setUintFromBuffer(v, raw)
}
// varintDecoder reads one varint from d and stores it in the unsigned integer
// v points to, returning the number of bytes the varint occupied.  It panics
// with an error if the target type is smaller, in bytes, than the encoded
// varint.  The varint is consumed from d before that check is made.
func varintDecoder(d *decodeState, v reflect.Value, opts fieldOptions) int {
	encodedLen, value := readVarint(d)

	target := v.Elem()
	if capacity := int(target.Type().Size()); capacity < encodedLen {
		panic(fmt.Errorf("Uint too small to fit varint: %d < %d", capacity, encodedLen))
	}

	target.SetUint(value)
	return encodedLen
}
// readVarint reads one variable-length integer from d and returns the number
// of bytes it occupied together with its value.
//
// The two most significant bits of the first byte select the total encoded
// length (1, 2, 4 or 8 bytes); the remaining six bits of that byte and all
// following bytes hold the value, most significant byte first.  It panics
// with an error if d holds fewer bytes than the encoding requires.
func readVarint(d *decodeState) (int, uint64) {
	first := d.Next(1)
	if len(first) != 1 {
		panic(fmt.Errorf("Insufficient data to read varint length"))
	}
	// Copy the byte out right away: slices returned by Next alias the buffer's
	// storage and are only valid until the next read.
	head := first[0]

	varintLen := 1 << (head >> 6)
	rest := d.Next(varintLen - 1)
	if len(rest) != varintLen-1 {
		panic(fmt.Errorf("Insufficient data to read varint"))
	}

	// Assemble the value in a local array so that masking off the length bits
	// never writes into the buffer that first and rest alias.
	var buf [8]byte
	buf[0] = head & 0x3f
	copy(buf[1:], rest)
	return varintLen, decodeUintFromBuffer(buf[:varintLen])
}
// decodeUintFromBuffer interprets buf as an unsigned integer with the most
// significant byte first.  An empty buffer yields 0; when buf is longer than
// eight bytes the leading bytes are shifted out of the 64-bit result.
func decodeUintFromBuffer(buf []byte) uint64 {
	var acc uint64
	for i := 0; i < len(buf); i++ {
		acc = acc<<8 | uint64(buf[i])
	}
	return acc
}
// setUintFromBuffer decodes buf with decodeUintFromBuffer, stores the result
// in the unsigned integer v points to, and returns len(buf).
func setUintFromBuffer(v reflect.Value, buf []byte) int {
	decoded := decodeUintFromBuffer(buf)
	v.Elem().SetUint(decoded)
	return len(buf)
}
//////////
// arrayDecoder decodes fixed-length arrays by applying elemDec to every
// element in index order.
type arrayDecoder struct {
	elemDec decoderFunc
}

// decode fills each element of the array v points to, handing opts through to
// the element decoder unchanged, and returns the sum of the element decoders'
// results.
func (ad *arrayDecoder) decode(d *decodeState, v reflect.Value, opts fieldOptions) int {
	count := v.Elem().Type().Len()

	total := 0
	for idx := 0; idx < count; idx++ {
		elemPtr := v.Elem().Index(idx).Addr()
		total += ad.elemDec(d, elemPtr, opts)
	}
	return total
}
// newArrayDecoder returns a decoder for the array type t; elements are decoded
// with the decoder for t's element type.
func newArrayDecoder(t reflect.Type) decoderFunc {
	ad := arrayDecoder{elemDec: typeDecoder(t.Elem())}
	return ad.decode
}
//////////
// decodeLength determines, according to opts, how many payload bytes follow.
// It returns the number of header bytes read and the payload length.
//
//   - omitHeader: no header is read; the length is everything left in d.
//   - varintHeader: the length is read as a varint.
//   - headerSize > 0: the length is read from a header of that many bytes.
//
// It panics with an error if none of these applies, if the header is
// truncated, if maxSize is positive and the length exceeds it, or if the
// length is below minSize.
func decodeLength(d *decodeState, opts fieldOptions) (int, int) {
	var read, length int

	switch {
	case opts.omitHeader:
		length = d.Len()

	case opts.varintHeader:
		n, encoded := readVarint(d)
		read, length = n, int(encoded)

	case opts.headerSize > 0:
		size := int(opts.headerSize)
		header := d.Next(size)
		if len(header) != size {
			panic(fmt.Errorf("Not enough data to read header"))
		}
		read = len(header)
		length = int(decodeUintFromBuffer(header))

	default:
		panic(fmt.Errorf("Cannot decode a slice without a header length"))
	}

	// Enforce the declared bounds on the length.
	switch {
	case opts.maxSize > 0 && length > opts.maxSize:
		panic(fmt.Errorf("Length of vector exceeds declared max"))
	case length < opts.minSize:
		panic(fmt.Errorf("Length of vector below declared min"))
	}

	return read, length
}
//////////
// sliceDecoder decodes variable-length vectors whose elements have type
// elementType and are decoded with elementDec.
type sliceDecoder struct {
	elementType reflect.Type
	elementDec  decoderFunc
}

// decode reads a vector into the slice v points to.
//
// The payload length comes from decodeLength.  When the slice's element type
// is exactly uint8, the slice is set to the payload bytes themselves (no copy)
// and header plus payload length is returned.  For any other element type the
// payload is copied into its own decodeState and decoded element by element
// (each with the same opts) until it is exhausted; the target slice is
// assigned only after every element has been decoded, and the header size
// plus the element decoders' results is returned.
//
// It panics with an error if d holds fewer payload bytes than the length.
func (sd *sliceDecoder) decode(d *decodeState, v reflect.Value, opts fieldOptions) int {
	read, length := decodeLength(d, opts)

	payload := d.Next(length)
	if len(payload) != length {
		panic(fmt.Errorf("Not enough data to read elements"))
	}

	target := v.Elem()

	// Opaque data: hand out a reference to the payload instead of copying it.
	if target.Type().Elem() == uint8Type {
		target.Set(reflect.ValueOf(payload))
		return read + length
	}

	// Everything else is decoded out of a buffer limited to this payload.
	var sub decodeState
	sub.Write(payload)

	var decoded []reflect.Value
	for sub.Len() > 0 {
		elem := reflect.New(sd.elementType)
		read += sd.elementDec(&sub, elem, opts)
		decoded = append(decoded, elem)
	}

	count := len(decoded)
	target.Set(reflect.MakeSlice(target.Type(), count, count))
	for i, elem := range decoded {
		target.Index(i).Set(elem.Elem())
	}
	return read
}
// newSliceDecoder returns a decoder for the slice type t; elements are decoded
// with the decoder for t's element type.
func newSliceDecoder(t reflect.Type) decoderFunc {
	elemType := t.Elem()
	sd := sliceDecoder{
		elementType: elemType,
		elementDec:  typeDecoder(elemType),
	}
	return sd.decode
}
//////////
// mapDecoder decodes maps encoded as a length-delimited sequence of key/value
// pairs, using keyDec and valDec for the two halves of each pair.
type mapDecoder struct {
	keyType reflect.Type
	valType reflect.Type
	keyDec  decoderFunc
	valDec  decoderFunc
}

// decode reads a map into the value v points to.
//
// The payload length comes from decodeLength.  The target is replaced by a
// fresh, empty map before any pair is decoded; the payload is then copied into
// its own decodeState and decoded as alternating keys and values, each with
// zero-valued field options, until it is exhausted.  The return value is the
// header size plus the key and value decoders' results.
//
// It panics with an error if d holds fewer payload bytes than the length.
func (md mapDecoder) decode(d *decodeState, v reflect.Value, opts fieldOptions) int {
	read, length := decodeLength(d, opts)

	payload := d.Next(length)
	if len(payload) != length {
		panic(fmt.Errorf("Not enough data to read elements"))
	}

	target := v.Elem()
	target.Set(reflect.MakeMap(reflect.MapOf(md.keyType, md.valType)))

	var (
		sub    decodeState
		noOpts fieldOptions
	)
	sub.Write(payload)

	for sub.Len() > 0 {
		key := reflect.New(md.keyType)
		read += md.keyDec(&sub, key, noOpts)

		val := reflect.New(md.valType)
		read += md.valDec(&sub, val, noOpts)

		target.SetMapIndex(key.Elem(), val.Elem())
	}
	return read
}
// newMapDecoder returns a decoder for the map type t, using the decoders for
// t's key type and element type.
func newMapDecoder(t reflect.Type) decoderFunc {
	keyType, valType := t.Key(), t.Elem()
	md := mapDecoder{
		keyType: keyType,
		valType: valType,
		keyDec:  typeDecoder(keyType),
		valDec:  typeDecoder(valType),
	}
	return md.decode
}
//////////
// structDecoder decodes structs field by field.  fieldOpts[i] and fieldDecs[i]
// are the options and the decoder for the struct's i-th field.
type structDecoder struct {
	fieldOpts []fieldOptions
	fieldDecs []decoderFunc
}

// decode decodes every field of the struct v points to, in declaration order,
// each with its own field options, and returns the sum of the field decoders'
// results.  The opts argument is not used.
func (sd *structDecoder) decode(d *decodeState, v reflect.Value, opts fieldOptions) int {
	total := 0
	for idx, fieldDec := range sd.fieldDecs {
		fieldPtr := v.Elem().Field(idx).Addr()
		total += fieldDec(d, fieldPtr, sd.fieldOpts[idx])
	}
	return total
}
// newStructDecoder returns a decoder for the struct type t.
//
// Each field's "tls" tag is parsed into field options, and the function panics
// with an error if those options are not valid for the field's type.  Fields
// whose options have omit set get omitDecoder; every other field gets the
// decoder for its type.
func newStructDecoder(t reflect.Type) decoderFunc {
	numFields := t.NumField()
	opts := make([]fieldOptions, numFields)
	decs := make([]decoderFunc, numFields)

	for i := range decs {
		field := t.Field(i)

		fieldOpts := parseTag(field.Tag.Get("tls"))
		if !fieldOpts.ValidForType(field.Type) {
			panic(fmt.Errorf("Tags invalid for field type"))
		}
		opts[i] = fieldOpts

		if fieldOpts.omit {
			decs[i] = omitDecoder
			continue
		}
		decs[i] = typeDecoder(field.Type)
	}

	sd := structDecoder{fieldOpts: opts, fieldDecs: decs}
	return sd.decode
}
//////////
// pointerDecoder decodes pointer-typed values by allocating the pointee and
// decoding into it with base.
type pointerDecoder struct {
	base decoderFunc
}

// decode decodes into the pointer that v points to and returns the number of
// bytes read.
//
// When opts.optional is set, one flag byte is read first:
//   - optionalFlagAbsent sets the pointer to its zero value (nil) and returns
//     1 without reading anything else;
//   - optionalFlagPresent continues with the pointee;
//   - any other value, or a missing flag byte, panics with an error.
//
// For a present (or non-optional) value a new pointee is allocated, stored in
// the pointer, and decoded with base using the same opts; the result is the
// flag byte (if any) plus base's result.
func (pd *pointerDecoder) decode(d *decodeState, v reflect.Value, opts fieldOptions) int {
	readBase := 0
	if opts.optional {
		flag := d.Next(1)
		// Truncated input must surface as a decode error; indexing an empty
		// slice would raise a runtime panic instead.
		if len(flag) != 1 {
			panic(fmt.Errorf("Insufficient data to read optional flag"))
		}
		readBase = 1

		switch flag[0] {
		case optionalFlagAbsent:
			indir := v.Elem()
			indir.Set(reflect.Zero(indir.Type()))
			return 1

		case optionalFlagPresent:
			// No action; continue as normal

		default:
			panic(fmt.Errorf("Invalid flag byte for optional: [%x]", flag))
		}
	}

	v.Elem().Set(reflect.New(v.Elem().Type().Elem()))
	return readBase + pd.base(d, v.Elem(), opts)
}
// newPointerDecoder returns a decoder for the pointer type t; the pointee is
// decoded with the decoder for t's element type.
func newPointerDecoder(t reflect.Type) decoderFunc {
	pd := &pointerDecoder{base: typeDecoder(t.Elem())}
	return pd.decode
}