Add go modules vendor

parent 5826732fb4
commit 748eb367fa
@@ -0,0 +1,5 @@
TAGS
tags
.*.swp
tomlcheck/tomlcheck
toml.test
@@ -0,0 +1,15 @@
language: go
go:
  - 1.1
  - 1.2
  - 1.3
  - 1.4
  - 1.5
  - 1.6
  - tip
install:
  - go install ./...
  - go get github.com/BurntSushi/toml-test
script:
  - export PATH="$PATH:$HOME/gopath/bin"
  - make test
@@ -0,0 +1,3 @@
Compatible with TOML version
[v0.4.0](https://github.com/toml-lang/toml/blob/v0.4.0/versions/en/toml-v0.4.0.md)
@@ -0,0 +1,21 @@
The MIT License (MIT)

Copyright (c) 2013 TOML authors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
@@ -0,0 +1,19 @@
install:
	go install ./...

test: install
	go test -v
	toml-test toml-test-decoder
	toml-test -encoder toml-test-encoder

fmt:
	gofmt -w *.go */*.go
	colcheck *.go */*.go

tags:
	find ./ -name '*.go' -print0 | xargs -0 gotags > TAGS

push:
	git push origin master
	git push github master
@@ -0,0 +1,218 @@
## TOML parser and encoder for Go with reflection

TOML stands for Tom's Obvious, Minimal Language. This Go package provides a
reflection interface similar to Go's standard library `json` and `xml`
packages. This package also supports the `encoding.TextUnmarshaler` and
`encoding.TextMarshaler` interfaces so that you can define custom data
representations. (There is an example of this below.)

Spec: https://github.com/toml-lang/toml

Compatible with TOML version
[v0.4.0](https://github.com/toml-lang/toml/blob/master/versions/en/toml-v0.4.0.md)

Documentation: https://godoc.org/github.com/BurntSushi/toml

Installation:

```bash
go get github.com/BurntSushi/toml
```

Try the toml validator:

```bash
go get github.com/BurntSushi/toml/cmd/tomlv
tomlv some-toml-file.toml
```

[Build status](https://travis-ci.org/BurntSushi/toml) [GoDoc](https://godoc.org/github.com/BurntSushi/toml)

### Testing

This package passes all tests in
[toml-test](https://github.com/BurntSushi/toml-test) for both the decoder
and the encoder.

### Examples

This package works similarly to how the Go standard library handles `XML`
and `JSON`. Namely, data is loaded into Go values via reflection.

For the simplest example, consider some TOML file as just a list of keys
and values:

```toml
Age = 25
Cats = [ "Cauchy", "Plato" ]
Pi = 3.14
Perfection = [ 6, 28, 496, 8128 ]
DOB = 1987-07-05T05:45:00Z
```

Which could be defined in Go as:

```go
type Config struct {
	Age        int
	Cats       []string
	Pi         float64
	Perfection []int
	DOB        time.Time // requires `import time`
}
```

And then decoded with:

```go
var conf Config
if _, err := toml.Decode(tomlData, &conf); err != nil {
	// handle error
}
```

You can also use struct tags if your struct field name doesn't map to a TOML
key value directly:

```toml
some_key_NAME = "wat"
```

```go
type TOML struct {
	ObscureKey string `toml:"some_key_NAME"`
}
```
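
For reference, here is a minimal, self-contained program that ties the pieces above together. It is only a sketch: it assumes the TOML from the first example has been saved to a hypothetical `config.toml` and reads it with `toml.DecodeFile`.

```go
package main

import (
	"fmt"
	"log"
	"time"

	"github.com/BurntSushi/toml"
)

type Config struct {
	Age        int
	Cats       []string
	Pi         float64
	Perfection []int
	DOB        time.Time
}

func main() {
	var conf Config
	// DecodeFile reads and decodes the file in one step.
	if _, err := toml.DecodeFile("config.toml", &conf); err != nil {
		log.Fatal(err)
	}
	fmt.Printf("Age: %d, Cats: %v, DOB: %s\n",
		conf.Age, conf.Cats, conf.DOB.Format(time.RFC3339))
}
```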

### Using the `encoding.TextUnmarshaler` interface

Here's an example that automatically parses duration strings into
`time.Duration` values:

```toml
[[song]]
name = "Thunder Road"
duration = "4m49s"

[[song]]
name = "Stairway to Heaven"
duration = "8m03s"
```

Which can be decoded with:

```go
type song struct {
	Name     string
	Duration duration
}
type songs struct {
	Song []song
}
var favorites songs
if _, err := toml.Decode(blob, &favorites); err != nil {
	log.Fatal(err)
}

for _, s := range favorites.Song {
	fmt.Printf("%s (%s)\n", s.Name, s.Duration)
}
```

And you'll also need a `duration` type that satisfies the
`encoding.TextUnmarshaler` interface:

```go
type duration struct {
	time.Duration
}

func (d *duration) UnmarshalText(text []byte) error {
	var err error
	d.Duration, err = time.ParseDuration(string(text))
	return err
}
```
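
To go the other way and encode such values back to TOML, the same idea works in reverse: the encoder in this package recognizes `encoding.TextMarshaler`. The method below is only a sketch of that counterpart; it is not part of the example above, and the exact output format is up to you.

```go
// MarshalText makes duration satisfy encoding.TextMarshaler, so the encoder
// writes it back out as a string such as "4m49s".
func (d duration) MarshalText() ([]byte, error) {
	return []byte(d.Duration.String()), nil
}
```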

### More complex usage

Here's an example of how to load the example from the official spec page:

```toml
# This is a TOML document. Boom.

title = "TOML Example"

[owner]
name = "Tom Preston-Werner"
organization = "GitHub"
bio = "GitHub Cofounder & CEO\nLikes tater tots and beer."
dob = 1979-05-27T07:32:00Z # First class dates? Why not?

[database]
server = "192.168.1.1"
ports = [ 8001, 8001, 8002 ]
connection_max = 5000
enabled = true

[servers]

  # You can indent as you please. Tabs or spaces. TOML don't care.
  [servers.alpha]
  ip = "10.0.0.1"
  dc = "eqdc10"

  [servers.beta]
  ip = "10.0.0.2"
  dc = "eqdc10"

[clients]
data = [ ["gamma", "delta"], [1, 2] ] # just an update to make sure parsers support it

# Line breaks are OK when inside arrays
hosts = [
  "alpha",
  "omega"
]
```

And the corresponding Go types are:

```go
type tomlConfig struct {
	Title   string
	Owner   ownerInfo
	DB      database `toml:"database"`
	Servers map[string]server
	Clients clients
}

type ownerInfo struct {
	Name string
	Org  string `toml:"organization"`
	Bio  string
	DOB  time.Time
}

type database struct {
	Server  string
	Ports   []int
	ConnMax int `toml:"connection_max"`
	Enabled bool
}

type server struct {
	IP string
	DC string
}

type clients struct {
	Data  [][]interface{}
	Hosts []string
}
```

Note that a case insensitive match will be tried if an exact match can't be
found.

A working example of the above can be found in `_examples/example.{go,toml}`.
@@ -0,0 +1,509 @@
|
|||
package toml
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func e(format string, args ...interface{}) error {
|
||||
return fmt.Errorf("toml: "+format, args...)
|
||||
}
|
||||
|
||||
// Unmarshaler is the interface implemented by objects that can unmarshal a
|
||||
// TOML description of themselves.
|
||||
type Unmarshaler interface {
|
||||
UnmarshalTOML(interface{}) error
|
||||
}
|
||||
|
||||
// Unmarshal decodes the contents of `p` in TOML format into a pointer `v`.
|
||||
func Unmarshal(p []byte, v interface{}) error {
|
||||
_, err := Decode(string(p), v)
|
||||
return err
|
||||
}
|
||||
|
||||
// Primitive is a TOML value that hasn't been decoded into a Go value.
|
||||
// When using the various `Decode*` functions, the type `Primitive` may
|
||||
// be given to any value, and its decoding will be delayed.
|
||||
//
|
||||
// A `Primitive` value can be decoded using the `PrimitiveDecode` function.
|
||||
//
|
||||
// The underlying representation of a `Primitive` value is subject to change.
|
||||
// Do not rely on it.
|
||||
//
|
||||
// N.B. Primitive values are still parsed, so using them will only avoid
|
||||
// the overhead of reflection. They can be useful when you don't know the
|
||||
// exact type of TOML data until run time.
|
||||
type Primitive struct {
|
||||
undecoded interface{}
|
||||
context Key
|
||||
}
|
||||
|
||||
// DEPRECATED!
|
||||
//
|
||||
// Use MetaData.PrimitiveDecode instead.
|
||||
func PrimitiveDecode(primValue Primitive, v interface{}) error {
|
||||
md := MetaData{decoded: make(map[string]bool)}
|
||||
return md.unify(primValue.undecoded, rvalue(v))
|
||||
}
|
||||
|
||||
// PrimitiveDecode is just like the other `Decode*` functions, except it
|
||||
// decodes a TOML value that has already been parsed. Valid primitive values
|
||||
// can *only* be obtained from values filled by the decoder functions,
|
||||
// including this method. (i.e., `v` may contain more `Primitive`
|
||||
// values.)
|
||||
//
|
||||
// Meta data for primitive values is included in the meta data returned by
|
||||
// the `Decode*` functions with one exception: keys returned by the Undecoded
|
||||
// method will only reflect keys that were decoded. Namely, any keys hidden
|
||||
// behind a Primitive will be considered undecoded. Executing this method will
|
||||
// update the undecoded keys in the meta data. (See the example.)
|
||||
func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error {
|
||||
md.context = primValue.context
|
||||
defer func() { md.context = nil }()
|
||||
return md.unify(primValue.undecoded, rvalue(v))
|
||||
}
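
// Illustrative sketch (not part of the package): delaying the decoding of a
// value whose concrete type is only known at run time. The TOML snippet and
// the wrapper struct below are hypothetical.
//
//	var blob = `
//	kind = "point"
//	[value]
//	x = 1
//	y = 2
//	`
//
//	type wrapper struct {
//		Kind  string
//		Value toml.Primitive
//	}
//
//	var w wrapper
//	md, err := toml.Decode(blob, &w)
//	if err != nil {
//		// handle error
//	}
//	if w.Kind == "point" {
//		var p struct{ X, Y int }
//		// Now that the type is known, decode the delayed value.
//		err = md.PrimitiveDecode(w.Value, &p)
//	}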
|
||||
|
||||
// Decode will decode the contents of `data` in TOML format into a pointer
|
||||
// `v`.
|
||||
//
|
||||
// TOML hashes correspond to Go structs or maps. (Dealer's choice. They can be
|
||||
// used interchangeably.)
|
||||
//
|
||||
// TOML arrays of tables correspond to either a slice of structs or a slice
|
||||
// of maps.
|
||||
//
|
||||
// TOML datetimes correspond to Go `time.Time` values.
|
||||
//
|
||||
// All other TOML types (float, string, int, bool and array) correspond
|
||||
// to the obvious Go types.
|
||||
//
|
||||
// An exception to the above rules is if a type implements the
|
||||
// encoding.TextUnmarshaler interface. In this case, any primitive TOML value
|
||||
// (floats, strings, integers, booleans and datetimes) will be converted to
|
||||
// a byte string and given to the value's UnmarshalText method. See the
|
||||
// Unmarshaler example for a demonstration with time duration strings.
|
||||
//
|
||||
// Key mapping
|
||||
//
|
||||
// TOML keys can map to either keys in a Go map or field names in a Go
|
||||
// struct. The special `toml` struct tag may be used to map TOML keys to
|
||||
// struct fields that don't match the key name exactly. (See the example.)
|
||||
// A case insensitive match to struct names will be tried if an exact match
|
||||
// can't be found.
|
||||
//
|
||||
// The mapping between TOML values and Go values is loose. That is, there
|
||||
// may exist TOML values that cannot be placed into your representation, and
|
||||
// there may be parts of your representation that do not correspond to
|
||||
// TOML values. This loose mapping can be made stricter by using the IsDefined
|
||||
// and/or Undecoded methods on the MetaData returned.
|
||||
//
|
||||
// This decoder will not handle cyclic types. If a cyclic type is passed,
|
||||
// `Decode` will not terminate.
|
||||
func Decode(data string, v interface{}) (MetaData, error) {
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() != reflect.Ptr {
|
||||
return MetaData{}, e("Decode of non-pointer %s", reflect.TypeOf(v))
|
||||
}
|
||||
if rv.IsNil() {
|
||||
return MetaData{}, e("Decode of nil %s", reflect.TypeOf(v))
|
||||
}
|
||||
p, err := parse(data)
|
||||
if err != nil {
|
||||
return MetaData{}, err
|
||||
}
|
||||
md := MetaData{
|
||||
p.mapping, p.types, p.ordered,
|
||||
make(map[string]bool, len(p.ordered)), nil,
|
||||
}
|
||||
return md, md.unify(p.mapping, indirect(rv))
|
||||
}
|
||||
|
||||
// DecodeFile is just like Decode, except it will automatically read the
|
||||
// contents of the file at `fpath` and decode it for you.
|
||||
func DecodeFile(fpath string, v interface{}) (MetaData, error) {
|
||||
bs, err := ioutil.ReadFile(fpath)
|
||||
if err != nil {
|
||||
return MetaData{}, err
|
||||
}
|
||||
return Decode(string(bs), v)
|
||||
}
|
||||
|
||||
// DecodeReader is just like Decode, except it will consume all bytes
|
||||
// from the reader and decode it for you.
|
||||
func DecodeReader(r io.Reader, v interface{}) (MetaData, error) {
|
||||
bs, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
return MetaData{}, err
|
||||
}
|
||||
return Decode(string(bs), v)
|
||||
}
|
||||
|
||||
// unify performs a sort of type unification based on the structure of `rv`,
|
||||
// which is the client representation.
|
||||
//
|
||||
// Any type mismatch produces an error. Finding a type that we don't know
|
||||
// how to handle produces an unsupported type error.
|
||||
func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
|
||||
|
||||
// Special case. Look for a `Primitive` value.
|
||||
if rv.Type() == reflect.TypeOf((*Primitive)(nil)).Elem() {
|
||||
// Save the undecoded data and the key context into the primitive
|
||||
// value.
|
||||
context := make(Key, len(md.context))
|
||||
copy(context, md.context)
|
||||
rv.Set(reflect.ValueOf(Primitive{
|
||||
undecoded: data,
|
||||
context: context,
|
||||
}))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Special case. Unmarshaler Interface support.
|
||||
if rv.CanAddr() {
|
||||
if v, ok := rv.Addr().Interface().(Unmarshaler); ok {
|
||||
return v.UnmarshalTOML(data)
|
||||
}
|
||||
}
|
||||
|
||||
// Special case. Handle time.Time values specifically.
|
||||
// TODO: Remove this code when we decide to drop support for Go 1.1.
|
||||
// This isn't necessary in Go 1.2 because time.Time satisfies the encoding
|
||||
// interfaces.
|
||||
if rv.Type().AssignableTo(rvalue(time.Time{}).Type()) {
|
||||
return md.unifyDatetime(data, rv)
|
||||
}
|
||||
|
||||
// Special case. Look for a value satisfying the TextUnmarshaler interface.
|
||||
if v, ok := rv.Interface().(TextUnmarshaler); ok {
|
||||
return md.unifyText(data, v)
|
||||
}
|
||||
// BUG(burntsushi)
|
||||
// The behavior here is incorrect whenever a Go type satisfies the
|
||||
// encoding.TextUnmarshaler interface but also corresponds to a TOML
|
||||
// hash or array. In particular, the unmarshaler should only be applied
|
||||
// to primitive TOML values. But at this point, it will be applied to
|
||||
// all kinds of values and produce an incorrect error whenever those values
|
||||
// are hashes or arrays (including arrays of tables).
|
||||
|
||||
k := rv.Kind()
|
||||
|
||||
// laziness
|
||||
if k >= reflect.Int && k <= reflect.Uint64 {
|
||||
return md.unifyInt(data, rv)
|
||||
}
|
||||
switch k {
|
||||
case reflect.Ptr:
|
||||
elem := reflect.New(rv.Type().Elem())
|
||||
err := md.unify(data, reflect.Indirect(elem))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rv.Set(elem)
|
||||
return nil
|
||||
case reflect.Struct:
|
||||
return md.unifyStruct(data, rv)
|
||||
case reflect.Map:
|
||||
return md.unifyMap(data, rv)
|
||||
case reflect.Array:
|
||||
return md.unifyArray(data, rv)
|
||||
case reflect.Slice:
|
||||
return md.unifySlice(data, rv)
|
||||
case reflect.String:
|
||||
return md.unifyString(data, rv)
|
||||
case reflect.Bool:
|
||||
return md.unifyBool(data, rv)
|
||||
case reflect.Interface:
|
||||
// we only support empty interfaces.
|
||||
if rv.NumMethod() > 0 {
|
||||
return e("unsupported type %s", rv.Type())
|
||||
}
|
||||
return md.unifyAnything(data, rv)
|
||||
case reflect.Float32:
|
||||
fallthrough
|
||||
case reflect.Float64:
|
||||
return md.unifyFloat64(data, rv)
|
||||
}
|
||||
return e("unsupported type %s", rv.Kind())
|
||||
}
|
||||
|
||||
func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
|
||||
tmap, ok := mapping.(map[string]interface{})
|
||||
if !ok {
|
||||
if mapping == nil {
|
||||
return nil
|
||||
}
|
||||
return e("type mismatch for %s: expected table but found %T",
|
||||
rv.Type().String(), mapping)
|
||||
}
|
||||
|
||||
for key, datum := range tmap {
|
||||
var f *field
|
||||
fields := cachedTypeFields(rv.Type())
|
||||
for i := range fields {
|
||||
ff := &fields[i]
|
||||
if ff.name == key {
|
||||
f = ff
|
||||
break
|
||||
}
|
||||
if f == nil && strings.EqualFold(ff.name, key) {
|
||||
f = ff
|
||||
}
|
||||
}
|
||||
if f != nil {
|
||||
subv := rv
|
||||
for _, i := range f.index {
|
||||
subv = indirect(subv.Field(i))
|
||||
}
|
||||
if isUnifiable(subv) {
|
||||
md.decoded[md.context.add(key).String()] = true
|
||||
md.context = append(md.context, key)
|
||||
if err := md.unify(datum, subv); err != nil {
|
||||
return err
|
||||
}
|
||||
md.context = md.context[0 : len(md.context)-1]
|
||||
} else if f.name != "" {
|
||||
// Bad user! No soup for you!
|
||||
return e("cannot write unexported field %s.%s",
|
||||
rv.Type().String(), f.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error {
|
||||
tmap, ok := mapping.(map[string]interface{})
|
||||
if !ok {
|
||||
if tmap == nil {
|
||||
return nil
|
||||
}
|
||||
return badtype("map", mapping)
|
||||
}
|
||||
if rv.IsNil() {
|
||||
rv.Set(reflect.MakeMap(rv.Type()))
|
||||
}
|
||||
for k, v := range tmap {
|
||||
md.decoded[md.context.add(k).String()] = true
|
||||
md.context = append(md.context, k)
|
||||
|
||||
rvkey := indirect(reflect.New(rv.Type().Key()))
|
||||
rvval := reflect.Indirect(reflect.New(rv.Type().Elem()))
|
||||
if err := md.unify(v, rvval); err != nil {
|
||||
return err
|
||||
}
|
||||
md.context = md.context[0 : len(md.context)-1]
|
||||
|
||||
rvkey.SetString(k)
|
||||
rv.SetMapIndex(rvkey, rvval)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error {
|
||||
datav := reflect.ValueOf(data)
|
||||
if datav.Kind() != reflect.Slice {
|
||||
if !datav.IsValid() {
|
||||
return nil
|
||||
}
|
||||
return badtype("slice", data)
|
||||
}
|
||||
sliceLen := datav.Len()
|
||||
if sliceLen != rv.Len() {
|
||||
return e("expected array length %d; got TOML array of length %d",
|
||||
rv.Len(), sliceLen)
|
||||
}
|
||||
return md.unifySliceArray(datav, rv)
|
||||
}
|
||||
|
||||
func (md *MetaData) unifySlice(data interface{}, rv reflect.Value) error {
|
||||
datav := reflect.ValueOf(data)
|
||||
if datav.Kind() != reflect.Slice {
|
||||
if !datav.IsValid() {
|
||||
return nil
|
||||
}
|
||||
return badtype("slice", data)
|
||||
}
|
||||
n := datav.Len()
|
||||
if rv.IsNil() || rv.Cap() < n {
|
||||
rv.Set(reflect.MakeSlice(rv.Type(), n, n))
|
||||
}
|
||||
rv.SetLen(n)
|
||||
return md.unifySliceArray(datav, rv)
|
||||
}
|
||||
|
||||
func (md *MetaData) unifySliceArray(data, rv reflect.Value) error {
|
||||
sliceLen := data.Len()
|
||||
for i := 0; i < sliceLen; i++ {
|
||||
v := data.Index(i).Interface()
|
||||
sliceval := indirect(rv.Index(i))
|
||||
if err := md.unify(v, sliceval); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (md *MetaData) unifyDatetime(data interface{}, rv reflect.Value) error {
|
||||
if _, ok := data.(time.Time); ok {
|
||||
rv.Set(reflect.ValueOf(data))
|
||||
return nil
|
||||
}
|
||||
return badtype("time.Time", data)
|
||||
}
|
||||
|
||||
func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error {
|
||||
if s, ok := data.(string); ok {
|
||||
rv.SetString(s)
|
||||
return nil
|
||||
}
|
||||
return badtype("string", data)
|
||||
}
|
||||
|
||||
func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
|
||||
if num, ok := data.(float64); ok {
|
||||
switch rv.Kind() {
|
||||
case reflect.Float32:
|
||||
fallthrough
|
||||
case reflect.Float64:
|
||||
rv.SetFloat(num)
|
||||
default:
|
||||
panic("bug")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return badtype("float", data)
|
||||
}
|
||||
|
||||
func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
|
||||
if num, ok := data.(int64); ok {
|
||||
if rv.Kind() >= reflect.Int && rv.Kind() <= reflect.Int64 {
|
||||
switch rv.Kind() {
|
||||
case reflect.Int, reflect.Int64:
|
||||
// No bounds checking necessary.
|
||||
case reflect.Int8:
|
||||
if num < math.MinInt8 || num > math.MaxInt8 {
|
||||
return e("value %d is out of range for int8", num)
|
||||
}
|
||||
case reflect.Int16:
|
||||
if num < math.MinInt16 || num > math.MaxInt16 {
|
||||
return e("value %d is out of range for int16", num)
|
||||
}
|
||||
case reflect.Int32:
|
||||
if num < math.MinInt32 || num > math.MaxInt32 {
|
||||
return e("value %d is out of range for int32", num)
|
||||
}
|
||||
}
|
||||
rv.SetInt(num)
|
||||
} else if rv.Kind() >= reflect.Uint && rv.Kind() <= reflect.Uint64 {
|
||||
unum := uint64(num)
|
||||
switch rv.Kind() {
|
||||
case reflect.Uint, reflect.Uint64:
|
||||
// No bounds checking necessary.
|
||||
case reflect.Uint8:
|
||||
if num < 0 || unum > math.MaxUint8 {
|
||||
return e("value %d is out of range for uint8", num)
|
||||
}
|
||||
case reflect.Uint16:
|
||||
if num < 0 || unum > math.MaxUint16 {
|
||||
return e("value %d is out of range for uint16", num)
|
||||
}
|
||||
case reflect.Uint32:
|
||||
if num < 0 || unum > math.MaxUint32 {
|
||||
return e("value %d is out of range for uint32", num)
|
||||
}
|
||||
}
|
||||
rv.SetUint(unum)
|
||||
} else {
|
||||
panic("unreachable")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return badtype("integer", data)
|
||||
}
|
||||
|
||||
func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error {
|
||||
if b, ok := data.(bool); ok {
|
||||
rv.SetBool(b)
|
||||
return nil
|
||||
}
|
||||
return badtype("boolean", data)
|
||||
}
|
||||
|
||||
func (md *MetaData) unifyAnything(data interface{}, rv reflect.Value) error {
|
||||
rv.Set(reflect.ValueOf(data))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (md *MetaData) unifyText(data interface{}, v TextUnmarshaler) error {
|
||||
var s string
|
||||
switch sdata := data.(type) {
|
||||
case TextMarshaler:
|
||||
text, err := sdata.MarshalText()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s = string(text)
|
||||
case fmt.Stringer:
|
||||
s = sdata.String()
|
||||
case string:
|
||||
s = sdata
|
||||
case bool:
|
||||
s = fmt.Sprintf("%v", sdata)
|
||||
case int64:
|
||||
s = fmt.Sprintf("%d", sdata)
|
||||
case float64:
|
||||
s = fmt.Sprintf("%f", sdata)
|
||||
default:
|
||||
return badtype("primitive (string-like)", data)
|
||||
}
|
||||
if err := v.UnmarshalText([]byte(s)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// rvalue returns a reflect.Value of `v`. All pointers are resolved.
|
||||
func rvalue(v interface{}) reflect.Value {
|
||||
return indirect(reflect.ValueOf(v))
|
||||
}
|
||||
|
||||
// indirect returns the value pointed to by a pointer.
|
||||
// Pointers are followed until the value is not a pointer.
|
||||
// New values are allocated for each nil pointer.
|
||||
//
|
||||
// An exception to this rule is if the value satisfies an interface of
|
||||
// interest to us (like encoding.TextUnmarshaler).
|
||||
func indirect(v reflect.Value) reflect.Value {
|
||||
if v.Kind() != reflect.Ptr {
|
||||
if v.CanSet() {
|
||||
pv := v.Addr()
|
||||
if _, ok := pv.Interface().(TextUnmarshaler); ok {
|
||||
return pv
|
||||
}
|
||||
}
|
||||
return v
|
||||
}
|
||||
if v.IsNil() {
|
||||
v.Set(reflect.New(v.Type().Elem()))
|
||||
}
|
||||
return indirect(reflect.Indirect(v))
|
||||
}
|
||||
|
||||
func isUnifiable(rv reflect.Value) bool {
|
||||
if rv.CanSet() {
|
||||
return true
|
||||
}
|
||||
if _, ok := rv.Interface().(TextUnmarshaler); ok {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func badtype(expected string, data interface{}) error {
|
||||
return e("cannot load TOML value of type %T into a Go %s", data, expected)
|
||||
}
|
|
@@ -0,0 +1,121 @@
|
|||
package toml
|
||||
|
||||
import "strings"
|
||||
|
||||
// MetaData allows access to meta information about TOML data that may not
|
||||
// be inferrable via reflection. In particular, whether a key has been defined
|
||||
// and the TOML type of a key.
|
||||
type MetaData struct {
|
||||
mapping map[string]interface{}
|
||||
types map[string]tomlType
|
||||
keys []Key
|
||||
decoded map[string]bool
|
||||
context Key // Used only during decoding.
|
||||
}
|
||||
|
||||
// IsDefined returns true if the key given exists in the TOML data. The key
|
||||
// should be specified hierarchically. e.g.,
|
||||
//
|
||||
// // access the TOML key 'a.b.c'
|
||||
// IsDefined("a", "b", "c")
|
||||
//
|
||||
// IsDefined will return false if an empty key is given. Keys are case sensitive.
|
||||
func (md *MetaData) IsDefined(key ...string) bool {
|
||||
if len(key) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
var hash map[string]interface{}
|
||||
var ok bool
|
||||
var hashOrVal interface{} = md.mapping
|
||||
for _, k := range key {
|
||||
if hash, ok = hashOrVal.(map[string]interface{}); !ok {
|
||||
return false
|
||||
}
|
||||
if hashOrVal, ok = hash[k]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Type returns a string representation of the type of the key specified.
|
||||
//
|
||||
// Type will return the empty string if given an empty key or a key that
|
||||
// does not exist. Keys are case sensitive.
|
||||
func (md *MetaData) Type(key ...string) string {
|
||||
fullkey := strings.Join(key, ".")
|
||||
if typ, ok := md.types[fullkey]; ok {
|
||||
return typ.typeString()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Key is the type of any TOML key, including key groups. Use (MetaData).Keys
|
||||
// to get values of this type.
|
||||
type Key []string
|
||||
|
||||
func (k Key) String() string {
|
||||
return strings.Join(k, ".")
|
||||
}
|
||||
|
||||
func (k Key) maybeQuotedAll() string {
|
||||
var ss []string
|
||||
for i := range k {
|
||||
ss = append(ss, k.maybeQuoted(i))
|
||||
}
|
||||
return strings.Join(ss, ".")
|
||||
}
|
||||
|
||||
func (k Key) maybeQuoted(i int) string {
|
||||
quote := false
|
||||
for _, c := range k[i] {
|
||||
if !isBareKeyChar(c) {
|
||||
quote = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if quote {
|
||||
return "\"" + strings.Replace(k[i], "\"", "\\\"", -1) + "\""
|
||||
}
|
||||
return k[i]
|
||||
}
|
||||
|
||||
func (k Key) add(piece string) Key {
|
||||
newKey := make(Key, len(k)+1)
|
||||
copy(newKey, k)
|
||||
newKey[len(k)] = piece
|
||||
return newKey
|
||||
}
|
||||
|
||||
// Keys returns a slice of every key in the TOML data, including key groups.
|
||||
// Each key is itself a slice, where the first element is the top of the
|
||||
// hierarchy and the last is the most specific.
|
||||
//
|
||||
// The list will have the same order as the keys appeared in the TOML data.
|
||||
//
|
||||
// All keys returned are non-empty.
|
||||
func (md *MetaData) Keys() []Key {
|
||||
return md.keys
|
||||
}
|
||||
|
||||
// Undecoded returns all keys that have not been decoded in the order in which
|
||||
// they appear in the original TOML document.
|
||||
//
|
||||
// This includes keys that haven't been decoded because of a Primitive value.
|
||||
// Once the Primitive value is decoded, the keys will be considered decoded.
|
||||
//
|
||||
// Also note that decoding into an empty interface will result in no decoding,
|
||||
// and so no keys will be considered decoded.
|
||||
//
|
||||
// In this sense, the Undecoded keys correspond to keys in the TOML document
|
||||
// that do not have a concrete type in your representation.
|
||||
func (md *MetaData) Undecoded() []Key {
|
||||
undecoded := make([]Key, 0, len(md.keys))
|
||||
for _, key := range md.keys {
|
||||
if !md.decoded[key.String()] {
|
||||
undecoded = append(undecoded, key)
|
||||
}
|
||||
}
|
||||
return undecoded
|
||||
}
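
// Illustrative sketch (not part of the package): using Undecoded to warn
// about keys in a TOML document that the target Go value does not know
// about. Config and blob are hypothetical.
//
//	var conf Config
//	md, err := toml.Decode(blob, &conf)
//	if err != nil {
//		log.Fatal(err)
//	}
//	for _, key := range md.Undecoded() {
//		log.Printf("warning: undecoded key %q", key)
//	}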
|
|
@@ -0,0 +1,27 @@
/*
Package toml provides facilities for decoding and encoding TOML configuration
files via reflection. There is also support for delaying decoding with
the Primitive type, and querying the set of keys in a TOML document with the
MetaData type.

The specification implemented: https://github.com/toml-lang/toml

The sub-command github.com/BurntSushi/toml/cmd/tomlv can be used to verify
whether a file is a valid TOML document. It can also be used to print the
type of each key in a TOML document.

Testing

There are two important types of tests used for this package. The first is
contained inside '*_test.go' files and uses the standard Go unit testing
framework. These tests are primarily devoted to holistically testing the
decoder and encoder.

The second type of testing is used to verify the implementation's adherence
to the TOML specification. These tests have been factored into their own
project: https://github.com/BurntSushi/toml-test

The reason the tests are in a separate project is so that they can be used by
any implementation of TOML. Namely, it is language agnostic.
*/
package toml
@@ -0,0 +1,568 @@
|
|||
package toml
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type tomlEncodeError struct{ error }
|
||||
|
||||
var (
|
||||
errArrayMixedElementTypes = errors.New(
|
||||
"toml: cannot encode array with mixed element types")
|
||||
errArrayNilElement = errors.New(
|
||||
"toml: cannot encode array with nil element")
|
||||
errNonString = errors.New(
|
||||
"toml: cannot encode a map with non-string key type")
|
||||
errAnonNonStruct = errors.New(
|
||||
"toml: cannot encode an anonymous field that is not a struct")
|
||||
errArrayNoTable = errors.New(
|
||||
"toml: TOML array element cannot contain a table")
|
||||
errNoKey = errors.New(
|
||||
"toml: top-level values must be Go maps or structs")
|
||||
errAnything = errors.New("") // used in testing
|
||||
)
|
||||
|
||||
var quotedReplacer = strings.NewReplacer(
|
||||
"\t", "\\t",
|
||||
"\n", "\\n",
|
||||
"\r", "\\r",
|
||||
"\"", "\\\"",
|
||||
"\\", "\\\\",
|
||||
)
|
||||
|
||||
// Encoder controls the encoding of Go values to a TOML document to some
|
||||
// io.Writer.
|
||||
//
|
||||
// The indentation level can be controlled with the Indent field.
|
||||
type Encoder struct {
|
||||
// A single indentation level. By default it is two spaces.
|
||||
Indent string
|
||||
|
||||
// hasWritten is whether we have written any output to w yet.
|
||||
hasWritten bool
|
||||
w *bufio.Writer
|
||||
}
|
||||
|
||||
// NewEncoder returns a TOML encoder that encodes Go values to the io.Writer
|
||||
// given. By default, a single indentation level is 2 spaces.
|
||||
func NewEncoder(w io.Writer) *Encoder {
|
||||
return &Encoder{
|
||||
w: bufio.NewWriter(w),
|
||||
Indent: " ",
|
||||
}
|
||||
}
|
||||
|
||||
// Encode writes a TOML representation of the Go value to the underlying
|
||||
// io.Writer. If the value given cannot be encoded to a valid TOML document,
|
||||
// then an error is returned.
|
||||
//
|
||||
// The mapping between Go values and TOML values should be precisely the same
|
||||
// as for the Decode* functions. Similarly, the TextMarshaler interface is
|
||||
// supported by encoding the resulting bytes as strings. (If you want to write
|
||||
// arbitrary binary data then you will need to use something like base64 since
|
||||
// TOML does not have any binary types.)
|
||||
//
|
||||
// When encoding TOML hashes (i.e., Go maps or structs), keys without any
|
||||
// sub-hashes are encoded first.
|
||||
//
|
||||
// If a Go map is encoded, then its keys are sorted alphabetically for
|
||||
// deterministic output. More control over this behavior may be provided if
|
||||
// there is demand for it.
|
||||
//
|
||||
// Encoding Go values without a corresponding TOML representation---like map
|
||||
// types with non-string keys---will cause an error to be returned. Similarly
|
||||
// for mixed arrays/slices, arrays/slices with nil elements, embedded
|
||||
// non-struct types and nested slices containing maps or structs.
|
||||
// (e.g., [][]map[string]string is not allowed but []map[string]string is OK
|
||||
// and so is []map[string][]string.)
|
||||
func (enc *Encoder) Encode(v interface{}) error {
|
||||
rv := eindirect(reflect.ValueOf(v))
|
||||
if err := enc.safeEncode(Key([]string{}), rv); err != nil {
|
||||
return err
|
||||
}
|
||||
return enc.w.Flush()
|
||||
}
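
// Illustrative sketch (not part of the package): encoding a Go struct to
// TOML with an Encoder. The config type here is hypothetical; maps with
// string keys work the same way.
//
//	type config struct {
//		Name  string
//		Ports []int
//	}
//
//	var buf bytes.Buffer
//	enc := toml.NewEncoder(&buf)
//	enc.Indent = "    " // optional: override the default two-space indent
//	if err := enc.Encode(config{Name: "example", Ports: []int{8001, 8002}}); err != nil {
//		log.Fatal(err)
//	}
//	fmt.Print(buf.String())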
|
||||
|
||||
func (enc *Encoder) safeEncode(key Key, rv reflect.Value) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if terr, ok := r.(tomlEncodeError); ok {
|
||||
err = terr.error
|
||||
return
|
||||
}
|
||||
panic(r)
|
||||
}
|
||||
}()
|
||||
enc.encode(key, rv)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (enc *Encoder) encode(key Key, rv reflect.Value) {
|
||||
// Special case. Time needs to be in ISO8601 format.
|
||||
// Special case. If we can marshal the type to text, then we use that.
|
||||
// Basically, this prevents the encoder for handling these types as
|
||||
// generic structs (or whatever the underlying type of a TextMarshaler is).
|
||||
switch rv.Interface().(type) {
|
||||
case time.Time, TextMarshaler:
|
||||
enc.keyEqElement(key, rv)
|
||||
return
|
||||
}
|
||||
|
||||
k := rv.Kind()
|
||||
switch k {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
|
||||
reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
|
||||
reflect.Uint64,
|
||||
reflect.Float32, reflect.Float64, reflect.String, reflect.Bool:
|
||||
enc.keyEqElement(key, rv)
|
||||
case reflect.Array, reflect.Slice:
|
||||
if typeEqual(tomlArrayHash, tomlTypeOfGo(rv)) {
|
||||
enc.eArrayOfTables(key, rv)
|
||||
} else {
|
||||
enc.keyEqElement(key, rv)
|
||||
}
|
||||
case reflect.Interface:
|
||||
if rv.IsNil() {
|
||||
return
|
||||
}
|
||||
enc.encode(key, rv.Elem())
|
||||
case reflect.Map:
|
||||
if rv.IsNil() {
|
||||
return
|
||||
}
|
||||
enc.eTable(key, rv)
|
||||
case reflect.Ptr:
|
||||
if rv.IsNil() {
|
||||
return
|
||||
}
|
||||
enc.encode(key, rv.Elem())
|
||||
case reflect.Struct:
|
||||
enc.eTable(key, rv)
|
||||
default:
|
||||
panic(e("unsupported type for key '%s': %s", key, k))
|
||||
}
|
||||
}
|
||||
|
||||
// eElement encodes any value that can be an array element (primitives and
|
||||
// arrays).
|
||||
func (enc *Encoder) eElement(rv reflect.Value) {
|
||||
switch v := rv.Interface().(type) {
|
||||
case time.Time:
|
||||
// Special case time.Time as a primitive. Has to come before
|
||||
// TextMarshaler below because time.Time implements
|
||||
// encoding.TextMarshaler, but we need to always use UTC.
|
||||
enc.wf(v.UTC().Format("2006-01-02T15:04:05Z"))
|
||||
return
|
||||
case TextMarshaler:
|
||||
// Special case. Use text marshaler if it's available for this value.
|
||||
if s, err := v.MarshalText(); err != nil {
|
||||
encPanic(err)
|
||||
} else {
|
||||
enc.writeQuoted(string(s))
|
||||
}
|
||||
return
|
||||
}
|
||||
switch rv.Kind() {
|
||||
case reflect.Bool:
|
||||
enc.wf(strconv.FormatBool(rv.Bool()))
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
|
||||
reflect.Int64:
|
||||
enc.wf(strconv.FormatInt(rv.Int(), 10))
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16,
|
||||
reflect.Uint32, reflect.Uint64:
|
||||
enc.wf(strconv.FormatUint(rv.Uint(), 10))
|
||||
case reflect.Float32:
|
||||
enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 32)))
|
||||
case reflect.Float64:
|
||||
enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 64)))
|
||||
case reflect.Array, reflect.Slice:
|
||||
enc.eArrayOrSliceElement(rv)
|
||||
case reflect.Interface:
|
||||
enc.eElement(rv.Elem())
|
||||
case reflect.String:
|
||||
enc.writeQuoted(rv.String())
|
||||
default:
|
||||
panic(e("unexpected primitive type: %s", rv.Kind()))
|
||||
}
|
||||
}
|
||||
|
||||
// By the TOML spec, all floats must have a decimal with at least one
|
||||
// number on either side.
|
||||
func floatAddDecimal(fstr string) string {
|
||||
if !strings.Contains(fstr, ".") {
|
||||
return fstr + ".0"
|
||||
}
|
||||
return fstr
|
||||
}
|
||||
|
||||
func (enc *Encoder) writeQuoted(s string) {
|
||||
enc.wf("\"%s\"", quotedReplacer.Replace(s))
|
||||
}
|
||||
|
||||
func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) {
|
||||
length := rv.Len()
|
||||
enc.wf("[")
|
||||
for i := 0; i < length; i++ {
|
||||
elem := rv.Index(i)
|
||||
enc.eElement(elem)
|
||||
if i != length-1 {
|
||||
enc.wf(", ")
|
||||
}
|
||||
}
|
||||
enc.wf("]")
|
||||
}
|
||||
|
||||
func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) {
|
||||
if len(key) == 0 {
|
||||
encPanic(errNoKey)
|
||||
}
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
trv := rv.Index(i)
|
||||
if isNil(trv) {
|
||||
continue
|
||||
}
|
||||
panicIfInvalidKey(key)
|
||||
enc.newline()
|
||||
enc.wf("%s[[%s]]", enc.indentStr(key), key.maybeQuotedAll())
|
||||
enc.newline()
|
||||
enc.eMapOrStruct(key, trv)
|
||||
}
|
||||
}
|
||||
|
||||
func (enc *Encoder) eTable(key Key, rv reflect.Value) {
|
||||
panicIfInvalidKey(key)
|
||||
if len(key) == 1 {
|
||||
// Output an extra newline between top-level tables.
|
||||
// (The newline isn't written if nothing else has been written though.)
|
||||
enc.newline()
|
||||
}
|
||||
if len(key) > 0 {
|
||||
enc.wf("%s[%s]", enc.indentStr(key), key.maybeQuotedAll())
|
||||
enc.newline()
|
||||
}
|
||||
enc.eMapOrStruct(key, rv)
|
||||
}
|
||||
|
||||
func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value) {
|
||||
switch rv := eindirect(rv); rv.Kind() {
|
||||
case reflect.Map:
|
||||
enc.eMap(key, rv)
|
||||
case reflect.Struct:
|
||||
enc.eStruct(key, rv)
|
||||
default:
|
||||
panic("eTable: unhandled reflect.Value Kind: " + rv.Kind().String())
|
||||
}
|
||||
}
|
||||
|
||||
func (enc *Encoder) eMap(key Key, rv reflect.Value) {
|
||||
rt := rv.Type()
|
||||
if rt.Key().Kind() != reflect.String {
|
||||
encPanic(errNonString)
|
||||
}
|
||||
|
||||
// Sort keys so that we have deterministic output. And write keys directly
|
||||
// underneath this key first, before writing sub-structs or sub-maps.
|
||||
var mapKeysDirect, mapKeysSub []string
|
||||
for _, mapKey := range rv.MapKeys() {
|
||||
k := mapKey.String()
|
||||
if typeIsHash(tomlTypeOfGo(rv.MapIndex(mapKey))) {
|
||||
mapKeysSub = append(mapKeysSub, k)
|
||||
} else {
|
||||
mapKeysDirect = append(mapKeysDirect, k)
|
||||
}
|
||||
}
|
||||
|
||||
var writeMapKeys = func(mapKeys []string) {
|
||||
sort.Strings(mapKeys)
|
||||
for _, mapKey := range mapKeys {
|
||||
mrv := rv.MapIndex(reflect.ValueOf(mapKey))
|
||||
if isNil(mrv) {
|
||||
// Don't write anything for nil fields.
|
||||
continue
|
||||
}
|
||||
enc.encode(key.add(mapKey), mrv)
|
||||
}
|
||||
}
|
||||
writeMapKeys(mapKeysDirect)
|
||||
writeMapKeys(mapKeysSub)
|
||||
}
|
||||
|
||||
func (enc *Encoder) eStruct(key Key, rv reflect.Value) {
|
||||
// Write keys for fields directly under this key first, because if we write
|
||||
// a field that creates a new table, then all keys under it will be in that
|
||||
// table (not the one we're writing here).
|
||||
rt := rv.Type()
|
||||
var fieldsDirect, fieldsSub [][]int
|
||||
var addFields func(rt reflect.Type, rv reflect.Value, start []int)
|
||||
addFields = func(rt reflect.Type, rv reflect.Value, start []int) {
|
||||
for i := 0; i < rt.NumField(); i++ {
|
||||
f := rt.Field(i)
|
||||
// skip unexported fields
|
||||
if f.PkgPath != "" && !f.Anonymous {
|
||||
continue
|
||||
}
|
||||
frv := rv.Field(i)
|
||||
if f.Anonymous {
|
||||
t := f.Type
|
||||
switch t.Kind() {
|
||||
case reflect.Struct:
|
||||
// Treat anonymous struct fields with
|
||||
// tag names as though they are not
|
||||
// anonymous, like encoding/json does.
|
||||
if getOptions(f.Tag).name == "" {
|
||||
addFields(t, frv, f.Index)
|
||||
continue
|
||||
}
|
||||
case reflect.Ptr:
|
||||
if t.Elem().Kind() == reflect.Struct &&
|
||||
getOptions(f.Tag).name == "" {
|
||||
if !frv.IsNil() {
|
||||
addFields(t.Elem(), frv.Elem(), f.Index)
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Fall through to the normal field encoding logic below
|
||||
// for non-struct anonymous fields.
|
||||
}
|
||||
}
|
||||
|
||||
if typeIsHash(tomlTypeOfGo(frv)) {
|
||||
fieldsSub = append(fieldsSub, append(start, f.Index...))
|
||||
} else {
|
||||
fieldsDirect = append(fieldsDirect, append(start, f.Index...))
|
||||
}
|
||||
}
|
||||
}
|
||||
addFields(rt, rv, nil)
|
||||
|
||||
var writeFields = func(fields [][]int) {
|
||||
for _, fieldIndex := range fields {
|
||||
sft := rt.FieldByIndex(fieldIndex)
|
||||
sf := rv.FieldByIndex(fieldIndex)
|
||||
if isNil(sf) {
|
||||
// Don't write anything for nil fields.
|
||||
continue
|
||||
}
|
||||
|
||||
opts := getOptions(sft.Tag)
|
||||
if opts.skip {
|
||||
continue
|
||||
}
|
||||
keyName := sft.Name
|
||||
if opts.name != "" {
|
||||
keyName = opts.name
|
||||
}
|
||||
if opts.omitempty && isEmpty(sf) {
|
||||
continue
|
||||
}
|
||||
if opts.omitzero && isZero(sf) {
|
||||
continue
|
||||
}
|
||||
|
||||
enc.encode(key.add(keyName), sf)
|
||||
}
|
||||
}
|
||||
writeFields(fieldsDirect)
|
||||
writeFields(fieldsSub)
|
||||
}
|
||||
|
||||
// tomlTypeName returns the TOML type name of the Go value's type. It is
|
||||
// used to determine whether the types of array elements are mixed (which is
|
||||
// forbidden). If the Go value is nil, then it is illegal for it to be an array
|
||||
// element, and valueIsNil is returned as true.
|
||||
|
||||
// Returns the TOML type of a Go value. The type may be `nil`, which means
|
||||
// no concrete TOML type could be found.
|
||||
func tomlTypeOfGo(rv reflect.Value) tomlType {
|
||||
if isNil(rv) || !rv.IsValid() {
|
||||
return nil
|
||||
}
|
||||
switch rv.Kind() {
|
||||
case reflect.Bool:
|
||||
return tomlBool
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
|
||||
reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
|
||||
reflect.Uint64:
|
||||
return tomlInteger
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return tomlFloat
|
||||
case reflect.Array, reflect.Slice:
|
||||
if typeEqual(tomlHash, tomlArrayType(rv)) {
|
||||
return tomlArrayHash
|
||||
}
|
||||
return tomlArray
|
||||
case reflect.Ptr, reflect.Interface:
|
||||
return tomlTypeOfGo(rv.Elem())
|
||||
case reflect.String:
|
||||
return tomlString
|
||||
case reflect.Map:
|
||||
return tomlHash
|
||||
case reflect.Struct:
|
||||
switch rv.Interface().(type) {
|
||||
case time.Time:
|
||||
return tomlDatetime
|
||||
case TextMarshaler:
|
||||
return tomlString
|
||||
default:
|
||||
return tomlHash
|
||||
}
|
||||
default:
|
||||
panic("unexpected reflect.Kind: " + rv.Kind().String())
|
||||
}
|
||||
}
|
||||
|
||||
// tomlArrayType returns the element type of a TOML array. The type returned
|
||||
// may be nil if it cannot be determined (e.g., a nil slice or a zero length
|
||||
// slice). This function may also panic if it finds a type that cannot be
|
||||
// expressed in TOML (such as nil elements, heterogeneous arrays or directly
|
||||
// nested arrays of tables).
|
||||
func tomlArrayType(rv reflect.Value) tomlType {
|
||||
if isNil(rv) || !rv.IsValid() || rv.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
firstType := tomlTypeOfGo(rv.Index(0))
|
||||
if firstType == nil {
|
||||
encPanic(errArrayNilElement)
|
||||
}
|
||||
|
||||
rvlen := rv.Len()
|
||||
for i := 1; i < rvlen; i++ {
|
||||
elem := rv.Index(i)
|
||||
switch elemType := tomlTypeOfGo(elem); {
|
||||
case elemType == nil:
|
||||
encPanic(errArrayNilElement)
|
||||
case !typeEqual(firstType, elemType):
|
||||
encPanic(errArrayMixedElementTypes)
|
||||
}
|
||||
}
|
||||
// If we have a nested array, then we must make sure that the nested
|
||||
// array contains ONLY primitives.
|
||||
// This checks arbitrarily nested arrays.
|
||||
if typeEqual(firstType, tomlArray) || typeEqual(firstType, tomlArrayHash) {
|
||||
nest := tomlArrayType(eindirect(rv.Index(0)))
|
||||
if typeEqual(nest, tomlHash) || typeEqual(nest, tomlArrayHash) {
|
||||
encPanic(errArrayNoTable)
|
||||
}
|
||||
}
|
||||
return firstType
|
||||
}
|
||||
|
||||
type tagOptions struct {
|
||||
skip bool // "-"
|
||||
name string
|
||||
omitempty bool
|
||||
omitzero bool
|
||||
}
|
||||
|
||||
func getOptions(tag reflect.StructTag) tagOptions {
|
||||
t := tag.Get("toml")
|
||||
if t == "-" {
|
||||
return tagOptions{skip: true}
|
||||
}
|
||||
var opts tagOptions
|
||||
parts := strings.Split(t, ",")
|
||||
opts.name = parts[0]
|
||||
for _, s := range parts[1:] {
|
||||
switch s {
|
||||
case "omitempty":
|
||||
opts.omitempty = true
|
||||
case "omitzero":
|
||||
opts.omitzero = true
|
||||
}
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
func isZero(rv reflect.Value) bool {
|
||||
switch rv.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return rv.Int() == 0
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return rv.Uint() == 0
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return rv.Float() == 0.0
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isEmpty(rv reflect.Value) bool {
|
||||
switch rv.Kind() {
|
||||
case reflect.Array, reflect.Slice, reflect.Map, reflect.String:
|
||||
return rv.Len() == 0
|
||||
case reflect.Bool:
|
||||
return !rv.Bool()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (enc *Encoder) newline() {
|
||||
if enc.hasWritten {
|
||||
enc.wf("\n")
|
||||
}
|
||||
}
|
||||
|
||||
func (enc *Encoder) keyEqElement(key Key, val reflect.Value) {
|
||||
if len(key) == 0 {
|
||||
encPanic(errNoKey)
|
||||
}
|
||||
panicIfInvalidKey(key)
|
||||
enc.wf("%s%s = ", enc.indentStr(key), key.maybeQuoted(len(key)-1))
|
||||
enc.eElement(val)
|
||||
enc.newline()
|
||||
}
|
||||
|
||||
func (enc *Encoder) wf(format string, v ...interface{}) {
|
||||
if _, err := fmt.Fprintf(enc.w, format, v...); err != nil {
|
||||
encPanic(err)
|
||||
}
|
||||
enc.hasWritten = true
|
||||
}
|
||||
|
||||
func (enc *Encoder) indentStr(key Key) string {
|
||||
return strings.Repeat(enc.Indent, len(key)-1)
|
||||
}
|
||||
|
||||
func encPanic(err error) {
|
||||
panic(tomlEncodeError{err})
|
||||
}
|
||||
|
||||
func eindirect(v reflect.Value) reflect.Value {
|
||||
switch v.Kind() {
|
||||
case reflect.Ptr, reflect.Interface:
|
||||
return eindirect(v.Elem())
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
func isNil(rv reflect.Value) bool {
|
||||
switch rv.Kind() {
|
||||
case reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
|
||||
return rv.IsNil()
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func panicIfInvalidKey(key Key) {
|
||||
for _, k := range key {
|
||||
if len(k) == 0 {
|
||||
encPanic(e("Key '%s' is not a valid table name. Key names "+
|
||||
"cannot be empty.", key.maybeQuotedAll()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isValidKeyName(s string) bool {
|
||||
return len(s) != 0
|
||||
}
|
|
@@ -0,0 +1,19 @@
// +build go1.2

package toml

// In order to support Go 1.1, we define our own TextMarshaler and
// TextUnmarshaler types. For Go 1.2+, we just alias them with the
// standard library interfaces.

import (
	"encoding"
)

// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here
// so that Go 1.1 can be supported.
type TextMarshaler encoding.TextMarshaler

// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined
// here so that Go 1.1 can be supported.
type TextUnmarshaler encoding.TextUnmarshaler
@@ -0,0 +1,18 @@
// +build !go1.2

package toml

// These interfaces were introduced in Go 1.2, so we add them manually when
// compiling for Go 1.1.

// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here
// so that Go 1.1 can be supported.
type TextMarshaler interface {
	MarshalText() (text []byte, err error)
}

// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined
// here so that Go 1.1 can be supported.
type TextUnmarshaler interface {
	UnmarshalText(text []byte) error
}
@@ -0,0 +1,953 @@
|
|||
package toml
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
type itemType int
|
||||
|
||||
const (
|
||||
itemError itemType = iota
|
||||
itemNIL // used in the parser to indicate no type
|
||||
itemEOF
|
||||
itemText
|
||||
itemString
|
||||
itemRawString
|
||||
itemMultilineString
|
||||
itemRawMultilineString
|
||||
itemBool
|
||||
itemInteger
|
||||
itemFloat
|
||||
itemDatetime
|
||||
itemArray // the start of an array
|
||||
itemArrayEnd
|
||||
itemTableStart
|
||||
itemTableEnd
|
||||
itemArrayTableStart
|
||||
itemArrayTableEnd
|
||||
itemKeyStart
|
||||
itemCommentStart
|
||||
itemInlineTableStart
|
||||
itemInlineTableEnd
|
||||
)
|
||||
|
||||
const (
|
||||
eof = 0
|
||||
comma = ','
|
||||
tableStart = '['
|
||||
tableEnd = ']'
|
||||
arrayTableStart = '['
|
||||
arrayTableEnd = ']'
|
||||
tableSep = '.'
|
||||
keySep = '='
|
||||
arrayStart = '['
|
||||
arrayEnd = ']'
|
||||
commentStart = '#'
|
||||
stringStart = '"'
|
||||
stringEnd = '"'
|
||||
rawStringStart = '\''
|
||||
rawStringEnd = '\''
|
||||
inlineTableStart = '{'
|
||||
inlineTableEnd = '}'
|
||||
)
|
||||
|
||||
type stateFn func(lx *lexer) stateFn
|
||||
|
||||
type lexer struct {
|
||||
input string
|
||||
start int
|
||||
pos int
|
||||
line int
|
||||
state stateFn
|
||||
items chan item
|
||||
|
||||
// Allow for backing up up to three runes.
|
||||
// This is necessary because TOML contains 3-rune tokens (""" and ''').
|
||||
prevWidths [3]int
|
||||
nprev int // how many of prevWidths are in use
|
||||
// If we emit an eof, we can still back up, but it is not OK to call
|
||||
// next again.
|
||||
atEOF bool
|
||||
|
||||
// A stack of state functions used to maintain context.
|
||||
// The idea is to reuse parts of the state machine in various places.
|
||||
// For example, values can appear at the top level or within arbitrarily
|
||||
// nested arrays. The last state on the stack is used after a value has
|
||||
// been lexed. Similarly for comments.
|
||||
stack []stateFn
|
||||
}
|
||||
|
||||
type item struct {
|
||||
typ itemType
|
||||
val string
|
||||
line int
|
||||
}
|
||||
|
||||
func (lx *lexer) nextItem() item {
|
||||
for {
|
||||
select {
|
||||
case item := <-lx.items:
|
||||
return item
|
||||
default:
|
||||
lx.state = lx.state(lx)
|
||||
}
|
||||
}
|
||||
}
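
// Illustrative sketch (not part of the package): how a caller drives the
// lexer. The parser in this package does essentially this loop.
//
//	lx := lex(`answer = 42`)
//	for {
//		it := lx.nextItem()
//		if it.typ == itemEOF || it.typ == itemError {
//			break
//		}
//		fmt.Printf("%v: %q\n", it.typ, it.val)
//	}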
|
||||
|
||||
func lex(input string) *lexer {
|
||||
lx := &lexer{
|
||||
input: input,
|
||||
state: lexTop,
|
||||
line: 1,
|
||||
items: make(chan item, 10),
|
||||
stack: make([]stateFn, 0, 10),
|
||||
}
|
||||
return lx
|
||||
}
|
||||
|
||||
func (lx *lexer) push(state stateFn) {
|
||||
lx.stack = append(lx.stack, state)
|
||||
}
|
||||
|
||||
func (lx *lexer) pop() stateFn {
|
||||
if len(lx.stack) == 0 {
|
||||
return lx.errorf("BUG in lexer: no states to pop")
|
||||
}
|
||||
last := lx.stack[len(lx.stack)-1]
|
||||
lx.stack = lx.stack[0 : len(lx.stack)-1]
|
||||
return last
|
||||
}
|
||||
|
||||
func (lx *lexer) current() string {
|
||||
return lx.input[lx.start:lx.pos]
|
||||
}
|
||||
|
||||
func (lx *lexer) emit(typ itemType) {
|
||||
lx.items <- item{typ, lx.current(), lx.line}
|
||||
lx.start = lx.pos
|
||||
}
|
||||
|
||||
func (lx *lexer) emitTrim(typ itemType) {
|
||||
lx.items <- item{typ, strings.TrimSpace(lx.current()), lx.line}
|
||||
lx.start = lx.pos
|
||||
}
|
||||
|
||||
func (lx *lexer) next() (r rune) {
|
||||
if lx.atEOF {
|
||||
panic("next called after EOF")
|
||||
}
|
||||
if lx.pos >= len(lx.input) {
|
||||
lx.atEOF = true
|
||||
return eof
|
||||
}
|
||||
|
||||
if lx.input[lx.pos] == '\n' {
|
||||
lx.line++
|
||||
}
|
||||
lx.prevWidths[2] = lx.prevWidths[1]
|
||||
lx.prevWidths[1] = lx.prevWidths[0]
|
||||
if lx.nprev < 3 {
|
||||
lx.nprev++
|
||||
}
|
||||
r, w := utf8.DecodeRuneInString(lx.input[lx.pos:])
|
||||
lx.prevWidths[0] = w
|
||||
lx.pos += w
|
||||
return r
|
||||
}
|
||||
|
||||
// ignore skips over the pending input before this point.
|
||||
func (lx *lexer) ignore() {
|
||||
lx.start = lx.pos
|
||||
}
|
||||
|
||||
// backup steps back one rune. Can be called at most three times between calls to next.
|
||||
func (lx *lexer) backup() {
|
||||
if lx.atEOF {
|
||||
lx.atEOF = false
|
||||
return
|
||||
}
|
||||
if lx.nprev < 1 {
|
||||
panic("backed up too far")
|
||||
}
|
||||
w := lx.prevWidths[0]
|
||||
lx.prevWidths[0] = lx.prevWidths[1]
|
||||
lx.prevWidths[1] = lx.prevWidths[2]
|
||||
lx.nprev--
|
||||
lx.pos -= w
|
||||
if lx.pos < len(lx.input) && lx.input[lx.pos] == '\n' {
|
||||
lx.line--
|
||||
}
|
||||
}
|
||||
|
||||
// accept consumes the next rune if it's equal to `valid`.
|
||||
func (lx *lexer) accept(valid rune) bool {
|
||||
if lx.next() == valid {
|
||||
return true
|
||||
}
|
||||
lx.backup()
|
||||
return false
|
||||
}
|
||||
|
||||
// peek returns but does not consume the next rune in the input.
|
||||
func (lx *lexer) peek() rune {
|
||||
r := lx.next()
|
||||
lx.backup()
|
||||
return r
|
||||
}
|
||||
|
||||
// skip ignores all input that matches the given predicate.
|
||||
func (lx *lexer) skip(pred func(rune) bool) {
|
||||
for {
|
||||
r := lx.next()
|
||||
if pred(r) {
|
||||
continue
|
||||
}
|
||||
lx.backup()
|
||||
lx.ignore()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// errorf stops all lexing by emitting an error and returning `nil`.
|
||||
// Note that any value that is a character is escaped if it's a special
|
||||
// character (newlines, tabs, etc.).
|
||||
func (lx *lexer) errorf(format string, values ...interface{}) stateFn {
|
||||
lx.items <- item{
|
||||
itemError,
|
||||
fmt.Sprintf(format, values...),
|
||||
lx.line,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// lexTop consumes elements at the top level of TOML data.
|
||||
func lexTop(lx *lexer) stateFn {
|
||||
r := lx.next()
|
||||
if isWhitespace(r) || isNL(r) {
|
||||
return lexSkip(lx, lexTop)
|
||||
}
|
||||
switch r {
|
||||
case commentStart:
|
||||
lx.push(lexTop)
|
||||
return lexCommentStart
|
||||
case tableStart:
|
||||
return lexTableStart
|
||||
case eof:
|
||||
if lx.pos > lx.start {
|
||||
return lx.errorf("unexpected EOF")
|
||||
}
|
||||
lx.emit(itemEOF)
|
||||
return nil
|
||||
}
|
||||
|
||||
// At this point, the only valid item can be a key, so we back up
|
||||
// and let the key lexer do the rest.
|
||||
lx.backup()
|
||||
lx.push(lexTopEnd)
|
||||
return lexKeyStart
|
||||
}
|
||||
|
||||
// lexTopEnd is entered whenever a top-level item has been consumed. (A value
|
||||
// or a table.) It must see only whitespace, and will turn back to lexTop
|
||||
// upon a newline. If it sees EOF, it will quit the lexer successfully.
|
||||
func lexTopEnd(lx *lexer) stateFn {
|
||||
r := lx.next()
|
||||
switch {
|
||||
case r == commentStart:
|
||||
// a comment will read to a newline for us.
|
||||
lx.push(lexTop)
|
||||
return lexCommentStart
|
||||
case isWhitespace(r):
|
||||
return lexTopEnd
|
||||
case isNL(r):
|
||||
lx.ignore()
|
||||
return lexTop
|
||||
case r == eof:
|
||||
lx.emit(itemEOF)
|
||||
return nil
|
||||
}
|
||||
return lx.errorf("expected a top-level item to end with a newline, "+
|
||||
"comment, or EOF, but got %q instead", r)
|
||||
}
|
||||
|
||||
// lexTableStart lexes the beginning of a table. Namely, it makes sure that
|
||||
// it starts with a character other than '.' and ']'.
|
||||
// It assumes that '[' has already been consumed.
|
||||
// It also handles the case that this is an item in an array of tables.
|
||||
// e.g., '[[name]]'.
|
||||
func lexTableStart(lx *lexer) stateFn {
|
||||
if lx.peek() == arrayTableStart {
|
||||
lx.next()
|
||||
lx.emit(itemArrayTableStart)
|
||||
lx.push(lexArrayTableEnd)
|
||||
} else {
|
||||
lx.emit(itemTableStart)
|
||||
lx.push(lexTableEnd)
|
||||
}
|
||||
return lexTableNameStart
|
||||
}
|
||||
|
||||
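// lexTableEnd emits the end of a standard table header and returns control to
|
||||
// the top level.
|
||||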
func lexTableEnd(lx *lexer) stateFn {
|
||||
lx.emit(itemTableEnd)
|
||||
return lexTopEnd
|
||||
}
|
||||
|
||||
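// lexArrayTableEnd consumes the final ']' of an array-of-tables header and
|
||||
// emits itemArrayTableEnd.
|
||||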
func lexArrayTableEnd(lx *lexer) stateFn {
|
||||
if r := lx.next(); r != arrayTableEnd {
|
||||
return lx.errorf("expected end of table array name delimiter %q, "+
|
||||
"but got %q instead", arrayTableEnd, r)
|
||||
}
|
||||
lx.emit(itemArrayTableEnd)
|
||||
return lexTopEnd
|
||||
}
|
||||
|
||||
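// lexTableNameStart lexes one dotted component of a table name, which may be
|
||||
// bare or quoted; empty components are an error.
|
||||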
func lexTableNameStart(lx *lexer) stateFn {
|
||||
lx.skip(isWhitespace)
|
||||
switch r := lx.peek(); {
|
||||
case r == tableEnd || r == eof:
|
||||
return lx.errorf("unexpected end of table name " +
|
||||
"(table names cannot be empty)")
|
||||
case r == tableSep:
|
||||
return lx.errorf("unexpected table separator " +
|
||||
"(table names cannot be empty)")
|
||||
case r == stringStart || r == rawStringStart:
|
||||
lx.ignore()
|
||||
lx.push(lexTableNameEnd)
|
||||
return lexValue // reuse string lexing
|
||||
default:
|
||||
return lexBareTableName
|
||||
}
|
||||
}
|
||||
|
||||
// lexBareTableName lexes the name of a table. It assumes that at least one
|
||||
// valid character for the table has already been read.
|
||||
func lexBareTableName(lx *lexer) stateFn {
|
||||
r := lx.next()
|
||||
if isBareKeyChar(r) {
|
||||
return lexBareTableName
|
||||
}
|
||||
lx.backup()
|
||||
lx.emit(itemText)
|
||||
return lexTableNameEnd
|
||||
}
|
||||
|
||||
// lexTableNameEnd reads the end of a piece of a table name, optionally
|
||||
// consuming whitespace.
|
||||
func lexTableNameEnd(lx *lexer) stateFn {
|
||||
lx.skip(isWhitespace)
|
||||
switch r := lx.next(); {
|
||||
case isWhitespace(r):
|
||||
return lexTableNameEnd
|
||||
case r == tableSep:
|
||||
lx.ignore()
|
||||
return lexTableNameStart
|
||||
case r == tableEnd:
|
||||
return lx.pop()
|
||||
default:
|
||||
return lx.errorf("expected '.' or ']' to end table name, "+
|
||||
"but got %q instead", r)
|
||||
}
|
||||
}
|
||||
|
||||
// lexKeyStart skips any whitespace before a key, emits itemKeyStart, and hands
|
||||
// off to the bare or quoted key lexer.
|
||||
func lexKeyStart(lx *lexer) stateFn {
|
||||
r := lx.peek()
|
||||
switch {
|
||||
case r == keySep:
|
||||
return lx.errorf("unexpected key separator %q", keySep)
|
||||
case isWhitespace(r) || isNL(r):
|
||||
lx.next()
|
||||
return lexSkip(lx, lexKeyStart)
|
||||
case r == stringStart || r == rawStringStart:
|
||||
lx.ignore()
|
||||
lx.emit(itemKeyStart)
|
||||
lx.push(lexKeyEnd)
|
||||
return lexValue // reuse string lexing
|
||||
default:
|
||||
lx.ignore()
|
||||
lx.emit(itemKeyStart)
|
||||
return lexBareKey
|
||||
}
|
||||
}
|
||||
|
||||
// lexBareKey consumes the text of a bare key. Assumes that the first character
|
||||
// (which is not whitespace) has not yet been consumed.
|
||||
func lexBareKey(lx *lexer) stateFn {
|
||||
switch r := lx.next(); {
|
||||
case isBareKeyChar(r):
|
||||
return lexBareKey
|
||||
case isWhitespace(r):
|
||||
lx.backup()
|
||||
lx.emit(itemText)
|
||||
return lexKeyEnd
|
||||
case r == keySep:
|
||||
lx.backup()
|
||||
lx.emit(itemText)
|
||||
return lexKeyEnd
|
||||
default:
|
||||
return lx.errorf("bare keys cannot contain %q", r)
|
||||
}
|
||||
}
|
||||
|
||||
// lexKeyEnd consumes the end of a key and trims whitespace (up to the key
|
||||
// separator).
|
||||
func lexKeyEnd(lx *lexer) stateFn {
|
||||
switch r := lx.next(); {
|
||||
case r == keySep:
|
||||
return lexSkip(lx, lexValue)
|
||||
case isWhitespace(r):
|
||||
return lexSkip(lx, lexKeyEnd)
|
||||
default:
|
||||
return lx.errorf("expected key separator %q, but got %q instead",
|
||||
keySep, r)
|
||||
}
|
||||
}
|
||||
|
||||
// lexValue starts the consumption of a value anywhere a value is expected.
|
||||
// lexValue will ignore whitespace.
|
||||
// After a value is lexed, the last state on the stack is popped and returned.
|
||||
func lexValue(lx *lexer) stateFn {
|
||||
// We allow whitespace to precede a value, but NOT newlines.
|
||||
// In array syntax, the array states are responsible for ignoring newlines.
|
||||
r := lx.next()
|
||||
switch {
|
||||
case isWhitespace(r):
|
||||
return lexSkip(lx, lexValue)
|
||||
case isDigit(r):
|
||||
lx.backup() // avoid an extra state and use the same as above
|
||||
return lexNumberOrDateStart
|
||||
}
|
||||
switch r {
|
||||
case arrayStart:
|
||||
lx.ignore()
|
||||
lx.emit(itemArray)
|
||||
return lexArrayValue
|
||||
case inlineTableStart:
|
||||
lx.ignore()
|
||||
lx.emit(itemInlineTableStart)
|
||||
return lexInlineTableValue
|
||||
case stringStart:
|
||||
if lx.accept(stringStart) {
|
||||
if lx.accept(stringStart) {
|
||||
lx.ignore() // Ignore """
|
||||
return lexMultilineString
|
||||
}
|
||||
lx.backup()
|
||||
}
|
||||
lx.ignore() // ignore the '"'
|
||||
return lexString
|
||||
case rawStringStart:
|
||||
if lx.accept(rawStringStart) {
|
||||
if lx.accept(rawStringStart) {
|
||||
lx.ignore() // Ignore '''
|
||||
return lexMultilineRawString
|
||||
}
|
||||
lx.backup()
|
||||
}
|
||||
lx.ignore() // ignore the "'"
|
||||
return lexRawString
|
||||
case '+', '-':
|
||||
return lexNumberStart
|
||||
case '.': // special error case, be kind to users
|
||||
return lx.errorf("floats must start with a digit, not '.'")
|
||||
}
|
||||
if unicode.IsLetter(r) {
|
||||
// Be permissive here; lexBool will give a nice error if the
|
||||
// user wrote something like
|
||||
// x = foo
|
||||
// (i.e. not 'true' or 'false' but is something else word-like.)
|
||||
lx.backup()
|
||||
return lexBool
|
||||
}
|
||||
return lx.errorf("expected value but found %q instead", r)
|
||||
}
|
||||
|
||||
// lexArrayValue consumes one value in an array. It assumes that '[' or ','
|
||||
// have already been consumed. All whitespace and newlines are ignored.
|
||||
func lexArrayValue(lx *lexer) stateFn {
|
||||
r := lx.next()
|
||||
switch {
|
||||
case isWhitespace(r) || isNL(r):
|
||||
return lexSkip(lx, lexArrayValue)
|
||||
case r == commentStart:
|
||||
lx.push(lexArrayValue)
|
||||
return lexCommentStart
|
||||
case r == comma:
|
||||
return lx.errorf("unexpected comma")
|
||||
case r == arrayEnd:
|
||||
// NOTE(caleb): The spec isn't clear about whether you can have
|
||||
// a trailing comma or not, so we'll allow it.
|
||||
return lexArrayEnd
|
||||
}
|
||||
|
||||
lx.backup()
|
||||
lx.push(lexArrayValueEnd)
|
||||
return lexValue
|
||||
}
|
||||
|
||||
// lexArrayValueEnd consumes everything between the end of an array value and
|
||||
// the next value (or the end of the array): it ignores whitespace and newlines
|
||||
// and expects either a ',' or a ']'.
|
||||
func lexArrayValueEnd(lx *lexer) stateFn {
|
||||
r := lx.next()
|
||||
switch {
|
||||
case isWhitespace(r) || isNL(r):
|
||||
return lexSkip(lx, lexArrayValueEnd)
|
||||
case r == commentStart:
|
||||
lx.push(lexArrayValueEnd)
|
||||
return lexCommentStart
|
||||
case r == comma:
|
||||
lx.ignore()
|
||||
return lexArrayValue // move on to the next value
|
||||
case r == arrayEnd:
|
||||
return lexArrayEnd
|
||||
}
|
||||
return lx.errorf(
|
||||
"expected a comma or array terminator %q, but got %q instead",
|
||||
arrayEnd, r,
|
||||
)
|
||||
}
|
||||
|
||||
// lexArrayEnd finishes the lexing of an array.
|
||||
// It assumes that a ']' has just been consumed.
|
||||
func lexArrayEnd(lx *lexer) stateFn {
|
||||
lx.ignore()
|
||||
lx.emit(itemArrayEnd)
|
||||
return lx.pop()
|
||||
}
|
||||
|
||||
// lexInlineTableValue consumes one key/value pair in an inline table.
|
||||
// It assumes that '{' or ',' have already been consumed. Whitespace is ignored.
|
||||
func lexInlineTableValue(lx *lexer) stateFn {
|
||||
r := lx.next()
|
||||
switch {
|
||||
case isWhitespace(r):
|
||||
return lexSkip(lx, lexInlineTableValue)
|
||||
case isNL(r):
|
||||
return lx.errorf("newlines not allowed within inline tables")
|
||||
case r == commentStart:
|
||||
lx.push(lexInlineTableValue)
|
||||
return lexCommentStart
|
||||
case r == comma:
|
||||
return lx.errorf("unexpected comma")
|
||||
case r == inlineTableEnd:
|
||||
return lexInlineTableEnd
|
||||
}
|
||||
lx.backup()
|
||||
lx.push(lexInlineTableValueEnd)
|
||||
return lexKeyStart
|
||||
}
|
||||
|
||||
// lexInlineTableValueEnd consumes everything between the end of an inline table
|
||||
// key/value pair and the next pair (or the end of the table):
|
||||
// it ignores whitespace and expects either a ',' or a '}'.
|
||||
func lexInlineTableValueEnd(lx *lexer) stateFn {
|
||||
r := lx.next()
|
||||
switch {
|
||||
case isWhitespace(r):
|
||||
return lexSkip(lx, lexInlineTableValueEnd)
|
||||
case isNL(r):
|
||||
return lx.errorf("newlines not allowed within inline tables")
|
||||
case r == commentStart:
|
||||
lx.push(lexInlineTableValueEnd)
|
||||
return lexCommentStart
|
||||
case r == comma:
|
||||
lx.ignore()
|
||||
return lexInlineTableValue
|
||||
case r == inlineTableEnd:
|
||||
return lexInlineTableEnd
|
||||
}
|
||||
return lx.errorf("expected a comma or an inline table terminator %q, "+
|
||||
"but got %q instead", inlineTableEnd, r)
|
||||
}
|
||||
|
||||
// lexInlineTableEnd finishes the lexing of an inline table.
|
||||
// It assumes that a '}' has just been consumed.
|
||||
func lexInlineTableEnd(lx *lexer) stateFn {
|
||||
lx.ignore()
|
||||
lx.emit(itemInlineTableEnd)
|
||||
return lx.pop()
|
||||
}
|
||||
|
||||
// lexString consumes the inner contents of a string. It assumes that the
|
||||
// beginning '"' has already been consumed and ignored.
|
||||
func lexString(lx *lexer) stateFn {
|
||||
r := lx.next()
|
||||
switch {
|
||||
case r == eof:
|
||||
return lx.errorf("unexpected EOF")
|
||||
case isNL(r):
|
||||
return lx.errorf("strings cannot contain newlines")
|
||||
case r == '\\':
|
||||
lx.push(lexString)
|
||||
return lexStringEscape
|
||||
case r == stringEnd:
|
||||
lx.backup()
|
||||
lx.emit(itemString)
|
||||
lx.next()
|
||||
lx.ignore()
|
||||
return lx.pop()
|
||||
}
|
||||
return lexString
|
||||
}
|
||||
|
||||
// lexMultilineString consumes the inner contents of a string. It assumes that
|
||||
// the beginning '"""' has already been consumed and ignored.
|
||||
func lexMultilineString(lx *lexer) stateFn {
|
||||
switch lx.next() {
|
||||
case eof:
|
||||
return lx.errorf("unexpected EOF")
|
||||
case '\\':
|
||||
return lexMultilineStringEscape
|
||||
case stringEnd:
|
||||
if lx.accept(stringEnd) {
|
||||
if lx.accept(stringEnd) {
|
||||
lx.backup()
|
||||
lx.backup()
|
||||
lx.backup()
|
||||
lx.emit(itemMultilineString)
|
||||
lx.next()
|
||||
lx.next()
|
||||
lx.next()
|
||||
lx.ignore()
|
||||
return lx.pop()
|
||||
}
|
||||
lx.backup()
|
||||
}
|
||||
}
|
||||
return lexMultilineString
|
||||
}
|
||||
|
||||
// lexRawString consumes a raw string. Nothing can be escaped in such a string.
|
||||
// It assumes that the beginning "'" has already been consumed and ignored.
|
||||
func lexRawString(lx *lexer) stateFn {
|
||||
r := lx.next()
|
||||
switch {
|
||||
case r == eof:
|
||||
return lx.errorf("unexpected EOF")
|
||||
case isNL(r):
|
||||
return lx.errorf("strings cannot contain newlines")
|
||||
case r == rawStringEnd:
|
||||
lx.backup()
|
||||
lx.emit(itemRawString)
|
||||
lx.next()
|
||||
lx.ignore()
|
||||
return lx.pop()
|
||||
}
|
||||
return lexRawString
|
||||
}
|
||||
|
||||
// lexMultilineRawString consumes a raw string. Nothing can be escaped in such
|
||||
// a string. It assumes that the beginning "'''" has already been consumed and
|
||||
// ignored.
|
||||
func lexMultilineRawString(lx *lexer) stateFn {
|
||||
switch lx.next() {
|
||||
case eof:
|
||||
return lx.errorf("unexpected EOF")
|
||||
case rawStringEnd:
|
||||
if lx.accept(rawStringEnd) {
|
||||
if lx.accept(rawStringEnd) {
|
||||
lx.backup()
|
||||
lx.backup()
|
||||
lx.backup()
|
||||
lx.emit(itemRawMultilineString)
|
||||
lx.next()
|
||||
lx.next()
|
||||
lx.next()
|
||||
lx.ignore()
|
||||
return lx.pop()
|
||||
}
|
||||
lx.backup()
|
||||
}
|
||||
}
|
||||
return lexMultilineRawString
|
||||
}
|
||||
|
||||
// lexMultilineStringEscape consumes an escaped character. It assumes that the
|
||||
// preceding '\\' has already been consumed.
|
||||
func lexMultilineStringEscape(lx *lexer) stateFn {
|
||||
// Handle the special case first:
|
||||
if isNL(lx.next()) {
|
||||
return lexMultilineString
|
||||
}
|
||||
lx.backup()
|
||||
lx.push(lexMultilineString)
|
||||
return lexStringEscape(lx)
|
||||
}
|
||||
|
||||
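// lexStringEscape consumes one escape sequence. It assumes that the preceding
|
||||
// '\' has already been consumed.
|
||||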
func lexStringEscape(lx *lexer) stateFn {
|
||||
r := lx.next()
|
||||
switch r {
|
||||
case 'b':
|
||||
fallthrough
|
||||
case 't':
|
||||
fallthrough
|
||||
case 'n':
|
||||
fallthrough
|
||||
case 'f':
|
||||
fallthrough
|
||||
case 'r':
|
||||
fallthrough
|
||||
case '"':
|
||||
fallthrough
|
||||
case '\\':
|
||||
return lx.pop()
|
||||
case 'u':
|
||||
return lexShortUnicodeEscape
|
||||
case 'U':
|
||||
return lexLongUnicodeEscape
|
||||
}
|
||||
return lx.errorf("invalid escape character %q; only the following "+
|
||||
"escape characters are allowed: "+
|
||||
`\b, \t, \n, \f, \r, \", \\, \uXXXX, and \UXXXXXXXX`, r)
|
||||
}
|
||||
|
||||
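// lexShortUnicodeEscape consumes the four hexadecimal digits of a '\u' escape.
|
||||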
func lexShortUnicodeEscape(lx *lexer) stateFn {
|
||||
var r rune
|
||||
for i := 0; i < 4; i++ {
|
||||
r = lx.next()
|
||||
if !isHexadecimal(r) {
|
||||
return lx.errorf(`expected four hexadecimal digits after '\u', `+
|
||||
"but got %q instead", lx.current())
|
||||
}
|
||||
}
|
||||
return lx.pop()
|
||||
}
|
||||
|
||||
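// lexLongUnicodeEscape consumes the eight hexadecimal digits of a '\U' escape.
|
||||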
func lexLongUnicodeEscape(lx *lexer) stateFn {
|
||||
var r rune
|
||||
for i := 0; i < 8; i++ {
|
||||
r = lx.next()
|
||||
if !isHexadecimal(r) {
|
||||
return lx.errorf(`expected eight hexadecimal digits after '\U', `+
|
||||
"but got %q instead", lx.current())
|
||||
}
|
||||
}
|
||||
return lx.pop()
|
||||
}
|
||||
|
||||
// lexNumberOrDateStart consumes either an integer, a float, or datetime.
|
||||
func lexNumberOrDateStart(lx *lexer) stateFn {
|
||||
r := lx.next()
|
||||
if isDigit(r) {
|
||||
return lexNumberOrDate
|
||||
}
|
||||
switch r {
|
||||
case '_':
|
||||
return lexNumber
|
||||
case 'e', 'E':
|
||||
return lexFloat
|
||||
case '.':
|
||||
return lx.errorf("floats must start with a digit, not '.'")
|
||||
}
|
||||
return lx.errorf("expected a digit but got %q", r)
|
||||
}
|
||||
|
||||
// lexNumberOrDate consumes either an integer, float or datetime.
|
||||
func lexNumberOrDate(lx *lexer) stateFn {
|
||||
r := lx.next()
|
||||
if isDigit(r) {
|
||||
return lexNumberOrDate
|
||||
}
|
||||
switch r {
|
||||
case '-':
|
||||
return lexDatetime
|
||||
case '_':
|
||||
return lexNumber
|
||||
case '.', 'e', 'E':
|
||||
return lexFloat
|
||||
}
|
||||
|
||||
lx.backup()
|
||||
lx.emit(itemInteger)
|
||||
return lx.pop()
|
||||
}
|
||||
|
||||
// lexDatetime consumes a Datetime, to a first approximation.
|
||||
// The parser validates that it matches one of the accepted formats.
|
||||
func lexDatetime(lx *lexer) stateFn {
|
||||
r := lx.next()
|
||||
if isDigit(r) {
|
||||
return lexDatetime
|
||||
}
|
||||
switch r {
|
||||
case '-', 'T', ':', '.', 'Z', '+':
|
||||
return lexDatetime
|
||||
}
|
||||
|
||||
lx.backup()
|
||||
lx.emit(itemDatetime)
|
||||
return lx.pop()
|
||||
}
|
||||
|
||||
// lexNumberStart consumes either an integer or a float. It assumes that a sign
|
||||
// has already been read, but that *no* digits have been consumed.
|
||||
// lexNumberStart will move to the appropriate integer or float states.
|
||||
func lexNumberStart(lx *lexer) stateFn {
|
||||
// We MUST see a digit. Even floats have to start with a digit.
|
||||
r := lx.next()
|
||||
if !isDigit(r) {
|
||||
if r == '.' {
|
||||
return lx.errorf("floats must start with a digit, not '.'")
|
||||
}
|
||||
return lx.errorf("expected a digit but got %q", r)
|
||||
}
|
||||
return lexNumber
|
||||
}
|
||||
|
||||
// lexNumber consumes an integer or a float after seeing the first digit.
|
||||
func lexNumber(lx *lexer) stateFn {
|
||||
r := lx.next()
|
||||
if isDigit(r) {
|
||||
return lexNumber
|
||||
}
|
||||
switch r {
|
||||
case '_':
|
||||
return lexNumber
|
||||
case '.', 'e', 'E':
|
||||
return lexFloat
|
||||
}
|
||||
|
||||
lx.backup()
|
||||
lx.emit(itemInteger)
|
||||
return lx.pop()
|
||||
}
|
||||
|
||||
// lexFloat consumes the elements of a float. It allows any sequence of
|
||||
// float-like characters, so floats emitted by the lexer are only a first
|
||||
// approximation and must be validated by the parser.
|
||||
func lexFloat(lx *lexer) stateFn {
|
||||
r := lx.next()
|
||||
if isDigit(r) {
|
||||
return lexFloat
|
||||
}
|
||||
switch r {
|
||||
case '_', '.', '-', '+', 'e', 'E':
|
||||
return lexFloat
|
||||
}
|
||||
|
||||
lx.backup()
|
||||
lx.emit(itemFloat)
|
||||
return lx.pop()
|
||||
}
|
||||
|
||||
// lexBool consumes a bool string: 'true' or 'false'.
|
||||
func lexBool(lx *lexer) stateFn {
|
||||
var rs []rune
|
||||
for {
|
||||
r := lx.next()
|
||||
if !unicode.IsLetter(r) {
|
||||
lx.backup()
|
||||
break
|
||||
}
|
||||
rs = append(rs, r)
|
||||
}
|
||||
s := string(rs)
|
||||
switch s {
|
||||
case "true", "false":
|
||||
lx.emit(itemBool)
|
||||
return lx.pop()
|
||||
}
|
||||
return lx.errorf("expected value but found %q instead", s)
|
||||
}
|
||||
|
||||
// lexCommentStart begins the lexing of a comment. It will emit
|
||||
// itemCommentStart and consume no characters, passing control to lexComment.
|
||||
func lexCommentStart(lx *lexer) stateFn {
|
||||
lx.ignore()
|
||||
lx.emit(itemCommentStart)
|
||||
return lexComment
|
||||
}
|
||||
|
||||
// lexComment lexes an entire comment. It assumes that '#' has been consumed.
|
||||
// It will consume *up to* the first newline character, and pass control
|
||||
// back to the last state on the stack.
|
||||
func lexComment(lx *lexer) stateFn {
|
||||
r := lx.peek()
|
||||
if isNL(r) || r == eof {
|
||||
lx.emit(itemText)
|
||||
return lx.pop()
|
||||
}
|
||||
lx.next()
|
||||
return lexComment
|
||||
}
|
||||
|
||||
// lexSkip ignores all slurped input and moves on to the next state.
|
||||
func lexSkip(lx *lexer, nextState stateFn) stateFn {
|
||||
return func(lx *lexer) stateFn {
|
||||
lx.ignore()
|
||||
return nextState
|
||||
}
|
||||
}
|
||||
|
||||
// isWhitespace returns true if `r` is a whitespace character according
|
||||
// to the spec.
|
||||
func isWhitespace(r rune) bool {
|
||||
return r == '\t' || r == ' '
|
||||
}
|
||||
|
||||
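// isNL returns true if `r` is a newline character (LF or CR).
|
||||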
func isNL(r rune) bool {
|
||||
return r == '\n' || r == '\r'
|
||||
}
|
||||
|
||||
func isDigit(r rune) bool {
|
||||
return r >= '0' && r <= '9'
|
||||
}
|
||||
|
||||
func isHexadecimal(r rune) bool {
|
||||
return (r >= '0' && r <= '9') ||
|
||||
(r >= 'a' && r <= 'f') ||
|
||||
(r >= 'A' && r <= 'F')
|
||||
}
|
||||
|
||||
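// isBareKeyChar returns true if `r` may appear in a bare (unquoted) key:
|
||||
// ASCII letters, digits, '_' and '-'.
|
||||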
func isBareKeyChar(r rune) bool {
|
||||
return (r >= 'A' && r <= 'Z') ||
|
||||
(r >= 'a' && r <= 'z') ||
|
||||
(r >= '0' && r <= '9') ||
|
||||
r == '_' ||
|
||||
r == '-'
|
||||
}
|
||||
|
||||
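// String returns a human-readable name for the item type, used in error and
|
||||
// debug messages.
|
||||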
func (itype itemType) String() string {
|
||||
switch itype {
|
||||
case itemError:
|
||||
return "Error"
|
||||
case itemNIL:
|
||||
return "NIL"
|
||||
case itemEOF:
|
||||
return "EOF"
|
||||
case itemText:
|
||||
return "Text"
|
||||
case itemString, itemRawString, itemMultilineString, itemRawMultilineString:
|
||||
return "String"
|
||||
case itemBool:
|
||||
return "Bool"
|
||||
case itemInteger:
|
||||
return "Integer"
|
||||
case itemFloat:
|
||||
return "Float"
|
||||
case itemDatetime:
|
||||
return "DateTime"
|
||||
case itemTableStart:
|
||||
return "TableStart"
|
||||
case itemTableEnd:
|
||||
return "TableEnd"
|
||||
case itemKeyStart:
|
||||
return "KeyStart"
|
||||
case itemArray:
|
||||
return "Array"
|
||||
case itemArrayEnd:
|
||||
return "ArrayEnd"
|
||||
case itemCommentStart:
|
||||
return "CommentStart"
|
||||
}
|
||||
panic(fmt.Sprintf("BUG: Unknown type '%d'.", int(itype)))
|
||||
}
|
||||
|
||||
func (item item) String() string {
|
||||
return fmt.Sprintf("(%s, %s)", item.typ.String(), item.val)
|
||||
}
|
|
@ -0,0 +1,592 @@
|
|||
package toml
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
type parser struct {
|
||||
mapping map[string]interface{}
|
||||
types map[string]tomlType
|
||||
lx *lexer
|
||||
|
||||
// A list of keys in the order that they appear in the TOML data.
|
||||
ordered []Key
|
||||
|
||||
// the full key for the current hash in scope
|
||||
context Key
|
||||
|
||||
// the base key name for everything except hashes
|
||||
currentKey string
|
||||
|
||||
// rough approximation of line number
|
||||
approxLine int
|
||||
|
||||
// A map of 'key.group.names' to whether they were created implicitly.
|
||||
implicits map[string]bool
|
||||
}
|
||||
|
||||
type parseError string
|
||||
|
||||
func (pe parseError) Error() string {
|
||||
return string(pe)
|
||||
}
|
||||
|
||||
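// parse lexes and parses the given TOML data. Parse errors are raised as
|
||||
// parseError panics and converted into the returned error.
|
||||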
func parse(data string) (p *parser, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
var ok bool
|
||||
if err, ok = r.(parseError); ok {
|
||||
return
|
||||
}
|
||||
panic(r)
|
||||
}
|
||||
}()
|
||||
|
||||
p = &parser{
|
||||
mapping: make(map[string]interface{}),
|
||||
types: make(map[string]tomlType),
|
||||
lx: lex(data),
|
||||
ordered: make([]Key, 0),
|
||||
implicits: make(map[string]bool),
|
||||
}
|
||||
for {
|
||||
item := p.next()
|
||||
if item.typ == itemEOF {
|
||||
break
|
||||
}
|
||||
p.topLevel(item)
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *parser) panicf(format string, v ...interface{}) {
|
||||
msg := fmt.Sprintf("Near line %d (last key parsed '%s'): %s",
|
||||
p.approxLine, p.current(), fmt.Sprintf(format, v...))
|
||||
panic(parseError(msg))
|
||||
}
|
||||
|
||||
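// next returns the next item from the lexer, aborting the parse if the lexer
|
||||
// reports an error.
|
||||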
func (p *parser) next() item {
|
||||
it := p.lx.nextItem()
|
||||
if it.typ == itemError {
|
||||
p.panicf("%s", it.val)
|
||||
}
|
||||
return it
|
||||
}
|
||||
|
||||
func (p *parser) bug(format string, v ...interface{}) {
|
||||
panic(fmt.Sprintf("BUG: "+format+"\n\n", v...))
|
||||
}
|
||||
|
||||
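// expect returns the next item and asserts that its type matches typ.
|
||||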
func (p *parser) expect(typ itemType) item {
|
||||
it := p.next()
|
||||
p.assertEqual(typ, it.typ)
|
||||
return it
|
||||
}
|
||||
|
||||
func (p *parser) assertEqual(expected, got itemType) {
|
||||
if expected != got {
|
||||
p.bug("Expected '%s' but got '%s'.", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
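// topLevel handles a single top-level item: a comment, a table header, an
|
||||
// array-of-tables header, or a key/value pair.
|
||||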
func (p *parser) topLevel(item item) {
|
||||
switch item.typ {
|
||||
case itemCommentStart:
|
||||
p.approxLine = item.line
|
||||
p.expect(itemText)
|
||||
case itemTableStart:
|
||||
kg := p.next()
|
||||
p.approxLine = kg.line
|
||||
|
||||
var key Key
|
||||
for ; kg.typ != itemTableEnd && kg.typ != itemEOF; kg = p.next() {
|
||||
key = append(key, p.keyString(kg))
|
||||
}
|
||||
p.assertEqual(itemTableEnd, kg.typ)
|
||||
|
||||
p.establishContext(key, false)
|
||||
p.setType("", tomlHash)
|
||||
p.ordered = append(p.ordered, key)
|
||||
case itemArrayTableStart:
|
||||
kg := p.next()
|
||||
p.approxLine = kg.line
|
||||
|
||||
var key Key
|
||||
for ; kg.typ != itemArrayTableEnd && kg.typ != itemEOF; kg = p.next() {
|
||||
key = append(key, p.keyString(kg))
|
||||
}
|
||||
p.assertEqual(itemArrayTableEnd, kg.typ)
|
||||
|
||||
p.establishContext(key, true)
|
||||
p.setType("", tomlArrayHash)
|
||||
p.ordered = append(p.ordered, key)
|
||||
case itemKeyStart:
|
||||
kname := p.next()
|
||||
p.approxLine = kname.line
|
||||
p.currentKey = p.keyString(kname)
|
||||
|
||||
val, typ := p.value(p.next())
|
||||
p.setValue(p.currentKey, val)
|
||||
p.setType(p.currentKey, typ)
|
||||
p.ordered = append(p.ordered, p.context.add(p.currentKey))
|
||||
p.currentKey = ""
|
||||
default:
|
||||
p.bug("Unexpected type at top level: %s", item.typ)
|
||||
}
|
||||
}
|
||||
|
||||
// Gets a string for a key (or part of a key in a table name).
|
||||
func (p *parser) keyString(it item) string {
|
||||
switch it.typ {
|
||||
case itemText:
|
||||
return it.val
|
||||
case itemString, itemMultilineString,
|
||||
itemRawString, itemRawMultilineString:
|
||||
s, _ := p.value(it)
|
||||
return s.(string)
|
||||
default:
|
||||
p.bug("Unexpected key type: %s", it.typ)
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
// value translates an expected value from the lexer into a Go value wrapped
|
||||
// as an empty interface.
|
||||
func (p *parser) value(it item) (interface{}, tomlType) {
|
||||
switch it.typ {
|
||||
case itemString:
|
||||
return p.replaceEscapes(it.val), p.typeOfPrimitive(it)
|
||||
case itemMultilineString:
|
||||
trimmed := stripFirstNewline(stripEscapedWhitespace(it.val))
|
||||
return p.replaceEscapes(trimmed), p.typeOfPrimitive(it)
|
||||
case itemRawString:
|
||||
return it.val, p.typeOfPrimitive(it)
|
||||
case itemRawMultilineString:
|
||||
return stripFirstNewline(it.val), p.typeOfPrimitive(it)
|
||||
case itemBool:
|
||||
switch it.val {
|
||||
case "true":
|
||||
return true, p.typeOfPrimitive(it)
|
||||
case "false":
|
||||
return false, p.typeOfPrimitive(it)
|
||||
}
|
||||
p.bug("Expected boolean value, but got '%s'.", it.val)
|
||||
case itemInteger:
|
||||
if !numUnderscoresOK(it.val) {
|
||||
p.panicf("Invalid integer %q: underscores must be surrounded by digits",
|
||||
it.val)
|
||||
}
|
||||
val := strings.Replace(it.val, "_", "", -1)
|
||||
num, err := strconv.ParseInt(val, 10, 64)
|
||||
if err != nil {
|
||||
// Distinguish integer values. Normally, it'd be a bug if the lexer
|
||||
// provides an invalid integer, but it's possible that the number is
|
||||
// out of range of valid values (which the lexer cannot determine).
|
||||
// So mark the former as a bug but the latter as a legitimate user
|
||||
// error.
|
||||
if e, ok := err.(*strconv.NumError); ok &&
|
||||
e.Err == strconv.ErrRange {
|
||||
|
||||
p.panicf("Integer '%s' is out of the range of 64-bit "+
|
||||
"signed integers.", it.val)
|
||||
} else {
|
||||
p.bug("Expected integer value, but got '%s'.", it.val)
|
||||
}
|
||||
}
|
||||
return num, p.typeOfPrimitive(it)
|
||||
case itemFloat:
|
||||
parts := strings.FieldsFunc(it.val, func(r rune) bool {
|
||||
switch r {
|
||||
case '.', 'e', 'E':
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
for _, part := range parts {
|
||||
if !numUnderscoresOK(part) {
|
||||
p.panicf("Invalid float %q: underscores must be "+
|
||||
"surrounded by digits", it.val)
|
||||
}
|
||||
}
|
||||
if !numPeriodsOK(it.val) {
|
||||
// As a special case, numbers like '123.' or '1.e2',
|
||||
// which are valid as far as Go/strconv are concerned,
|
||||
// must be rejected because TOML says that a fractional
|
||||
// part consists of '.' followed by 1+ digits.
|
||||
p.panicf("Invalid float %q: '.' must be followed "+
|
||||
"by one or more digits", it.val)
|
||||
}
|
||||
val := strings.Replace(it.val, "_", "", -1)
|
||||
num, err := strconv.ParseFloat(val, 64)
|
||||
if err != nil {
|
||||
if e, ok := err.(*strconv.NumError); ok &&
|
||||
e.Err == strconv.ErrRange {
|
||||
|
||||
p.panicf("Float '%s' is out of the range of 64-bit "+
|
||||
"IEEE-754 floating-point numbers.", it.val)
|
||||
} else {
|
||||
p.panicf("Invalid float value: %q", it.val)
|
||||
}
|
||||
}
|
||||
return num, p.typeOfPrimitive(it)
|
||||
case itemDatetime:
|
||||
var t time.Time
|
||||
var ok bool
|
||||
var err error
|
||||
for _, format := range []string{
|
||||
"2006-01-02T15:04:05Z07:00",
|
||||
"2006-01-02T15:04:05",
|
||||
"2006-01-02",
|
||||
} {
|
||||
t, err = time.ParseInLocation(format, it.val, time.Local)
|
||||
if err == nil {
|
||||
ok = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
p.panicf("Invalid TOML Datetime: %q.", it.val)
|
||||
}
|
||||
return t, p.typeOfPrimitive(it)
|
||||
case itemArray:
|
||||
array := make([]interface{}, 0)
|
||||
types := make([]tomlType, 0)
|
||||
|
||||
for it = p.next(); it.typ != itemArrayEnd; it = p.next() {
|
||||
if it.typ == itemCommentStart {
|
||||
p.expect(itemText)
|
||||
continue
|
||||
}
|
||||
|
||||
val, typ := p.value(it)
|
||||
array = append(array, val)
|
||||
types = append(types, typ)
|
||||
}
|
||||
return array, p.typeOfArray(types)
|
||||
case itemInlineTableStart:
|
||||
var (
|
||||
hash = make(map[string]interface{})
|
||||
outerContext = p.context
|
||||
outerKey = p.currentKey
|
||||
)
|
||||
|
||||
p.context = append(p.context, p.currentKey)
|
||||
p.currentKey = ""
|
||||
for it := p.next(); it.typ != itemInlineTableEnd; it = p.next() {
|
||||
if it.typ == itemCommentStart {
|
||||
p.expect(itemText)
|
||||
continue
|
||||
}
|
||||
if it.typ != itemKeyStart {
|
||||
p.bug("Expected key start but instead found %q, around line %d",
|
||||
it.val, p.approxLine)
|
||||
}
|
||||
|
||||
// retrieve key
|
||||
k := p.next()
|
||||
p.approxLine = k.line
|
||||
kname := p.keyString(k)
|
||||
|
||||
// retrieve value
|
||||
p.currentKey = kname
|
||||
val, typ := p.value(p.next())
|
||||
// make sure we keep metadata up to date
|
||||
p.setType(kname, typ)
|
||||
p.ordered = append(p.ordered, p.context.add(p.currentKey))
|
||||
hash[kname] = val
|
||||
}
|
||||
p.context = outerContext
|
||||
p.currentKey = outerKey
|
||||
return hash, tomlHash
|
||||
}
|
||||
p.bug("Unexpected value type: %s", it.typ)
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
// numUnderscoresOK checks whether each underscore in s is surrounded by
|
||||
// characters that are not underscores.
|
||||
func numUnderscoresOK(s string) bool {
|
||||
accept := false
|
||||
for _, r := range s {
|
||||
if r == '_' {
|
||||
if !accept {
|
||||
return false
|
||||
}
|
||||
accept = false
|
||||
continue
|
||||
}
|
||||
accept = true
|
||||
}
|
||||
return accept
|
||||
}
|
||||
|
||||
// numPeriodsOK checks whether every period in s is followed by a digit.
|
||||
func numPeriodsOK(s string) bool {
|
||||
period := false
|
||||
for _, r := range s {
|
||||
if period && !isDigit(r) {
|
||||
return false
|
||||
}
|
||||
period = r == '.'
|
||||
}
|
||||
return !period
|
||||
}
|
||||
|
||||
// establishContext sets the current context of the parser,
|
||||
// where the context is either a hash or an array of hashes. Which one is
|
||||
// set depends on the value of the `array` parameter.
|
||||
//
|
||||
// Establishing the context also makes sure that the key isn't a duplicate, and
|
||||
// will create implicit hashes automatically.
|
||||
func (p *parser) establishContext(key Key, array bool) {
|
||||
var ok bool
|
||||
|
||||
// Always start at the top level and drill down for our context.
|
||||
hashContext := p.mapping
|
||||
keyContext := make(Key, 0)
|
||||
|
||||
// We only need implicit hashes for key[0:-1]
|
||||
for _, k := range key[0 : len(key)-1] {
|
||||
_, ok = hashContext[k]
|
||||
keyContext = append(keyContext, k)
|
||||
|
||||
// No key? Make an implicit hash and move on.
|
||||
if !ok {
|
||||
p.addImplicit(keyContext)
|
||||
hashContext[k] = make(map[string]interface{})
|
||||
}
|
||||
|
||||
// If the hash context is actually an array of tables, then set
|
||||
// the hash context to the last element in that array.
|
||||
//
|
||||
// Otherwise, it better be a table, since this MUST be a key group (by
|
||||
// virtue of it not being the last element in a key).
|
||||
switch t := hashContext[k].(type) {
|
||||
case []map[string]interface{}:
|
||||
hashContext = t[len(t)-1]
|
||||
case map[string]interface{}:
|
||||
hashContext = t
|
||||
default:
|
||||
p.panicf("Key '%s' was already created as a hash.", keyContext)
|
||||
}
|
||||
}
|
||||
|
||||
p.context = keyContext
|
||||
if array {
|
||||
// If this is the first element for this array, then allocate a new
|
||||
// list of tables for it.
|
||||
k := key[len(key)-1]
|
||||
if _, ok := hashContext[k]; !ok {
|
||||
hashContext[k] = make([]map[string]interface{}, 0, 5)
|
||||
}
|
||||
|
||||
// Add a new table. But make sure the key hasn't already been used
|
||||
// for something else.
|
||||
if hash, ok := hashContext[k].([]map[string]interface{}); ok {
|
||||
hashContext[k] = append(hash, make(map[string]interface{}))
|
||||
} else {
|
||||
p.panicf("Key '%s' was already created and cannot be used as "+
|
||||
"an array.", keyContext)
|
||||
}
|
||||
} else {
|
||||
p.setValue(key[len(key)-1], make(map[string]interface{}))
|
||||
}
|
||||
p.context = append(p.context, key[len(key)-1])
|
||||
}
|
||||
|
||||
// setValue sets the given key to the given value in the current context.
|
||||
// It will make sure that the key hasn't already been defined, accounting for
|
||||
// implicit key groups.
|
||||
func (p *parser) setValue(key string, value interface{}) {
|
||||
var tmpHash interface{}
|
||||
var ok bool
|
||||
|
||||
hash := p.mapping
|
||||
keyContext := make(Key, 0)
|
||||
for _, k := range p.context {
|
||||
keyContext = append(keyContext, k)
|
||||
if tmpHash, ok = hash[k]; !ok {
|
||||
p.bug("Context for key '%s' has not been established.", keyContext)
|
||||
}
|
||||
switch t := tmpHash.(type) {
|
||||
case []map[string]interface{}:
|
||||
// The context is a table of hashes. Pick the most recent table
|
||||
// defined as the current hash.
|
||||
hash = t[len(t)-1]
|
||||
case map[string]interface{}:
|
||||
hash = t
|
||||
default:
|
||||
p.bug("Expected hash to have type 'map[string]interface{}', but "+
|
||||
"it has '%T' instead.", tmpHash)
|
||||
}
|
||||
}
|
||||
keyContext = append(keyContext, key)
|
||||
|
||||
if _, ok := hash[key]; ok {
|
||||
// Typically, if the given key has already been set, then we have
|
||||
// to raise an error since duplicate keys are disallowed. However,
|
||||
// it's possible that a key was previously defined implicitly. In this
|
||||
// case, it is allowed to be redefined concretely. (See the
|
||||
// `tests/valid/implicit-and-explicit-after.toml` test in `toml-test`.)
|
||||
//
|
||||
// But we have to make sure to stop marking it as an implicit. (So that
|
||||
// another redefinition provokes an error.)
|
||||
//
|
||||
// Note that since it has already been defined (as a hash), we don't
|
||||
// want to overwrite it. So our business is done.
|
||||
if p.isImplicit(keyContext) {
|
||||
p.removeImplicit(keyContext)
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, we have a concrete key trying to override a previous
|
||||
// key, which is *always* wrong.
|
||||
p.panicf("Key '%s' has already been defined.", keyContext)
|
||||
}
|
||||
hash[key] = value
|
||||
}
|
||||
|
||||
// setType sets the type of a particular value at a given key.
|
||||
// It should be called immediately AFTER setValue.
|
||||
//
|
||||
// Note that if `key` is empty, then the type given will be applied to the
|
||||
// current context (which is either a table or an array of tables).
|
||||
func (p *parser) setType(key string, typ tomlType) {
|
||||
keyContext := make(Key, 0, len(p.context)+1)
|
||||
for _, k := range p.context {
|
||||
keyContext = append(keyContext, k)
|
||||
}
|
||||
if len(key) > 0 { // allow type setting for hashes
|
||||
keyContext = append(keyContext, key)
|
||||
}
|
||||
p.types[keyContext.String()] = typ
|
||||
}
|
||||
|
||||
// addImplicit sets the given Key as having been created implicitly.
|
||||
func (p *parser) addImplicit(key Key) {
|
||||
p.implicits[key.String()] = true
|
||||
}
|
||||
|
||||
// removeImplicit stops tagging the given key as having been implicitly
|
||||
// created.
|
||||
func (p *parser) removeImplicit(key Key) {
|
||||
p.implicits[key.String()] = false
|
||||
}
|
||||
|
||||
// isImplicit returns true if the key group pointed to by the key was created
|
||||
// implicitly.
|
||||
func (p *parser) isImplicit(key Key) bool {
|
||||
return p.implicits[key.String()]
|
||||
}
|
||||
|
||||
// current returns the full key name of the current context.
|
||||
func (p *parser) current() string {
|
||||
if len(p.currentKey) == 0 {
|
||||
return p.context.String()
|
||||
}
|
||||
if len(p.context) == 0 {
|
||||
return p.currentKey
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", p.context, p.currentKey)
|
||||
}
|
||||
|
||||
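// stripFirstNewline drops a single leading newline, as required for multiline
|
||||
// string values.
|
||||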
func stripFirstNewline(s string) string {
|
||||
if len(s) == 0 || s[0] != '\n' {
|
||||
return s
|
||||
}
|
||||
return s[1:]
|
||||
}
|
||||
|
||||
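// stripEscapedWhitespace removes the whitespace that follows a line-ending
|
||||
// backslash in a multiline basic string.
|
||||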
func stripEscapedWhitespace(s string) string {
|
||||
esc := strings.Split(s, "\\\n")
|
||||
if len(esc) > 1 {
|
||||
for i := 1; i < len(esc); i++ {
|
||||
esc[i] = strings.TrimLeftFunc(esc[i], unicode.IsSpace)
|
||||
}
|
||||
}
|
||||
return strings.Join(esc, "")
|
||||
}
|
||||
|
||||
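// replaceEscapes expands backslash escape sequences in a basic string.
|
||||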
func (p *parser) replaceEscapes(str string) string {
|
||||
var replaced []rune
|
||||
s := []byte(str)
|
||||
r := 0
|
||||
for r < len(s) {
|
||||
if s[r] != '\\' {
|
||||
c, size := utf8.DecodeRune(s[r:])
|
||||
r += size
|
||||
replaced = append(replaced, c)
|
||||
continue
|
||||
}
|
||||
r += 1
|
||||
if r >= len(s) {
|
||||
p.bug("Escape sequence at end of string.")
|
||||
return ""
|
||||
}
|
||||
switch s[r] {
|
||||
default:
|
||||
p.bug("Expected valid escape code after \\, but got %q.", s[r])
|
||||
return ""
|
||||
case 'b':
|
||||
replaced = append(replaced, rune(0x0008))
|
||||
r += 1
|
||||
case 't':
|
||||
replaced = append(replaced, rune(0x0009))
|
||||
r += 1
|
||||
case 'n':
|
||||
replaced = append(replaced, rune(0x000A))
|
||||
r += 1
|
||||
case 'f':
|
||||
replaced = append(replaced, rune(0x000C))
|
||||
r += 1
|
||||
case 'r':
|
||||
replaced = append(replaced, rune(0x000D))
|
||||
r += 1
|
||||
case '"':
|
||||
replaced = append(replaced, rune(0x0022))
|
||||
r += 1
|
||||
case '\\':
|
||||
replaced = append(replaced, rune(0x005C))
|
||||
r += 1
|
||||
case 'u':
|
||||
// At this point, we know we have a Unicode escape of the form
|
||||
// `uXXXX` at [r, r+5). (Because the lexer guarantees this
|
||||
// for us.)
|
||||
escaped := p.asciiEscapeToUnicode(s[r+1 : r+5])
|
||||
replaced = append(replaced, escaped)
|
||||
r += 5
|
||||
case 'U':
|
||||
// At this point, we know we have a Unicode escape of the form
|
||||
// `UXXXXXXXX` at [r, r+9). (Because the lexer guarantees this
|
||||
// for us.)
|
||||
escaped := p.asciiEscapeToUnicode(s[r+1 : r+9])
|
||||
replaced = append(replaced, escaped)
|
||||
r += 9
|
||||
}
|
||||
}
|
||||
return string(replaced)
|
||||
}
|
||||
|
||||
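// asciiEscapeToUnicode converts the hex digits of a '\u' or '\U' escape into a
|
||||
// rune, checking that the result is a valid code point.
|
||||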
func (p *parser) asciiEscapeToUnicode(bs []byte) rune {
|
||||
s := string(bs)
|
||||
hex, err := strconv.ParseUint(strings.ToLower(s), 16, 32)
|
||||
if err != nil {
|
||||
p.bug("Could not parse '%s' as a hexadecimal number, but the "+
|
||||
"lexer claims it's OK: %s", s, err)
|
||||
}
|
||||
if !utf8.ValidRune(rune(hex)) {
|
||||
p.panicf("Escaped character '\\u%s' is not valid UTF-8.", s)
|
||||
}
|
||||
return rune(hex)
|
||||
}
|
||||
|
||||
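// isStringType returns true if `ty` is one of the string item types.
|
||||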
func isStringType(ty itemType) bool {
|
||||
return ty == itemString || ty == itemMultilineString ||
|
||||
ty == itemRawString || ty == itemRawMultilineString
|
||||
}
|
|
@ -0,0 +1 @@
|
|||
au BufWritePost *.go silent!make tags > /dev/null 2>&1
|
|
@ -0,0 +1,91 @@
|
|||
package toml
|
||||
|
||||
// tomlType represents any Go type that corresponds to a TOML type.
|
||||
// While the first draft of the TOML spec has a simplistic type system that
|
||||
// probably doesn't need this level of sophistication, we seem to be moving
|
||||
// toward adding real composite types.
|
||||
type tomlType interface {
|
||||
typeString() string
|
||||
}
|
||||
|
||||
// typeEqual accepts any two types and returns true if they are equal.
|
||||
func typeEqual(t1, t2 tomlType) bool {
|
||||
if t1 == nil || t2 == nil {
|
||||
return false
|
||||
}
|
||||
return t1.typeString() == t2.typeString()
|
||||
}
|
||||
|
||||
func typeIsHash(t tomlType) bool {
|
||||
return typeEqual(t, tomlHash) || typeEqual(t, tomlArrayHash)
|
||||
}
|
||||
|
||||
type tomlBaseType string
|
||||
|
||||
func (btype tomlBaseType) typeString() string {
|
||||
return string(btype)
|
||||
}
|
||||
|
||||
func (btype tomlBaseType) String() string {
|
||||
return btype.typeString()
|
||||
}
|
||||
|
||||
var (
|
||||
tomlInteger tomlBaseType = "Integer"
|
||||
tomlFloat tomlBaseType = "Float"
|
||||
tomlDatetime tomlBaseType = "Datetime"
|
||||
tomlString tomlBaseType = "String"
|
||||
tomlBool tomlBaseType = "Bool"
|
||||
tomlArray tomlBaseType = "Array"
|
||||
tomlHash tomlBaseType = "Hash"
|
||||
tomlArrayHash tomlBaseType = "ArrayHash"
|
||||
)
|
||||
|
||||
// typeOfPrimitive returns a tomlType of any primitive value in TOML.
|
||||
// Primitive values are: Integer, Float, Datetime, String and Bool.
|
||||
//
|
||||
// Passing a lexer item other than the following will cause a BUG message
|
||||
// to occur: itemString, itemBool, itemInteger, itemFloat, itemDatetime.
|
||||
func (p *parser) typeOfPrimitive(lexItem item) tomlType {
|
||||
switch lexItem.typ {
|
||||
case itemInteger:
|
||||
return tomlInteger
|
||||
case itemFloat:
|
||||
return tomlFloat
|
||||
case itemDatetime:
|
||||
return tomlDatetime
|
||||
case itemString:
|
||||
return tomlString
|
||||
case itemMultilineString:
|
||||
return tomlString
|
||||
case itemRawString:
|
||||
return tomlString
|
||||
case itemRawMultilineString:
|
||||
return tomlString
|
||||
case itemBool:
|
||||
return tomlBool
|
||||
}
|
||||
p.bug("Cannot infer primitive type of lex item '%s'.", lexItem)
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
// typeOfArray returns a tomlType for an array given a list of types of its
|
||||
// values.
|
||||
//
|
||||
// In the current spec, if an array is homogeneous, then its type is always
|
||||
// "Array". If the array is not homogeneous, an error is generated.
|
||||
func (p *parser) typeOfArray(types []tomlType) tomlType {
|
||||
// Empty arrays are cool.
|
||||
if len(types) == 0 {
|
||||
return tomlArray
|
||||
}
|
||||
|
||||
theType := types[0]
|
||||
for _, t := range types[1:] {
|
||||
if !typeEqual(theType, t) {
|
||||
p.panicf("Array contains values of type '%s' and '%s', but "+
|
||||
"arrays must be homogeneous.", theType, t)
|
||||
}
|
||||
}
|
||||
return tomlArray
|
||||
}
|
|
@ -0,0 +1,242 @@
|
|||
package toml
|
||||
|
||||
// Struct field handling is adapted from code in encoding/json:
|
||||
//
|
||||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the Go distribution.
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// A field represents a single field found in a struct.
|
||||
type field struct {
|
||||
name string // the name of the field (`toml` tag included)
|
||||
tag bool // whether field has a `toml` tag
|
||||
index []int // represents the depth of an anonymous field
|
||||
typ reflect.Type // the type of the field
|
||||
}
|
||||
|
||||
// byName sorts field by name, breaking ties with depth,
|
||||
// then breaking ties with "name came from toml tag", then
|
||||
// breaking ties with index sequence.
|
||||
type byName []field
|
||||
|
||||
func (x byName) Len() int { return len(x) }
|
||||
|
||||
func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
|
||||
|
||||
func (x byName) Less(i, j int) bool {
|
||||
if x[i].name != x[j].name {
|
||||
return x[i].name < x[j].name
|
||||
}
|
||||
if len(x[i].index) != len(x[j].index) {
|
||||
return len(x[i].index) < len(x[j].index)
|
||||
}
|
||||
if x[i].tag != x[j].tag {
|
||||
return x[i].tag
|
||||
}
|
||||
return byIndex(x).Less(i, j)
|
||||
}
|
||||
|
||||
// byIndex sorts field by index sequence.
|
||||
type byIndex []field
|
||||
|
||||
func (x byIndex) Len() int { return len(x) }
|
||||
|
||||
func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
|
||||
|
||||
func (x byIndex) Less(i, j int) bool {
|
||||
for k, xik := range x[i].index {
|
||||
if k >= len(x[j].index) {
|
||||
return false
|
||||
}
|
||||
if xik != x[j].index[k] {
|
||||
return xik < x[j].index[k]
|
||||
}
|
||||
}
|
||||
return len(x[i].index) < len(x[j].index)
|
||||
}
|
||||
|
||||
// typeFields returns a list of fields that TOML should recognize for the given
|
||||
// type. The algorithm is breadth-first search over the set of structs to
|
||||
// include - the top struct and then any reachable anonymous structs.
|
||||
func typeFields(t reflect.Type) []field {
|
||||
// Anonymous fields to explore at the current level and the next.
|
||||
current := []field{}
|
||||
next := []field{{typ: t}}
|
||||
|
||||
// Count of queued names for current level and the next.
|
||||
count := map[reflect.Type]int{}
|
||||
nextCount := map[reflect.Type]int{}
|
||||
|
||||
// Types already visited at an earlier level.
|
||||
visited := map[reflect.Type]bool{}
|
||||
|
||||
// Fields found.
|
||||
var fields []field
|
||||
|
||||
for len(next) > 0 {
|
||||
current, next = next, current[:0]
|
||||
count, nextCount = nextCount, map[reflect.Type]int{}
|
||||
|
||||
for _, f := range current {
|
||||
if visited[f.typ] {
|
||||
continue
|
||||
}
|
||||
visited[f.typ] = true
|
||||
|
||||
// Scan f.typ for fields to include.
|
||||
for i := 0; i < f.typ.NumField(); i++ {
|
||||
sf := f.typ.Field(i)
|
||||
if sf.PkgPath != "" && !sf.Anonymous { // unexported
|
||||
continue
|
||||
}
|
||||
opts := getOptions(sf.Tag)
|
||||
if opts.skip {
|
||||
continue
|
||||
}
|
||||
index := make([]int, len(f.index)+1)
|
||||
copy(index, f.index)
|
||||
index[len(f.index)] = i
|
||||
|
||||
ft := sf.Type
|
||||
if ft.Name() == "" && ft.Kind() == reflect.Ptr {
|
||||
// Follow pointer.
|
||||
ft = ft.Elem()
|
||||
}
|
||||
|
||||
// Record found field and index sequence.
|
||||
if opts.name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct {
|
||||
tagged := opts.name != ""
|
||||
name := opts.name
|
||||
if name == "" {
|
||||
name = sf.Name
|
||||
}
|
||||
fields = append(fields, field{name, tagged, index, ft})
|
||||
if count[f.typ] > 1 {
|
||||
// If there were multiple instances, add a second,
|
||||
// so that the annihilation code will see a duplicate.
|
||||
// It only cares about the distinction between 1 or 2,
|
||||
// so don't bother generating any more copies.
|
||||
fields = append(fields, fields[len(fields)-1])
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Record new anonymous struct to explore in next round.
|
||||
nextCount[ft]++
|
||||
if nextCount[ft] == 1 {
|
||||
f := field{name: ft.Name(), index: index, typ: ft}
|
||||
next = append(next, f)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sort.Sort(byName(fields))
|
||||
|
||||
// Delete all fields that are hidden by the Go rules for embedded fields,
|
||||
// except that fields with TOML tags are promoted.
|
||||
|
||||
// The fields are sorted in primary order of name, secondary order
|
||||
// of field index length. Loop over names; for each name, delete
|
||||
// hidden fields by choosing the one dominant field that survives.
|
||||
out := fields[:0]
|
||||
for advance, i := 0, 0; i < len(fields); i += advance {
|
||||
// One iteration per name.
|
||||
// Find the sequence of fields with the name of this first field.
|
||||
fi := fields[i]
|
||||
name := fi.name
|
||||
for advance = 1; i+advance < len(fields); advance++ {
|
||||
fj := fields[i+advance]
|
||||
if fj.name != name {
|
||||
break
|
||||
}
|
||||
}
|
||||
if advance == 1 { // Only one field with this name
|
||||
out = append(out, fi)
|
||||
continue
|
||||
}
|
||||
dominant, ok := dominantField(fields[i : i+advance])
|
||||
if ok {
|
||||
out = append(out, dominant)
|
||||
}
|
||||
}
|
||||
|
||||
fields = out
|
||||
sort.Sort(byIndex(fields))
|
||||
|
||||
return fields
|
||||
}
|
||||
|
||||
// dominantField looks through the fields, all of which are known to
|
||||
// have the same name, to find the single field that dominates the
|
||||
// others using Go's embedding rules, modified by the presence of
|
||||
// TOML tags. If there are multiple top-level fields, the boolean
|
||||
// will be false: This condition is an error in Go and we skip all
|
||||
// the fields.
|
||||
func dominantField(fields []field) (field, bool) {
|
||||
// The fields are sorted in increasing index-length order. The winner
|
||||
// must therefore be one with the shortest index length. Drop all
|
||||
// longer entries, which is easy: just truncate the slice.
|
||||
length := len(fields[0].index)
|
||||
tagged := -1 // Index of first tagged field.
|
||||
for i, f := range fields {
|
||||
if len(f.index) > length {
|
||||
fields = fields[:i]
|
||||
break
|
||||
}
|
||||
if f.tag {
|
||||
if tagged >= 0 {
|
||||
// Multiple tagged fields at the same level: conflict.
|
||||
// Return no field.
|
||||
return field{}, false
|
||||
}
|
||||
tagged = i
|
||||
}
|
||||
}
|
||||
if tagged >= 0 {
|
||||
return fields[tagged], true
|
||||
}
|
||||
// All remaining fields have the same length. If there's more than one,
|
||||
// we have a conflict (two fields named "X" at the same level) and we
|
||||
// return no field.
|
||||
if len(fields) > 1 {
|
||||
return field{}, false
|
||||
}
|
||||
return fields[0], true
|
||||
}
|
||||
|
||||
var fieldCache struct {
|
||||
sync.RWMutex
|
||||
m map[reflect.Type][]field
|
||||
}
|
||||
|
||||
// cachedTypeFields is like typeFields but uses a cache to avoid repeated work.
|
||||
func cachedTypeFields(t reflect.Type) []field {
|
||||
fieldCache.RLock()
|
||||
f := fieldCache.m[t]
|
||||
fieldCache.RUnlock()
|
||||
if f != nil {
|
||||
return f
|
||||
}
|
||||
|
||||
// Compute fields without lock.
|
||||
// Might duplicate effort but won't hold other computations back.
|
||||
f = typeFields(t)
|
||||
if f == nil {
|
||||
f = []field{}
|
||||
}
|
||||
|
||||
fieldCache.Lock()
|
||||
if fieldCache.m == nil {
|
||||
fieldCache.m = map[reflect.Type][]field{}
|
||||
}
|
||||
fieldCache.m[t] = f
|
||||
fieldCache.Unlock()
|
||||
return f
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
Copyright (C) 2013 Blake Mizerany
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining
|
||||
a copy of this software and associated documentation files (the
|
||||
"Software"), to deal in the Software without restriction, including
|
||||
without limitation the rights to use, copy, modify, merge, publish,
|
||||
distribute, sublicense, and/or sell copies of the Software, and to
|
||||
permit persons to whom the Software is furnished to do so, subject to
|
||||
the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be
|
||||
included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
File diff suppressed because it is too large
|
@ -0,0 +1,316 @@
|
|||
// Package quantile computes approximate quantiles over an unbounded data
|
||||
// stream within low memory and CPU bounds.
|
||||
//
|
||||
// A small amount of accuracy is traded to achieve the above properties.
|
||||
//
|
||||
// Multiple streams can be merged before calling Query to generate a single set
|
||||
// of results. This is meaningful when the streams represent the same type of
|
||||
// data. See Merge and Samples.
|
||||
//
|
||||
// For more detailed information about the algorithm used, see:
|
||||
//
|
||||
// Effective Computation of Biased Quantiles over Data Streams
|
||||
//
|
||||
// http://www.cs.rutgers.edu/~muthu/bquant.pdf
|
||||
package quantile
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// Sample holds an observed value and meta information for compression. JSON
|
||||
// tags have been added for convenience.
|
||||
type Sample struct {
|
||||
Value float64 `json:",string"`
|
||||
Width float64 `json:",string"`
|
||||
Delta float64 `json:",string"`
|
||||
}
|
||||
|
||||
// Samples represents a slice of samples. It implements sort.Interface.
|
||||
type Samples []Sample
|
||||
|
||||
func (a Samples) Len() int { return len(a) }
|
||||
func (a Samples) Less(i, j int) bool { return a[i].Value < a[j].Value }
|
||||
func (a Samples) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
|
||||
type invariant func(s *stream, r float64) float64
|
||||
|
||||
// NewLowBiased returns an initialized Stream for low-biased quantiles
|
||||
// (e.g. 0.01, 0.1, 0.5) where the needed quantiles are not known a priori, but
|
||||
// error guarantees can still be given even for the lower ranks of the data
|
||||
// distribution.
|
||||
//
|
||||
// The provided epsilon is a relative error, i.e. the true quantile of a value
|
||||
// returned by a query is guaranteed to be within (1±Epsilon)*Quantile.
|
||||
//
|
||||
// See http://www.cs.rutgers.edu/~muthu/bquant.pdf for time, space, and error
|
||||
// properties.
|
||||
func NewLowBiased(epsilon float64) *Stream {
|
||||
ƒ := func(s *stream, r float64) float64 {
|
||||
return 2 * epsilon * r
|
||||
}
|
||||
return newStream(ƒ)
|
||||
}
|
||||
|
||||
// NewHighBiased returns an initialized Stream for high-biased quantiles
|
||||
// (e.g. 0.01, 0.1, 0.5) where the needed quantiles are not known a priori, but
|
||||
// error guarantees can still be given even for the higher ranks of the data
|
||||
// distribution.
|
||||
//
|
||||
// The provided epsilon is a relative error, i.e. the true quantile of a value
|
||||
// returned by a query is guaranteed to be within 1-(1±Epsilon)*(1-Quantile).
|
||||
//
|
||||
// See http://www.cs.rutgers.edu/~muthu/bquant.pdf for time, space, and error
|
||||
// properties.
|
||||
func NewHighBiased(epsilon float64) *Stream {
|
||||
ƒ := func(s *stream, r float64) float64 {
|
||||
return 2 * epsilon * (s.n - r)
|
||||
}
|
||||
return newStream(ƒ)
|
||||
}
|
||||
|
||||
// NewTargeted returns an initialized Stream concerned with a particular set of
|
||||
// quantile values that are supplied a priori. Knowing these a priori reduces
|
||||
// space and computation time. The targets map maps the desired quantiles to
|
||||
// their absolute errors, i.e. the true quantile of a value returned by a query
|
||||
// is guaranteed to be within (Quantile±Epsilon).
|
||||
//
|
||||
// See http://www.cs.rutgers.edu/~muthu/bquant.pdf for time, space, and error properties.
|
||||
func NewTargeted(targetMap map[float64]float64) *Stream {
|
||||
// Convert map to slice to avoid slow iterations on a map.
|
||||
// ƒ is called on the hot path, so converting the map to a slice
|
||||
// beforehand results in significant CPU savings.
|
||||
targets := targetMapToSlice(targetMap)
|
||||
|
||||
ƒ := func(s *stream, r float64) float64 {
|
||||
var m = math.MaxFloat64
|
||||
var f float64
|
||||
for _, t := range targets {
|
||||
if t.quantile*s.n <= r {
|
||||
f = (2 * t.epsilon * r) / t.quantile
|
||||
} else {
|
||||
f = (2 * t.epsilon * (s.n - r)) / (1 - t.quantile)
|
||||
}
|
||||
if f < m {
|
||||
m = f
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
return newStream(ƒ)
|
||||
}
|
||||
|
||||
type target struct {
|
||||
quantile float64
|
||||
epsilon float64
|
||||
}
|
||||
|
||||
func targetMapToSlice(targetMap map[float64]float64) []target {
|
||||
targets := make([]target, 0, len(targetMap))
|
||||
|
||||
for quantile, epsilon := range targetMap {
|
||||
t := target{
|
||||
quantile: quantile,
|
||||
epsilon: epsilon,
|
||||
}
|
||||
targets = append(targets, t)
|
||||
}
|
||||
|
||||
return targets
|
||||
}
|
||||
|
||||
// Stream computes quantiles for a stream of float64s. It is not thread-safe by
|
||||
// design. Take care when using across multiple goroutines.
|
||||
type Stream struct {
|
||||
*stream
|
||||
b Samples
|
||||
sorted bool
|
||||
}
|
||||
|
||||
func newStream(ƒ invariant) *Stream {
|
||||
x := &stream{ƒ: ƒ}
|
||||
return &Stream{x, make(Samples, 0, 500), true}
|
||||
}
|
||||
|
||||
// Insert inserts v into the stream.
|
||||
func (s *Stream) Insert(v float64) {
|
||||
s.insert(Sample{Value: v, Width: 1})
|
||||
}
|
||||
|
||||
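// insert buffers the sample and flushes into the underlying summary once the
|
||||
// buffer is full.
|
||||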
func (s *Stream) insert(sample Sample) {
|
||||
s.b = append(s.b, sample)
|
||||
s.sorted = false
|
||||
if len(s.b) == cap(s.b) {
|
||||
s.flush()
|
||||
}
|
||||
}
|
||||
|
||||
// Query returns the computed qth percentile value. If s was created with
|
||||
// NewTargeted, and q is not in the set of quantiles provided a priori, Query
|
||||
// will return an unspecified result.
|
||||
func (s *Stream) Query(q float64) float64 {
|
||||
if !s.flushed() {
|
||||
// Fast path when there hasn't been enough data for a flush;
|
||||
// this also yields better accuracy for small sets of data.
|
||||
l := len(s.b)
|
||||
if l == 0 {
|
||||
return 0
|
||||
}
|
||||
i := int(math.Ceil(float64(l) * q))
|
||||
if i > 0 {
|
||||
i -= 1
|
||||
}
|
||||
s.maybeSort()
|
||||
return s.b[i].Value
|
||||
}
|
||||
s.flush()
|
||||
return s.stream.query(q)
|
||||
}
|
||||
|
||||
// Merge merges samples into the underlying stream's samples. This is handy when
|
||||
// merging multiple streams from separate threads, database shards, etc.
|
||||
//
|
||||
// ATTENTION: This method is broken and does not yield correct results. The
|
||||
// underlying algorithm is not capable of merging streams correctly.
|
||||
func (s *Stream) Merge(samples Samples) {
|
||||
sort.Sort(samples)
|
||||
s.stream.merge(samples)
|
||||
}
|
||||
|
||||
// Reset reinitializes and clears the list, reusing the samples buffer memory.
|
||||
func (s *Stream) Reset() {
|
||||
s.stream.reset()
|
||||
s.b = s.b[:0]
|
||||
}
|
||||
|
||||
// Samples returns stream samples held by s.
|
||||
func (s *Stream) Samples() Samples {
|
||||
if !s.flushed() {
|
||||
return s.b
|
||||
}
|
||||
s.flush()
|
||||
return s.stream.samples()
|
||||
}
|
||||
|
||||
// Count returns the total number of samples observed in the stream
|
||||
// since initialization.
|
||||
func (s *Stream) Count() int {
|
||||
return len(s.b) + s.stream.count()
|
||||
}
|
||||
|
||||
func (s *Stream) flush() {
|
||||
s.maybeSort()
|
||||
s.stream.merge(s.b)
|
||||
s.b = s.b[:0]
|
||||
}
|
||||
|
||||
func (s *Stream) maybeSort() {
|
||||
if !s.sorted {
|
||||
s.sorted = true
|
||||
sort.Sort(s.b)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Stream) flushed() bool {
|
||||
return len(s.stream.l) > 0
|
||||
}
|
||||
|
||||
type stream struct {
|
||||
n float64
|
||||
l []Sample
|
||||
ƒ invariant
|
||||
}
|
||||
|
||||
func (s *stream) reset() {
|
||||
s.l = s.l[:0]
|
||||
s.n = 0
|
||||
}
|
||||
|
||||
func (s *stream) insert(v float64) {
|
||||
s.merge(Samples{{v, 1, 0}})
|
||||
}
|
||||
|
||||
func (s *stream) merge(samples Samples) {
|
||||
// TODO(beorn7): This tries to merge not only individual samples, but
|
||||
// whole summaries. The paper doesn't mention merging summaries at
|
||||
// all. Unittests show that the merging is inaccurate. Find out how to
|
||||
// do merges properly.
|
||||
var r float64
|
||||
i := 0
|
||||
for _, sample := range samples {
|
||||
for ; i < len(s.l); i++ {
|
||||
c := s.l[i]
|
||||
if c.Value > sample.Value {
|
||||
// Insert at position i.
|
||||
s.l = append(s.l, Sample{})
|
||||
copy(s.l[i+1:], s.l[i:])
|
||||
s.l[i] = Sample{
|
||||
sample.Value,
|
||||
sample.Width,
|
||||
math.Max(sample.Delta, math.Floor(s.ƒ(s, r))-1),
|
||||
// TODO(beorn7): How to calculate delta correctly?
|
||||
}
|
||||
i++
|
||||
goto inserted
|
||||
}
|
||||
r += c.Width
|
||||
}
|
||||
s.l = append(s.l, Sample{sample.Value, sample.Width, 0})
|
||||
i++
|
||||
inserted:
|
||||
s.n += sample.Width
|
||||
r += sample.Width
|
||||
}
|
||||
s.compress()
|
||||
}
|
||||
|
||||
func (s *stream) count() int {
|
||||
return int(s.n)
|
||||
}
|
||||
|
||||
func (s *stream) query(q float64) float64 {
|
||||
t := math.Ceil(q * s.n)
|
||||
t += math.Ceil(s.ƒ(s, t) / 2)
|
||||
p := s.l[0]
|
||||
var r float64
|
||||
for _, c := range s.l[1:] {
|
||||
r += p.Width
|
||||
if r+c.Width+c.Delta > t {
|
||||
return p.Value
|
||||
}
|
||||
p = c
|
||||
}
|
||||
return p.Value
|
||||
}
|
||||
|
||||
func (s *stream) compress() {
|
||||
if len(s.l) < 2 {
|
||||
return
|
||||
}
|
||||
x := s.l[len(s.l)-1]
|
||||
xi := len(s.l) - 1
|
||||
r := s.n - 1 - x.Width
|
||||
|
||||
for i := len(s.l) - 2; i >= 0; i-- {
|
||||
c := s.l[i]
|
||||
if c.Width+x.Width+x.Delta <= s.ƒ(s, r) {
|
||||
x.Width += c.Width
|
||||
s.l[xi] = x
|
||||
// Remove element at i.
|
||||
copy(s.l[i:], s.l[i+1:])
|
||||
s.l = s.l[:len(s.l)-1]
|
||||
xi -= 1
|
||||
} else {
|
||||
x = c
|
||||
xi = i
|
||||
}
|
||||
r -= c.Width
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stream) samples() Samples {
|
||||
samples := make(Samples, len(s.l))
|
||||
copy(samples, s.l)
|
||||
return samples
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
This Source Code Form is subject to the terms of the Mozilla Public License,
|
||||
v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain
|
||||
one at http://mozilla.org/MPL/2.0/.
|
|
@ -0,0 +1,60 @@
|
|||
# GoCertifi: SSL Certificates for Golang
|
||||
|
||||
This Go package contains a CA bundle that you can reference in your Go code.
|
||||
This is useful for systems that do not have CA bundles that Golang can find
|
||||
itself, or where a uniform set of CAs is valuable.
|
||||
|
||||
This is the same CA bundle that ships with the
|
||||
[Python Requests](https://github.com/kennethreitz/requests) library, and is a
|
||||
Golang-specific port of [certifi](https://github.com/kennethreitz/certifi). The
|
||||
CA bundle is derived from Mozilla's canonical set.
|
||||
|
||||
## Usage
|
||||
|
||||
You can use the `gocertifi` package as follows:
|
||||
|
||||
```go
|
||||
import "github.com/certifi/gocertifi"
|
||||
|
||||
cert_pool, err := gocertifi.CACerts()
|
||||
```
|
||||
|
||||
You can use the returned `*x509.CertPool` as part of an HTTP transport, for example:
|
||||
|
||||
```go
|
||||
import (
|
||||
"net/http"
|
||||
"crypto/tls"
|
||||
)
|
||||
|
||||
// Setup an HTTP client with a custom transport
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{RootCAs: cert_pool},
|
||||
}
|
||||
client := &http.Client{Transport: transport}
|
||||
|
||||
// Make an HTTP request using our custom transport
|
||||
resp, err := client.Get("https://example.com")
|
||||
```
|
||||
|
||||
## Detailed Documentation
|
||||
|
||||
Import as follows:
|
||||
|
||||
```go
|
||||
import "github.com/certifi/gocertifi"
|
||||
```
|
||||
|
||||
### Errors
|
||||
|
||||
```go
|
||||
var ErrParseFailed = errors.New("gocertifi: error when parsing certificates")
|
||||
```
|
||||
|
||||
### Functions
|
||||
|
||||
```go
|
||||
func CACerts() (*x509.CertPool, error)
|
||||
```
|
||||
CACerts builds an X.509 certificate pool containing the Mozilla CA Certificate
|
||||
bundle. On failure it returns nil along with an appropriate error.
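
As a minimal, hedged sketch (the helper name `newClient` is purely illustrative and not part of gocertifi), the error from `CACerts` can be checked before wiring the pool into an HTTP client, combining the snippets above:

```go
import (
	"crypto/tls"
	"log"
	"net/http"

	"github.com/certifi/gocertifi"
)

// newClient is an illustrative helper, not part of gocertifi.
func newClient() *http.Client {
	pool, err := gocertifi.CACerts()
	if err != nil {
		log.Fatalf("loading Mozilla CA bundle: %v", err)
	}
	return &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{RootCAs: pool},
		},
	}
}
```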
|
File diff suppressed because it is too large
|
@ -0,0 +1,72 @@
|
|||
// +build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"text/template"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
const url = "https://mkcert.org/generate/"
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
log.Fatal("expected 200, got", resp.StatusCode)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
bundle, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Fatal(err)
}
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM(bundle) {
|
||||
log.Fatalf("can't parse cerficiates from %s", url)
|
||||
}
|
||||
|
||||
fp, err := os.Create("certifi.go")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer fp.Close()
|
||||
|
||||
tmpl.Execute(fp, struct {
|
||||
Timestamp time.Time
|
||||
URL string
|
||||
Bundle string
|
||||
}{
|
||||
Timestamp: time.Now(),
|
||||
URL: url,
|
||||
Bundle: string(bundle),
|
||||
})
|
||||
}
|
||||
|
||||
var tmpl = template.Must(template.New("").Parse(`// Code generated by go generate; DO NOT EDIT.
|
||||
// {{ .Timestamp }}
|
||||
// {{ .URL }}
|
||||
|
||||
package gocertifi
|
||||
|
||||
//go:generate go run gen.go
|
||||
|
||||
import "crypto/x509"
|
||||
|
||||
const pemcerts string = ` + "`" + `
|
||||
{{ .Bundle }}
|
||||
` + "`" + `
|
||||
|
||||
// CACerts builds an X.509 certificate pool containing the Mozilla CA
|
||||
// Certificate bundle. Returns nil on error along with an appropriate error
|
||||
// code.
|
||||
func CACerts() (*x509.CertPool, error) {
|
||||
pool := x509.NewCertPool()
|
||||
pool.AppendCertsFromPEM([]byte(pemcerts))
|
||||
return pool, nil
|
||||
}
|
||||
`))
|
|
@ -0,0 +1,130 @@
|
|||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Function to find backward reference copies. */
|
||||
|
||||
#include "./enc/backward_references.h"
|
||||
|
||||
#include "./common/constants.h"
|
||||
#include "./common/dictionary.h"
|
||||
#include <brotli/types.h>
|
||||
#include "./enc/command.h"
|
||||
#include "./enc/dictionary_hash.h"
|
||||
#include "./enc/memory.h"
|
||||
#include "./enc/port.h"
|
||||
#include "./enc/quality.h"
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
static BROTLI_INLINE size_t ComputeDistanceCode(size_t distance,
|
||||
size_t max_distance,
|
||||
const int* dist_cache) {
|
||||
if (distance <= max_distance) {
|
||||
size_t distance_plus_3 = distance + 3;
|
||||
size_t offset0 = distance_plus_3 - (size_t)dist_cache[0];
|
||||
size_t offset1 = distance_plus_3 - (size_t)dist_cache[1];
|
||||
if (distance == (size_t)dist_cache[0]) {
|
||||
return 0;
|
||||
} else if (distance == (size_t)dist_cache[1]) {
|
||||
return 1;
|
||||
} else if (offset0 < 7) {
|
||||
return (0x9750468 >> (4 * offset0)) & 0xF;
|
||||
} else if (offset1 < 7) {
|
||||
return (0xFDB1ACE >> (4 * offset1)) & 0xF;
|
||||
} else if (distance == (size_t)dist_cache[2]) {
|
||||
return 2;
|
||||
} else if (distance == (size_t)dist_cache[3]) {
|
||||
return 3;
|
||||
}
|
||||
}
|
||||
return distance + BROTLI_NUM_DISTANCE_SHORT_CODES - 1;
|
||||
}
|
||||
|
||||
#define EXPAND_CAT(a, b) CAT(a, b)
|
||||
#define CAT(a, b) a ## b
|
||||
#define FN(X) EXPAND_CAT(X, HASHER())
|
||||
|
||||
#define HASHER() H2
|
||||
/* NOLINTNEXTLINE(build/include) */
|
||||
#include "./enc/backward_references_inc.h"
|
||||
#undef HASHER
|
||||
|
||||
#define HASHER() H3
|
||||
/* NOLINTNEXTLINE(build/include) */
|
||||
#include "./enc/backward_references_inc.h"
|
||||
#undef HASHER
|
||||
|
||||
#define HASHER() H4
|
||||
/* NOLINTNEXTLINE(build/include) */
|
||||
#include "./enc/backward_references_inc.h"
|
||||
#undef HASHER
|
||||
|
||||
#define HASHER() H5
|
||||
/* NOLINTNEXTLINE(build/include) */
|
||||
#include "./enc/backward_references_inc.h"
|
||||
#undef HASHER
|
||||
|
||||
#define HASHER() H6
|
||||
/* NOLINTNEXTLINE(build/include) */
|
||||
#include "./enc/backward_references_inc.h"
|
||||
#undef HASHER
|
||||
|
||||
#define HASHER() H40
|
||||
/* NOLINTNEXTLINE(build/include) */
|
||||
#include "./enc/backward_references_inc.h"
|
||||
#undef HASHER
|
||||
|
||||
#define HASHER() H41
|
||||
/* NOLINTNEXTLINE(build/include) */
|
||||
#include "./enc/backward_references_inc.h"
|
||||
#undef HASHER
|
||||
|
||||
#define HASHER() H42
|
||||
/* NOLINTNEXTLINE(build/include) */
|
||||
#include "./enc/backward_references_inc.h"
|
||||
#undef HASHER
|
||||
|
||||
#define HASHER() H54
|
||||
/* NOLINTNEXTLINE(build/include) */
|
||||
#include "./enc/backward_references_inc.h"
|
||||
#undef HASHER
|
||||
|
||||
#undef FN
|
||||
#undef CAT
|
||||
#undef EXPAND_CAT
|
||||
|
||||
void BrotliCreateBackwardReferences(const BrotliDictionary* dictionary,
|
||||
size_t num_bytes,
|
||||
size_t position,
|
||||
const uint8_t* ringbuffer,
|
||||
size_t ringbuffer_mask,
|
||||
const BrotliEncoderParams* params,
|
||||
HasherHandle hasher,
|
||||
int* dist_cache,
|
||||
size_t* last_insert_len,
|
||||
Command* commands,
|
||||
size_t* num_commands,
|
||||
size_t* num_literals) {
|
||||
switch (params->hasher.type) {
|
||||
#define CASE_(N) \
|
||||
case N: \
|
||||
CreateBackwardReferencesH ## N(dictionary, \
|
||||
kStaticDictionaryHash, num_bytes, position, ringbuffer, \
|
||||
ringbuffer_mask, params, hasher, dist_cache, \
|
||||
last_insert_len, commands, num_commands, num_literals); \
|
||||
break;
|
||||
FOR_GENERIC_HASHERS(CASE_)
|
||||
#undef CASE_
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
} /* extern "C" */
|
||||
#endif
|
|
@ -0,0 +1,790 @@
|
|||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Function to find backward reference copies. */
|
||||
|
||||
#include "./enc/backward_references_hq.h"
|
||||
|
||||
#include <string.h> /* memcpy, memset */
|
||||
|
||||
#include "./common/constants.h"
|
||||
#include <brotli/types.h>
|
||||
#include "./enc/command.h"
|
||||
#include "./enc/fast_log.h"
|
||||
#include "./enc/find_match_length.h"
|
||||
#include "./enc/literal_cost.h"
|
||||
#include "./enc/memory.h"
|
||||
#include "./enc/port.h"
|
||||
#include "./enc/prefix.h"
|
||||
#include "./enc/quality.h"
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
static const float kInfinity = 1.7e38f; /* ~= 2 ^ 127 */
|
||||
|
||||
static const uint32_t kDistanceCacheIndex[] = {
|
||||
0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
|
||||
};
|
||||
static const int kDistanceCacheOffset[] = {
|
||||
0, 0, 0, 0, -1, 1, -2, 2, -3, 3, -1, 1, -2, 2, -3, 3
|
||||
};
|
||||
|
||||
void BrotliInitZopfliNodes(ZopfliNode* array, size_t length) {
|
||||
ZopfliNode stub;
|
||||
size_t i;
|
||||
stub.length = 1;
|
||||
stub.distance = 0;
|
||||
stub.insert_length = 0;
|
||||
stub.u.cost = kInfinity;
|
||||
for (i = 0; i < length; ++i) array[i] = stub;
|
||||
}
|
||||
|
||||
static BROTLI_INLINE uint32_t ZopfliNodeCopyLength(const ZopfliNode* self) {
|
||||
return self->length & 0xffffff;
|
||||
}
|
||||
|
||||
static BROTLI_INLINE uint32_t ZopfliNodeLengthCode(const ZopfliNode* self) {
|
||||
const uint32_t modifier = self->length >> 24;
|
||||
return ZopfliNodeCopyLength(self) + 9u - modifier;
|
||||
}
|
||||
|
||||
static BROTLI_INLINE uint32_t ZopfliNodeCopyDistance(const ZopfliNode* self) {
|
||||
return self->distance & 0x1ffffff;
|
||||
}
|
||||
|
||||
static BROTLI_INLINE uint32_t ZopfliNodeDistanceCode(const ZopfliNode* self) {
|
||||
const uint32_t short_code = self->distance >> 25;
|
||||
return short_code == 0 ?
|
||||
ZopfliNodeCopyDistance(self) + BROTLI_NUM_DISTANCE_SHORT_CODES - 1 :
|
||||
short_code - 1;
|
||||
}
|
||||
|
||||
static BROTLI_INLINE uint32_t ZopfliNodeCommandLength(const ZopfliNode* self) {
|
||||
return ZopfliNodeCopyLength(self) + self->insert_length;
|
||||
}
|
||||
|
||||
/* Histogram based cost model for zopflification. */
|
||||
typedef struct ZopfliCostModel {
|
||||
/* The insert and copy length symbols. */
|
||||
float cost_cmd_[BROTLI_NUM_COMMAND_SYMBOLS];
|
||||
float cost_dist_[BROTLI_NUM_DISTANCE_SYMBOLS];
|
||||
/* Cumulative costs of literals per position in the stream. */
|
||||
float* literal_costs_;
|
||||
float min_cost_cmd_;
|
||||
size_t num_bytes_;
|
||||
} ZopfliCostModel;
|
||||
|
||||
static void InitZopfliCostModel(
|
||||
MemoryManager* m, ZopfliCostModel* self, size_t num_bytes) {
|
||||
self->num_bytes_ = num_bytes;
|
||||
self->literal_costs_ = BROTLI_ALLOC(m, float, num_bytes + 2);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
}
|
||||
|
||||
static void CleanupZopfliCostModel(MemoryManager* m, ZopfliCostModel* self) {
|
||||
BROTLI_FREE(m, self->literal_costs_);
|
||||
}
|
||||
|
||||
static void SetCost(const uint32_t* histogram, size_t histogram_size,
|
||||
float* cost) {
|
||||
size_t sum = 0;
|
||||
float log2sum;
|
||||
size_t i;
|
||||
for (i = 0; i < histogram_size; i++) {
|
||||
sum += histogram[i];
|
||||
}
|
||||
log2sum = (float)FastLog2(sum);
|
||||
for (i = 0; i < histogram_size; i++) {
|
||||
if (histogram[i] == 0) {
|
||||
cost[i] = log2sum + 2;
|
||||
continue;
|
||||
}
|
||||
|
||||
/* Shannon bits for this symbol. */
|
||||
cost[i] = log2sum - (float)FastLog2(histogram[i]);
|
||||
|
||||
/* Cannot be coded with less than 1 bit */
|
||||
if (cost[i] < 1) cost[i] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
static void ZopfliCostModelSetFromCommands(ZopfliCostModel* self,
|
||||
size_t position,
|
||||
const uint8_t* ringbuffer,
|
||||
size_t ringbuffer_mask,
|
||||
const Command* commands,
|
||||
size_t num_commands,
|
||||
size_t last_insert_len) {
|
||||
uint32_t histogram_literal[BROTLI_NUM_LITERAL_SYMBOLS];
|
||||
uint32_t histogram_cmd[BROTLI_NUM_COMMAND_SYMBOLS];
|
||||
uint32_t histogram_dist[BROTLI_NUM_DISTANCE_SYMBOLS];
|
||||
float cost_literal[BROTLI_NUM_LITERAL_SYMBOLS];
|
||||
size_t pos = position - last_insert_len;
|
||||
float min_cost_cmd = kInfinity;
|
||||
size_t i;
|
||||
float* cost_cmd = self->cost_cmd_;
|
||||
|
||||
memset(histogram_literal, 0, sizeof(histogram_literal));
|
||||
memset(histogram_cmd, 0, sizeof(histogram_cmd));
|
||||
memset(histogram_dist, 0, sizeof(histogram_dist));
|
||||
|
||||
for (i = 0; i < num_commands; i++) {
|
||||
size_t inslength = commands[i].insert_len_;
|
||||
size_t copylength = CommandCopyLen(&commands[i]);
|
||||
size_t distcode = commands[i].dist_prefix_;
|
||||
size_t cmdcode = commands[i].cmd_prefix_;
|
||||
size_t j;
|
||||
|
||||
histogram_cmd[cmdcode]++;
|
||||
if (cmdcode >= 128) histogram_dist[distcode]++;
|
||||
|
||||
for (j = 0; j < inslength; j++) {
|
||||
histogram_literal[ringbuffer[(pos + j) & ringbuffer_mask]]++;
|
||||
}
|
||||
|
||||
pos += inslength + copylength;
|
||||
}
|
||||
|
||||
SetCost(histogram_literal, BROTLI_NUM_LITERAL_SYMBOLS, cost_literal);
|
||||
SetCost(histogram_cmd, BROTLI_NUM_COMMAND_SYMBOLS, cost_cmd);
|
||||
SetCost(histogram_dist, BROTLI_NUM_DISTANCE_SYMBOLS, self->cost_dist_);
|
||||
|
||||
for (i = 0; i < BROTLI_NUM_COMMAND_SYMBOLS; ++i) {
|
||||
min_cost_cmd = BROTLI_MIN(float, min_cost_cmd, cost_cmd[i]);
|
||||
}
|
||||
self->min_cost_cmd_ = min_cost_cmd;
|
||||
|
||||
{
|
||||
float* literal_costs = self->literal_costs_;
|
||||
size_t num_bytes = self->num_bytes_;
|
||||
literal_costs[0] = 0.0;
|
||||
for (i = 0; i < num_bytes; ++i) {
|
||||
literal_costs[i + 1] = literal_costs[i] +
|
||||
cost_literal[ringbuffer[(position + i) & ringbuffer_mask]];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ZopfliCostModelSetFromLiteralCosts(ZopfliCostModel* self,
|
||||
size_t position,
|
||||
const uint8_t* ringbuffer,
|
||||
size_t ringbuffer_mask) {
|
||||
float* literal_costs = self->literal_costs_;
|
||||
float* cost_dist = self->cost_dist_;
|
||||
float* cost_cmd = self->cost_cmd_;
|
||||
size_t num_bytes = self->num_bytes_;
|
||||
size_t i;
|
||||
BrotliEstimateBitCostsForLiterals(position, num_bytes, ringbuffer_mask,
|
||||
ringbuffer, &literal_costs[1]);
|
||||
literal_costs[0] = 0.0;
|
||||
for (i = 0; i < num_bytes; ++i) {
|
||||
literal_costs[i + 1] += literal_costs[i];
|
||||
}
|
||||
for (i = 0; i < BROTLI_NUM_COMMAND_SYMBOLS; ++i) {
|
||||
cost_cmd[i] = (float)FastLog2(11 + (uint32_t)i);
|
||||
}
|
||||
for (i = 0; i < BROTLI_NUM_DISTANCE_SYMBOLS; ++i) {
|
||||
cost_dist[i] = (float)FastLog2(20 + (uint32_t)i);
|
||||
}
|
||||
self->min_cost_cmd_ = (float)FastLog2(11);
|
||||
}
|
||||
|
||||
static BROTLI_INLINE float ZopfliCostModelGetCommandCost(
|
||||
const ZopfliCostModel* self, uint16_t cmdcode) {
|
||||
return self->cost_cmd_[cmdcode];
|
||||
}
|
||||
|
||||
static BROTLI_INLINE float ZopfliCostModelGetDistanceCost(
|
||||
const ZopfliCostModel* self, size_t distcode) {
|
||||
return self->cost_dist_[distcode];
|
||||
}
|
||||
|
||||
static BROTLI_INLINE float ZopfliCostModelGetLiteralCosts(
|
||||
const ZopfliCostModel* self, size_t from, size_t to) {
|
||||
return self->literal_costs_[to] - self->literal_costs_[from];
|
||||
}
|
||||
|
||||
static BROTLI_INLINE float ZopfliCostModelGetMinCostCmd(
|
||||
const ZopfliCostModel* self) {
|
||||
return self->min_cost_cmd_;
|
||||
}
|
||||
|
||||
/* REQUIRES: len >= 2, start_pos <= pos */
|
||||
/* REQUIRES: cost < kInfinity, nodes[start_pos].cost < kInfinity */
|
||||
/* Maintains the "ZopfliNode array invariant". */
|
||||
static BROTLI_INLINE void UpdateZopfliNode(ZopfliNode* nodes, size_t pos,
|
||||
size_t start_pos, size_t len, size_t len_code, size_t dist,
|
||||
size_t short_code, float cost) {
|
||||
ZopfliNode* next = &nodes[pos + len];
|
||||
next->length = (uint32_t)(len | ((len + 9u - len_code) << 24));
|
||||
next->distance = (uint32_t)(dist | (short_code << 25));
|
||||
next->insert_length = (uint32_t)(pos - start_pos);
|
||||
next->u.cost = cost;
|
||||
}
|
||||
|
||||
typedef struct PosData {
|
||||
size_t pos;
|
||||
int distance_cache[4];
|
||||
float costdiff;
|
||||
float cost;
|
||||
} PosData;
|
||||
|
||||
/* Maintains the smallest 8 cost differences together with their positions */
|
||||
typedef struct StartPosQueue {
|
||||
PosData q_[8];
|
||||
size_t idx_;
|
||||
} StartPosQueue;
|
||||
|
||||
static BROTLI_INLINE void InitStartPosQueue(StartPosQueue* self) {
|
||||
self->idx_ = 0;
|
||||
}
|
||||
|
||||
static size_t StartPosQueueSize(const StartPosQueue* self) {
|
||||
return BROTLI_MIN(size_t, self->idx_, 8);
|
||||
}
|
||||
|
||||
static void StartPosQueuePush(StartPosQueue* self, const PosData* posdata) {
|
||||
size_t offset = ~(self->idx_++) & 7;
|
||||
size_t len = StartPosQueueSize(self);
|
||||
size_t i;
|
||||
PosData* q = self->q_;
|
||||
q[offset] = *posdata;
|
||||
/* Restore the sorted order. In the list of |len| items at most |len - 1|
|
||||
adjacent element comparisons / swaps are required. */
|
||||
for (i = 1; i < len; ++i) {
|
||||
if (q[offset & 7].costdiff > q[(offset + 1) & 7].costdiff) {
|
||||
BROTLI_SWAP(PosData, q, offset & 7, (offset + 1) & 7);
|
||||
}
|
||||
++offset;
|
||||
}
|
||||
}
|
||||
|
||||
static const PosData* StartPosQueueAt(const StartPosQueue* self, size_t k) {
|
||||
return &self->q_[(k - self->idx_) & 7];
|
||||
}
|
||||
|
||||
/* Returns the minimum possible copy length that can improve the cost of any */
|
||||
/* future position. */
|
||||
static size_t ComputeMinimumCopyLength(const float start_cost,
|
||||
const ZopfliNode* nodes,
|
||||
const size_t num_bytes,
|
||||
const size_t pos) {
|
||||
/* Compute the minimum possible cost of reaching any future position. */
|
||||
float min_cost = start_cost;
|
||||
size_t len = 2;
|
||||
size_t next_len_bucket = 4;
|
||||
size_t next_len_offset = 10;
|
||||
while (pos + len <= num_bytes && nodes[pos + len].u.cost <= min_cost) {
|
||||
/* We already reached (pos + len) with no more cost than the minimum
|
||||
possible cost of reaching anything from this pos, so there is no point in
|
||||
looking for lengths <= len. */
|
||||
++len;
|
||||
if (len == next_len_offset) {
|
||||
/* We reached the next copy length code bucket, so we add one more
|
||||
extra bit to the minimum cost. */
|
||||
min_cost += 1.0f;
|
||||
next_len_offset += next_len_bucket;
|
||||
next_len_bucket *= 2;
|
||||
}
|
||||
}
|
||||
return len;
|
||||
}
|
||||
|
||||
/* REQUIRES: nodes[pos].cost < kInfinity
|
||||
REQUIRES: nodes[0..pos] satisfies that "ZopfliNode array invariant". */
|
||||
static uint32_t ComputeDistanceShortcut(const size_t block_start,
|
||||
const size_t pos,
|
||||
const size_t max_backward,
|
||||
const ZopfliNode* nodes) {
|
||||
const size_t clen = ZopfliNodeCopyLength(&nodes[pos]);
|
||||
const size_t ilen = nodes[pos].insert_length;
|
||||
const size_t dist = ZopfliNodeCopyDistance(&nodes[pos]);
|
||||
/* Since |block_start + pos| is the end position of the command, the copy part
|
||||
starts from |block_start + pos - clen|. Distances that are greater than
|
||||
this or greater than |max_backward| are static dictionary references, and
|
||||
do not update the last distances. Also distance code 0 (last distance)
|
||||
does not update the last distances. */
|
||||
if (pos == 0) {
|
||||
return 0;
|
||||
} else if (dist + clen <= block_start + pos &&
|
||||
dist <= max_backward &&
|
||||
ZopfliNodeDistanceCode(&nodes[pos]) > 0) {
|
||||
return (uint32_t)pos;
|
||||
} else {
|
||||
return nodes[pos - clen - ilen].u.shortcut;
|
||||
}
|
||||
}
|
||||
|
||||
/* Fills in dist_cache[0..3] with the last four distances (as defined by
|
||||
Section 4. of the Spec) that would be used at (block_start + pos) if we
|
||||
used the shortest path of commands from block_start, computed from
|
||||
nodes[0..pos]. The last four distances at block_start are in
|
||||
starting_dist_cache[0..3].
|
||||
REQUIRES: nodes[pos].cost < kInfinity
|
||||
REQUIRES: nodes[0..pos] satisfies that "ZopfliNode array invariant". */
|
||||
static void ComputeDistanceCache(const size_t pos,
|
||||
const int* starting_dist_cache,
|
||||
const ZopfliNode* nodes,
|
||||
int* dist_cache) {
|
||||
int idx = 0;
|
||||
size_t p = nodes[pos].u.shortcut;
|
||||
while (idx < 4 && p > 0) {
|
||||
const size_t ilen = nodes[p].insert_length;
|
||||
const size_t clen = ZopfliNodeCopyLength(&nodes[p]);
|
||||
const size_t dist = ZopfliNodeCopyDistance(&nodes[p]);
|
||||
dist_cache[idx++] = (int)dist;
|
||||
/* Because of prerequisite, p >= clen + ilen >= 2. */
|
||||
p = nodes[p - clen - ilen].u.shortcut;
|
||||
}
|
||||
for (; idx < 4; ++idx) {
|
||||
dist_cache[idx] = *starting_dist_cache++;
|
||||
}
|
||||
}
|
||||
|
||||
/* Maintains "ZopfliNode array invariant" and pushes node to the queue, if it
|
||||
is eligible. */
|
||||
static void EvaluateNode(
|
||||
const size_t block_start, const size_t pos, const size_t max_backward_limit,
|
||||
const int* starting_dist_cache, const ZopfliCostModel* model,
|
||||
StartPosQueue* queue, ZopfliNode* nodes) {
|
||||
/* Save cost, because ComputeDistanceCache invalidates it. */
|
||||
float node_cost = nodes[pos].u.cost;
|
||||
nodes[pos].u.shortcut = ComputeDistanceShortcut(
|
||||
block_start, pos, max_backward_limit, nodes);
|
||||
if (node_cost <= ZopfliCostModelGetLiteralCosts(model, 0, pos)) {
|
||||
PosData posdata;
|
||||
posdata.pos = pos;
|
||||
posdata.cost = node_cost;
|
||||
posdata.costdiff = node_cost -
|
||||
ZopfliCostModelGetLiteralCosts(model, 0, pos);
|
||||
ComputeDistanceCache(
|
||||
pos, starting_dist_cache, nodes, posdata.distance_cache);
|
||||
StartPosQueuePush(queue, &posdata);
|
||||
}
|
||||
}
|
||||
|
||||
/* Returns longest copy length. */
|
||||
static size_t UpdateNodes(
|
||||
const size_t num_bytes, const size_t block_start, const size_t pos,
|
||||
const uint8_t* ringbuffer, const size_t ringbuffer_mask,
|
||||
const BrotliEncoderParams* params, const size_t max_backward_limit,
|
||||
const int* starting_dist_cache, const size_t num_matches,
|
||||
const BackwardMatch* matches, const ZopfliCostModel* model,
|
||||
StartPosQueue* queue, ZopfliNode* nodes) {
|
||||
const size_t cur_ix = block_start + pos;
|
||||
const size_t cur_ix_masked = cur_ix & ringbuffer_mask;
|
||||
const size_t max_distance = BROTLI_MIN(size_t, cur_ix, max_backward_limit);
|
||||
const size_t max_len = num_bytes - pos;
|
||||
const size_t max_zopfli_len = MaxZopfliLen(params);
|
||||
const size_t max_iters = MaxZopfliCandidates(params);
|
||||
size_t min_len;
|
||||
size_t result = 0;
|
||||
size_t k;
|
||||
|
||||
EvaluateNode(block_start, pos, max_backward_limit, starting_dist_cache, model,
|
||||
queue, nodes);
|
||||
|
||||
{
|
||||
const PosData* posdata = StartPosQueueAt(queue, 0);
|
||||
float min_cost = (posdata->cost + ZopfliCostModelGetMinCostCmd(model) +
|
||||
ZopfliCostModelGetLiteralCosts(model, posdata->pos, pos));
|
||||
min_len = ComputeMinimumCopyLength(min_cost, nodes, num_bytes, pos);
|
||||
}
|
||||
|
||||
/* Go over the command starting positions in order of increasing cost
|
||||
difference. */
|
||||
for (k = 0; k < max_iters && k < StartPosQueueSize(queue); ++k) {
|
||||
const PosData* posdata = StartPosQueueAt(queue, k);
|
||||
const size_t start = posdata->pos;
|
||||
const uint16_t inscode = GetInsertLengthCode(pos - start);
|
||||
const float start_costdiff = posdata->costdiff;
|
||||
const float base_cost = start_costdiff + (float)GetInsertExtra(inscode) +
|
||||
ZopfliCostModelGetLiteralCosts(model, 0, pos);
|
||||
|
||||
/* Look for last distance matches using the distance cache from this
|
||||
starting position. */
|
||||
size_t best_len = min_len - 1;
|
||||
size_t j = 0;
|
||||
for (; j < BROTLI_NUM_DISTANCE_SHORT_CODES && best_len < max_len; ++j) {
|
||||
const size_t idx = kDistanceCacheIndex[j];
|
||||
const size_t backward =
|
||||
(size_t)(posdata->distance_cache[idx] + kDistanceCacheOffset[j]);
|
||||
size_t prev_ix = cur_ix - backward;
|
||||
if (prev_ix >= cur_ix) {
|
||||
continue;
|
||||
}
|
||||
if (BROTLI_PREDICT_FALSE(backward > max_distance)) {
|
||||
continue;
|
||||
}
|
||||
prev_ix &= ringbuffer_mask;
|
||||
|
||||
if (cur_ix_masked + best_len > ringbuffer_mask ||
|
||||
prev_ix + best_len > ringbuffer_mask ||
|
||||
ringbuffer[cur_ix_masked + best_len] !=
|
||||
ringbuffer[prev_ix + best_len]) {
|
||||
continue;
|
||||
}
|
||||
{
|
||||
const size_t len =
|
||||
FindMatchLengthWithLimit(&ringbuffer[prev_ix],
|
||||
&ringbuffer[cur_ix_masked],
|
||||
max_len);
|
||||
const float dist_cost = base_cost +
|
||||
ZopfliCostModelGetDistanceCost(model, j);
|
||||
size_t l;
|
||||
for (l = best_len + 1; l <= len; ++l) {
|
||||
const uint16_t copycode = GetCopyLengthCode(l);
|
||||
const uint16_t cmdcode =
|
||||
CombineLengthCodes(inscode, copycode, j == 0);
|
||||
const float cost = (cmdcode < 128 ? base_cost : dist_cost) +
|
||||
(float)GetCopyExtra(copycode) +
|
||||
ZopfliCostModelGetCommandCost(model, cmdcode);
|
||||
if (cost < nodes[pos + l].u.cost) {
|
||||
UpdateZopfliNode(nodes, pos, start, l, l, backward, j + 1, cost);
|
||||
result = BROTLI_MAX(size_t, result, l);
|
||||
}
|
||||
best_len = l;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* At higher iterations look only for new last distance matches, since
|
||||
looking only for new command start positions with the same distances
|
||||
does not help much. */
|
||||
if (k >= 2) continue;
|
||||
|
||||
{
|
||||
/* Loop through all possible copy lengths at this position. */
|
||||
size_t len = min_len;
|
||||
for (j = 0; j < num_matches; ++j) {
|
||||
BackwardMatch match = matches[j];
|
||||
size_t dist = match.distance;
|
||||
BROTLI_BOOL is_dictionary_match = TO_BROTLI_BOOL(dist > max_distance);
|
||||
/* We already tried all possible last distance matches, so we can use
|
||||
normal distance code here. */
|
||||
size_t dist_code = dist + BROTLI_NUM_DISTANCE_SHORT_CODES - 1;
|
||||
uint16_t dist_symbol;
|
||||
uint32_t distextra;
|
||||
uint32_t distnumextra;
|
||||
float dist_cost;
|
||||
size_t max_match_len;
|
||||
PrefixEncodeCopyDistance(dist_code, 0, 0, &dist_symbol, &distextra);
|
||||
distnumextra = distextra >> 24;
|
||||
dist_cost = base_cost + (float)distnumextra +
|
||||
ZopfliCostModelGetDistanceCost(model, dist_symbol);
|
||||
|
||||
/* Try all copy lengths up until the maximum copy length corresponding
|
||||
to this distance. If the distance refers to the static dictionary, or
|
||||
the maximum length is long enough, try only one maximum length. */
|
||||
max_match_len = BackwardMatchLength(&match);
|
||||
if (len < max_match_len &&
|
||||
(is_dictionary_match || max_match_len > max_zopfli_len)) {
|
||||
len = max_match_len;
|
||||
}
|
||||
for (; len <= max_match_len; ++len) {
|
||||
const size_t len_code =
|
||||
is_dictionary_match ? BackwardMatchLengthCode(&match) : len;
|
||||
const uint16_t copycode = GetCopyLengthCode(len_code);
|
||||
const uint16_t cmdcode = CombineLengthCodes(inscode, copycode, 0);
|
||||
const float cost = dist_cost + (float)GetCopyExtra(copycode) +
|
||||
ZopfliCostModelGetCommandCost(model, cmdcode);
|
||||
if (cost < nodes[pos + len].u.cost) {
|
||||
UpdateZopfliNode(nodes, pos, start, len, len_code, dist, 0, cost);
|
||||
result = BROTLI_MAX(size_t, result, len);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static size_t ComputeShortestPathFromNodes(size_t num_bytes,
|
||||
ZopfliNode* nodes) {
|
||||
size_t index = num_bytes;
|
||||
size_t num_commands = 0;
|
||||
while (nodes[index].insert_length == 0 && nodes[index].length == 1) --index;
|
||||
nodes[index].u.next = BROTLI_UINT32_MAX;
|
||||
while (index != 0) {
|
||||
size_t len = ZopfliNodeCommandLength(&nodes[index]);
|
||||
index -= len;
|
||||
nodes[index].u.next = (uint32_t)len;
|
||||
num_commands++;
|
||||
}
|
||||
return num_commands;
|
||||
}
|
||||
|
||||
/* REQUIRES: nodes != NULL and len(nodes) >= num_bytes + 1 */
|
||||
void BrotliZopfliCreateCommands(const size_t num_bytes,
|
||||
const size_t block_start,
|
||||
const size_t max_backward_limit,
|
||||
const ZopfliNode* nodes,
|
||||
int* dist_cache,
|
||||
size_t* last_insert_len,
|
||||
Command* commands,
|
||||
size_t* num_literals) {
|
||||
size_t pos = 0;
|
||||
uint32_t offset = nodes[0].u.next;
|
||||
size_t i;
|
||||
for (i = 0; offset != BROTLI_UINT32_MAX; i++) {
|
||||
const ZopfliNode* next = &nodes[pos + offset];
|
||||
size_t copy_length = ZopfliNodeCopyLength(next);
|
||||
size_t insert_length = next->insert_length;
|
||||
pos += insert_length;
|
||||
offset = next->u.next;
|
||||
if (i == 0) {
|
||||
insert_length += *last_insert_len;
|
||||
*last_insert_len = 0;
|
||||
}
|
||||
{
|
||||
size_t distance = ZopfliNodeCopyDistance(next);
|
||||
size_t len_code = ZopfliNodeLengthCode(next);
|
||||
size_t max_distance =
|
||||
BROTLI_MIN(size_t, block_start + pos, max_backward_limit);
|
||||
BROTLI_BOOL is_dictionary = TO_BROTLI_BOOL(distance > max_distance);
|
||||
size_t dist_code = ZopfliNodeDistanceCode(next);
|
||||
|
||||
InitCommand(&commands[i], insert_length,
|
||||
copy_length, (int)len_code - (int)copy_length, dist_code);
|
||||
|
||||
if (!is_dictionary && dist_code > 0) {
|
||||
dist_cache[3] = dist_cache[2];
|
||||
dist_cache[2] = dist_cache[1];
|
||||
dist_cache[1] = dist_cache[0];
|
||||
dist_cache[0] = (int)distance;
|
||||
}
|
||||
}
|
||||
|
||||
*num_literals += insert_length;
|
||||
pos += copy_length;
|
||||
}
|
||||
*last_insert_len += num_bytes - pos;
|
||||
}
|
||||
|
||||
static size_t ZopfliIterate(size_t num_bytes,
|
||||
size_t position,
|
||||
const uint8_t* ringbuffer,
|
||||
size_t ringbuffer_mask,
|
||||
const BrotliEncoderParams* params,
|
||||
const size_t max_backward_limit,
|
||||
const int* dist_cache,
|
||||
const ZopfliCostModel* model,
|
||||
const uint32_t* num_matches,
|
||||
const BackwardMatch* matches,
|
||||
ZopfliNode* nodes) {
|
||||
const size_t max_zopfli_len = MaxZopfliLen(params);
|
||||
StartPosQueue queue;
|
||||
size_t cur_match_pos = 0;
|
||||
size_t i;
|
||||
nodes[0].length = 0;
|
||||
nodes[0].u.cost = 0;
|
||||
InitStartPosQueue(&queue);
|
||||
for (i = 0; i + 3 < num_bytes; i++) {
|
||||
size_t skip = UpdateNodes(num_bytes, position, i, ringbuffer,
|
||||
ringbuffer_mask, params, max_backward_limit, dist_cache,
|
||||
num_matches[i], &matches[cur_match_pos], model, &queue, nodes);
|
||||
if (skip < BROTLI_LONG_COPY_QUICK_STEP) skip = 0;
|
||||
cur_match_pos += num_matches[i];
|
||||
if (num_matches[i] == 1 &&
|
||||
BackwardMatchLength(&matches[cur_match_pos - 1]) > max_zopfli_len) {
|
||||
skip = BROTLI_MAX(size_t,
|
||||
BackwardMatchLength(&matches[cur_match_pos - 1]), skip);
|
||||
}
|
||||
if (skip > 1) {
|
||||
skip--;
|
||||
while (skip) {
|
||||
i++;
|
||||
if (i + 3 >= num_bytes) break;
|
||||
EvaluateNode(
|
||||
position, i, max_backward_limit, dist_cache, model, &queue, nodes);
|
||||
cur_match_pos += num_matches[i];
|
||||
skip--;
|
||||
}
|
||||
}
|
||||
}
|
||||
return ComputeShortestPathFromNodes(num_bytes, nodes);
|
||||
}
|
||||
|
||||
/* REQUIRES: nodes != NULL and len(nodes) >= num_bytes + 1 */
|
||||
size_t BrotliZopfliComputeShortestPath(MemoryManager* m,
|
||||
const BrotliDictionary* dictionary,
|
||||
size_t num_bytes,
|
||||
size_t position,
|
||||
const uint8_t* ringbuffer,
|
||||
size_t ringbuffer_mask,
|
||||
const BrotliEncoderParams* params,
|
||||
const size_t max_backward_limit,
|
||||
const int* dist_cache,
|
||||
HasherHandle hasher,
|
||||
ZopfliNode* nodes) {
|
||||
const size_t max_zopfli_len = MaxZopfliLen(params);
|
||||
ZopfliCostModel model;
|
||||
StartPosQueue queue;
|
||||
BackwardMatch matches[MAX_NUM_MATCHES_H10];
|
||||
const size_t store_end = num_bytes >= StoreLookaheadH10() ?
|
||||
position + num_bytes - StoreLookaheadH10() + 1 : position;
|
||||
size_t i;
|
||||
nodes[0].length = 0;
|
||||
nodes[0].u.cost = 0;
|
||||
InitZopfliCostModel(m, &model, num_bytes);
|
||||
if (BROTLI_IS_OOM(m)) return 0;
|
||||
ZopfliCostModelSetFromLiteralCosts(
|
||||
&model, position, ringbuffer, ringbuffer_mask);
|
||||
InitStartPosQueue(&queue);
|
||||
for (i = 0; i + HashTypeLengthH10() - 1 < num_bytes; i++) {
|
||||
const size_t pos = position + i;
|
||||
const size_t max_distance = BROTLI_MIN(size_t, pos, max_backward_limit);
|
||||
size_t num_matches = FindAllMatchesH10(hasher, dictionary, ringbuffer,
|
||||
ringbuffer_mask, pos, num_bytes - i, max_distance, params, matches);
|
||||
size_t skip;
|
||||
if (num_matches > 0 &&
|
||||
BackwardMatchLength(&matches[num_matches - 1]) > max_zopfli_len) {
|
||||
matches[0] = matches[num_matches - 1];
|
||||
num_matches = 1;
|
||||
}
|
||||
skip = UpdateNodes(num_bytes, position, i, ringbuffer, ringbuffer_mask,
|
||||
params, max_backward_limit, dist_cache, num_matches, matches, &model,
|
||||
&queue, nodes);
|
||||
if (skip < BROTLI_LONG_COPY_QUICK_STEP) skip = 0;
|
||||
if (num_matches == 1 && BackwardMatchLength(&matches[0]) > max_zopfli_len) {
|
||||
skip = BROTLI_MAX(size_t, BackwardMatchLength(&matches[0]), skip);
|
||||
}
|
||||
if (skip > 1) {
|
||||
/* Add the tail of the copy to the hasher. */
|
||||
StoreRangeH10(hasher, ringbuffer, ringbuffer_mask, pos + 1, BROTLI_MIN(
|
||||
size_t, pos + skip, store_end));
|
||||
skip--;
|
||||
while (skip) {
|
||||
i++;
|
||||
if (i + HashTypeLengthH10() - 1 >= num_bytes) break;
|
||||
EvaluateNode(
|
||||
position, i, max_backward_limit, dist_cache, &model, &queue, nodes);
|
||||
skip--;
|
||||
}
|
||||
}
|
||||
}
|
||||
CleanupZopfliCostModel(m, &model);
|
||||
return ComputeShortestPathFromNodes(num_bytes, nodes);
|
||||
}
|
||||
|
||||
void BrotliCreateZopfliBackwardReferences(
|
||||
MemoryManager* m, const BrotliDictionary* dictionary, size_t num_bytes,
|
||||
size_t position, const uint8_t* ringbuffer, size_t ringbuffer_mask,
|
||||
const BrotliEncoderParams* params, HasherHandle hasher, int* dist_cache,
|
||||
size_t* last_insert_len, Command* commands, size_t* num_commands,
|
||||
size_t* num_literals) {
|
||||
const size_t max_backward_limit = BROTLI_MAX_BACKWARD_LIMIT(params->lgwin);
|
||||
ZopfliNode* nodes;
|
||||
nodes = BROTLI_ALLOC(m, ZopfliNode, num_bytes + 1);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
BrotliInitZopfliNodes(nodes, num_bytes + 1);
|
||||
*num_commands += BrotliZopfliComputeShortestPath(m, dictionary, num_bytes,
|
||||
position, ringbuffer, ringbuffer_mask, params, max_backward_limit,
|
||||
dist_cache, hasher, nodes);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
BrotliZopfliCreateCommands(num_bytes, position, max_backward_limit, nodes,
|
||||
dist_cache, last_insert_len, commands, num_literals);
|
||||
BROTLI_FREE(m, nodes);
|
||||
}
|
||||
|
||||
void BrotliCreateHqZopfliBackwardReferences(
|
||||
MemoryManager* m, const BrotliDictionary* dictionary, size_t num_bytes,
|
||||
size_t position, const uint8_t* ringbuffer, size_t ringbuffer_mask,
|
||||
const BrotliEncoderParams* params, HasherHandle hasher, int* dist_cache,
|
||||
size_t* last_insert_len, Command* commands, size_t* num_commands,
|
||||
size_t* num_literals) {
|
||||
const size_t max_backward_limit = BROTLI_MAX_BACKWARD_LIMIT(params->lgwin);
|
||||
uint32_t* num_matches = BROTLI_ALLOC(m, uint32_t, num_bytes);
|
||||
size_t matches_size = 4 * num_bytes;
|
||||
const size_t store_end = num_bytes >= StoreLookaheadH10() ?
|
||||
position + num_bytes - StoreLookaheadH10() + 1 : position;
|
||||
size_t cur_match_pos = 0;
|
||||
size_t i;
|
||||
size_t orig_num_literals;
|
||||
size_t orig_last_insert_len;
|
||||
int orig_dist_cache[4];
|
||||
size_t orig_num_commands;
|
||||
ZopfliCostModel model;
|
||||
ZopfliNode* nodes;
|
||||
BackwardMatch* matches = BROTLI_ALLOC(m, BackwardMatch, matches_size);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
for (i = 0; i + HashTypeLengthH10() - 1 < num_bytes; ++i) {
|
||||
const size_t pos = position + i;
|
||||
size_t max_distance = BROTLI_MIN(size_t, pos, max_backward_limit);
|
||||
size_t max_length = num_bytes - i;
|
||||
size_t num_found_matches;
|
||||
size_t cur_match_end;
|
||||
size_t j;
|
||||
/* Ensure that we have enough free slots. */
|
||||
BROTLI_ENSURE_CAPACITY(m, BackwardMatch, matches, matches_size,
|
||||
cur_match_pos + MAX_NUM_MATCHES_H10);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
num_found_matches = FindAllMatchesH10(hasher, dictionary, ringbuffer,
|
||||
ringbuffer_mask, pos, max_length, max_distance, params,
|
||||
&matches[cur_match_pos]);
|
||||
cur_match_end = cur_match_pos + num_found_matches;
|
||||
for (j = cur_match_pos; j + 1 < cur_match_end; ++j) {
|
||||
assert(BackwardMatchLength(&matches[j]) <
|
||||
BackwardMatchLength(&matches[j + 1]));
|
||||
assert(matches[j].distance > max_distance ||
|
||||
matches[j].distance <= matches[j + 1].distance);
|
||||
}
|
||||
num_matches[i] = (uint32_t)num_found_matches;
|
||||
if (num_found_matches > 0) {
|
||||
const size_t match_len = BackwardMatchLength(&matches[cur_match_end - 1]);
|
||||
if (match_len > MAX_ZOPFLI_LEN_QUALITY_11) {
|
||||
const size_t skip = match_len - 1;
|
||||
matches[cur_match_pos++] = matches[cur_match_end - 1];
|
||||
num_matches[i] = 1;
|
||||
/* Add the tail of the copy to the hasher. */
|
||||
StoreRangeH10(hasher, ringbuffer, ringbuffer_mask, pos + 1,
|
||||
BROTLI_MIN(size_t, pos + match_len, store_end));
|
||||
memset(&num_matches[i + 1], 0, skip * sizeof(num_matches[0]));
|
||||
i += skip;
|
||||
} else {
|
||||
cur_match_pos = cur_match_end;
|
||||
}
|
||||
}
|
||||
}
|
||||
orig_num_literals = *num_literals;
|
||||
orig_last_insert_len = *last_insert_len;
|
||||
memcpy(orig_dist_cache, dist_cache, 4 * sizeof(dist_cache[0]));
|
||||
orig_num_commands = *num_commands;
|
||||
nodes = BROTLI_ALLOC(m, ZopfliNode, num_bytes + 1);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
InitZopfliCostModel(m, &model, num_bytes);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
for (i = 0; i < 2; i++) {
|
||||
BrotliInitZopfliNodes(nodes, num_bytes + 1);
|
||||
if (i == 0) {
|
||||
ZopfliCostModelSetFromLiteralCosts(
|
||||
&model, position, ringbuffer, ringbuffer_mask);
|
||||
} else {
|
||||
ZopfliCostModelSetFromCommands(&model, position, ringbuffer,
|
||||
ringbuffer_mask, commands, *num_commands - orig_num_commands,
|
||||
orig_last_insert_len);
|
||||
}
|
||||
*num_commands = orig_num_commands;
|
||||
*num_literals = orig_num_literals;
|
||||
*last_insert_len = orig_last_insert_len;
|
||||
memcpy(dist_cache, orig_dist_cache, 4 * sizeof(dist_cache[0]));
|
||||
*num_commands += ZopfliIterate(num_bytes, position, ringbuffer,
|
||||
ringbuffer_mask, params, max_backward_limit, dist_cache,
|
||||
&model, num_matches, matches, nodes);
|
||||
BrotliZopfliCreateCommands(num_bytes, position, max_backward_limit,
|
||||
nodes, dist_cache, last_insert_len, commands, num_literals);
|
||||
}
|
||||
CleanupZopfliCostModel(m, &model);
|
||||
BROTLI_FREE(m, nodes);
|
||||
BROTLI_FREE(m, matches);
|
||||
BROTLI_FREE(m, num_matches);
|
||||
}
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
} /* extern "C" */
|
||||
#endif
|
|
@ -0,0 +1,35 @@
|
|||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Functions to estimate the bit cost of Huffman trees. */
|
||||
|
||||
#include "./enc/bit_cost.h"
|
||||
|
||||
#include "./common/constants.h"
|
||||
#include <brotli/types.h>
|
||||
#include "./enc/fast_log.h"
|
||||
#include "./enc/histogram.h"
|
||||
#include "./enc/port.h"
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define FN(X) X ## Literal
|
||||
#include "./enc/bit_cost_inc.h" /* NOLINT(build/include) */
|
||||
#undef FN
|
||||
|
||||
#define FN(X) X ## Command
|
||||
#include "./enc/bit_cost_inc.h" /* NOLINT(build/include) */
|
||||
#undef FN
|
||||
|
||||
#define FN(X) X ## Distance
|
||||
#include "./enc/bit_cost_inc.h" /* NOLINT(build/include) */
|
||||
#undef FN
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
} /* extern "C" */
|
||||
#endif
|
|
@ -0,0 +1,48 @@
|
|||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Bit reading helpers */
|
||||
|
||||
#include "./dec/bit_reader.h"
|
||||
|
||||
#include <brotli/types.h>
|
||||
#include "./dec/port.h"
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
void BrotliInitBitReader(BrotliBitReader* const br) {
|
||||
br->val_ = 0;
|
||||
br->bit_pos_ = sizeof(br->val_) << 3;
|
||||
}
|
||||
|
||||
BROTLI_BOOL BrotliWarmupBitReader(BrotliBitReader* const br) {
|
||||
size_t aligned_read_mask = (sizeof(br->val_) >> 1) - 1;
|
||||
/* Fixing alignment after unaligned BrotliFillWindow would result in accumulator
|
||||
overflow. If the misalignment is caused by BrotliSafeReadBits, then there is
|
||||
enough space in accumulator to fix alignment. */
|
||||
if (!BROTLI_ALIGNED_READ) {
|
||||
aligned_read_mask = 0;
|
||||
}
|
||||
if (BrotliGetAvailableBits(br) == 0) {
|
||||
if (!BrotliPullByte(br)) {
|
||||
return BROTLI_FALSE;
|
||||
}
|
||||
}
|
||||
|
||||
while ((((size_t)br->next_in) & aligned_read_mask) != 0) {
|
||||
if (!BrotliPullByte(br)) {
|
||||
/* If we consumed all the input, we don't care about the alignment. */
|
||||
return BROTLI_TRUE;
|
||||
}
|
||||
}
|
||||
return BROTLI_TRUE;
|
||||
}
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
} /* extern "C" */
|
||||
#endif
|
|
@ -0,0 +1,197 @@
|
|||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Block split point selection utilities. */
|
||||
|
||||
#include "./enc/block_splitter.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <string.h> /* memcpy, memset */
|
||||
|
||||
#include "./enc/bit_cost.h"
|
||||
#include "./enc/cluster.h"
|
||||
#include "./enc/command.h"
|
||||
#include "./enc/fast_log.h"
|
||||
#include "./enc/histogram.h"
|
||||
#include "./enc/memory.h"
|
||||
#include "./enc/port.h"
|
||||
#include "./enc/quality.h"
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
static const size_t kMaxLiteralHistograms = 100;
|
||||
static const size_t kMaxCommandHistograms = 50;
|
||||
static const double kLiteralBlockSwitchCost = 28.1;
|
||||
static const double kCommandBlockSwitchCost = 13.5;
|
||||
static const double kDistanceBlockSwitchCost = 14.6;
|
||||
static const size_t kLiteralStrideLength = 70;
|
||||
static const size_t kCommandStrideLength = 40;
|
||||
static const size_t kSymbolsPerLiteralHistogram = 544;
|
||||
static const size_t kSymbolsPerCommandHistogram = 530;
|
||||
static const size_t kSymbolsPerDistanceHistogram = 544;
|
||||
static const size_t kMinLengthForBlockSplitting = 128;
|
||||
static const size_t kIterMulForRefining = 2;
|
||||
static const size_t kMinItersForRefining = 100;
|
||||
|
||||
static size_t CountLiterals(const Command* cmds, const size_t num_commands) {
|
||||
/* Count how many literals we have. */
|
||||
size_t total_length = 0;
|
||||
size_t i;
|
||||
for (i = 0; i < num_commands; ++i) {
|
||||
total_length += cmds[i].insert_len_;
|
||||
}
|
||||
return total_length;
|
||||
}
|
||||
|
||||
static void CopyLiteralsToByteArray(const Command* cmds,
|
||||
const size_t num_commands,
|
||||
const uint8_t* data,
|
||||
const size_t offset,
|
||||
const size_t mask,
|
||||
uint8_t* literals) {
|
||||
size_t pos = 0;
|
||||
size_t from_pos = offset & mask;
|
||||
size_t i;
|
||||
for (i = 0; i < num_commands; ++i) {
|
||||
size_t insert_len = cmds[i].insert_len_;
|
||||
if (from_pos + insert_len > mask) {
|
||||
size_t head_size = mask + 1 - from_pos;
|
||||
memcpy(literals + pos, data + from_pos, head_size);
|
||||
from_pos = 0;
|
||||
pos += head_size;
|
||||
insert_len -= head_size;
|
||||
}
|
||||
if (insert_len > 0) {
|
||||
memcpy(literals + pos, data + from_pos, insert_len);
|
||||
pos += insert_len;
|
||||
}
|
||||
from_pos = (from_pos + insert_len + CommandCopyLen(&cmds[i])) & mask;
|
||||
}
|
||||
}
|
||||
|
||||
static BROTLI_INLINE unsigned int MyRand(unsigned int* seed) {
|
||||
*seed *= 16807U;
|
||||
if (*seed == 0) {
|
||||
*seed = 1;
|
||||
}
|
||||
return *seed;
|
||||
}
|
||||
|
||||
static BROTLI_INLINE double BitCost(size_t count) {
|
||||
return count == 0 ? -2.0 : FastLog2(count);
|
||||
}
|
||||
|
||||
#define HISTOGRAMS_PER_BATCH 64
|
||||
#define CLUSTERS_PER_BATCH 16
|
||||
|
||||
#define FN(X) X ## Literal
|
||||
#define DataType uint8_t
|
||||
/* NOLINTNEXTLINE(build/include) */
|
||||
#include "./enc/block_splitter_inc.h"
|
||||
#undef DataType
|
||||
#undef FN
|
||||
|
||||
#define FN(X) X ## Command
|
||||
#define DataType uint16_t
|
||||
/* NOLINTNEXTLINE(build/include) */
|
||||
#include "./enc/block_splitter_inc.h"
|
||||
#undef FN
|
||||
|
||||
#define FN(X) X ## Distance
|
||||
/* NOLINTNEXTLINE(build/include) */
|
||||
#include "./enc/block_splitter_inc.h"
|
||||
#undef DataType
|
||||
#undef FN
|
||||
|
||||
void BrotliInitBlockSplit(BlockSplit* self) {
|
||||
self->num_types = 0;
|
||||
self->num_blocks = 0;
|
||||
self->types = 0;
|
||||
self->lengths = 0;
|
||||
self->types_alloc_size = 0;
|
||||
self->lengths_alloc_size = 0;
|
||||
}
|
||||
|
||||
void BrotliDestroyBlockSplit(MemoryManager* m, BlockSplit* self) {
|
||||
BROTLI_FREE(m, self->types);
|
||||
BROTLI_FREE(m, self->lengths);
|
||||
}
|
||||
|
||||
void BrotliSplitBlock(MemoryManager* m,
|
||||
const Command* cmds,
|
||||
const size_t num_commands,
|
||||
const uint8_t* data,
|
||||
const size_t pos,
|
||||
const size_t mask,
|
||||
const BrotliEncoderParams* params,
|
||||
BlockSplit* literal_split,
|
||||
BlockSplit* insert_and_copy_split,
|
||||
BlockSplit* dist_split) {
|
||||
{
|
||||
size_t literals_count = CountLiterals(cmds, num_commands);
|
||||
uint8_t* literals = BROTLI_ALLOC(m, uint8_t, literals_count);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
/* Create a continuous array of literals. */
|
||||
CopyLiteralsToByteArray(cmds, num_commands, data, pos, mask, literals);
|
||||
/* Create the block split on the array of literals.
|
||||
Literal histograms have alphabet size 256. */
|
||||
SplitByteVectorLiteral(
|
||||
m, literals, literals_count,
|
||||
kSymbolsPerLiteralHistogram, kMaxLiteralHistograms,
|
||||
kLiteralStrideLength, kLiteralBlockSwitchCost, params,
|
||||
literal_split);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
BROTLI_FREE(m, literals);
|
||||
}
|
||||
|
||||
{
|
||||
/* Compute prefix codes for commands. */
|
||||
uint16_t* insert_and_copy_codes = BROTLI_ALLOC(m, uint16_t, num_commands);
|
||||
size_t i;
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
for (i = 0; i < num_commands; ++i) {
|
||||
insert_and_copy_codes[i] = cmds[i].cmd_prefix_;
|
||||
}
|
||||
/* Create the block split on the array of command prefixes. */
|
||||
SplitByteVectorCommand(
|
||||
m, insert_and_copy_codes, num_commands,
|
||||
kSymbolsPerCommandHistogram, kMaxCommandHistograms,
|
||||
kCommandStrideLength, kCommandBlockSwitchCost, params,
|
||||
insert_and_copy_split);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
/* TODO: reuse for distances? */
|
||||
BROTLI_FREE(m, insert_and_copy_codes);
|
||||
}
|
||||
|
||||
{
|
||||
/* Create a continuous array of distance prefixes. */
|
||||
uint16_t* distance_prefixes = BROTLI_ALLOC(m, uint16_t, num_commands);
|
||||
size_t j = 0;
|
||||
size_t i;
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
for (i = 0; i < num_commands; ++i) {
|
||||
const Command* cmd = &cmds[i];
|
||||
if (CommandCopyLen(cmd) && cmd->cmd_prefix_ >= 128) {
|
||||
distance_prefixes[j++] = cmd->dist_prefix_;
|
||||
}
|
||||
}
|
||||
/* Create the block split on the array of distance prefixes. */
|
||||
SplitByteVectorDistance(
|
||||
m, distance_prefixes, j,
|
||||
kSymbolsPerDistanceHistogram, kMaxCommandHistograms,
|
||||
kCommandStrideLength, kDistanceBlockSwitchCost, params,
|
||||
dist_split);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
BROTLI_FREE(m, distance_prefixes);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
} /* extern "C" */
|
||||
#endif
|
File diff suppressed because it is too large
|
@ -0,0 +1,12 @@
|
|||
// Copyright 2017 Google Inc. All Rights Reserved.
|
||||
//
|
||||
// Distributed under MIT license.
|
||||
// See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
|
||||
package brotli
|
||||
|
||||
// Inform the Go build system that it should link the brotli libraries.
|
||||
|
||||
// #cgo CFLAGS: -O3
|
||||
// #cgo LDFLAGS: -lm
|
||||
import "C"
|
|
@ -0,0 +1,56 @@
|
|||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Functions for clustering similar histograms together. */
|
||||
|
||||
#include "./enc/cluster.h"
|
||||
|
||||
#include <brotli/types.h>
|
||||
#include "./enc/bit_cost.h" /* BrotliPopulationCost */
|
||||
#include "./enc/fast_log.h"
|
||||
#include "./enc/histogram.h"
|
||||
#include "./enc/memory.h"
|
||||
#include "./enc/port.h"
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
static BROTLI_INLINE BROTLI_BOOL HistogramPairIsLess(
|
||||
const HistogramPair* p1, const HistogramPair* p2) {
|
||||
if (p1->cost_diff != p2->cost_diff) {
|
||||
return TO_BROTLI_BOOL(p1->cost_diff > p2->cost_diff);
|
||||
}
|
||||
return TO_BROTLI_BOOL((p1->idx2 - p1->idx1) > (p2->idx2 - p2->idx1));
|
||||
}
|
||||
|
||||
/* Returns entropy reduction of the context map when we combine two clusters. */
|
||||
static BROTLI_INLINE double ClusterCostDiff(size_t size_a, size_t size_b) {
|
||||
size_t size_c = size_a + size_b;
|
||||
return (double)size_a * FastLog2(size_a) +
|
||||
(double)size_b * FastLog2(size_b) -
|
||||
(double)size_c * FastLog2(size_c);
|
||||
}
|
||||
|
||||
#define CODE(X) X
|
||||
|
||||
#define FN(X) X ## Literal
|
||||
#include "./enc/cluster_inc.h" /* NOLINT(build/include) */
|
||||
#undef FN
|
||||
|
||||
#define FN(X) X ## Command
|
||||
#include "./enc/cluster_inc.h" /* NOLINT(build/include) */
|
||||
#undef FN
|
||||
|
||||
#define FN(X) X ## Distance
|
||||
#include "./enc/cluster_inc.h" /* NOLINT(build/include) */
|
||||
#undef FN
|
||||
|
||||
#undef CODE
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
} /* extern "C" */
|
||||
#endif
|
|
@ -0,0 +1,791 @@
|
|||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Function for fast encoding of an input fragment, independently of the input
|
||||
history. This function uses one-pass processing: when we find a backward
|
||||
match, we immediately emit the corresponding command and literal codes to
|
||||
the bit stream.
|
||||
|
||||
Adapted from the CompressFragment() function in
|
||||
https://github.com/google/snappy/blob/master/snappy.cc */
|
||||
|
||||
#include "./enc/compress_fragment.h"
|
||||
|
||||
#include <string.h> /* memcmp, memcpy, memset */
|
||||
|
||||
#include "./common/constants.h"
|
||||
#include <brotli/types.h>
|
||||
#include "./enc/brotli_bit_stream.h"
|
||||
#include "./enc/entropy_encode.h"
|
||||
#include "./enc/fast_log.h"
|
||||
#include "./enc/find_match_length.h"
|
||||
#include "./enc/memory.h"
|
||||
#include "./enc/port.h"
|
||||
#include "./enc/write_bits.h"
|
||||
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define MAX_DISTANCE (long)BROTLI_MAX_BACKWARD_LIMIT(18)
|
||||
|
||||
/* kHashMul32 multiplier has these properties:
|
||||
* The multiplier must be odd. Otherwise we may lose the highest bit.
|
||||
* No long streaks of ones or zeros.
|
||||
* There is no effort to ensure that it is a prime, the oddity is enough
|
||||
for this use.
|
||||
* The number has been tuned heuristically against compression benchmarks. */
|
||||
static const uint32_t kHashMul32 = 0x1e35a7bd;
|
||||
|
||||
static BROTLI_INLINE uint32_t Hash(const uint8_t* p, size_t shift) {
|
||||
const uint64_t h = (BROTLI_UNALIGNED_LOAD64(p) << 24) * kHashMul32;
|
||||
return (uint32_t)(h >> shift);
|
||||
}
|
||||
|
||||
static BROTLI_INLINE uint32_t HashBytesAtOffset(
|
||||
uint64_t v, int offset, size_t shift) {
|
||||
assert(offset >= 0);
|
||||
assert(offset <= 3);
|
||||
{
|
||||
const uint64_t h = ((v >> (8 * offset)) << 24) * kHashMul32;
|
||||
return (uint32_t)(h >> shift);
|
||||
}
|
||||
}
|
||||
|
||||
static BROTLI_INLINE BROTLI_BOOL IsMatch(const uint8_t* p1, const uint8_t* p2) {
|
||||
return TO_BROTLI_BOOL(
|
||||
BROTLI_UNALIGNED_LOAD32(p1) == BROTLI_UNALIGNED_LOAD32(p2) &&
|
||||
p1[4] == p2[4]);
|
||||
}
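/* Editor's note: illustrative sketch, not part of the vendored brotli source.
   Hash() above loads 8 bytes, drops the top 24 bits via the << 24 (so on a
   little-endian load only the first five input bytes matter, matching the
   5-byte IsMatch() check), multiplies by the odd constant kHashMul32 and
   keeps the top bits as a table index.  A standalone version of the same
   idea, assuming a hash table of 1 << bits slots: */
#if 0  /* illustration only */
#include <stdint.h>
#include <string.h>
static uint32_t HashBytesSketch(const uint8_t* p, int bits) {
  uint64_t v;
  memcpy(&v, p, sizeof(v));              /* unaligned-safe 64-bit load */
  v = (v << 24) * 0x1e35a7bd;            /* keep 5 bytes, mix with odd mult */
  return (uint32_t)(v >> (64 - bits));   /* top bits index a 2^bits table */
}
#endif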
|
||||
|
||||
/* Builds a literal prefix code into "depths" and "bits" based on the statistics
|
||||
of the "input" string and stores it into the bit stream.
|
||||
Note that the prefix code here is built from the pre-LZ77 input, therefore
|
||||
we can only approximate the statistics of the actual literal stream.
|
||||
Moreover, for long inputs we build a histogram from a sample of the input
|
||||
and thus have to assign a non-zero depth for each literal.
|
||||
Returns estimated compression ratio millibytes/char for encoding given input
|
||||
with generated code. */
|
||||
static size_t BuildAndStoreLiteralPrefixCode(MemoryManager* m,
|
||||
const uint8_t* input,
|
||||
const size_t input_size,
|
||||
uint8_t depths[256],
|
||||
uint16_t bits[256],
|
||||
size_t* storage_ix,
|
||||
uint8_t* storage) {
|
||||
uint32_t histogram[256] = { 0 };
|
||||
size_t histogram_total;
|
||||
size_t i;
|
||||
if (input_size < (1 << 15)) {
|
||||
for (i = 0; i < input_size; ++i) {
|
||||
++histogram[input[i]];
|
||||
}
|
||||
histogram_total = input_size;
|
||||
for (i = 0; i < 256; ++i) {
|
||||
/* We weigh the first 11 samples with weight 3 to account for the
|
||||
balancing effect of the LZ77 phase on the histogram. */
|
||||
const uint32_t adjust = 2 * BROTLI_MIN(uint32_t, histogram[i], 11u);
|
||||
histogram[i] += adjust;
|
||||
histogram_total += adjust;
|
||||
}
|
||||
} else {
|
||||
static const size_t kSampleRate = 29;
|
||||
for (i = 0; i < input_size; i += kSampleRate) {
|
||||
++histogram[input[i]];
|
||||
}
|
||||
histogram_total = (input_size + kSampleRate - 1) / kSampleRate;
|
||||
for (i = 0; i < 256; ++i) {
|
||||
/* We add 1 to each population count to avoid 0 bit depths (since this is
|
||||
only a sample and we don't know if the symbol appears or not), and we
|
||||
weigh the first 11 samples with weight 3 to account for the balancing
|
||||
effect of the LZ77 phase on the histogram (more frequent symbols are
|
||||
more likely to be in backward references instead as literals). */
|
||||
const uint32_t adjust = 1 + 2 * BROTLI_MIN(uint32_t, histogram[i], 11u);
|
||||
histogram[i] += adjust;
|
||||
histogram_total += adjust;
|
||||
}
|
||||
}
|
||||
BrotliBuildAndStoreHuffmanTreeFast(m, histogram, histogram_total,
|
||||
/* max_bits = */ 8,
|
||||
depths, bits, storage_ix, storage);
|
||||
if (BROTLI_IS_OOM(m)) return 0;
|
||||
{
|
||||
size_t literal_ratio = 0;
|
||||
for (i = 0; i < 256; ++i) {
|
||||
if (histogram[i]) literal_ratio += histogram[i] * depths[i];
|
||||
}
|
||||
/* Estimated encoding ratio, millibytes per symbol. */
|
||||
return (literal_ratio * 125) / histogram_total;
|
||||
}
|
||||
}
|
||||
|
||||
/* Builds a command and distance prefix code (each 64 symbols) into "depth" and
|
||||
"bits" based on "histogram" and stores it into the bit stream. */
|
||||
static void BuildAndStoreCommandPrefixCode(const uint32_t histogram[128],
|
||||
uint8_t depth[128], uint16_t bits[128], size_t* storage_ix,
|
||||
uint8_t* storage) {
|
||||
/* Tree size for building a tree over 64 symbols is 2 * 64 + 1. */
|
||||
HuffmanTree tree[129];
|
||||
uint8_t cmd_depth[BROTLI_NUM_COMMAND_SYMBOLS] = { 0 };
|
||||
uint16_t cmd_bits[64];
|
||||
|
||||
BrotliCreateHuffmanTree(histogram, 64, 15, tree, depth);
|
||||
BrotliCreateHuffmanTree(&histogram[64], 64, 14, tree, &depth[64]);
|
||||
/* We have to jump through a few hoops here in order to compute
|
||||
the command bits because the symbols are in a different order than in
|
||||
the full alphabet. This looks complicated, but having the symbols
|
||||
in this order in the command bits saves a few branches in the Emit*
|
||||
functions. */
|
||||
memcpy(cmd_depth, depth, 24);
|
||||
memcpy(cmd_depth + 24, depth + 40, 8);
|
||||
memcpy(cmd_depth + 32, depth + 24, 8);
|
||||
memcpy(cmd_depth + 40, depth + 48, 8);
|
||||
memcpy(cmd_depth + 48, depth + 32, 8);
|
||||
memcpy(cmd_depth + 56, depth + 56, 8);
|
||||
BrotliConvertBitDepthsToSymbols(cmd_depth, 64, cmd_bits);
|
||||
memcpy(bits, cmd_bits, 48);
|
||||
memcpy(bits + 24, cmd_bits + 32, 16);
|
||||
memcpy(bits + 32, cmd_bits + 48, 16);
|
||||
memcpy(bits + 40, cmd_bits + 24, 16);
|
||||
memcpy(bits + 48, cmd_bits + 40, 16);
|
||||
memcpy(bits + 56, cmd_bits + 56, 16);
|
||||
BrotliConvertBitDepthsToSymbols(&depth[64], 64, &bits[64]);
|
||||
{
|
||||
/* Create the bit length array for the full command alphabet. */
|
||||
size_t i;
|
||||
memset(cmd_depth, 0, 64); /* only 64 first values were used */
|
||||
memcpy(cmd_depth, depth, 8);
|
||||
memcpy(cmd_depth + 64, depth + 8, 8);
|
||||
memcpy(cmd_depth + 128, depth + 16, 8);
|
||||
memcpy(cmd_depth + 192, depth + 24, 8);
|
||||
memcpy(cmd_depth + 384, depth + 32, 8);
|
||||
for (i = 0; i < 8; ++i) {
|
||||
cmd_depth[128 + 8 * i] = depth[40 + i];
|
||||
cmd_depth[256 + 8 * i] = depth[48 + i];
|
||||
cmd_depth[448 + 8 * i] = depth[56 + i];
|
||||
}
|
||||
BrotliStoreHuffmanTree(
|
||||
cmd_depth, BROTLI_NUM_COMMAND_SYMBOLS, tree, storage_ix, storage);
|
||||
}
|
||||
BrotliStoreHuffmanTree(&depth[64], 64, tree, storage_ix, storage);
|
||||
}
|
||||
|
||||
/* REQUIRES: insertlen < 6210 */
|
||||
static BROTLI_INLINE void EmitInsertLen(size_t insertlen,
|
||||
const uint8_t depth[128],
|
||||
const uint16_t bits[128],
|
||||
uint32_t histo[128],
|
||||
size_t* storage_ix,
|
||||
uint8_t* storage) {
|
||||
if (insertlen < 6) {
|
||||
const size_t code = insertlen + 40;
|
||||
BrotliWriteBits(depth[code], bits[code], storage_ix, storage);
|
||||
++histo[code];
|
||||
} else if (insertlen < 130) {
|
||||
const size_t tail = insertlen - 2;
|
||||
const uint32_t nbits = Log2FloorNonZero(tail) - 1u;
|
||||
const size_t prefix = tail >> nbits;
|
||||
const size_t inscode = (nbits << 1) + prefix + 42;
|
||||
BrotliWriteBits(depth[inscode], bits[inscode], storage_ix, storage);
|
||||
BrotliWriteBits(nbits, tail - (prefix << nbits), storage_ix, storage);
|
||||
++histo[inscode];
|
||||
} else if (insertlen < 2114) {
|
||||
const size_t tail = insertlen - 66;
|
||||
const uint32_t nbits = Log2FloorNonZero(tail);
|
||||
const size_t code = nbits + 50;
|
||||
BrotliWriteBits(depth[code], bits[code], storage_ix, storage);
|
||||
BrotliWriteBits(nbits, tail - ((size_t)1 << nbits), storage_ix, storage);
|
||||
++histo[code];
|
||||
} else {
|
||||
BrotliWriteBits(depth[61], bits[61], storage_ix, storage);
|
||||
BrotliWriteBits(12, insertlen - 2114, storage_ix, storage);
|
||||
++histo[21];
|
||||
}
|
||||
}
|
||||
|
||||
static BROTLI_INLINE void EmitLongInsertLen(size_t insertlen,
|
||||
const uint8_t depth[128],
|
||||
const uint16_t bits[128],
|
||||
uint32_t histo[128],
|
||||
size_t* storage_ix,
|
||||
uint8_t* storage) {
|
||||
if (insertlen < 22594) {
|
||||
BrotliWriteBits(depth[62], bits[62], storage_ix, storage);
|
||||
BrotliWriteBits(14, insertlen - 6210, storage_ix, storage);
|
||||
++histo[22];
|
||||
} else {
|
||||
BrotliWriteBits(depth[63], bits[63], storage_ix, storage);
|
||||
BrotliWriteBits(24, insertlen - 22594, storage_ix, storage);
|
||||
++histo[23];
|
||||
}
|
||||
}
|
||||
|
||||
static BROTLI_INLINE void EmitCopyLen(size_t copylen,
|
||||
const uint8_t depth[128],
|
||||
const uint16_t bits[128],
|
||||
uint32_t histo[128],
|
||||
size_t* storage_ix,
|
||||
uint8_t* storage) {
|
||||
if (copylen < 10) {
|
||||
BrotliWriteBits(
|
||||
depth[copylen + 14], bits[copylen + 14], storage_ix, storage);
|
||||
++histo[copylen + 14];
|
||||
} else if (copylen < 134) {
|
||||
const size_t tail = copylen - 6;
|
||||
const uint32_t nbits = Log2FloorNonZero(tail) - 1u;
|
||||
const size_t prefix = tail >> nbits;
|
||||
const size_t code = (nbits << 1) + prefix + 20;
|
||||
BrotliWriteBits(depth[code], bits[code], storage_ix, storage);
|
||||
BrotliWriteBits(nbits, tail - (prefix << nbits), storage_ix, storage);
|
||||
++histo[code];
|
||||
} else if (copylen < 2118) {
|
||||
const size_t tail = copylen - 70;
|
||||
const uint32_t nbits = Log2FloorNonZero(tail);
|
||||
const size_t code = nbits + 28;
|
||||
BrotliWriteBits(depth[code], bits[code], storage_ix, storage);
|
||||
BrotliWriteBits(nbits, tail - ((size_t)1 << nbits), storage_ix, storage);
|
||||
++histo[code];
|
||||
} else {
|
||||
BrotliWriteBits(depth[39], bits[39], storage_ix, storage);
|
||||
BrotliWriteBits(24, copylen - 2118, storage_ix, storage);
|
||||
++histo[47];
|
||||
}
|
||||
}
|
||||
|
||||
static BROTLI_INLINE void EmitCopyLenLastDistance(size_t copylen,
|
||||
const uint8_t depth[128],
|
||||
const uint16_t bits[128],
|
||||
uint32_t histo[128],
|
||||
size_t* storage_ix,
|
||||
uint8_t* storage) {
|
||||
if (copylen < 12) {
|
||||
BrotliWriteBits(depth[copylen - 4], bits[copylen - 4], storage_ix, storage);
|
||||
++histo[copylen - 4];
|
||||
} else if (copylen < 72) {
|
||||
const size_t tail = copylen - 8;
|
||||
const uint32_t nbits = Log2FloorNonZero(tail) - 1;
|
||||
const size_t prefix = tail >> nbits;
|
||||
const size_t code = (nbits << 1) + prefix + 4;
|
||||
BrotliWriteBits(depth[code], bits[code], storage_ix, storage);
|
||||
BrotliWriteBits(nbits, tail - (prefix << nbits), storage_ix, storage);
|
||||
++histo[code];
|
||||
} else if (copylen < 136) {
|
||||
const size_t tail = copylen - 8;
|
||||
const size_t code = (tail >> 5) + 30;
|
||||
BrotliWriteBits(depth[code], bits[code], storage_ix, storage);
|
||||
BrotliWriteBits(5, tail & 31, storage_ix, storage);
|
||||
BrotliWriteBits(depth[64], bits[64], storage_ix, storage);
|
||||
++histo[code];
|
||||
++histo[64];
|
||||
} else if (copylen < 2120) {
|
||||
const size_t tail = copylen - 72;
|
||||
const uint32_t nbits = Log2FloorNonZero(tail);
|
||||
const size_t code = nbits + 28;
|
||||
BrotliWriteBits(depth[code], bits[code], storage_ix, storage);
|
||||
BrotliWriteBits(nbits, tail - ((size_t)1 << nbits), storage_ix, storage);
|
||||
BrotliWriteBits(depth[64], bits[64], storage_ix, storage);
|
||||
++histo[code];
|
||||
++histo[64];
|
||||
} else {
|
||||
BrotliWriteBits(depth[39], bits[39], storage_ix, storage);
|
||||
BrotliWriteBits(24, copylen - 2120, storage_ix, storage);
|
||||
BrotliWriteBits(depth[64], bits[64], storage_ix, storage);
|
||||
++histo[47];
|
||||
++histo[64];
|
||||
}
|
||||
}
|
||||
|
||||
static BROTLI_INLINE void EmitDistance(size_t distance,
|
||||
const uint8_t depth[128],
|
||||
const uint16_t bits[128],
|
||||
uint32_t histo[128],
|
||||
size_t* storage_ix, uint8_t* storage) {
|
||||
const size_t d = distance + 3;
|
||||
const uint32_t nbits = Log2FloorNonZero(d) - 1u;
|
||||
const size_t prefix = (d >> nbits) & 1;
|
||||
const size_t offset = (2 + prefix) << nbits;
|
||||
const size_t distcode = 2 * (nbits - 1) + prefix + 80;
|
||||
BrotliWriteBits(depth[distcode], bits[distcode], storage_ix, storage);
|
||||
BrotliWriteBits(nbits, d - offset, storage_ix, storage);
|
||||
++histo[distcode];
|
||||
}
|
||||
|
||||
static BROTLI_INLINE void EmitLiterals(const uint8_t* input, const size_t len,
|
||||
const uint8_t depth[256],
|
||||
const uint16_t bits[256],
|
||||
size_t* storage_ix, uint8_t* storage) {
|
||||
size_t j;
|
||||
for (j = 0; j < len; j++) {
|
||||
const uint8_t lit = input[j];
|
||||
BrotliWriteBits(depth[lit], bits[lit], storage_ix, storage);
|
||||
}
|
||||
}
|
||||
|
||||
/* REQUIRES: len <= 1 << 24. */
|
||||
static void BrotliStoreMetaBlockHeader(
|
||||
size_t len, BROTLI_BOOL is_uncompressed, size_t* storage_ix,
|
||||
uint8_t* storage) {
|
||||
size_t nibbles = 6;
|
||||
/* ISLAST */
|
||||
BrotliWriteBits(1, 0, storage_ix, storage);
|
||||
if (len <= (1U << 16)) {
|
||||
nibbles = 4;
|
||||
} else if (len <= (1U << 20)) {
|
||||
nibbles = 5;
|
||||
}
|
||||
BrotliWriteBits(2, nibbles - 4, storage_ix, storage);
|
||||
BrotliWriteBits(nibbles * 4, len - 1, storage_ix, storage);
|
||||
/* ISUNCOMPRESSED */
|
||||
BrotliWriteBits(1, (uint64_t)is_uncompressed, storage_ix, storage);
|
||||
}
|
||||
|
||||
static void UpdateBits(size_t n_bits, uint32_t bits, size_t pos,
    uint8_t *array) {
  while (n_bits > 0) {
    size_t byte_pos = pos >> 3;
    size_t n_unchanged_bits = pos & 7;
    size_t n_changed_bits = BROTLI_MIN(size_t, n_bits, 8 - n_unchanged_bits);
    size_t total_bits = n_unchanged_bits + n_changed_bits;
    uint32_t mask =
        (~((1u << total_bits) - 1u)) | ((1u << n_unchanged_bits) - 1u);
    uint32_t unchanged_bits = array[byte_pos] & mask;
    uint32_t changed_bits = bits & ((1u << n_changed_bits) - 1u);
    array[byte_pos] =
        (uint8_t)((changed_bits << n_unchanged_bits) | unchanged_bits);
    n_bits -= n_changed_bits;
    bits >>= n_changed_bits;
    pos += n_changed_bits;
  }
}
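/* Editor's note: illustrative worked example, not part of the vendored source.
   UpdateBits() patches already-written bits in place, LSB-first within each
   byte.  Rewriting 6 bits of the value 0b101101 at bit position 5 spans two
   bytes: the low 3 bits (0b101) land in the top 3 bits of byte 0, and the
   remaining 3 bits (0b101) land in the low bits of byte 1. */
#if 0  /* illustration only; assumes the UpdateBits() defined above */
uint8_t buf[2] = { 0x1F, 0x00 };      /* bits 0..4 already written as ones */
UpdateBits(6, 0x2D /* 0b101101 */, 5, buf);
/* now buf[0] == 0xBF and buf[1] == 0x05 */
#endif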
|
||||
|
||||
static void RewindBitPosition(const size_t new_storage_ix,
|
||||
size_t* storage_ix, uint8_t* storage) {
|
||||
const size_t bitpos = new_storage_ix & 7;
|
||||
const size_t mask = (1u << bitpos) - 1;
|
||||
storage[new_storage_ix >> 3] &= (uint8_t)mask;
|
||||
*storage_ix = new_storage_ix;
|
||||
}
|
||||
|
||||
static BROTLI_BOOL ShouldMergeBlock(
|
||||
const uint8_t* data, size_t len, const uint8_t* depths) {
|
||||
size_t histo[256] = { 0 };
|
||||
static const size_t kSampleRate = 43;
|
||||
size_t i;
|
||||
for (i = 0; i < len; i += kSampleRate) {
|
||||
++histo[data[i]];
|
||||
}
|
||||
{
|
||||
const size_t total = (len + kSampleRate - 1) / kSampleRate;
|
||||
double r = (FastLog2(total) + 0.5) * (double)total + 200;
|
||||
for (i = 0; i < 256; ++i) {
|
||||
r -= (double)histo[i] * (depths[i] + FastLog2(histo[i]));
|
||||
}
|
||||
return TO_BROTLI_BOOL(r >= 0.0);
|
||||
}
|
||||
}
|
||||
|
||||
/* Acceptable loss for uncompressible speedup is 2% */
|
||||
#define MIN_RATIO 980
|
||||
|
||||
static BROTLI_INLINE BROTLI_BOOL ShouldUseUncompressedMode(
|
||||
const uint8_t* metablock_start, const uint8_t* next_emit,
|
||||
const size_t insertlen, const size_t literal_ratio) {
|
||||
const size_t compressed = (size_t)(next_emit - metablock_start);
|
||||
if (compressed * 50 > insertlen) {
|
||||
return BROTLI_FALSE;
|
||||
} else {
|
||||
return TO_BROTLI_BOOL(literal_ratio > MIN_RATIO);
|
||||
}
|
||||
}
|
||||
|
||||
static void EmitUncompressedMetaBlock(const uint8_t* begin, const uint8_t* end,
|
||||
const size_t storage_ix_start,
|
||||
size_t* storage_ix, uint8_t* storage) {
|
||||
const size_t len = (size_t)(end - begin);
|
||||
RewindBitPosition(storage_ix_start, storage_ix, storage);
|
||||
BrotliStoreMetaBlockHeader(len, 1, storage_ix, storage);
|
||||
*storage_ix = (*storage_ix + 7u) & ~7u;
|
||||
memcpy(&storage[*storage_ix >> 3], begin, len);
|
||||
*storage_ix += len << 3;
|
||||
storage[*storage_ix >> 3] = 0;
|
||||
}
|
||||
|
||||
static uint32_t kCmdHistoSeed[128] = {
|
||||
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 0, 0, 0, 0,
|
||||
};
|
||||
|
||||
static BROTLI_INLINE void BrotliCompressFragmentFastImpl(
|
||||
MemoryManager* m, const uint8_t* input, size_t input_size,
|
||||
BROTLI_BOOL is_last, int* table, size_t table_bits, uint8_t cmd_depth[128],
|
||||
uint16_t cmd_bits[128], size_t* cmd_code_numbits, uint8_t* cmd_code,
|
||||
size_t* storage_ix, uint8_t* storage) {
|
||||
uint32_t cmd_histo[128];
|
||||
const uint8_t* ip_end;
|
||||
|
||||
/* "next_emit" is a pointer to the first byte that is not covered by a
|
||||
previous copy. Bytes between "next_emit" and the start of the next copy or
|
||||
the end of the input will be emitted as literal bytes. */
|
||||
const uint8_t* next_emit = input;
|
||||
/* Save the start of the first block for position and distance computations.
|
||||
*/
|
||||
const uint8_t* base_ip = input;
|
||||
|
||||
static const size_t kFirstBlockSize = 3 << 15;
|
||||
static const size_t kMergeBlockSize = 1 << 16;
|
||||
|
||||
const size_t kInputMarginBytes = BROTLI_WINDOW_GAP;
|
||||
const size_t kMinMatchLen = 5;
|
||||
|
||||
const uint8_t* metablock_start = input;
|
||||
size_t block_size = BROTLI_MIN(size_t, input_size, kFirstBlockSize);
|
||||
size_t total_block_size = block_size;
|
||||
/* Save the bit position of the MLEN field of the meta-block header, so that
|
||||
we can update it later if we decide to extend this meta-block. */
|
||||
size_t mlen_storage_ix = *storage_ix + 3;
|
||||
|
||||
uint8_t lit_depth[256];
|
||||
uint16_t lit_bits[256];
|
||||
|
||||
size_t literal_ratio;
|
||||
|
||||
const uint8_t* ip;
|
||||
int last_distance;
|
||||
|
||||
const size_t shift = 64u - table_bits;
|
||||
|
||||
BrotliStoreMetaBlockHeader(block_size, 0, storage_ix, storage);
|
||||
/* No block splits, no contexts. */
|
||||
BrotliWriteBits(13, 0, storage_ix, storage);
|
||||
|
||||
literal_ratio = BuildAndStoreLiteralPrefixCode(
|
||||
m, input, block_size, lit_depth, lit_bits, storage_ix, storage);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
|
||||
{
|
||||
/* Store the pre-compressed command and distance prefix codes. */
|
||||
size_t i;
|
||||
for (i = 0; i + 7 < *cmd_code_numbits; i += 8) {
|
||||
BrotliWriteBits(8, cmd_code[i >> 3], storage_ix, storage);
|
||||
}
|
||||
}
|
||||
BrotliWriteBits(*cmd_code_numbits & 7, cmd_code[*cmd_code_numbits >> 3],
|
||||
storage_ix, storage);
|
||||
|
||||
emit_commands:
|
||||
/* Initialize the command and distance histograms. We will gather
|
||||
statistics of command and distance codes during the processing
|
||||
of this block and use it to update the command and distance
|
||||
prefix codes for the next block. */
|
||||
memcpy(cmd_histo, kCmdHistoSeed, sizeof(kCmdHistoSeed));
|
||||
|
||||
/* "ip" is the input pointer. */
|
||||
ip = input;
|
||||
last_distance = -1;
|
||||
ip_end = input + block_size;
|
||||
|
||||
if (BROTLI_PREDICT_TRUE(block_size >= kInputMarginBytes)) {
|
||||
/* For the last block, we need to keep a 16 bytes margin so that we can be
|
||||
sure that all distances are at most window size - 16.
|
||||
For all other blocks, we only need to keep a margin of 5 bytes so that
|
||||
we don't go over the block size with a copy. */
|
||||
const size_t len_limit = BROTLI_MIN(size_t, block_size - kMinMatchLen,
|
||||
input_size - kInputMarginBytes);
|
||||
const uint8_t* ip_limit = input + len_limit;
|
||||
|
||||
uint32_t next_hash;
|
||||
for (next_hash = Hash(++ip, shift); ; ) {
|
||||
/* Step 1: Scan forward in the input looking for a 5-byte-long match.
|
||||
If we get close to exhausting the input then goto emit_remainder.
|
||||
|
||||
Heuristic match skipping: If 32 bytes are scanned with no matches
|
||||
found, start looking only at every other byte. If 32 more bytes are
|
||||
scanned, look at every third byte, etc.. When a match is found,
|
||||
immediately go back to looking at every byte. This is a small loss
|
||||
(~5% performance, ~0.1% density) for compressible data due to more
|
||||
bookkeeping, but for non-compressible data (such as JPEG) it's a huge
|
||||
win since the compressor quickly "realizes" the data is incompressible
|
||||
and doesn't bother looking for matches everywhere.
|
||||
|
||||
The "skip" variable keeps track of how many bytes there are since the
|
||||
last match; dividing it by 32 (i.e. right-shifting by five) gives the
|
||||
number of bytes to move ahead for each iteration. */
|
||||
uint32_t skip = 32;
|
||||
|
||||
const uint8_t* next_ip = ip;
|
||||
const uint8_t* candidate;
|
||||
assert(next_emit < ip);
|
||||
trawl:
|
||||
do {
|
||||
uint32_t hash = next_hash;
|
||||
uint32_t bytes_between_hash_lookups = skip++ >> 5;
|
||||
assert(hash == Hash(next_ip, shift));
|
||||
ip = next_ip;
|
||||
next_ip = ip + bytes_between_hash_lookups;
|
||||
if (BROTLI_PREDICT_FALSE(next_ip > ip_limit)) {
|
||||
goto emit_remainder;
|
||||
}
|
||||
next_hash = Hash(next_ip, shift);
|
||||
candidate = ip - last_distance;
|
||||
if (IsMatch(ip, candidate)) {
|
||||
if (BROTLI_PREDICT_TRUE(candidate < ip)) {
|
||||
table[hash] = (int)(ip - base_ip);
|
||||
break;
|
||||
}
|
||||
}
|
||||
candidate = base_ip + table[hash];
|
||||
assert(candidate >= base_ip);
|
||||
assert(candidate < ip);
|
||||
|
||||
table[hash] = (int)(ip - base_ip);
|
||||
} while (BROTLI_PREDICT_TRUE(!IsMatch(ip, candidate)));
|
||||
|
||||
/* Check copy distance. If candidate is not feasible, continue search.
|
||||
Checking is done outside of hot loop to reduce overhead. */
|
||||
if (ip - candidate > MAX_DISTANCE) goto trawl;
|
||||
|
||||
/* Step 2: Emit the found match together with the literal bytes from
|
||||
"next_emit" to the bit stream, and then see if we can find a next match
|
||||
immediately afterwards. Repeat until we find no match for the input
|
||||
without emitting some literal bytes. */
|
||||
|
||||
{
|
||||
/* We have a 5-byte match at ip, and we need to emit bytes in
|
||||
[next_emit, ip). */
|
||||
const uint8_t* base = ip;
|
||||
size_t matched = 5 + FindMatchLengthWithLimit(
|
||||
candidate + 5, ip + 5, (size_t)(ip_end - ip) - 5);
|
||||
int distance = (int)(base - candidate); /* > 0 */
|
||||
size_t insert = (size_t)(base - next_emit);
|
||||
ip += matched;
|
||||
assert(0 == memcmp(base, candidate, matched));
|
||||
if (BROTLI_PREDICT_TRUE(insert < 6210)) {
|
||||
EmitInsertLen(insert, cmd_depth, cmd_bits, cmd_histo,
|
||||
storage_ix, storage);
|
||||
} else if (ShouldUseUncompressedMode(metablock_start, next_emit, insert,
|
||||
literal_ratio)) {
|
||||
EmitUncompressedMetaBlock(metablock_start, base, mlen_storage_ix - 3,
|
||||
storage_ix, storage);
|
||||
input_size -= (size_t)(base - input);
|
||||
input = base;
|
||||
next_emit = input;
|
||||
goto next_block;
|
||||
} else {
|
||||
EmitLongInsertLen(insert, cmd_depth, cmd_bits, cmd_histo,
|
||||
storage_ix, storage);
|
||||
}
|
||||
EmitLiterals(next_emit, insert, lit_depth, lit_bits,
|
||||
storage_ix, storage);
|
||||
if (distance == last_distance) {
|
||||
BrotliWriteBits(cmd_depth[64], cmd_bits[64], storage_ix, storage);
|
||||
++cmd_histo[64];
|
||||
} else {
|
||||
EmitDistance((size_t)distance, cmd_depth, cmd_bits,
|
||||
cmd_histo, storage_ix, storage);
|
||||
last_distance = distance;
|
||||
}
|
||||
EmitCopyLenLastDistance(matched, cmd_depth, cmd_bits, cmd_histo,
|
||||
storage_ix, storage);
|
||||
|
||||
next_emit = ip;
|
||||
if (BROTLI_PREDICT_FALSE(ip >= ip_limit)) {
|
||||
goto emit_remainder;
|
||||
}
|
||||
/* We could immediately start working at ip now, but to improve
|
||||
compression we first update "table" with the hashes of some positions
|
||||
within the last copy. */
|
||||
{
|
||||
uint64_t input_bytes = BROTLI_UNALIGNED_LOAD64(ip - 3);
|
||||
uint32_t prev_hash = HashBytesAtOffset(input_bytes, 0, shift);
|
||||
uint32_t cur_hash = HashBytesAtOffset(input_bytes, 3, shift);
|
||||
table[prev_hash] = (int)(ip - base_ip - 3);
|
||||
prev_hash = HashBytesAtOffset(input_bytes, 1, shift);
|
||||
table[prev_hash] = (int)(ip - base_ip - 2);
|
||||
prev_hash = HashBytesAtOffset(input_bytes, 2, shift);
|
||||
table[prev_hash] = (int)(ip - base_ip - 1);
|
||||
|
||||
candidate = base_ip + table[cur_hash];
|
||||
table[cur_hash] = (int)(ip - base_ip);
|
||||
}
|
||||
}
|
||||
|
||||
while (IsMatch(ip, candidate)) {
|
||||
/* We have a 5-byte match at ip, and no need to emit any literal bytes
|
||||
prior to ip. */
|
||||
const uint8_t* base = ip;
|
||||
size_t matched = 5 + FindMatchLengthWithLimit(
|
||||
candidate + 5, ip + 5, (size_t)(ip_end - ip) - 5);
|
||||
if (ip - candidate > MAX_DISTANCE) break;
|
||||
ip += matched;
|
||||
last_distance = (int)(base - candidate); /* > 0 */
|
||||
assert(0 == memcmp(base, candidate, matched));
|
||||
EmitCopyLen(matched, cmd_depth, cmd_bits, cmd_histo,
|
||||
storage_ix, storage);
|
||||
EmitDistance((size_t)last_distance, cmd_depth, cmd_bits,
|
||||
cmd_histo, storage_ix, storage);
|
||||
|
||||
next_emit = ip;
|
||||
if (BROTLI_PREDICT_FALSE(ip >= ip_limit)) {
|
||||
goto emit_remainder;
|
||||
}
|
||||
/* We could immediately start working at ip now, but to improve
|
||||
compression we first update "table" with the hashes of some positions
|
||||
within the last copy. */
|
||||
{
|
||||
uint64_t input_bytes = BROTLI_UNALIGNED_LOAD64(ip - 3);
|
||||
uint32_t prev_hash = HashBytesAtOffset(input_bytes, 0, shift);
|
||||
uint32_t cur_hash = HashBytesAtOffset(input_bytes, 3, shift);
|
||||
table[prev_hash] = (int)(ip - base_ip - 3);
|
||||
prev_hash = HashBytesAtOffset(input_bytes, 1, shift);
|
||||
table[prev_hash] = (int)(ip - base_ip - 2);
|
||||
prev_hash = HashBytesAtOffset(input_bytes, 2, shift);
|
||||
table[prev_hash] = (int)(ip - base_ip - 1);
|
||||
|
||||
candidate = base_ip + table[cur_hash];
|
||||
table[cur_hash] = (int)(ip - base_ip);
|
||||
}
|
||||
}
|
||||
|
||||
next_hash = Hash(++ip, shift);
|
||||
}
|
||||
}
|
||||
|
||||
emit_remainder:
|
||||
assert(next_emit <= ip_end);
|
||||
input += block_size;
|
||||
input_size -= block_size;
|
||||
block_size = BROTLI_MIN(size_t, input_size, kMergeBlockSize);
|
||||
|
||||
/* Decide if we want to continue this meta-block instead of emitting the
|
||||
last insert-only command. */
|
||||
if (input_size > 0 &&
|
||||
total_block_size + block_size <= (1 << 20) &&
|
||||
ShouldMergeBlock(input, block_size, lit_depth)) {
|
||||
assert(total_block_size > (1 << 16));
|
||||
/* Update the size of the current meta-block and continue emitting commands.
|
||||
We can do this because the current size and the new size both have 5
|
||||
nibbles. */
|
||||
total_block_size += block_size;
|
||||
UpdateBits(20, (uint32_t)(total_block_size - 1), mlen_storage_ix, storage);
|
||||
goto emit_commands;
|
||||
}
|
||||
|
||||
/* Emit the remaining bytes as literals. */
|
||||
if (next_emit < ip_end) {
|
||||
const size_t insert = (size_t)(ip_end - next_emit);
|
||||
if (BROTLI_PREDICT_TRUE(insert < 6210)) {
|
||||
EmitInsertLen(insert, cmd_depth, cmd_bits, cmd_histo,
|
||||
storage_ix, storage);
|
||||
EmitLiterals(next_emit, insert, lit_depth, lit_bits, storage_ix, storage);
|
||||
} else if (ShouldUseUncompressedMode(metablock_start, next_emit, insert,
|
||||
literal_ratio)) {
|
||||
EmitUncompressedMetaBlock(metablock_start, ip_end, mlen_storage_ix - 3,
|
||||
storage_ix, storage);
|
||||
} else {
|
||||
EmitLongInsertLen(insert, cmd_depth, cmd_bits, cmd_histo,
|
||||
storage_ix, storage);
|
||||
EmitLiterals(next_emit, insert, lit_depth, lit_bits,
|
||||
storage_ix, storage);
|
||||
}
|
||||
}
|
||||
next_emit = ip_end;
|
||||
|
||||
next_block:
|
||||
/* If we have more data, write a new meta-block header and prefix codes and
|
||||
then continue emitting commands. */
|
||||
if (input_size > 0) {
|
||||
metablock_start = input;
|
||||
block_size = BROTLI_MIN(size_t, input_size, kFirstBlockSize);
|
||||
total_block_size = block_size;
|
||||
/* Save the bit position of the MLEN field of the meta-block header, so that
|
||||
we can update it later if we decide to extend this meta-block. */
|
||||
mlen_storage_ix = *storage_ix + 3;
|
||||
BrotliStoreMetaBlockHeader(block_size, 0, storage_ix, storage);
|
||||
/* No block splits, no contexts. */
|
||||
BrotliWriteBits(13, 0, storage_ix, storage);
|
||||
literal_ratio = BuildAndStoreLiteralPrefixCode(
|
||||
m, input, block_size, lit_depth, lit_bits, storage_ix, storage);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
BuildAndStoreCommandPrefixCode(cmd_histo, cmd_depth, cmd_bits,
|
||||
storage_ix, storage);
|
||||
goto emit_commands;
|
||||
}
|
||||
|
||||
if (!is_last) {
|
||||
/* If this is not the last block, update the command and distance prefix
|
||||
codes for the next block and store the compressed forms. */
|
||||
cmd_code[0] = 0;
|
||||
*cmd_code_numbits = 0;
|
||||
BuildAndStoreCommandPrefixCode(cmd_histo, cmd_depth, cmd_bits,
|
||||
cmd_code_numbits, cmd_code);
|
||||
}
|
||||
}
|
||||
|
||||
#define FOR_TABLE_BITS_(X) X(9) X(11) X(13) X(15)
|
||||
|
||||
#define BAKE_METHOD_PARAM_(B) \
|
||||
static BROTLI_NOINLINE void BrotliCompressFragmentFastImpl ## B( \
|
||||
MemoryManager* m, const uint8_t* input, size_t input_size, \
|
||||
BROTLI_BOOL is_last, int* table, uint8_t cmd_depth[128], \
|
||||
uint16_t cmd_bits[128], size_t* cmd_code_numbits, uint8_t* cmd_code, \
|
||||
size_t* storage_ix, uint8_t* storage) { \
|
||||
BrotliCompressFragmentFastImpl(m, input, input_size, is_last, table, B, \
|
||||
cmd_depth, cmd_bits, cmd_code_numbits, cmd_code, storage_ix, storage); \
|
||||
}
|
||||
FOR_TABLE_BITS_(BAKE_METHOD_PARAM_)
|
||||
#undef BAKE_METHOD_PARAM_
|
||||
|
||||
void BrotliCompressFragmentFast(
|
||||
MemoryManager* m, const uint8_t* input, size_t input_size,
|
||||
BROTLI_BOOL is_last, int* table, size_t table_size, uint8_t cmd_depth[128],
|
||||
uint16_t cmd_bits[128], size_t* cmd_code_numbits, uint8_t* cmd_code,
|
||||
size_t* storage_ix, uint8_t* storage) {
|
||||
const size_t initial_storage_ix = *storage_ix;
|
||||
const size_t table_bits = Log2FloorNonZero(table_size);
|
||||
|
||||
if (input_size == 0) {
|
||||
assert(is_last);
|
||||
BrotliWriteBits(1, 1, storage_ix, storage); /* islast */
|
||||
BrotliWriteBits(1, 1, storage_ix, storage); /* isempty */
|
||||
*storage_ix = (*storage_ix + 7u) & ~7u;
|
||||
return;
|
||||
}
|
||||
|
||||
switch (table_bits) {
|
||||
#define CASE_(B) \
|
||||
case B: \
|
||||
BrotliCompressFragmentFastImpl ## B( \
|
||||
m, input, input_size, is_last, table, cmd_depth, cmd_bits, \
|
||||
cmd_code_numbits, cmd_code, storage_ix, storage); \
|
||||
break;
|
||||
FOR_TABLE_BITS_(CASE_)
|
||||
#undef CASE_
|
||||
default: assert(0); break;
|
||||
}
|
||||
|
||||
/* If output is larger than single uncompressed block, rewrite it. */
|
||||
if (*storage_ix - initial_storage_ix > 31 + (input_size << 3)) {
|
||||
EmitUncompressedMetaBlock(input, input + input_size, initial_storage_ix,
|
||||
storage_ix, storage);
|
||||
}
|
||||
|
||||
if (is_last) {
|
||||
BrotliWriteBits(1, 1, storage_ix, storage); /* islast */
|
||||
BrotliWriteBits(1, 1, storage_ix, storage); /* isempty */
|
||||
*storage_ix = (*storage_ix + 7u) & ~7u;
|
||||
}
|
||||
}
|
||||
|
||||
#undef FOR_TABLE_BITS_
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
} /* extern "C" */
|
||||
#endif
|
vendor/github.com/cloudflare/brotli-go/compress_fragment_two_pass.c (612 lines, generated, vendored, new file)
@ -0,0 +1,612 @@
/* Copyright 2015 Google Inc. All Rights Reserved.

   Distributed under MIT license.
   See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
*/

/* Function for fast encoding of an input fragment, independently from the input
   history. This function uses two-pass processing: in the first pass we save
   the found backward matches and literal bytes into a buffer, and in the
   second pass we emit them into the bit stream using prefix codes built based
   on the actual command and literal byte histograms. */

#include "./enc/compress_fragment_two_pass.h"

#include <string.h>  /* memcmp, memcpy, memset */

#include "./common/constants.h"
#include <brotli/types.h>
#include "./enc/bit_cost.h"
#include "./enc/brotli_bit_stream.h"
#include "./enc/entropy_encode.h"
#include "./enc/fast_log.h"
#include "./enc/find_match_length.h"
#include "./enc/memory.h"
#include "./enc/port.h"
#include "./enc/write_bits.h"


#if defined(__cplusplus) || defined(c_plusplus)
extern "C" {
#endif
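/* Editor's note: illustrative sketch, not part of the vendored brotli source.
   In this two-pass variant, pass one records each command as a single
   uint32_t of the form (code | extra_bits << 8) and copies literal bytes into
   a side buffer; pass two builds prefix codes from the real histograms and
   only then serializes bits.  Roughly, with hypothetical buffer sizes: */
#if 0  /* illustration only */
uint32_t command_buf[1 << 17];
uint8_t  literal_buf[1 << 17];
uint32_t* commands = command_buf;
uint8_t*  literals = literal_buf;
/* pass 1: CreateCommands(...) appends to *commands and *literals           */
/* pass 2: StoreCommands(...) builds prefix codes and emits the bit stream  */
#endif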
|
||||
|
||||
#define MAX_DISTANCE (long)BROTLI_MAX_BACKWARD_LIMIT(18)
|
||||
|
||||
/* kHashMul32 multiplier has these properties:
|
||||
* The multiplier must be odd. Otherwise we may lose the highest bit.
|
||||
* No long streaks of ones or zeros.
|
||||
* There is no effort to ensure that it is a prime, the oddity is enough
|
||||
for this use.
|
||||
* The number has been tuned heuristically against compression benchmarks. */
|
||||
static const uint32_t kHashMul32 = 0x1e35a7bd;
|
||||
|
||||
static BROTLI_INLINE uint32_t Hash(const uint8_t* p, size_t shift) {
|
||||
const uint64_t h = (BROTLI_UNALIGNED_LOAD64(p) << 16) * kHashMul32;
|
||||
return (uint32_t)(h >> shift);
|
||||
}
|
||||
|
||||
static BROTLI_INLINE uint32_t HashBytesAtOffset(
|
||||
uint64_t v, int offset, size_t shift) {
|
||||
assert(offset >= 0);
|
||||
assert(offset <= 2);
|
||||
{
|
||||
const uint64_t h = ((v >> (8 * offset)) << 16) * kHashMul32;
|
||||
return (uint32_t)(h >> shift);
|
||||
}
|
||||
}
|
||||
|
||||
static BROTLI_INLINE BROTLI_BOOL IsMatch(const uint8_t* p1, const uint8_t* p2) {
|
||||
return TO_BROTLI_BOOL(
|
||||
BROTLI_UNALIGNED_LOAD32(p1) == BROTLI_UNALIGNED_LOAD32(p2) &&
|
||||
p1[4] == p2[4] &&
|
||||
p1[5] == p2[5]);
|
||||
}
|
||||
|
||||
/* Builds a command and distance prefix code (each 64 symbols) into "depth" and
|
||||
"bits" based on "histogram" and stores it into the bit stream. */
|
||||
static void BuildAndStoreCommandPrefixCode(
|
||||
const uint32_t histogram[128],
|
||||
uint8_t depth[128], uint16_t bits[128],
|
||||
size_t* storage_ix, uint8_t* storage) {
|
||||
/* Tree size for building a tree over 64 symbols is 2 * 64 + 1. */
|
||||
HuffmanTree tree[129];
|
||||
uint8_t cmd_depth[BROTLI_NUM_COMMAND_SYMBOLS] = { 0 };
|
||||
uint16_t cmd_bits[64];
|
||||
BrotliCreateHuffmanTree(histogram, 64, 15, tree, depth);
|
||||
BrotliCreateHuffmanTree(&histogram[64], 64, 14, tree, &depth[64]);
|
||||
/* We have to jump through a few hoops here in order to compute
|
||||
the command bits because the symbols are in a different order than in
|
||||
the full alphabet. This looks complicated, but having the symbols
|
||||
in this order in the command bits saves a few branches in the Emit*
|
||||
functions. */
|
||||
memcpy(cmd_depth, depth + 24, 24);
|
||||
memcpy(cmd_depth + 24, depth, 8);
|
||||
memcpy(cmd_depth + 32, depth + 48, 8);
|
||||
memcpy(cmd_depth + 40, depth + 8, 8);
|
||||
memcpy(cmd_depth + 48, depth + 56, 8);
|
||||
memcpy(cmd_depth + 56, depth + 16, 8);
|
||||
BrotliConvertBitDepthsToSymbols(cmd_depth, 64, cmd_bits);
|
||||
memcpy(bits, cmd_bits + 24, 16);
|
||||
memcpy(bits + 8, cmd_bits + 40, 16);
|
||||
memcpy(bits + 16, cmd_bits + 56, 16);
|
||||
memcpy(bits + 24, cmd_bits, 48);
|
||||
memcpy(bits + 48, cmd_bits + 32, 16);
|
||||
memcpy(bits + 56, cmd_bits + 48, 16);
|
||||
BrotliConvertBitDepthsToSymbols(&depth[64], 64, &bits[64]);
|
||||
{
|
||||
/* Create the bit length array for the full command alphabet. */
|
||||
size_t i;
|
||||
memset(cmd_depth, 0, 64); /* only 64 first values were used */
|
||||
memcpy(cmd_depth, depth + 24, 8);
|
||||
memcpy(cmd_depth + 64, depth + 32, 8);
|
||||
memcpy(cmd_depth + 128, depth + 40, 8);
|
||||
memcpy(cmd_depth + 192, depth + 48, 8);
|
||||
memcpy(cmd_depth + 384, depth + 56, 8);
|
||||
for (i = 0; i < 8; ++i) {
|
||||
cmd_depth[128 + 8 * i] = depth[i];
|
||||
cmd_depth[256 + 8 * i] = depth[8 + i];
|
||||
cmd_depth[448 + 8 * i] = depth[16 + i];
|
||||
}
|
||||
BrotliStoreHuffmanTree(
|
||||
cmd_depth, BROTLI_NUM_COMMAND_SYMBOLS, tree, storage_ix, storage);
|
||||
}
|
||||
BrotliStoreHuffmanTree(&depth[64], 64, tree, storage_ix, storage);
|
||||
}
|
||||
|
||||
static BROTLI_INLINE void EmitInsertLen(
|
||||
uint32_t insertlen, uint32_t** commands) {
|
||||
if (insertlen < 6) {
|
||||
**commands = insertlen;
|
||||
} else if (insertlen < 130) {
|
||||
const uint32_t tail = insertlen - 2;
|
||||
const uint32_t nbits = Log2FloorNonZero(tail) - 1u;
|
||||
const uint32_t prefix = tail >> nbits;
|
||||
const uint32_t inscode = (nbits << 1) + prefix + 2;
|
||||
const uint32_t extra = tail - (prefix << nbits);
|
||||
**commands = inscode | (extra << 8);
|
||||
} else if (insertlen < 2114) {
|
||||
const uint32_t tail = insertlen - 66;
|
||||
const uint32_t nbits = Log2FloorNonZero(tail);
|
||||
const uint32_t code = nbits + 10;
|
||||
const uint32_t extra = tail - (1u << nbits);
|
||||
**commands = code | (extra << 8);
|
||||
} else if (insertlen < 6210) {
|
||||
const uint32_t extra = insertlen - 2114;
|
||||
**commands = 21 | (extra << 8);
|
||||
} else if (insertlen < 22594) {
|
||||
const uint32_t extra = insertlen - 6210;
|
||||
**commands = 22 | (extra << 8);
|
||||
} else {
|
||||
const uint32_t extra = insertlen - 22594;
|
||||
**commands = 23 | (extra << 8);
|
||||
}
|
||||
++(*commands);
|
||||
}
|
||||
|
||||
static BROTLI_INLINE void EmitCopyLen(size_t copylen, uint32_t** commands) {
|
||||
if (copylen < 10) {
|
||||
**commands = (uint32_t)(copylen + 38);
|
||||
} else if (copylen < 134) {
|
||||
const size_t tail = copylen - 6;
|
||||
const size_t nbits = Log2FloorNonZero(tail) - 1;
|
||||
const size_t prefix = tail >> nbits;
|
||||
const size_t code = (nbits << 1) + prefix + 44;
|
||||
const size_t extra = tail - (prefix << nbits);
|
||||
**commands = (uint32_t)(code | (extra << 8));
|
||||
} else if (copylen < 2118) {
|
||||
const size_t tail = copylen - 70;
|
||||
const size_t nbits = Log2FloorNonZero(tail);
|
||||
const size_t code = nbits + 52;
|
||||
const size_t extra = tail - ((size_t)1 << nbits);
|
||||
**commands = (uint32_t)(code | (extra << 8));
|
||||
} else {
|
||||
const size_t extra = copylen - 2118;
|
||||
**commands = (uint32_t)(63 | (extra << 8));
|
||||
}
|
||||
++(*commands);
|
||||
}
|
||||
|
||||
static BROTLI_INLINE void EmitCopyLenLastDistance(
|
||||
size_t copylen, uint32_t** commands) {
|
||||
if (copylen < 12) {
|
||||
**commands = (uint32_t)(copylen + 20);
|
||||
++(*commands);
|
||||
} else if (copylen < 72) {
|
||||
const size_t tail = copylen - 8;
|
||||
const size_t nbits = Log2FloorNonZero(tail) - 1;
|
||||
const size_t prefix = tail >> nbits;
|
||||
const size_t code = (nbits << 1) + prefix + 28;
|
||||
const size_t extra = tail - (prefix << nbits);
|
||||
**commands = (uint32_t)(code | (extra << 8));
|
||||
++(*commands);
|
||||
} else if (copylen < 136) {
|
||||
const size_t tail = copylen - 8;
|
||||
const size_t code = (tail >> 5) + 54;
|
||||
const size_t extra = tail & 31;
|
||||
**commands = (uint32_t)(code | (extra << 8));
|
||||
++(*commands);
|
||||
**commands = 64;
|
||||
++(*commands);
|
||||
} else if (copylen < 2120) {
|
||||
const size_t tail = copylen - 72;
|
||||
const size_t nbits = Log2FloorNonZero(tail);
|
||||
const size_t code = nbits + 52;
|
||||
const size_t extra = tail - ((size_t)1 << nbits);
|
||||
**commands = (uint32_t)(code | (extra << 8));
|
||||
++(*commands);
|
||||
**commands = 64;
|
||||
++(*commands);
|
||||
} else {
|
||||
const size_t extra = copylen - 2120;
|
||||
**commands = (uint32_t)(63 | (extra << 8));
|
||||
++(*commands);
|
||||
**commands = 64;
|
||||
++(*commands);
|
||||
}
|
||||
}
|
||||
|
||||
static BROTLI_INLINE void EmitDistance(uint32_t distance, uint32_t** commands) {
|
||||
uint32_t d = distance + 3;
|
||||
uint32_t nbits = Log2FloorNonZero(d) - 1;
|
||||
const uint32_t prefix = (d >> nbits) & 1;
|
||||
const uint32_t offset = (2 + prefix) << nbits;
|
||||
const uint32_t distcode = 2 * (nbits - 1) + prefix + 80;
|
||||
uint32_t extra = d - offset;
|
||||
**commands = distcode | (extra << 8);
|
||||
++(*commands);
|
||||
}
|
||||
|
||||
/* REQUIRES: len <= 1 << 24. */
|
||||
static void BrotliStoreMetaBlockHeader(
|
||||
size_t len, BROTLI_BOOL is_uncompressed, size_t* storage_ix,
|
||||
uint8_t* storage) {
|
||||
size_t nibbles = 6;
|
||||
/* ISLAST */
|
||||
BrotliWriteBits(1, 0, storage_ix, storage);
|
||||
if (len <= (1U << 16)) {
|
||||
nibbles = 4;
|
||||
} else if (len <= (1U << 20)) {
|
||||
nibbles = 5;
|
||||
}
|
||||
BrotliWriteBits(2, nibbles - 4, storage_ix, storage);
|
||||
BrotliWriteBits(nibbles * 4, len - 1, storage_ix, storage);
|
||||
/* ISUNCOMPRESSED */
|
||||
BrotliWriteBits(1, (uint64_t)is_uncompressed, storage_ix, storage);
|
||||
}
|
||||
|
||||
static BROTLI_INLINE void CreateCommands(const uint8_t* input,
|
||||
size_t block_size, size_t input_size, const uint8_t* base_ip, int* table,
|
||||
size_t table_bits, uint8_t** literals, uint32_t** commands) {
|
||||
/* "ip" is the input pointer. */
|
||||
const uint8_t* ip = input;
|
||||
const size_t shift = 64u - table_bits;
|
||||
const uint8_t* ip_end = input + block_size;
|
||||
/* "next_emit" is a pointer to the first byte that is not covered by a
|
||||
previous copy. Bytes between "next_emit" and the start of the next copy or
|
||||
the end of the input will be emitted as literal bytes. */
|
||||
const uint8_t* next_emit = input;
|
||||
|
||||
int last_distance = -1;
|
||||
const size_t kInputMarginBytes = BROTLI_WINDOW_GAP;
|
||||
const size_t kMinMatchLen = 6;
|
||||
|
||||
if (BROTLI_PREDICT_TRUE(block_size >= kInputMarginBytes)) {
|
||||
/* For the last block, we need to keep a 16 bytes margin so that we can be
|
||||
sure that all distances are at most window size - 16.
|
||||
For all other blocks, we only need to keep a margin of 5 bytes so that
|
||||
we don't go over the block size with a copy. */
|
||||
const size_t len_limit = BROTLI_MIN(size_t, block_size - kMinMatchLen,
|
||||
input_size - kInputMarginBytes);
|
||||
const uint8_t* ip_limit = input + len_limit;
|
||||
|
||||
uint32_t next_hash;
|
||||
for (next_hash = Hash(++ip, shift); ; ) {
|
||||
/* Step 1: Scan forward in the input looking for a 6-byte-long match.
|
||||
If we get close to exhausting the input then goto emit_remainder.
|
||||
|
||||
Heuristic match skipping: If 32 bytes are scanned with no matches
|
||||
found, start looking only at every other byte. If 32 more bytes are
|
||||
scanned, look at every third byte, etc.. When a match is found,
|
||||
immediately go back to looking at every byte. This is a small loss
|
||||
(~5% performance, ~0.1% density) for compressible data due to more
|
||||
bookkeeping, but for non-compressible data (such as JPEG) it's a huge
|
||||
win since the compressor quickly "realizes" the data is incompressible
|
||||
and doesn't bother looking for matches everywhere.
|
||||
|
||||
The "skip" variable keeps track of how many bytes there are since the
|
||||
last match; dividing it by 32 (ie. right-shifting by five) gives the
|
||||
number of bytes to move ahead for each iteration. */
|
||||
uint32_t skip = 32;
|
||||
|
||||
const uint8_t* next_ip = ip;
|
||||
const uint8_t* candidate;
|
||||
|
||||
assert(next_emit < ip);
|
||||
trawl:
|
||||
do {
|
||||
uint32_t hash = next_hash;
|
||||
uint32_t bytes_between_hash_lookups = skip++ >> 5;
|
||||
ip = next_ip;
|
||||
assert(hash == Hash(ip, shift));
|
||||
next_ip = ip + bytes_between_hash_lookups;
|
||||
if (BROTLI_PREDICT_FALSE(next_ip > ip_limit)) {
|
||||
goto emit_remainder;
|
||||
}
|
||||
next_hash = Hash(next_ip, shift);
|
||||
candidate = ip - last_distance;
|
||||
if (IsMatch(ip, candidate)) {
|
||||
if (BROTLI_PREDICT_TRUE(candidate < ip)) {
|
||||
table[hash] = (int)(ip - base_ip);
|
||||
break;
|
||||
}
|
||||
}
|
||||
candidate = base_ip + table[hash];
|
||||
assert(candidate >= base_ip);
|
||||
assert(candidate < ip);
|
||||
|
||||
table[hash] = (int)(ip - base_ip);
|
||||
} while (BROTLI_PREDICT_TRUE(!IsMatch(ip, candidate)));
|
||||
|
||||
/* Check copy distance. If candidate is not feasible, continue search.
|
||||
Checking is done outside of hot loop to reduce overhead. */
|
||||
if (ip - candidate > MAX_DISTANCE) goto trawl;
|
||||
|
||||
/* Step 2: Emit the found match together with the literal bytes from
|
||||
"next_emit", and then see if we can find a next match immediately
|
||||
afterwards. Repeat until we find no match for the input
|
||||
without emitting some literal bytes. */
|
||||
|
||||
{
|
||||
/* We have a 6-byte match at ip, and we need to emit bytes in
|
||||
[next_emit, ip). */
|
||||
const uint8_t* base = ip;
|
||||
size_t matched = 6 + FindMatchLengthWithLimit(
|
||||
candidate + 6, ip + 6, (size_t)(ip_end - ip) - 6);
|
||||
int distance = (int)(base - candidate); /* > 0 */
|
||||
int insert = (int)(base - next_emit);
|
||||
ip += matched;
|
||||
assert(0 == memcmp(base, candidate, matched));
|
||||
EmitInsertLen((uint32_t)insert, commands);
|
||||
memcpy(*literals, next_emit, (size_t)insert);
|
||||
*literals += insert;
|
||||
if (distance == last_distance) {
|
||||
**commands = 64;
|
||||
++(*commands);
|
||||
} else {
|
||||
EmitDistance((uint32_t)distance, commands);
|
||||
last_distance = distance;
|
||||
}
|
||||
EmitCopyLenLastDistance(matched, commands);
|
||||
|
||||
next_emit = ip;
|
||||
if (BROTLI_PREDICT_FALSE(ip >= ip_limit)) {
|
||||
goto emit_remainder;
|
||||
}
|
||||
{
|
||||
/* We could immediately start working at ip now, but to improve
|
||||
compression we first update "table" with the hashes of some
|
||||
positions within the last copy. */
|
||||
uint64_t input_bytes = BROTLI_UNALIGNED_LOAD64(ip - 5);
|
||||
uint32_t prev_hash = HashBytesAtOffset(input_bytes, 0, shift);
|
||||
uint32_t cur_hash;
|
||||
table[prev_hash] = (int)(ip - base_ip - 5);
|
||||
prev_hash = HashBytesAtOffset(input_bytes, 1, shift);
|
||||
table[prev_hash] = (int)(ip - base_ip - 4);
|
||||
prev_hash = HashBytesAtOffset(input_bytes, 2, shift);
|
||||
table[prev_hash] = (int)(ip - base_ip - 3);
|
||||
input_bytes = BROTLI_UNALIGNED_LOAD64(ip - 2);
|
||||
cur_hash = HashBytesAtOffset(input_bytes, 2, shift);
|
||||
prev_hash = HashBytesAtOffset(input_bytes, 0, shift);
|
||||
table[prev_hash] = (int)(ip - base_ip - 2);
|
||||
prev_hash = HashBytesAtOffset(input_bytes, 1, shift);
|
||||
table[prev_hash] = (int)(ip - base_ip - 1);
|
||||
|
||||
candidate = base_ip + table[cur_hash];
|
||||
table[cur_hash] = (int)(ip - base_ip);
|
||||
}
|
||||
}
|
||||
|
||||
while (ip - candidate <= MAX_DISTANCE && IsMatch(ip, candidate)) {
|
||||
/* We have a 6-byte match at ip, and no need to emit any
|
||||
literal bytes prior to ip. */
|
||||
const uint8_t* base = ip;
|
||||
size_t matched = 6 + FindMatchLengthWithLimit(
|
||||
candidate + 6, ip + 6, (size_t)(ip_end - ip) - 6);
|
||||
ip += matched;
|
||||
last_distance = (int)(base - candidate); /* > 0 */
|
||||
assert(0 == memcmp(base, candidate, matched));
|
||||
EmitCopyLen(matched, commands);
|
||||
EmitDistance((uint32_t)last_distance, commands);
|
||||
|
||||
next_emit = ip;
|
||||
if (BROTLI_PREDICT_FALSE(ip >= ip_limit)) {
|
||||
goto emit_remainder;
|
||||
}
|
||||
{
|
||||
/* We could immediately start working at ip now, but to improve
|
||||
compression we first update "table" with the hashes of some
|
||||
positions within the last copy. */
|
||||
uint64_t input_bytes = BROTLI_UNALIGNED_LOAD64(ip - 5);
|
||||
uint32_t prev_hash = HashBytesAtOffset(input_bytes, 0, shift);
|
||||
uint32_t cur_hash;
|
||||
table[prev_hash] = (int)(ip - base_ip - 5);
|
||||
prev_hash = HashBytesAtOffset(input_bytes, 1, shift);
|
||||
table[prev_hash] = (int)(ip - base_ip - 4);
|
||||
prev_hash = HashBytesAtOffset(input_bytes, 2, shift);
|
||||
table[prev_hash] = (int)(ip - base_ip - 3);
|
||||
input_bytes = BROTLI_UNALIGNED_LOAD64(ip - 2);
|
||||
cur_hash = HashBytesAtOffset(input_bytes, 2, shift);
|
||||
prev_hash = HashBytesAtOffset(input_bytes, 0, shift);
|
||||
table[prev_hash] = (int)(ip - base_ip - 2);
|
||||
prev_hash = HashBytesAtOffset(input_bytes, 1, shift);
|
||||
table[prev_hash] = (int)(ip - base_ip - 1);
|
||||
|
||||
candidate = base_ip + table[cur_hash];
|
||||
table[cur_hash] = (int)(ip - base_ip);
|
||||
}
|
||||
}
|
||||
|
||||
next_hash = Hash(++ip, shift);
|
||||
}
|
||||
}
|
||||
|
||||
emit_remainder:
|
||||
assert(next_emit <= ip_end);
|
||||
/* Emit the remaining bytes as literals. */
|
||||
if (next_emit < ip_end) {
|
||||
const uint32_t insert = (uint32_t)(ip_end - next_emit);
|
||||
EmitInsertLen(insert, commands);
|
||||
memcpy(*literals, next_emit, insert);
|
||||
*literals += insert;
|
||||
}
|
||||
}
|
||||
|
||||
static void StoreCommands(MemoryManager* m,
|
||||
const uint8_t* literals, const size_t num_literals,
|
||||
const uint32_t* commands, const size_t num_commands,
|
||||
size_t* storage_ix, uint8_t* storage) {
|
||||
static const uint32_t kNumExtraBits[128] = {
|
||||
0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 7, 8, 9, 10, 12, 14, 24,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 7, 8, 9, 10, 24,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8,
|
||||
9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16,
|
||||
17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24,
|
||||
};
|
||||
static const uint32_t kInsertOffset[24] = {
|
||||
0, 1, 2, 3, 4, 5, 6, 8, 10, 14, 18, 26, 34, 50, 66, 98, 130, 194, 322, 578,
|
||||
1090, 2114, 6210, 22594,
|
||||
};
|
||||
|
||||
uint8_t lit_depths[256];
|
||||
uint16_t lit_bits[256];
|
||||
uint32_t lit_histo[256] = { 0 };
|
||||
uint8_t cmd_depths[128] = { 0 };
|
||||
uint16_t cmd_bits[128] = { 0 };
|
||||
uint32_t cmd_histo[128] = { 0 };
|
||||
size_t i;
|
||||
for (i = 0; i < num_literals; ++i) {
|
||||
++lit_histo[literals[i]];
|
||||
}
|
||||
BrotliBuildAndStoreHuffmanTreeFast(m, lit_histo, num_literals,
|
||||
/* max_bits = */ 8,
|
||||
lit_depths, lit_bits,
|
||||
storage_ix, storage);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
|
||||
for (i = 0; i < num_commands; ++i) {
|
||||
const uint32_t code = commands[i] & 0xFF;
|
||||
assert(code < 128);
|
||||
++cmd_histo[code];
|
||||
}
|
||||
cmd_histo[1] += 1;
|
||||
cmd_histo[2] += 1;
|
||||
cmd_histo[64] += 1;
|
||||
cmd_histo[84] += 1;
|
||||
BuildAndStoreCommandPrefixCode(cmd_histo, cmd_depths, cmd_bits,
|
||||
storage_ix, storage);
|
||||
|
||||
for (i = 0; i < num_commands; ++i) {
|
||||
const uint32_t cmd = commands[i];
|
||||
const uint32_t code = cmd & 0xFF;
|
||||
const uint32_t extra = cmd >> 8;
|
||||
assert(code < 128);
|
||||
BrotliWriteBits(cmd_depths[code], cmd_bits[code], storage_ix, storage);
|
||||
BrotliWriteBits(kNumExtraBits[code], extra, storage_ix, storage);
|
||||
if (code < 24) {
|
||||
const uint32_t insert = kInsertOffset[code] + extra;
|
||||
uint32_t j;
|
||||
for (j = 0; j < insert; ++j) {
|
||||
const uint8_t lit = *literals;
|
||||
BrotliWriteBits(lit_depths[lit], lit_bits[lit], storage_ix, storage);
|
||||
++literals;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Acceptable loss for uncompressible speedup is 2% */
#define MIN_RATIO 0.98
#define SAMPLE_RATE 43

static BROTLI_BOOL ShouldCompress(
    const uint8_t* input, size_t input_size, size_t num_literals) {
  double corpus_size = (double)input_size;
  if (num_literals < MIN_RATIO * corpus_size) {
    return BROTLI_TRUE;
  } else {
    uint32_t literal_histo[256] = { 0 };
    const double max_total_bit_cost = corpus_size * 8 * MIN_RATIO / SAMPLE_RATE;
    size_t i;
    for (i = 0; i < input_size; i += SAMPLE_RATE) {
      ++literal_histo[input[i]];
    }
    return TO_BROTLI_BOOL(BitsEntropy(literal_histo, 256) < max_total_bit_cost);
  }
}
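/* Editor's note: illustrative arithmetic, not part of the vendored source.
   ShouldCompress() returns TRUE right away if at least ~2% of the block was
   covered by copies; otherwise it samples one byte out of every 43 and
   compares the total Shannon bit cost of the sampled histogram against
   corpus_size * 8 * 0.98 / 43.  For a 64 KiB block that threshold is
   65536 * 8 * 0.98 / 43 ~= 11949 bits over the ~1525 sampled bytes, i.e.
   roughly 7.84 bits of entropy per sampled byte; blocks whose sampled entropy
   exceeds this are emitted as uncompressed meta-blocks instead. */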
|
||||
|
||||
static void RewindBitPosition(const size_t new_storage_ix,
|
||||
size_t* storage_ix, uint8_t* storage) {
|
||||
const size_t bitpos = new_storage_ix & 7;
|
||||
const size_t mask = (1u << bitpos) - 1;
|
||||
storage[new_storage_ix >> 3] &= (uint8_t)mask;
|
||||
*storage_ix = new_storage_ix;
|
||||
}
|
||||
|
||||
static void EmitUncompressedMetaBlock(const uint8_t* input, size_t input_size,
|
||||
size_t* storage_ix, uint8_t* storage) {
|
||||
BrotliStoreMetaBlockHeader(input_size, 1, storage_ix, storage);
|
||||
*storage_ix = (*storage_ix + 7u) & ~7u;
|
||||
memcpy(&storage[*storage_ix >> 3], input, input_size);
|
||||
*storage_ix += input_size << 3;
|
||||
storage[*storage_ix >> 3] = 0;
|
||||
}
|
||||
|
||||
static BROTLI_INLINE void BrotliCompressFragmentTwoPassImpl(
|
||||
MemoryManager* m, const uint8_t* input, size_t input_size,
|
||||
BROTLI_BOOL is_last, uint32_t* command_buf, uint8_t* literal_buf,
|
||||
int* table, size_t table_bits, size_t* storage_ix, uint8_t* storage) {
|
||||
/* Save the start of the first block for position and distance computations.
|
||||
*/
|
||||
const uint8_t* base_ip = input;
|
||||
BROTLI_UNUSED(is_last);
|
||||
|
||||
while (input_size > 0) {
|
||||
size_t block_size =
|
||||
BROTLI_MIN(size_t, input_size, kCompressFragmentTwoPassBlockSize);
|
||||
uint32_t* commands = command_buf;
|
||||
uint8_t* literals = literal_buf;
|
||||
size_t num_literals;
|
||||
CreateCommands(input, block_size, input_size, base_ip, table, table_bits,
|
||||
&literals, &commands);
|
||||
num_literals = (size_t)(literals - literal_buf);
|
||||
if (ShouldCompress(input, block_size, num_literals)) {
|
||||
const size_t num_commands = (size_t)(commands - command_buf);
|
||||
BrotliStoreMetaBlockHeader(block_size, 0, storage_ix, storage);
|
||||
/* No block splits, no contexts. */
|
||||
BrotliWriteBits(13, 0, storage_ix, storage);
|
||||
StoreCommands(m, literal_buf, num_literals, command_buf, num_commands,
|
||||
storage_ix, storage);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
} else {
|
||||
/* Since we did not find many backward references and the entropy of
|
||||
the data is close to 8 bits, we can simply emit an uncompressed block.
|
||||
This makes compression speed of uncompressible data about 3x faster. */
|
||||
EmitUncompressedMetaBlock(input, block_size, storage_ix, storage);
|
||||
}
|
||||
input += block_size;
|
||||
input_size -= block_size;
|
||||
}
|
||||
}
|
||||
|
||||
#define FOR_TABLE_BITS_(X) \
|
||||
X(8) X(9) X(10) X(11) X(12) X(13) X(14) X(15) X(16) X(17)
|
||||
|
||||
#define BAKE_METHOD_PARAM_(B) \
|
||||
static BROTLI_NOINLINE void BrotliCompressFragmentTwoPassImpl ## B( \
|
||||
MemoryManager* m, const uint8_t* input, size_t input_size, \
|
||||
BROTLI_BOOL is_last, uint32_t* command_buf, uint8_t* literal_buf, \
|
||||
int* table, size_t* storage_ix, uint8_t* storage) { \
|
||||
BrotliCompressFragmentTwoPassImpl(m, input, input_size, is_last, command_buf,\
|
||||
literal_buf, table, B, storage_ix, storage); \
|
||||
}
|
||||
FOR_TABLE_BITS_(BAKE_METHOD_PARAM_)
|
||||
#undef BAKE_METHOD_PARAM_
|
||||
|
||||
void BrotliCompressFragmentTwoPass(
|
||||
MemoryManager* m, const uint8_t* input, size_t input_size,
|
||||
BROTLI_BOOL is_last, uint32_t* command_buf, uint8_t* literal_buf,
|
||||
int* table, size_t table_size, size_t* storage_ix, uint8_t* storage) {
|
||||
const size_t initial_storage_ix = *storage_ix;
|
||||
const size_t table_bits = Log2FloorNonZero(table_size);
|
||||
switch (table_bits) {
|
||||
#define CASE_(B) \
|
||||
case B: \
|
||||
BrotliCompressFragmentTwoPassImpl ## B( \
|
||||
m, input, input_size, is_last, command_buf, \
|
||||
literal_buf, table, storage_ix, storage); \
|
||||
break;
|
||||
FOR_TABLE_BITS_(CASE_)
|
||||
#undef CASE_
|
||||
default: assert(0); break;
|
||||
}
|
||||
|
||||
/* If output is larger than single uncompressed block, rewrite it. */
|
||||
if (*storage_ix - initial_storage_ix > 31 + (input_size << 3)) {
|
||||
RewindBitPosition(initial_storage_ix, storage_ix, storage);
|
||||
EmitUncompressedMetaBlock(input, input_size, storage_ix, storage);
|
||||
}
|
||||
|
||||
if (is_last) {
|
||||
BrotliWriteBits(1, 1, storage_ix, storage); /* islast */
|
||||
BrotliWriteBits(1, 1, storage_ix, storage); /* isempty */
|
||||
*storage_ix = (*storage_ix + 7u) & ~7u;
|
||||
}
|
||||
}
|
||||
|
||||
#undef FOR_TABLE_BITS_
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
} /* extern "C" */
|
||||
#endif
|
|
@ -0,0 +1,501 @@
|
|||
/* Copyright 2010 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Entropy encoding (Huffman) utilities. */
|
||||
|
||||
#include "./enc/entropy_encode.h"
|
||||
|
||||
#include <string.h> /* memset */
|
||||
|
||||
#include "./common/constants.h"
|
||||
#include <brotli/types.h>
|
||||
#include "./enc/port.h"
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
BROTLI_BOOL BrotliSetDepth(
|
||||
int p0, HuffmanTree* pool, uint8_t* depth, int max_depth) {
|
||||
int stack[16];
|
||||
int level = 0;
|
||||
int p = p0;
|
||||
assert(max_depth <= 15);
|
||||
stack[0] = -1;
|
||||
while (BROTLI_TRUE) {
|
||||
if (pool[p].index_left_ >= 0) {
|
||||
level++;
|
||||
if (level > max_depth) return BROTLI_FALSE;
|
||||
stack[level] = pool[p].index_right_or_value_;
|
||||
p = pool[p].index_left_;
|
||||
continue;
|
||||
} else {
|
||||
depth[pool[p].index_right_or_value_] = (uint8_t)level;
|
||||
}
|
||||
while (level >= 0 && stack[level] == -1) level--;
|
||||
if (level < 0) return BROTLI_TRUE;
|
||||
p = stack[level];
|
||||
stack[level] = -1;
|
||||
}
|
||||
}
|
||||
|
||||
/* Sort the root nodes, least popular first. */
|
||||
static BROTLI_INLINE BROTLI_BOOL SortHuffmanTree(
|
||||
const HuffmanTree* v0, const HuffmanTree* v1) {
|
||||
if (v0->total_count_ != v1->total_count_) {
|
||||
return TO_BROTLI_BOOL(v0->total_count_ < v1->total_count_);
|
||||
}
|
||||
return TO_BROTLI_BOOL(v0->index_right_or_value_ > v1->index_right_or_value_);
|
||||
}
|
||||
|
||||
/* This function will create a Huffman tree.

   The catch here is that the tree cannot be arbitrarily deep.
   Brotli specifies a maximum depth of 15 bits for "code trees"
   and 7 bits for "code length code trees".

   count_limit is the minimum count that is faked for every present symbol;
   it is doubled until the resulting tree satisfies the maximum depth
   requirement.

   This algorithm does not perform well for very long data blocks,
   especially when population counts are larger than 2**tree_limit, but
   we are not planning to use this with extremely long blocks.

   See http://en.wikipedia.org/wiki/Huffman_coding */
|
||||
void BrotliCreateHuffmanTree(const uint32_t *data,
|
||||
const size_t length,
|
||||
const int tree_limit,
|
||||
HuffmanTree* tree,
|
||||
uint8_t *depth) {
|
||||
uint32_t count_limit;
|
||||
HuffmanTree sentinel;
|
||||
InitHuffmanTree(&sentinel, BROTLI_UINT32_MAX, -1, -1);
|
||||
/* For block sizes below 64 kB, we never need to do a second iteration
|
||||
of this loop. Probably all of our block sizes will be smaller than
|
||||
that, so this loop is mostly of academic interest. If we actually
|
||||
would need this, we would be better off with the Katajainen algorithm. */
|
||||
for (count_limit = 1; ; count_limit *= 2) {
|
||||
size_t n = 0;
|
||||
size_t i;
|
||||
size_t j;
|
||||
size_t k;
|
||||
for (i = length; i != 0;) {
|
||||
--i;
|
||||
if (data[i]) {
|
||||
const uint32_t count = BROTLI_MAX(uint32_t, data[i], count_limit);
|
||||
InitHuffmanTree(&tree[n++], count, -1, (int16_t)i);
|
||||
}
|
||||
}
|
||||
|
||||
if (n == 1) {
|
||||
depth[tree[0].index_right_or_value_] = 1; /* Only one element. */
|
||||
break;
|
||||
}
|
||||
|
||||
SortHuffmanTreeItems(tree, n, SortHuffmanTree);
|
||||
|
||||
/* The nodes are:
|
||||
[0, n): the sorted leaf nodes that we start with.
|
||||
[n]: we add a sentinel here.
|
||||
[n + 1, 2n): new parent nodes are added here, starting from
|
||||
(n+1). These are naturally in ascending order.
|
||||
[2n]: we add a sentinel at the end as well.
|
||||
There will be (2n+1) elements at the end. */
|
||||
tree[n] = sentinel;
|
||||
tree[n + 1] = sentinel;
|
||||
|
||||
i = 0; /* Points to the next leaf node. */
|
||||
j = n + 1; /* Points to the next non-leaf node. */
|
||||
for (k = n - 1; k != 0; --k) {
|
||||
size_t left, right;
|
||||
if (tree[i].total_count_ <= tree[j].total_count_) {
|
||||
left = i;
|
||||
++i;
|
||||
} else {
|
||||
left = j;
|
||||
++j;
|
||||
}
|
||||
if (tree[i].total_count_ <= tree[j].total_count_) {
|
||||
right = i;
|
||||
++i;
|
||||
} else {
|
||||
right = j;
|
||||
++j;
|
||||
}
|
||||
|
||||
{
|
||||
/* The sentinel node becomes the parent node. */
|
||||
size_t j_end = 2 * n - k;
|
||||
tree[j_end].total_count_ =
|
||||
tree[left].total_count_ + tree[right].total_count_;
|
||||
tree[j_end].index_left_ = (int16_t)left;
|
||||
tree[j_end].index_right_or_value_ = (int16_t)right;
|
||||
|
||||
/* Add back the last sentinel node. */
|
||||
tree[j_end + 1] = sentinel;
|
||||
}
|
||||
}
|
||||
if (BrotliSetDepth((int)(2 * n - 1), &tree[0], depth, tree_limit)) {
|
||||
/* We need to pack the Huffman tree in tree_limit bits. If this was not
|
||||
successful, add fake entities to the lowest values and retry. */
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
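/* Illustrative sketch (not part of the original source): building a
   depth-limited Huffman code for a small histogram with the function above.
   The counts and array sizes are example values; the scratch `tree` buffer
   is sized at 2 * length + 1 nodes, which covers the leaf nodes, the parent
   nodes and the two sentinels described above. */
static void ExampleBuildDepths(void) {
  uint32_t counts[6] = { 50, 20, 15, 10, 4, 1 };
  uint8_t depths[6];
  HuffmanTree tree[2 * 6 + 1];  /* scratch space: 2n + 1 nodes */
  memset(depths, 0, sizeof(depths));
  BrotliCreateHuffmanTree(counts, 6, 15, tree, depths);
  /* depths[i] now holds the code length (1..15) of symbol i; a symbol with
     counts[i] == 0 would keep depth 0, i.e. "absent". */
}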
|
||||
|
||||
static void Reverse(uint8_t* v, size_t start, size_t end) {
|
||||
--end;
|
||||
while (start < end) {
|
||||
uint8_t tmp = v[start];
|
||||
v[start] = v[end];
|
||||
v[end] = tmp;
|
||||
++start;
|
||||
--end;
|
||||
}
|
||||
}
|
||||
|
||||
static void BrotliWriteHuffmanTreeRepetitions(
|
||||
const uint8_t previous_value,
|
||||
const uint8_t value,
|
||||
size_t repetitions,
|
||||
size_t* tree_size,
|
||||
uint8_t* tree,
|
||||
uint8_t* extra_bits_data) {
|
||||
assert(repetitions > 0);
|
||||
if (previous_value != value) {
|
||||
tree[*tree_size] = value;
|
||||
extra_bits_data[*tree_size] = 0;
|
||||
++(*tree_size);
|
||||
--repetitions;
|
||||
}
|
||||
if (repetitions == 7) {
|
||||
tree[*tree_size] = value;
|
||||
extra_bits_data[*tree_size] = 0;
|
||||
++(*tree_size);
|
||||
--repetitions;
|
||||
}
|
||||
if (repetitions < 3) {
|
||||
size_t i;
|
||||
for (i = 0; i < repetitions; ++i) {
|
||||
tree[*tree_size] = value;
|
||||
extra_bits_data[*tree_size] = 0;
|
||||
++(*tree_size);
|
||||
}
|
||||
} else {
|
||||
size_t start = *tree_size;
|
||||
repetitions -= 3;
|
||||
while (BROTLI_TRUE) {
|
||||
tree[*tree_size] = BROTLI_REPEAT_PREVIOUS_CODE_LENGTH;
|
||||
extra_bits_data[*tree_size] = repetitions & 0x3;
|
||||
++(*tree_size);
|
||||
repetitions >>= 2;
|
||||
if (repetitions == 0) {
|
||||
break;
|
||||
}
|
||||
--repetitions;
|
||||
}
|
||||
Reverse(tree, start, *tree_size);
|
||||
Reverse(extra_bits_data, start, *tree_size);
|
||||
}
|
||||
}
|
||||
|
||||
static void BrotliWriteHuffmanTreeRepetitionsZeros(
|
||||
size_t repetitions,
|
||||
size_t* tree_size,
|
||||
uint8_t* tree,
|
||||
uint8_t* extra_bits_data) {
|
||||
if (repetitions == 11) {
|
||||
tree[*tree_size] = 0;
|
||||
extra_bits_data[*tree_size] = 0;
|
||||
++(*tree_size);
|
||||
--repetitions;
|
||||
}
|
||||
if (repetitions < 3) {
|
||||
size_t i;
|
||||
for (i = 0; i < repetitions; ++i) {
|
||||
tree[*tree_size] = 0;
|
||||
extra_bits_data[*tree_size] = 0;
|
||||
++(*tree_size);
|
||||
}
|
||||
} else {
|
||||
size_t start = *tree_size;
|
||||
repetitions -= 3;
|
||||
while (BROTLI_TRUE) {
|
||||
tree[*tree_size] = BROTLI_REPEAT_ZERO_CODE_LENGTH;
|
||||
extra_bits_data[*tree_size] = repetitions & 0x7;
|
||||
++(*tree_size);
|
||||
repetitions >>= 3;
|
||||
if (repetitions == 0) {
|
||||
break;
|
||||
}
|
||||
--repetitions;
|
||||
}
|
||||
Reverse(tree, start, *tree_size);
|
||||
Reverse(extra_bits_data, start, *tree_size);
|
||||
}
|
||||
}
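/* Illustrative sketch (not part of the original source): how a run of 11
   zero code lengths is serialized by the helper above. The first zero is
   emitted literally (the repetitions == 11 special case) and the remaining
   ten become a single repeat-zeros code with extra-bits value 7, meaning
   3 + 7 = 10 repetitions. Buffer sizes are example values. */
static void ExampleZeroRun(void) {
  uint8_t tree[8];
  uint8_t extra[8];
  size_t tree_size = 0;
  BrotliWriteHuffmanTreeRepetitionsZeros(11, &tree_size, tree, extra);
  /* tree_size == 2, tree == { 0, BROTLI_REPEAT_ZERO_CODE_LENGTH },
     extra == { 0, 7 } */
}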
|
||||
|
||||
void BrotliOptimizeHuffmanCountsForRle(size_t length, uint32_t* counts,
|
||||
uint8_t* good_for_rle) {
|
||||
size_t nonzero_count = 0;
|
||||
size_t stride;
|
||||
size_t limit;
|
||||
size_t sum;
|
||||
const size_t streak_limit = 1240;
|
||||
/* Let's make the Huffman code more compatible with RLE encoding. */
|
||||
size_t i;
|
||||
for (i = 0; i < length; i++) {
|
||||
if (counts[i]) {
|
||||
++nonzero_count;
|
||||
}
|
||||
}
|
||||
if (nonzero_count < 16) {
|
||||
return;
|
||||
}
|
||||
while (length != 0 && counts[length - 1] == 0) {
|
||||
--length;
|
||||
}
|
||||
if (length == 0) {
|
||||
return; /* All zeros. */
|
||||
}
|
||||
/* Now counts[0..length - 1] does not have trailing zeros. */
|
||||
{
|
||||
size_t nonzeros = 0;
|
||||
uint32_t smallest_nonzero = 1 << 30;
|
||||
for (i = 0; i < length; ++i) {
|
||||
if (counts[i] != 0) {
|
||||
++nonzeros;
|
||||
if (smallest_nonzero > counts[i]) {
|
||||
smallest_nonzero = counts[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
if (nonzeros < 5) {
|
||||
/* Small histogram will model it well. */
|
||||
return;
|
||||
}
|
||||
if (smallest_nonzero < 4) {
|
||||
size_t zeros = length - nonzeros;
|
||||
if (zeros < 6) {
|
||||
for (i = 1; i < length - 1; ++i) {
|
||||
if (counts[i - 1] != 0 && counts[i] == 0 && counts[i + 1] != 0) {
|
||||
counts[i] = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (nonzeros < 28) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
/* 2) Let's mark all population counts that already can be encoded
|
||||
with an RLE code. */
|
||||
memset(good_for_rle, 0, length);
|
||||
{
|
||||
/* Let's not spoil any of the existing good RLE codes.
   Mark any seq of 0's that is longer than 5 as a good_for_rle.
   Mark any seq of non-0's that is longer than 7 as a good_for_rle. */
|
||||
uint32_t symbol = counts[0];
|
||||
size_t step = 0;
|
||||
for (i = 0; i <= length; ++i) {
|
||||
if (i == length || counts[i] != symbol) {
|
||||
if ((symbol == 0 && step >= 5) ||
|
||||
(symbol != 0 && step >= 7)) {
|
||||
size_t k;
|
||||
for (k = 0; k < step; ++k) {
|
||||
good_for_rle[i - k - 1] = 1;
|
||||
}
|
||||
}
|
||||
step = 1;
|
||||
if (i != length) {
|
||||
symbol = counts[i];
|
||||
}
|
||||
} else {
|
||||
++step;
|
||||
}
|
||||
}
|
||||
}
|
||||
/* 3) Let's replace those population counts that lead to more RLE codes.
|
||||
Math here is in 24.8 fixed point representation. */
|
||||
stride = 0;
|
||||
limit = 256 * (counts[0] + counts[1] + counts[2]) / 3 + 420;
|
||||
sum = 0;
|
||||
for (i = 0; i <= length; ++i) {
|
||||
if (i == length || good_for_rle[i] ||
|
||||
(i != 0 && good_for_rle[i - 1]) ||
|
||||
(256 * counts[i] - limit + streak_limit) >= 2 * streak_limit) {
|
||||
if (stride >= 4 || (stride >= 3 && sum == 0)) {
|
||||
size_t k;
|
||||
/* The stride must end, collapse what we have, if we have enough (4). */
|
||||
size_t count = (sum + stride / 2) / stride;
|
||||
if (count == 0) {
|
||||
count = 1;
|
||||
}
|
||||
if (sum == 0) {
|
||||
/* Don't make an all zeros stride to be upgraded to ones. */
|
||||
count = 0;
|
||||
}
|
||||
for (k = 0; k < stride; ++k) {
|
||||
/* We don't want to change the value at counts[i], since it already
   belongs to the next stride; hence the - 1. */
|
||||
counts[i - k - 1] = (uint32_t)count;
|
||||
}
|
||||
}
|
||||
stride = 0;
|
||||
sum = 0;
|
||||
if (i < length - 2) {
|
||||
/* All interesting strides have a count of at least 4, */
/* at least when the values are non-zero. */
|
||||
limit = 256 * (counts[i] + counts[i + 1] + counts[i + 2]) / 3 + 420;
|
||||
} else if (i < length) {
|
||||
limit = 256 * counts[i];
|
||||
} else {
|
||||
limit = 0;
|
||||
}
|
||||
}
|
||||
++stride;
|
||||
if (i != length) {
|
||||
sum += counts[i];
|
||||
if (stride >= 4) {
|
||||
limit = (256 * sum + stride / 2) / stride;
|
||||
}
|
||||
if (stride == 4) {
|
||||
limit += 120;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
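/* Illustrative sketch (not part of the original source): calling convention
   of the RLE-friendly smoothing above. good_for_rle is caller-provided
   scratch of `length` bytes; the counts and length are example values, and
   the exact result depends on the heuristics, so it is not spelled out. */
static void ExampleRleSmoothing(void) {
  uint32_t counts[32];
  uint8_t good_for_rle[32];
  size_t i;
  for (i = 0; i < 32; ++i) counts[i] = 100 + (i & 1);  /* 100, 101, 100, ... */
  BrotliOptimizeHuffmanCountsForRle(32, counts, good_for_rle);
  /* Nearly-equal neighbouring counts may now be collapsed to a shared value
     so that the resulting code-length sequence RLE-encodes cheaply. */
}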
|
||||
|
||||
static void DecideOverRleUse(const uint8_t* depth, const size_t length,
|
||||
BROTLI_BOOL *use_rle_for_non_zero,
|
||||
BROTLI_BOOL *use_rle_for_zero) {
|
||||
size_t total_reps_zero = 0;
|
||||
size_t total_reps_non_zero = 0;
|
||||
size_t count_reps_zero = 1;
|
||||
size_t count_reps_non_zero = 1;
|
||||
size_t i;
|
||||
for (i = 0; i < length;) {
|
||||
const uint8_t value = depth[i];
|
||||
size_t reps = 1;
|
||||
size_t k;
|
||||
for (k = i + 1; k < length && depth[k] == value; ++k) {
|
||||
++reps;
|
||||
}
|
||||
if (reps >= 3 && value == 0) {
|
||||
total_reps_zero += reps;
|
||||
++count_reps_zero;
|
||||
}
|
||||
if (reps >= 4 && value != 0) {
|
||||
total_reps_non_zero += reps;
|
||||
++count_reps_non_zero;
|
||||
}
|
||||
i += reps;
|
||||
}
|
||||
*use_rle_for_non_zero =
|
||||
TO_BROTLI_BOOL(total_reps_non_zero > count_reps_non_zero * 2);
|
||||
*use_rle_for_zero = TO_BROTLI_BOOL(total_reps_zero > count_reps_zero * 2);
|
||||
}
|
||||
|
||||
void BrotliWriteHuffmanTree(const uint8_t* depth,
|
||||
size_t length,
|
||||
size_t* tree_size,
|
||||
uint8_t* tree,
|
||||
uint8_t* extra_bits_data) {
|
||||
uint8_t previous_value = BROTLI_INITIAL_REPEATED_CODE_LENGTH;
|
||||
size_t i;
|
||||
BROTLI_BOOL use_rle_for_non_zero = BROTLI_FALSE;
|
||||
BROTLI_BOOL use_rle_for_zero = BROTLI_FALSE;
|
||||
|
||||
/* Throw away trailing zeros. */
|
||||
size_t new_length = length;
|
||||
for (i = 0; i < length; ++i) {
|
||||
if (depth[length - i - 1] == 0) {
|
||||
--new_length;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/* First gather statistics on whether it is a good idea to use RLE. */
|
||||
if (length > 50) {
|
||||
/* Find RLE coding for longer codes.
|
||||
Shorter codes seem not to benefit from RLE. */
|
||||
DecideOverRleUse(depth, new_length,
|
||||
&use_rle_for_non_zero, &use_rle_for_zero);
|
||||
}
|
||||
|
||||
/* Actual RLE coding. */
|
||||
for (i = 0; i < new_length;) {
|
||||
const uint8_t value = depth[i];
|
||||
size_t reps = 1;
|
||||
if ((value != 0 && use_rle_for_non_zero) ||
|
||||
(value == 0 && use_rle_for_zero)) {
|
||||
size_t k;
|
||||
for (k = i + 1; k < new_length && depth[k] == value; ++k) {
|
||||
++reps;
|
||||
}
|
||||
}
|
||||
if (value == 0) {
|
||||
BrotliWriteHuffmanTreeRepetitionsZeros(
|
||||
reps, tree_size, tree, extra_bits_data);
|
||||
} else {
|
||||
BrotliWriteHuffmanTreeRepetitions(previous_value,
|
||||
value, reps, tree_size,
|
||||
tree, extra_bits_data);
|
||||
previous_value = value;
|
||||
}
|
||||
i += reps;
|
||||
}
|
||||
}
|
||||
|
||||
static uint16_t BrotliReverseBits(size_t num_bits, uint16_t bits) {
|
||||
static const size_t kLut[16] = { /* Pre-reversed 4-bit values. */
|
||||
0x0, 0x8, 0x4, 0xc, 0x2, 0xa, 0x6, 0xe,
|
||||
0x1, 0x9, 0x5, 0xd, 0x3, 0xb, 0x7, 0xf
|
||||
};
|
||||
size_t retval = kLut[bits & 0xf];
|
||||
size_t i;
|
||||
for (i = 4; i < num_bits; i += 4) {
|
||||
retval <<= 4;
|
||||
bits = (uint16_t)(bits >> 4);
|
||||
retval |= kLut[bits & 0xf];
|
||||
}
|
||||
retval >>= ((0 - num_bits) & 0x3);
|
||||
return (uint16_t)retval;
|
||||
}
|
||||
|
||||
/* 0..15 are values for bits */
|
||||
#define MAX_HUFFMAN_BITS 16
|
||||
|
||||
void BrotliConvertBitDepthsToSymbols(const uint8_t *depth,
|
||||
size_t len,
|
||||
uint16_t *bits) {
|
||||
/* In Brotli, all bit depths are [1..15]
|
||||
0 bit depth means that the symbol does not exist. */
|
||||
uint16_t bl_count[MAX_HUFFMAN_BITS] = { 0 };
|
||||
uint16_t next_code[MAX_HUFFMAN_BITS];
|
||||
size_t i;
|
||||
int code = 0;
|
||||
for (i = 0; i < len; ++i) {
|
||||
++bl_count[depth[i]];
|
||||
}
|
||||
bl_count[0] = 0;
|
||||
next_code[0] = 0;
|
||||
for (i = 1; i < MAX_HUFFMAN_BITS; ++i) {
|
||||
code = (code + bl_count[i - 1]) << 1;
|
||||
next_code[i] = (uint16_t)code;
|
||||
}
|
||||
for (i = 0; i < len; ++i) {
|
||||
if (depth[i]) {
|
||||
bits[i] = BrotliReverseBits(depth[i], next_code[depth[i]]++);
|
||||
}
|
||||
}
|
||||
}
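/* Illustrative sketch (not part of the original source): canonical code
   assignment for depths { 1, 2, 3, 3 }. The canonical codes are 0, 10, 110
   and 111; because the decoder consumes bits LSB-first, each code is stored
   bit-reversed within its length. */
static void ExampleCanonicalCodes(void) {
  const uint8_t depth[4] = { 1, 2, 3, 3 };
  uint16_t bits[4];
  BrotliConvertBitDepthsToSymbols(depth, 4, bits);
  /* bits == { 0x0, 0x1, 0x3, 0x7 }: the reversed forms of 0, 10, 110, 111. */
}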
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
} /* extern "C" */
|
||||
#endif
|
|
@ -0,0 +1,97 @@
|
|||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Build per-context histograms of literals, commands and distance codes. */
|
||||
|
||||
#include "./enc/histogram.h"
|
||||
|
||||
#include "./enc/block_splitter.h"
|
||||
#include "./enc/command.h"
|
||||
#include "./enc/context.h"
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef struct BlockSplitIterator {
|
||||
const BlockSplit* split_; /* Not owned. */
|
||||
size_t idx_;
|
||||
size_t type_;
|
||||
size_t length_;
|
||||
} BlockSplitIterator;
|
||||
|
||||
static void InitBlockSplitIterator(BlockSplitIterator* self,
|
||||
const BlockSplit* split) {
|
||||
self->split_ = split;
|
||||
self->idx_ = 0;
|
||||
self->type_ = 0;
|
||||
self->length_ = split->lengths ? split->lengths[0] : 0;
|
||||
}
|
||||
|
||||
static void BlockSplitIteratorNext(BlockSplitIterator* self) {
|
||||
if (self->length_ == 0) {
|
||||
++self->idx_;
|
||||
self->type_ = self->split_->types[self->idx_];
|
||||
self->length_ = self->split_->lengths[self->idx_];
|
||||
}
|
||||
--self->length_;
|
||||
}
|
||||
|
||||
void BrotliBuildHistogramsWithContext(
|
||||
const Command* cmds, const size_t num_commands,
|
||||
const BlockSplit* literal_split, const BlockSplit* insert_and_copy_split,
|
||||
const BlockSplit* dist_split, const uint8_t* ringbuffer, size_t start_pos,
|
||||
size_t mask, uint8_t prev_byte, uint8_t prev_byte2,
|
||||
const ContextType* context_modes, HistogramLiteral* literal_histograms,
|
||||
HistogramCommand* insert_and_copy_histograms,
|
||||
HistogramDistance* copy_dist_histograms) {
|
||||
size_t pos = start_pos;
|
||||
BlockSplitIterator literal_it;
|
||||
BlockSplitIterator insert_and_copy_it;
|
||||
BlockSplitIterator dist_it;
|
||||
size_t i;
|
||||
|
||||
InitBlockSplitIterator(&literal_it, literal_split);
|
||||
InitBlockSplitIterator(&insert_and_copy_it, insert_and_copy_split);
|
||||
InitBlockSplitIterator(&dist_it, dist_split);
|
||||
for (i = 0; i < num_commands; ++i) {
|
||||
const Command* cmd = &cmds[i];
|
||||
size_t j;
|
||||
BlockSplitIteratorNext(&insert_and_copy_it);
|
||||
HistogramAddCommand(&insert_and_copy_histograms[insert_and_copy_it.type_],
|
||||
cmd->cmd_prefix_);
|
||||
for (j = cmd->insert_len_; j != 0; --j) {
|
||||
size_t context;
|
||||
BlockSplitIteratorNext(&literal_it);
|
||||
context = context_modes ?
|
||||
((literal_it.type_ << BROTLI_LITERAL_CONTEXT_BITS) +
|
||||
Context(prev_byte, prev_byte2, context_modes[literal_it.type_])) :
|
||||
literal_it.type_;
|
||||
HistogramAddLiteral(&literal_histograms[context],
|
||||
ringbuffer[pos & mask]);
|
||||
prev_byte2 = prev_byte;
|
||||
prev_byte = ringbuffer[pos & mask];
|
||||
++pos;
|
||||
}
|
||||
pos += CommandCopyLen(cmd);
|
||||
if (CommandCopyLen(cmd)) {
|
||||
prev_byte2 = ringbuffer[(pos - 2) & mask];
|
||||
prev_byte = ringbuffer[(pos - 1) & mask];
|
||||
if (cmd->cmd_prefix_ >= 128) {
|
||||
size_t context;
|
||||
BlockSplitIteratorNext(&dist_it);
|
||||
context = (dist_it.type_ << BROTLI_DISTANCE_CONTEXT_BITS) +
|
||||
CommandDistanceContext(cmd);
|
||||
HistogramAddDistance(©_dist_histograms[context],
|
||||
cmd->dist_prefix_);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
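/* Illustrative sketch (not part of the original source): how the literal
   histogram bucket is chosen above. Each literal block type owns one bucket
   per context (1 << BROTLI_LITERAL_CONTEXT_BITS of them), and the context
   within the type is derived from the previous two bytes via Context(). The
   parameter names are illustrative. */
static size_t ExampleLiteralBucket(size_t block_type, uint8_t p1, uint8_t p2,
                                   ContextType mode) {
  return (block_type << BROTLI_LITERAL_CONTEXT_BITS) + Context(p1, p2, mode);
}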
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
} /* extern "C" */
|
||||
#endif
|
|
@ -0,0 +1,358 @@
|
|||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Utilities for building Huffman decoding tables. */
|
||||
|
||||
#include "./dec/huffman.h"
|
||||
|
||||
#include <string.h> /* memcpy, memset */
|
||||
|
||||
#include "./common/constants.h"
|
||||
#include <brotli/types.h>
|
||||
#include "./dec/port.h"
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define BROTLI_REVERSE_BITS_MAX 8
|
||||
|
||||
#ifdef BROTLI_RBIT
|
||||
#define BROTLI_REVERSE_BITS_BASE \
|
||||
((sizeof(reg_t) << 3) - BROTLI_REVERSE_BITS_MAX)
|
||||
#else
|
||||
#define BROTLI_REVERSE_BITS_BASE 0
|
||||
static uint8_t kReverseBits[1 << BROTLI_REVERSE_BITS_MAX] = {
|
||||
0x00, 0x80, 0x40, 0xC0, 0x20, 0xA0, 0x60, 0xE0,
|
||||
0x10, 0x90, 0x50, 0xD0, 0x30, 0xB0, 0x70, 0xF0,
|
||||
0x08, 0x88, 0x48, 0xC8, 0x28, 0xA8, 0x68, 0xE8,
|
||||
0x18, 0x98, 0x58, 0xD8, 0x38, 0xB8, 0x78, 0xF8,
|
||||
0x04, 0x84, 0x44, 0xC4, 0x24, 0xA4, 0x64, 0xE4,
|
||||
0x14, 0x94, 0x54, 0xD4, 0x34, 0xB4, 0x74, 0xF4,
|
||||
0x0C, 0x8C, 0x4C, 0xCC, 0x2C, 0xAC, 0x6C, 0xEC,
|
||||
0x1C, 0x9C, 0x5C, 0xDC, 0x3C, 0xBC, 0x7C, 0xFC,
|
||||
0x02, 0x82, 0x42, 0xC2, 0x22, 0xA2, 0x62, 0xE2,
|
||||
0x12, 0x92, 0x52, 0xD2, 0x32, 0xB2, 0x72, 0xF2,
|
||||
0x0A, 0x8A, 0x4A, 0xCA, 0x2A, 0xAA, 0x6A, 0xEA,
|
||||
0x1A, 0x9A, 0x5A, 0xDA, 0x3A, 0xBA, 0x7A, 0xFA,
|
||||
0x06, 0x86, 0x46, 0xC6, 0x26, 0xA6, 0x66, 0xE6,
|
||||
0x16, 0x96, 0x56, 0xD6, 0x36, 0xB6, 0x76, 0xF6,
|
||||
0x0E, 0x8E, 0x4E, 0xCE, 0x2E, 0xAE, 0x6E, 0xEE,
|
||||
0x1E, 0x9E, 0x5E, 0xDE, 0x3E, 0xBE, 0x7E, 0xFE,
|
||||
0x01, 0x81, 0x41, 0xC1, 0x21, 0xA1, 0x61, 0xE1,
|
||||
0x11, 0x91, 0x51, 0xD1, 0x31, 0xB1, 0x71, 0xF1,
|
||||
0x09, 0x89, 0x49, 0xC9, 0x29, 0xA9, 0x69, 0xE9,
|
||||
0x19, 0x99, 0x59, 0xD9, 0x39, 0xB9, 0x79, 0xF9,
|
||||
0x05, 0x85, 0x45, 0xC5, 0x25, 0xA5, 0x65, 0xE5,
|
||||
0x15, 0x95, 0x55, 0xD5, 0x35, 0xB5, 0x75, 0xF5,
|
||||
0x0D, 0x8D, 0x4D, 0xCD, 0x2D, 0xAD, 0x6D, 0xED,
|
||||
0x1D, 0x9D, 0x5D, 0xDD, 0x3D, 0xBD, 0x7D, 0xFD,
|
||||
0x03, 0x83, 0x43, 0xC3, 0x23, 0xA3, 0x63, 0xE3,
|
||||
0x13, 0x93, 0x53, 0xD3, 0x33, 0xB3, 0x73, 0xF3,
|
||||
0x0B, 0x8B, 0x4B, 0xCB, 0x2B, 0xAB, 0x6B, 0xEB,
|
||||
0x1B, 0x9B, 0x5B, 0xDB, 0x3B, 0xBB, 0x7B, 0xFB,
|
||||
0x07, 0x87, 0x47, 0xC7, 0x27, 0xA7, 0x67, 0xE7,
|
||||
0x17, 0x97, 0x57, 0xD7, 0x37, 0xB7, 0x77, 0xF7,
|
||||
0x0F, 0x8F, 0x4F, 0xCF, 0x2F, 0xAF, 0x6F, 0xEF,
|
||||
0x1F, 0x9F, 0x5F, 0xDF, 0x3F, 0xBF, 0x7F, 0xFF
|
||||
};
|
||||
#endif /* BROTLI_RBIT */
|
||||
|
||||
#define BROTLI_REVERSE_BITS_LOWEST \
|
||||
((reg_t)1 << (BROTLI_REVERSE_BITS_MAX - 1 + BROTLI_REVERSE_BITS_BASE))
|
||||
|
||||
/* Returns reverse(num >> BROTLI_REVERSE_BITS_BASE, BROTLI_REVERSE_BITS_MAX),
|
||||
where reverse(value, len) is the bit-wise reversal of the len least
|
||||
significant bits of value. */
|
||||
static BROTLI_INLINE reg_t BrotliReverseBits(reg_t num) {
|
||||
#ifdef BROTLI_RBIT
|
||||
return BROTLI_RBIT(num);
|
||||
#else
|
||||
return kReverseBits[num];
|
||||
#endif
|
||||
}
|
||||
|
||||
/* Stores code in table[0], table[step], table[2*step], ..., table[end] */
|
||||
/* Assumes that end is an integer multiple of step */
|
||||
static BROTLI_INLINE void ReplicateValue(HuffmanCode* table,
|
||||
int step, int end,
|
||||
HuffmanCode code) {
|
||||
do {
|
||||
end -= step;
|
||||
table[end] = code;
|
||||
} while (end > 0);
|
||||
}
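/* Illustrative sketch (not part of the original source): replicating a short
   code across a lookup table. With an 8-entry table, end == 8 and step == 2
   (a 1-bit code in a 3-bit table), the same entry is written to slots
   6, 4, 2 and 0; the remaining slots are left for the other 1-bit code. */
static void ExampleReplicate(void) {
  HuffmanCode table[8];
  HuffmanCode code;
  code.bits = 1;
  code.value = 42;
  ReplicateValue(table, 2, 8, code);
  /* table[0], table[2], table[4] and table[6] now hold { bits: 1, value: 42 }. */
}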
|
||||
|
||||
/* Returns the table width of the next 2nd level table. count is the histogram
|
||||
of bit lengths for the remaining symbols, len is the code length of the next
|
||||
processed symbol */
|
||||
static BROTLI_INLINE int NextTableBitSize(const uint16_t* const count,
|
||||
int len, int root_bits) {
|
||||
int left = 1 << (len - root_bits);
|
||||
while (len < BROTLI_HUFFMAN_MAX_CODE_LENGTH) {
|
||||
left -= count[len];
|
||||
if (left <= 0) break;
|
||||
++len;
|
||||
left <<= 1;
|
||||
}
|
||||
return len - root_bits;
|
||||
}
|
||||
|
||||
void BrotliBuildCodeLengthsHuffmanTable(HuffmanCode* table,
|
||||
const uint8_t* const code_lengths,
|
||||
uint16_t* count) {
|
||||
HuffmanCode code; /* current table entry */
|
||||
int symbol; /* symbol index in original or sorted table */
|
||||
reg_t key; /* prefix code */
|
||||
reg_t key_step; /* prefix code addend */
|
||||
int step; /* step size to replicate values in current table */
|
||||
int table_size; /* size of current table */
|
||||
int sorted[BROTLI_CODE_LENGTH_CODES]; /* symbols sorted by code length */
|
||||
/* offsets in sorted table for each length */
|
||||
int offset[BROTLI_HUFFMAN_MAX_CODE_LENGTH_CODE_LENGTH + 1];
|
||||
int bits;
|
||||
int bits_count;
|
||||
BROTLI_DCHECK(BROTLI_HUFFMAN_MAX_CODE_LENGTH_CODE_LENGTH <=
|
||||
BROTLI_REVERSE_BITS_MAX);
|
||||
|
||||
/* generate offsets into sorted symbol table by code length */
|
||||
symbol = -1;
|
||||
bits = 1;
|
||||
BROTLI_REPEAT(BROTLI_HUFFMAN_MAX_CODE_LENGTH_CODE_LENGTH, {
|
||||
symbol += count[bits];
|
||||
offset[bits] = symbol;
|
||||
bits++;
|
||||
});
|
||||
/* Symbols with code length 0 are placed after all other symbols. */
|
||||
offset[0] = BROTLI_CODE_LENGTH_CODES - 1;
|
||||
|
||||
/* sort symbols by length, by symbol order within each length */
|
||||
symbol = BROTLI_CODE_LENGTH_CODES;
|
||||
do {
|
||||
BROTLI_REPEAT(6, {
|
||||
symbol--;
|
||||
sorted[offset[code_lengths[symbol]]--] = symbol;
|
||||
});
|
||||
} while (symbol != 0);
|
||||
|
||||
table_size = 1 << BROTLI_HUFFMAN_MAX_CODE_LENGTH_CODE_LENGTH;
|
||||
|
||||
/* Special case: all symbols but one have 0 code length. */
|
||||
if (offset[0] == 0) {
|
||||
code.bits = 0;
|
||||
code.value = (uint16_t)sorted[0];
|
||||
for (key = 0; key < (reg_t)table_size; ++key) {
|
||||
table[key] = code;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
/* fill in table */
|
||||
key = 0;
|
||||
key_step = BROTLI_REVERSE_BITS_LOWEST;
|
||||
symbol = 0;
|
||||
bits = 1;
|
||||
step = 2;
|
||||
do {
|
||||
code.bits = (uint8_t)bits;
|
||||
for (bits_count = count[bits]; bits_count != 0; --bits_count) {
|
||||
code.value = (uint16_t)sorted[symbol++];
|
||||
ReplicateValue(&table[BrotliReverseBits(key)], step, table_size, code);
|
||||
key += key_step;
|
||||
}
|
||||
step <<= 1;
|
||||
key_step >>= 1;
|
||||
} while (++bits <= BROTLI_HUFFMAN_MAX_CODE_LENGTH_CODE_LENGTH);
|
||||
}
|
||||
|
||||
uint32_t BrotliBuildHuffmanTable(HuffmanCode* root_table,
|
||||
int root_bits,
|
||||
const uint16_t* const symbol_lists,
|
||||
uint16_t* count) {
|
||||
HuffmanCode code; /* current table entry */
|
||||
HuffmanCode* table; /* next available space in table */
|
||||
int len; /* current code length */
|
||||
int symbol; /* symbol index in original or sorted table */
|
||||
reg_t key; /* prefix code */
|
||||
reg_t key_step; /* prefix code addend */
|
||||
reg_t sub_key; /* 2nd level table prefix code */
|
||||
reg_t sub_key_step; /* 2nd level table prefix code addend */
|
||||
int step; /* step size to replicate values in current table */
|
||||
int table_bits; /* key length of current table */
|
||||
int table_size; /* size of current table */
|
||||
int total_size; /* sum of root table size and 2nd level table sizes */
|
||||
int max_length = -1;
|
||||
int bits;
|
||||
int bits_count;
|
||||
|
||||
BROTLI_DCHECK(root_bits <= BROTLI_REVERSE_BITS_MAX);
|
||||
BROTLI_DCHECK(BROTLI_HUFFMAN_MAX_CODE_LENGTH - root_bits <=
|
||||
BROTLI_REVERSE_BITS_MAX);
|
||||
|
||||
while (symbol_lists[max_length] == 0xFFFF) max_length--;
|
||||
max_length += BROTLI_HUFFMAN_MAX_CODE_LENGTH + 1;
|
||||
|
||||
table = root_table;
|
||||
table_bits = root_bits;
|
||||
table_size = 1 << table_bits;
|
||||
total_size = table_size;
|
||||
|
||||
/* fill in root table */
|
||||
/* Reduce the table size if possible, and */
/* create the repetitions by memcpy, if possible, in the coming loop. */
|
||||
if (table_bits > max_length) {
|
||||
table_bits = max_length;
|
||||
table_size = 1 << table_bits;
|
||||
}
|
||||
key = 0;
|
||||
key_step = BROTLI_REVERSE_BITS_LOWEST;
|
||||
bits = 1;
|
||||
step = 2;
|
||||
do {
|
||||
code.bits = (uint8_t)bits;
|
||||
symbol = bits - (BROTLI_HUFFMAN_MAX_CODE_LENGTH + 1);
|
||||
for (bits_count = count[bits]; bits_count != 0; --bits_count) {
|
||||
symbol = symbol_lists[symbol];
|
||||
code.value = (uint16_t)symbol;
|
||||
ReplicateValue(&table[BrotliReverseBits(key)], step, table_size, code);
|
||||
key += key_step;
|
||||
}
|
||||
step <<= 1;
|
||||
key_step >>= 1;
|
||||
} while (++bits <= table_bits);
|
||||
|
||||
/* if root_bits != table_bits we only created one fraction of the */
|
||||
/* table, and we need to replicate it now. */
|
||||
while (total_size != table_size) {
|
||||
memcpy(&table[table_size], &table[0],
|
||||
(size_t)table_size * sizeof(table[0]));
|
||||
table_size <<= 1;
|
||||
}
|
||||
|
||||
/* fill in 2nd level tables and add pointers to root table */
|
||||
key_step = BROTLI_REVERSE_BITS_LOWEST >> (root_bits - 1);
|
||||
sub_key = (BROTLI_REVERSE_BITS_LOWEST << 1);
|
||||
sub_key_step = BROTLI_REVERSE_BITS_LOWEST;
|
||||
for (len = root_bits + 1, step = 2; len <= max_length; ++len) {
|
||||
symbol = len - (BROTLI_HUFFMAN_MAX_CODE_LENGTH + 1);
|
||||
for (; count[len] != 0; --count[len]) {
|
||||
if (sub_key == (BROTLI_REVERSE_BITS_LOWEST << 1U)) {
|
||||
table += table_size;
|
||||
table_bits = NextTableBitSize(count, len, root_bits);
|
||||
table_size = 1 << table_bits;
|
||||
total_size += table_size;
|
||||
sub_key = BrotliReverseBits(key);
|
||||
key += key_step;
|
||||
root_table[sub_key].bits = (uint8_t)(table_bits + root_bits);
|
||||
root_table[sub_key].value =
|
||||
(uint16_t)(((size_t)(table - root_table)) - sub_key);
|
||||
sub_key = 0;
|
||||
}
|
||||
code.bits = (uint8_t)(len - root_bits);
|
||||
symbol = symbol_lists[symbol];
|
||||
code.value = (uint16_t)symbol;
|
||||
ReplicateValue(
|
||||
&table[BrotliReverseBits(sub_key)], step, table_size, code);
|
||||
sub_key += sub_key_step;
|
||||
}
|
||||
step <<= 1;
|
||||
sub_key_step >>= 1;
|
||||
}
|
||||
return (uint32_t)total_size;
|
||||
}
|
||||
|
||||
uint32_t BrotliBuildSimpleHuffmanTable(HuffmanCode* table,
|
||||
int root_bits,
|
||||
uint16_t* val,
|
||||
uint32_t num_symbols) {
|
||||
uint32_t table_size = 1;
|
||||
const uint32_t goal_size = 1U << root_bits;
|
||||
switch (num_symbols) {
|
||||
case 0:
|
||||
table[0].bits = 0;
|
||||
table[0].value = val[0];
|
||||
break;
|
||||
case 1:
|
||||
table[0].bits = 1;
|
||||
table[1].bits = 1;
|
||||
if (val[1] > val[0]) {
|
||||
table[0].value = val[0];
|
||||
table[1].value = val[1];
|
||||
} else {
|
||||
table[0].value = val[1];
|
||||
table[1].value = val[0];
|
||||
}
|
||||
table_size = 2;
|
||||
break;
|
||||
case 2:
|
||||
table[0].bits = 1;
|
||||
table[0].value = val[0];
|
||||
table[2].bits = 1;
|
||||
table[2].value = val[0];
|
||||
if (val[2] > val[1]) {
|
||||
table[1].value = val[1];
|
||||
table[3].value = val[2];
|
||||
} else {
|
||||
table[1].value = val[2];
|
||||
table[3].value = val[1];
|
||||
}
|
||||
table[1].bits = 2;
|
||||
table[3].bits = 2;
|
||||
table_size = 4;
|
||||
break;
|
||||
case 3: {
|
||||
int i, k;
|
||||
for (i = 0; i < 3; ++i) {
|
||||
for (k = i + 1; k < 4; ++k) {
|
||||
if (val[k] < val[i]) {
|
||||
uint16_t t = val[k];
|
||||
val[k] = val[i];
|
||||
val[i] = t;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (i = 0; i < 4; ++i) {
|
||||
table[i].bits = 2;
|
||||
}
|
||||
table[0].value = val[0];
|
||||
table[2].value = val[1];
|
||||
table[1].value = val[2];
|
||||
table[3].value = val[3];
|
||||
table_size = 4;
|
||||
break;
|
||||
}
|
||||
case 4: {
|
||||
int i;
|
||||
if (val[3] < val[2]) {
|
||||
uint16_t t = val[3];
|
||||
val[3] = val[2];
|
||||
val[2] = t;
|
||||
}
|
||||
for (i = 0; i < 7; ++i) {
|
||||
table[i].value = val[0];
|
||||
table[i].bits = (uint8_t)(1 + (i & 1));
|
||||
}
|
||||
table[1].value = val[1];
|
||||
table[3].value = val[2];
|
||||
table[5].value = val[1];
|
||||
table[7].value = val[3];
|
||||
table[3].bits = 3;
|
||||
table[7].bits = 3;
|
||||
table_size = 8;
|
||||
break;
|
||||
}
|
||||
}
|
||||
while (table_size != goal_size) {
|
||||
memcpy(&table[table_size], &table[0],
|
||||
(size_t)table_size * sizeof(table[0]));
|
||||
table_size <<= 1;
|
||||
}
|
||||
return goal_size;
|
||||
}
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
} /* extern "C" */
|
||||
#endif
|
|
@ -0,0 +1,175 @@
|
|||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Literal cost model to allow backward reference replacement to be efficient.
|
||||
*/
|
||||
|
||||
#include "./enc/literal_cost.h"
|
||||
|
||||
#include <brotli/types.h>
|
||||
#include "./enc/fast_log.h"
|
||||
#include "./enc/port.h"
|
||||
#include "./enc/utf8_util.h"
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
static size_t UTF8Position(size_t last, size_t c, size_t clamp) {
|
||||
if (c < 128) {
|
||||
return 0; /* Next one is the 'Byte 1' again. */
|
||||
} else if (c >= 192) { /* Next one is the 'Byte 2' of utf-8 encoding. */
|
||||
return BROTLI_MIN(size_t, 1, clamp);
|
||||
} else {
|
||||
/* Decide based on the last byte whether this continuation ends the sequence. */
|
||||
if (last < 0xe0) {
|
||||
return 0; /* Completed two or three byte coding. */
|
||||
} else { /* Next one is the 'Byte 3' of utf-8 encoding. */
|
||||
return BROTLI_MIN(size_t, 2, clamp);
|
||||
}
|
||||
}
|
||||
}
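/* Illustrative sketch (not part of the original source): how UTF8Position
   classifies the position of the *next* byte. For the two-byte sequence
   0xC3 0xA9 ('e' with acute accent) followed by ASCII 'a', the classes are
   1 (expect a continuation byte) and then 0 twice (sequence finished). */
static void ExampleUtf8Classes(void) {
  size_t after_c3 = UTF8Position(0, 0xC3, 2);     /* == 1 */
  size_t after_a9 = UTF8Position(0xC3, 0xA9, 2);  /* == 0 */
  size_t after_a = UTF8Position(0xA9, 'a', 2);    /* == 0 */
  (void)after_c3; (void)after_a9; (void)after_a;
}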
|
||||
|
||||
static size_t DecideMultiByteStatsLevel(size_t pos, size_t len, size_t mask,
|
||||
const uint8_t *data) {
|
||||
size_t counts[3] = { 0 };
|
||||
size_t max_utf8 = 1; /* should be 2, but 1 compresses better. */
|
||||
size_t last_c = 0;
|
||||
size_t i;
|
||||
for (i = 0; i < len; ++i) {
|
||||
size_t c = data[(pos + i) & mask];
|
||||
++counts[UTF8Position(last_c, c, 2)];
|
||||
last_c = c;
|
||||
}
|
||||
if (counts[2] < 500) {
|
||||
max_utf8 = 1;
|
||||
}
|
||||
if (counts[1] + counts[2] < 25) {
|
||||
max_utf8 = 0;
|
||||
}
|
||||
return max_utf8;
|
||||
}
|
||||
|
||||
static void EstimateBitCostsForLiteralsUTF8(size_t pos, size_t len, size_t mask,
|
||||
const uint8_t *data, float *cost) {
|
||||
/* max_utf8 is 0 (normal ASCII single byte modeling),
|
||||
1 (for 2-byte UTF-8 modeling), or 2 (for 3-byte UTF-8 modeling). */
|
||||
const size_t max_utf8 = DecideMultiByteStatsLevel(pos, len, mask, data);
|
||||
size_t histogram[3][256] = { { 0 } };
|
||||
size_t window_half = 495;
|
||||
size_t in_window = BROTLI_MIN(size_t, window_half, len);
|
||||
size_t in_window_utf8[3] = { 0 };
|
||||
|
||||
size_t i;
|
||||
{ /* Bootstrap histograms. */
|
||||
size_t last_c = 0;
|
||||
size_t utf8_pos = 0;
|
||||
for (i = 0; i < in_window; ++i) {
|
||||
size_t c = data[(pos + i) & mask];
|
||||
++histogram[utf8_pos][c];
|
||||
++in_window_utf8[utf8_pos];
|
||||
utf8_pos = UTF8Position(last_c, c, max_utf8);
|
||||
last_c = c;
|
||||
}
|
||||
}
|
||||
|
||||
/* Compute bit costs with sliding window. */
|
||||
for (i = 0; i < len; ++i) {
|
||||
if (i >= window_half) {
|
||||
/* Remove a byte in the past. */
|
||||
size_t c =
|
||||
i < window_half + 1 ? 0 : data[(pos + i - window_half - 1) & mask];
|
||||
size_t last_c =
|
||||
i < window_half + 2 ? 0 : data[(pos + i - window_half - 2) & mask];
|
||||
size_t utf8_pos2 = UTF8Position(last_c, c, max_utf8);
|
||||
--histogram[utf8_pos2][data[(pos + i - window_half) & mask]];
|
||||
--in_window_utf8[utf8_pos2];
|
||||
}
|
||||
if (i + window_half < len) {
|
||||
/* Add a byte in the future. */
|
||||
size_t c = data[(pos + i + window_half - 1) & mask];
|
||||
size_t last_c = data[(pos + i + window_half - 2) & mask];
|
||||
size_t utf8_pos2 = UTF8Position(last_c, c, max_utf8);
|
||||
++histogram[utf8_pos2][data[(pos + i + window_half) & mask]];
|
||||
++in_window_utf8[utf8_pos2];
|
||||
}
|
||||
{
|
||||
size_t c = i < 1 ? 0 : data[(pos + i - 1) & mask];
|
||||
size_t last_c = i < 2 ? 0 : data[(pos + i - 2) & mask];
|
||||
size_t utf8_pos = UTF8Position(last_c, c, max_utf8);
|
||||
size_t masked_pos = (pos + i) & mask;
|
||||
size_t histo = histogram[utf8_pos][data[masked_pos]];
|
||||
double lit_cost;
|
||||
if (histo == 0) {
|
||||
histo = 1;
|
||||
}
|
||||
lit_cost = FastLog2(in_window_utf8[utf8_pos]) - FastLog2(histo);
|
||||
lit_cost += 0.02905;
|
||||
if (lit_cost < 1.0) {
|
||||
lit_cost *= 0.5;
|
||||
lit_cost += 0.5;
|
||||
}
|
||||
/* Make the first bytes more expensive -- seems to help, not sure why.
|
||||
Perhaps because the entropy source is changing its properties
|
||||
rapidly in the beginning of the file, perhaps because the beginning
|
||||
of the data is a statistical "anomaly". */
|
||||
if (i < 2000) {
|
||||
lit_cost += 0.7 - ((double)(2000 - i) / 2000.0 * 0.35);
|
||||
}
|
||||
cost[i] = (float)lit_cost;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void BrotliEstimateBitCostsForLiterals(size_t pos, size_t len, size_t mask,
|
||||
const uint8_t *data, float *cost) {
|
||||
if (BrotliIsMostlyUTF8(data, pos, mask, len, kMinUTF8Ratio)) {
|
||||
EstimateBitCostsForLiteralsUTF8(pos, len, mask, data, cost);
|
||||
return;
|
||||
} else {
|
||||
size_t histogram[256] = { 0 };
|
||||
size_t window_half = 2000;
|
||||
size_t in_window = BROTLI_MIN(size_t, window_half, len);
|
||||
|
||||
/* Bootstrap histogram. */
|
||||
size_t i;
|
||||
for (i = 0; i < in_window; ++i) {
|
||||
++histogram[data[(pos + i) & mask]];
|
||||
}
|
||||
|
||||
/* Compute bit costs with sliding window. */
|
||||
for (i = 0; i < len; ++i) {
|
||||
size_t histo;
|
||||
if (i >= window_half) {
|
||||
/* Remove a byte in the past. */
|
||||
--histogram[data[(pos + i - window_half) & mask]];
|
||||
--in_window;
|
||||
}
|
||||
if (i + window_half < len) {
|
||||
/* Add a byte in the future. */
|
||||
++histogram[data[(pos + i + window_half) & mask]];
|
||||
++in_window;
|
||||
}
|
||||
histo = histogram[data[(pos + i) & mask]];
|
||||
if (histo == 0) {
|
||||
histo = 1;
|
||||
}
|
||||
{
|
||||
double lit_cost = FastLog2(in_window) - FastLog2(histo);
|
||||
lit_cost += 0.029;
|
||||
if (lit_cost < 1.0) {
|
||||
lit_cost *= 0.5;
|
||||
lit_cost += 0.5;
|
||||
}
|
||||
cost[i] = (float)lit_cost;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
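/* Illustrative sketch (not part of the original source): the cost model above
   is the self-information -log2(p) of a byte under the sliding-window
   histogram, plus a small constant. With 2000 bytes in the window and 250
   occurrences of the byte, the estimate is log2(2000) - log2(250) + 0.029,
   which is about 3.03 bits. */
static double ExampleLiteralCost(void) {
  return FastLog2(2000) - FastLog2(250) + 0.029;
}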
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
} /* extern "C" */
|
||||
#endif
|
|
@ -0,0 +1,181 @@
|
|||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Dynamic memory management for the encoder: wrappers around the
   user-supplied or default malloc/free with optional OOM and leak tracking. */
|
||||
|
||||
#include "./enc/memory.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdlib.h> /* exit, free, malloc */
|
||||
#include <string.h> /* memcpy */
|
||||
|
||||
#include <brotli/types.h>
|
||||
#include "./enc/port.h"
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define MAX_PERM_ALLOCATED 128
|
||||
#define MAX_NEW_ALLOCATED 64
|
||||
#define MAX_NEW_FREED 64
|
||||
|
||||
#define PERM_ALLOCATED_OFFSET 0
|
||||
#define NEW_ALLOCATED_OFFSET MAX_PERM_ALLOCATED
|
||||
#define NEW_FREED_OFFSET (MAX_PERM_ALLOCATED + MAX_NEW_ALLOCATED)
|
||||
|
||||
static void* DefaultAllocFunc(void* opaque, size_t size) {
|
||||
BROTLI_UNUSED(opaque);
|
||||
return malloc(size);
|
||||
}
|
||||
|
||||
static void DefaultFreeFunc(void* opaque, void* address) {
|
||||
BROTLI_UNUSED(opaque);
|
||||
free(address);
|
||||
}
|
||||
|
||||
void BrotliInitMemoryManager(
|
||||
MemoryManager* m, brotli_alloc_func alloc_func, brotli_free_func free_func,
|
||||
void* opaque) {
|
||||
if (!alloc_func) {
|
||||
m->alloc_func = DefaultAllocFunc;
|
||||
m->free_func = DefaultFreeFunc;
|
||||
m->opaque = 0;
|
||||
} else {
|
||||
m->alloc_func = alloc_func;
|
||||
m->free_func = free_func;
|
||||
m->opaque = opaque;
|
||||
}
|
||||
#if !defined(BROTLI_ENCODER_EXIT_ON_OOM)
|
||||
m->is_oom = BROTLI_FALSE;
|
||||
m->perm_allocated = 0;
|
||||
m->new_allocated = 0;
|
||||
m->new_freed = 0;
|
||||
#endif /* BROTLI_ENCODER_EXIT_ON_OOM */
|
||||
}
|
||||
|
||||
#if defined(BROTLI_ENCODER_EXIT_ON_OOM)
|
||||
|
||||
void* BrotliAllocate(MemoryManager* m, size_t n) {
|
||||
void* result = m->alloc_func(m->opaque, n);
|
||||
if (!result) exit(EXIT_FAILURE);
|
||||
return result;
|
||||
}
|
||||
|
||||
void BrotliFree(MemoryManager* m, void* p) {
|
||||
m->free_func(m->opaque, p);
|
||||
}
|
||||
|
||||
void BrotliWipeOutMemoryManager(MemoryManager* m) {
|
||||
BROTLI_UNUSED(m);
|
||||
}
|
||||
|
||||
#else /* BROTLI_ENCODER_EXIT_ON_OOM */
|
||||
|
||||
static void SortPointers(void** items, const size_t n) {
|
||||
/* Shell sort. */
|
||||
static const size_t gaps[] = {23, 10, 4, 1};
|
||||
int g = 0;
|
||||
for (; g < 4; ++g) {
|
||||
size_t gap = gaps[g];
|
||||
size_t i;
|
||||
for (i = gap; i < n; ++i) {
|
||||
size_t j = i;
|
||||
void* tmp = items[i];
|
||||
for (; j >= gap && tmp < items[j - gap]; j -= gap) {
|
||||
items[j] = items[j - gap];
|
||||
}
|
||||
items[j] = tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static size_t Annihilate(void** a, size_t a_len, void** b, size_t b_len) {
|
||||
size_t a_read_index = 0;
|
||||
size_t b_read_index = 0;
|
||||
size_t a_write_index = 0;
|
||||
size_t b_write_index = 0;
|
||||
size_t annihilated = 0;
|
||||
while (a_read_index < a_len && b_read_index < b_len) {
|
||||
if (a[a_read_index] == b[b_read_index]) {
|
||||
a_read_index++;
|
||||
b_read_index++;
|
||||
annihilated++;
|
||||
} else if (a[a_read_index] < b[b_read_index]) {
|
||||
a[a_write_index++] = a[a_read_index++];
|
||||
} else {
|
||||
b[b_write_index++] = b[b_read_index++];
|
||||
}
|
||||
}
|
||||
while (a_read_index < a_len) a[a_write_index++] = a[a_read_index++];
|
||||
while (b_read_index < b_len) b[b_write_index++] = b[b_read_index++];
|
||||
return annihilated;
|
||||
}
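/* Illustrative sketch (not part of the original source): Annihilate removes
   pointers that occur in both sorted lists, i.e. allocations that were later
   freed. Here the middle pointer cancels out and two live pointers remain. */
static void ExampleAnnihilate(void) {
  char block[3];  /* stand-ins for three allocated blocks */
  void* allocated[3] = { &block[0], &block[1], &block[2] };
  void* freed[1] = { &block[1] };
  size_t cancelled = Annihilate(allocated, 3, freed, 1);
  /* cancelled == 1; allocated[0] and allocated[1] now hold &block[0] and
     &block[2], the still-live pointers. */
  (void)cancelled;
}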
|
||||
|
||||
static void CollectGarbagePointers(MemoryManager* m) {
|
||||
size_t annihilated;
|
||||
SortPointers(m->pointers + NEW_ALLOCATED_OFFSET, m->new_allocated);
|
||||
SortPointers(m->pointers + NEW_FREED_OFFSET, m->new_freed);
|
||||
annihilated = Annihilate(
|
||||
m->pointers + NEW_ALLOCATED_OFFSET, m->new_allocated,
|
||||
m->pointers + NEW_FREED_OFFSET, m->new_freed);
|
||||
m->new_allocated -= annihilated;
|
||||
m->new_freed -= annihilated;
|
||||
|
||||
if (m->new_freed != 0) {
|
||||
annihilated = Annihilate(
|
||||
m->pointers + PERM_ALLOCATED_OFFSET, m->perm_allocated,
|
||||
m->pointers + NEW_FREED_OFFSET, m->new_freed);
|
||||
m->perm_allocated -= annihilated;
|
||||
m->new_freed -= annihilated;
|
||||
assert(m->new_freed == 0);
|
||||
}
|
||||
|
||||
if (m->new_allocated != 0) {
|
||||
assert(m->perm_allocated + m->new_allocated <= MAX_PERM_ALLOCATED);
|
||||
memcpy(m->pointers + PERM_ALLOCATED_OFFSET + m->perm_allocated,
|
||||
m->pointers + NEW_ALLOCATED_OFFSET,
|
||||
sizeof(void*) * m->new_allocated);
|
||||
m->perm_allocated += m->new_allocated;
|
||||
m->new_allocated = 0;
|
||||
SortPointers(m->pointers + PERM_ALLOCATED_OFFSET, m->perm_allocated);
|
||||
}
|
||||
}
|
||||
|
||||
void* BrotliAllocate(MemoryManager* m, size_t n) {
|
||||
void* result = m->alloc_func(m->opaque, n);
|
||||
if (!result) {
|
||||
m->is_oom = BROTLI_TRUE;
|
||||
return NULL;
|
||||
}
|
||||
if (m->new_allocated == MAX_NEW_ALLOCATED) CollectGarbagePointers(m);
|
||||
m->pointers[NEW_ALLOCATED_OFFSET + (m->new_allocated++)] = result;
|
||||
return result;
|
||||
}
|
||||
|
||||
void BrotliFree(MemoryManager* m, void* p) {
|
||||
if (!p) return;
|
||||
m->free_func(m->opaque, p);
|
||||
if (m->new_freed == MAX_NEW_FREED) CollectGarbagePointers(m);
|
||||
m->pointers[NEW_FREED_OFFSET + (m->new_freed++)] = p;
|
||||
}
|
||||
|
||||
void BrotliWipeOutMemoryManager(MemoryManager* m) {
|
||||
size_t i;
|
||||
CollectGarbagePointers(m);
|
||||
/* Now all unfreed pointers are in perm-allocated list. */
|
||||
for (i = 0; i < m->perm_allocated; ++i) {
|
||||
m->free_func(m->opaque, m->pointers[PERM_ALLOCATED_OFFSET + i]);
|
||||
}
|
||||
m->perm_allocated = 0;
|
||||
}
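/* Illustrative sketch (not part of the original source): typical lifecycle of
   the memory manager defined in this file. NULL callbacks select the default
   malloc/free; on allocation failure the OOM flag is recorded rather than
   aborting, and BrotliWipeOutMemoryManager releases anything still tracked. */
static void ExampleMemoryManagerUse(void) {
  MemoryManager m;
  void* buf;
  BrotliInitMemoryManager(&m, NULL, NULL, NULL);
  buf = BrotliAllocate(&m, 1024);
  if (BROTLI_IS_OOM(&m)) return;
  BrotliFree(&m, buf);
  BrotliWipeOutMemoryManager(&m);
}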
|
||||
|
||||
#endif /* BROTLI_ENCODER_EXIT_ON_OOM */
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
} /* extern "C" */
|
||||
#endif
|
|
@ -0,0 +1,528 @@
|
|||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Algorithms for distributing the literals and commands of a metablock between
|
||||
block types and contexts. */
|
||||
|
||||
#include "./enc/metablock.h"
|
||||
|
||||
#include "./common/constants.h"
|
||||
#include <brotli/types.h>
|
||||
#include "./enc/bit_cost.h"
|
||||
#include "./enc/block_splitter.h"
|
||||
#include "./enc/cluster.h"
|
||||
#include "./enc/context.h"
|
||||
#include "./enc/entropy_encode.h"
|
||||
#include "./enc/histogram.h"
|
||||
#include "./enc/memory.h"
|
||||
#include "./enc/port.h"
|
||||
#include "./enc/quality.h"
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
void BrotliBuildMetaBlock(MemoryManager* m,
|
||||
const uint8_t* ringbuffer,
|
||||
const size_t pos,
|
||||
const size_t mask,
|
||||
const BrotliEncoderParams* params,
|
||||
uint8_t prev_byte,
|
||||
uint8_t prev_byte2,
|
||||
const Command* cmds,
|
||||
size_t num_commands,
|
||||
ContextType literal_context_mode,
|
||||
MetaBlockSplit* mb) {
|
||||
/* Histogram ids need to fit in one byte. */
|
||||
static const size_t kMaxNumberOfHistograms = 256;
|
||||
HistogramDistance* distance_histograms;
|
||||
HistogramLiteral* literal_histograms;
|
||||
ContextType* literal_context_modes = NULL;
|
||||
size_t literal_histograms_size;
|
||||
size_t distance_histograms_size;
|
||||
size_t i;
|
||||
size_t literal_context_multiplier = 1;
|
||||
|
||||
BrotliSplitBlock(m, cmds, num_commands,
|
||||
ringbuffer, pos, mask, params,
|
||||
&mb->literal_split,
|
||||
&mb->command_split,
|
||||
&mb->distance_split);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
|
||||
if (!params->disable_literal_context_modeling) {
|
||||
literal_context_multiplier = 1 << BROTLI_LITERAL_CONTEXT_BITS;
|
||||
literal_context_modes =
|
||||
BROTLI_ALLOC(m, ContextType, mb->literal_split.num_types);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
for (i = 0; i < mb->literal_split.num_types; ++i) {
|
||||
literal_context_modes[i] = literal_context_mode;
|
||||
}
|
||||
}
|
||||
|
||||
literal_histograms_size =
|
||||
mb->literal_split.num_types * literal_context_multiplier;
|
||||
literal_histograms =
|
||||
BROTLI_ALLOC(m, HistogramLiteral, literal_histograms_size);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
ClearHistogramsLiteral(literal_histograms, literal_histograms_size);
|
||||
|
||||
distance_histograms_size =
|
||||
mb->distance_split.num_types << BROTLI_DISTANCE_CONTEXT_BITS;
|
||||
distance_histograms =
|
||||
BROTLI_ALLOC(m, HistogramDistance, distance_histograms_size);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
ClearHistogramsDistance(distance_histograms, distance_histograms_size);
|
||||
|
||||
assert(mb->command_histograms == 0);
|
||||
mb->command_histograms_size = mb->command_split.num_types;
|
||||
mb->command_histograms =
|
||||
BROTLI_ALLOC(m, HistogramCommand, mb->command_histograms_size);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
ClearHistogramsCommand(mb->command_histograms, mb->command_histograms_size);
|
||||
|
||||
BrotliBuildHistogramsWithContext(cmds, num_commands,
|
||||
&mb->literal_split, &mb->command_split, &mb->distance_split,
|
||||
ringbuffer, pos, mask, prev_byte, prev_byte2, literal_context_modes,
|
||||
literal_histograms, mb->command_histograms, distance_histograms);
|
||||
BROTLI_FREE(m, literal_context_modes);
|
||||
|
||||
assert(mb->literal_context_map == 0);
|
||||
mb->literal_context_map_size =
|
||||
mb->literal_split.num_types << BROTLI_LITERAL_CONTEXT_BITS;
|
||||
mb->literal_context_map =
|
||||
BROTLI_ALLOC(m, uint32_t, mb->literal_context_map_size);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
|
||||
assert(mb->literal_histograms == 0);
|
||||
mb->literal_histograms_size = mb->literal_context_map_size;
|
||||
mb->literal_histograms =
|
||||
BROTLI_ALLOC(m, HistogramLiteral, mb->literal_histograms_size);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
|
||||
BrotliClusterHistogramsLiteral(m, literal_histograms, literal_histograms_size,
|
||||
kMaxNumberOfHistograms, mb->literal_histograms,
|
||||
&mb->literal_histograms_size, mb->literal_context_map);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
BROTLI_FREE(m, literal_histograms);
|
||||
|
||||
if (params->disable_literal_context_modeling) {
|
||||
/* Distribute assignment to all contexts. */
|
||||
for (i = mb->literal_split.num_types; i != 0;) {
|
||||
size_t j = 0;
|
||||
i--;
|
||||
for (; j < (1 << BROTLI_LITERAL_CONTEXT_BITS); j++) {
|
||||
mb->literal_context_map[(i << BROTLI_LITERAL_CONTEXT_BITS) + j] =
|
||||
mb->literal_context_map[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert(mb->distance_context_map == 0);
|
||||
mb->distance_context_map_size =
|
||||
mb->distance_split.num_types << BROTLI_DISTANCE_CONTEXT_BITS;
|
||||
mb->distance_context_map =
|
||||
BROTLI_ALLOC(m, uint32_t, mb->distance_context_map_size);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
|
||||
assert(mb->distance_histograms == 0);
|
||||
mb->distance_histograms_size = mb->distance_context_map_size;
|
||||
mb->distance_histograms =
|
||||
BROTLI_ALLOC(m, HistogramDistance, mb->distance_histograms_size);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
|
||||
BrotliClusterHistogramsDistance(m, distance_histograms,
|
||||
mb->distance_context_map_size,
|
||||
kMaxNumberOfHistograms,
|
||||
mb->distance_histograms,
|
||||
&mb->distance_histograms_size,
|
||||
mb->distance_context_map);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
BROTLI_FREE(m, distance_histograms);
|
||||
}
|
||||
|
||||
#define FN(X) X ## Literal
|
||||
#include "./enc/metablock_inc.h" /* NOLINT(build/include) */
|
||||
#undef FN
|
||||
|
||||
#define FN(X) X ## Command
|
||||
#include "./enc/metablock_inc.h" /* NOLINT(build/include) */
|
||||
#undef FN
|
||||
|
||||
#define FN(X) X ## Distance
|
||||
#include "./enc/metablock_inc.h" /* NOLINT(build/include) */
|
||||
#undef FN
|
||||
|
||||
#define BROTLI_MAX_STATIC_CONTEXTS 13
|
||||
|
||||
/* Greedy block splitter for one block category (literal, command or distance).
|
||||
Gathers histograms for all context buckets. */
|
||||
typedef struct ContextBlockSplitter {
|
||||
/* Alphabet size of particular block category. */
|
||||
size_t alphabet_size_;
|
||||
size_t num_contexts_;
|
||||
size_t max_block_types_;
|
||||
/* We collect at least this many symbols for each block. */
|
||||
size_t min_block_size_;
|
||||
/* We merge histograms A and B if
   entropy(A+B) < entropy(A) + entropy(B) + split_threshold_,
   where A is the current histogram and B is the histogram of the last or the
   second last block type; see the numeric sketch after this struct. */
|
||||
double split_threshold_;
|
||||
|
||||
size_t num_blocks_;
|
||||
BlockSplit* split_; /* not owned */
|
||||
HistogramLiteral* histograms_; /* not owned */
|
||||
size_t* histograms_size_; /* not owned */
|
||||
|
||||
/* The number of symbols that we want to collect before deciding on whether
|
||||
or not to merge the block with a previous one or emit a new block. */
|
||||
size_t target_block_size_;
|
||||
/* The number of symbols in the current histogram. */
|
||||
size_t block_size_;
|
||||
/* Offset of the current histogram. */
|
||||
size_t curr_histogram_ix_;
|
||||
/* Offset of the histograms of the previous two block types. */
|
||||
size_t last_histogram_ix_[2];
|
||||
/* Entropy of the previous two block types. */
|
||||
double last_entropy_[2 * BROTLI_MAX_STATIC_CONTEXTS];
|
||||
/* The number of times we merged the current block with the last one. */
|
||||
size_t merge_last_count_;
|
||||
} ContextBlockSplitter;
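/* Illustrative sketch (not part of the original source): the merge rule from
   the struct above in numbers. Alphabet size 4 and these counts are example
   values. Two blocks with identical statistics merge because coding them
   together costs no more than coding them apart; two disjoint blocks stay
   separate, since the combined entropy exceeds the sum by far more than any
   reasonable split_threshold_. */
static void ExampleMergeRule(void) {
  uint32_t a[4] = { 8, 8, 0, 0 };   /* block A uses symbols 0 and 1 */
  uint32_t b[4] = { 0, 0, 8, 8 };   /* block B uses symbols 2 and 3 */
  uint32_t ab[4] = { 8, 8, 8, 8 };  /* A and B coded with one histogram */
  double separate = BitsEntropy(a, 4) + BitsEntropy(b, 4);  /* about 32 bits */
  double combined = BitsEntropy(ab, 4);                     /* about 64 bits */
  /* combined - separate is roughly 32 bits, so these blocks are kept apart. */
  (void)separate; (void)combined;
}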
|
||||
|
||||
static void InitContextBlockSplitter(
|
||||
MemoryManager* m, ContextBlockSplitter* self, size_t alphabet_size,
|
||||
size_t num_contexts, size_t min_block_size, double split_threshold,
|
||||
size_t num_symbols, BlockSplit* split, HistogramLiteral** histograms,
|
||||
size_t* histograms_size) {
|
||||
size_t max_num_blocks = num_symbols / min_block_size + 1;
|
||||
size_t max_num_types;
|
||||
assert(num_contexts <= BROTLI_MAX_STATIC_CONTEXTS);
|
||||
|
||||
self->alphabet_size_ = alphabet_size;
|
||||
self->num_contexts_ = num_contexts;
|
||||
self->max_block_types_ = BROTLI_MAX_NUMBER_OF_BLOCK_TYPES / num_contexts;
|
||||
self->min_block_size_ = min_block_size;
|
||||
self->split_threshold_ = split_threshold;
|
||||
self->num_blocks_ = 0;
|
||||
self->split_ = split;
|
||||
self->histograms_size_ = histograms_size;
|
||||
self->target_block_size_ = min_block_size;
|
||||
self->block_size_ = 0;
|
||||
self->curr_histogram_ix_ = 0;
|
||||
self->merge_last_count_ = 0;
|
||||
|
||||
/* We have to allocate one more histogram than the maximum number of block
|
||||
types for the current histogram when the meta-block is too big. */
|
||||
max_num_types =
|
||||
BROTLI_MIN(size_t, max_num_blocks, self->max_block_types_ + 1);
|
||||
BROTLI_ENSURE_CAPACITY(m, uint8_t,
|
||||
split->types, split->types_alloc_size, max_num_blocks);
|
||||
BROTLI_ENSURE_CAPACITY(m, uint32_t,
|
||||
split->lengths, split->lengths_alloc_size, max_num_blocks);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
split->num_blocks = max_num_blocks;
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
assert(*histograms == 0);
|
||||
*histograms_size = max_num_types * num_contexts;
|
||||
*histograms = BROTLI_ALLOC(m, HistogramLiteral, *histograms_size);
|
||||
self->histograms_ = *histograms;
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
/* Clear only current histogram. */
|
||||
ClearHistogramsLiteral(&self->histograms_[0], num_contexts);
|
||||
self->last_histogram_ix_[0] = self->last_histogram_ix_[1] = 0;
|
||||
}
|
||||
|
||||
/* Does one of three things:
|
||||
(1) emits the current block with a new block type;
|
||||
(2) emits the current block with the type of the second last block;
|
||||
(3) merges the current block with the last block. */
|
||||
static void ContextBlockSplitterFinishBlock(
|
||||
ContextBlockSplitter* self, MemoryManager* m, BROTLI_BOOL is_final) {
|
||||
BlockSplit* split = self->split_;
|
||||
const size_t num_contexts = self->num_contexts_;
|
||||
double* last_entropy = self->last_entropy_;
|
||||
HistogramLiteral* histograms = self->histograms_;
|
||||
|
||||
if (self->block_size_ < self->min_block_size_) {
|
||||
self->block_size_ = self->min_block_size_;
|
||||
}
|
||||
if (self->num_blocks_ == 0) {
|
||||
size_t i;
|
||||
/* Create first block. */
|
||||
split->lengths[0] = (uint32_t)self->block_size_;
|
||||
split->types[0] = 0;
|
||||
|
||||
for (i = 0; i < num_contexts; ++i) {
|
||||
last_entropy[i] =
|
||||
BitsEntropy(histograms[i].data_, self->alphabet_size_);
|
||||
last_entropy[num_contexts + i] = last_entropy[i];
|
||||
}
|
||||
++self->num_blocks_;
|
||||
++split->num_types;
|
||||
self->curr_histogram_ix_ += num_contexts;
|
||||
if (self->curr_histogram_ix_ < *self->histograms_size_) {
|
||||
ClearHistogramsLiteral(
|
||||
&self->histograms_[self->curr_histogram_ix_], self->num_contexts_);
|
||||
}
|
||||
self->block_size_ = 0;
|
||||
} else if (self->block_size_ > 0) {
|
||||
/* Try merging the set of histograms for the current block type with the
|
||||
respective set of histograms for the last and second last block types.
|
||||
Decide over the split based on the total reduction of entropy across
|
||||
all contexts. */
|
||||
double entropy[BROTLI_MAX_STATIC_CONTEXTS];
|
||||
HistogramLiteral* combined_histo =
|
||||
BROTLI_ALLOC(m, HistogramLiteral, 2 * num_contexts);
|
||||
double combined_entropy[2 * BROTLI_MAX_STATIC_CONTEXTS];
|
||||
double diff[2] = { 0.0 };
|
||||
size_t i;
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
for (i = 0; i < num_contexts; ++i) {
|
||||
size_t curr_histo_ix = self->curr_histogram_ix_ + i;
|
||||
size_t j;
|
||||
entropy[i] = BitsEntropy(histograms[curr_histo_ix].data_,
|
||||
self->alphabet_size_);
|
||||
for (j = 0; j < 2; ++j) {
|
||||
size_t jx = j * num_contexts + i;
|
||||
size_t last_histogram_ix = self->last_histogram_ix_[j] + i;
|
||||
combined_histo[jx] = histograms[curr_histo_ix];
|
||||
HistogramAddHistogramLiteral(&combined_histo[jx],
|
||||
&histograms[last_histogram_ix]);
|
||||
combined_entropy[jx] = BitsEntropy(
|
||||
&combined_histo[jx].data_[0], self->alphabet_size_);
|
||||
diff[j] += combined_entropy[jx] - entropy[i] - last_entropy[jx];
|
||||
}
|
||||
}
|
||||
|
||||
if (split->num_types < self->max_block_types_ &&
|
||||
diff[0] > self->split_threshold_ &&
|
||||
diff[1] > self->split_threshold_) {
|
||||
/* Create new block. */
|
||||
split->lengths[self->num_blocks_] = (uint32_t)self->block_size_;
|
||||
split->types[self->num_blocks_] = (uint8_t)split->num_types;
|
||||
self->last_histogram_ix_[1] = self->last_histogram_ix_[0];
|
||||
self->last_histogram_ix_[0] = split->num_types * num_contexts;
|
||||
for (i = 0; i < num_contexts; ++i) {
|
||||
last_entropy[num_contexts + i] = last_entropy[i];
|
||||
last_entropy[i] = entropy[i];
|
||||
}
|
||||
++self->num_blocks_;
|
||||
++split->num_types;
|
||||
self->curr_histogram_ix_ += num_contexts;
|
||||
if (self->curr_histogram_ix_ < *self->histograms_size_) {
|
||||
ClearHistogramsLiteral(
|
||||
&self->histograms_[self->curr_histogram_ix_], self->num_contexts_);
|
||||
}
|
||||
self->block_size_ = 0;
|
||||
self->merge_last_count_ = 0;
|
||||
self->target_block_size_ = self->min_block_size_;
|
||||
} else if (diff[1] < diff[0] - 20.0) {
|
||||
/* Combine this block with second last block. */
|
||||
split->lengths[self->num_blocks_] = (uint32_t)self->block_size_;
|
||||
split->types[self->num_blocks_] = split->types[self->num_blocks_ - 2];
|
||||
BROTLI_SWAP(size_t, self->last_histogram_ix_, 0, 1);
|
||||
for (i = 0; i < num_contexts; ++i) {
|
||||
histograms[self->last_histogram_ix_[0] + i] =
|
||||
combined_histo[num_contexts + i];
|
||||
last_entropy[num_contexts + i] = last_entropy[i];
|
||||
last_entropy[i] = combined_entropy[num_contexts + i];
|
||||
HistogramClearLiteral(&histograms[self->curr_histogram_ix_ + i]);
|
||||
}
|
||||
++self->num_blocks_;
|
||||
self->block_size_ = 0;
|
||||
self->merge_last_count_ = 0;
|
||||
self->target_block_size_ = self->min_block_size_;
|
||||
} else {
|
||||
/* Combine this block with last block. */
|
||||
split->lengths[self->num_blocks_ - 1] += (uint32_t)self->block_size_;
|
||||
for (i = 0; i < num_contexts; ++i) {
|
||||
histograms[self->last_histogram_ix_[0] + i] = combined_histo[i];
|
||||
last_entropy[i] = combined_entropy[i];
|
||||
if (split->num_types == 1) {
|
||||
last_entropy[num_contexts + i] = last_entropy[i];
|
||||
}
|
||||
HistogramClearLiteral(&histograms[self->curr_histogram_ix_ + i]);
|
||||
}
|
||||
self->block_size_ = 0;
|
||||
if (++self->merge_last_count_ > 1) {
|
||||
self->target_block_size_ += self->min_block_size_;
|
||||
}
|
||||
}
|
||||
BROTLI_FREE(m, combined_histo);
|
||||
}
|
||||
if (is_final) {
|
||||
*self->histograms_size_ = split->num_types * num_contexts;
|
||||
split->num_blocks = self->num_blocks_;
|
||||
}
|
||||
}
|
||||
|
||||
/* Adds the next symbol to the current block type and context. When the
|
||||
current block reaches the target size, decides on merging the block. */
|
||||
static void ContextBlockSplitterAddSymbol(
|
||||
ContextBlockSplitter* self, MemoryManager* m,
|
||||
size_t symbol, size_t context) {
|
||||
HistogramAddLiteral(&self->histograms_[self->curr_histogram_ix_ + context],
|
||||
symbol);
|
||||
++self->block_size_;
|
||||
if (self->block_size_ == self->target_block_size_) {
|
||||
ContextBlockSplitterFinishBlock(self, m, /* is_final = */ BROTLI_FALSE);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
}
|
||||
}
|
||||
|
||||
static void MapStaticContexts(MemoryManager* m,
|
||||
size_t num_contexts,
|
||||
const uint32_t* static_context_map,
|
||||
MetaBlockSplit* mb) {
|
||||
size_t i;
|
||||
assert(mb->literal_context_map == 0);
|
||||
mb->literal_context_map_size =
|
||||
mb->literal_split.num_types << BROTLI_LITERAL_CONTEXT_BITS;
|
||||
mb->literal_context_map =
|
||||
BROTLI_ALLOC(m, uint32_t, mb->literal_context_map_size);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
|
||||
for (i = 0; i < mb->literal_split.num_types; ++i) {
|
||||
uint32_t offset = (uint32_t)(i * num_contexts);
|
||||
size_t j;
|
||||
for (j = 0; j < (1u << BROTLI_LITERAL_CONTEXT_BITS); ++j) {
|
||||
mb->literal_context_map[(i << BROTLI_LITERAL_CONTEXT_BITS) + j] =
|
||||
offset + static_context_map[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static BROTLI_INLINE void BrotliBuildMetaBlockGreedyInternal(
|
||||
MemoryManager* m, const uint8_t* ringbuffer, size_t pos, size_t mask,
|
||||
uint8_t prev_byte, uint8_t prev_byte2, ContextType literal_context_mode,
|
||||
const size_t num_contexts, const uint32_t* static_context_map,
|
||||
const Command *commands, size_t n_commands, MetaBlockSplit* mb) {
|
||||
union {
|
||||
BlockSplitterLiteral plain;
|
||||
ContextBlockSplitter ctx;
|
||||
} lit_blocks;
|
||||
BlockSplitterCommand cmd_blocks;
|
||||
BlockSplitterDistance dist_blocks;
|
||||
size_t num_literals = 0;
|
||||
size_t i;
|
||||
for (i = 0; i < n_commands; ++i) {
|
||||
num_literals += commands[i].insert_len_;
|
||||
}
|
||||
|
||||
if (num_contexts == 1) {
|
||||
InitBlockSplitterLiteral(m, &lit_blocks.plain, 256, 512, 400.0,
|
||||
num_literals, &mb->literal_split, &mb->literal_histograms,
|
||||
&mb->literal_histograms_size);
|
||||
} else {
|
||||
InitContextBlockSplitter(m, &lit_blocks.ctx, 256, num_contexts, 512, 400.0,
|
||||
num_literals, &mb->literal_split, &mb->literal_histograms,
|
||||
&mb->literal_histograms_size);
|
||||
}
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
InitBlockSplitterCommand(m, &cmd_blocks, BROTLI_NUM_COMMAND_SYMBOLS, 1024,
|
||||
500.0, n_commands, &mb->command_split, &mb->command_histograms,
|
||||
&mb->command_histograms_size);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
InitBlockSplitterDistance(m, &dist_blocks, 64, 512, 100.0, n_commands,
|
||||
&mb->distance_split, &mb->distance_histograms,
|
||||
&mb->distance_histograms_size);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
|
||||
for (i = 0; i < n_commands; ++i) {
|
||||
const Command cmd = commands[i];
|
||||
size_t j;
|
||||
BlockSplitterAddSymbolCommand(&cmd_blocks, cmd.cmd_prefix_);
|
||||
for (j = cmd.insert_len_; j != 0; --j) {
|
||||
uint8_t literal = ringbuffer[pos & mask];
|
||||
if (num_contexts == 1) {
|
||||
BlockSplitterAddSymbolLiteral(&lit_blocks.plain, literal);
|
||||
} else {
|
||||
size_t context = Context(prev_byte, prev_byte2, literal_context_mode);
|
||||
ContextBlockSplitterAddSymbol(&lit_blocks.ctx, m, literal,
|
||||
static_context_map[context]);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
}
|
||||
prev_byte2 = prev_byte;
|
||||
prev_byte = literal;
|
||||
++pos;
|
||||
}
|
||||
pos += CommandCopyLen(&cmd);
|
||||
if (CommandCopyLen(&cmd)) {
|
||||
prev_byte2 = ringbuffer[(pos - 2) & mask];
|
||||
prev_byte = ringbuffer[(pos - 1) & mask];
|
||||
if (cmd.cmd_prefix_ >= 128) {
|
||||
BlockSplitterAddSymbolDistance(&dist_blocks, cmd.dist_prefix_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (num_contexts == 1) {
|
||||
BlockSplitterFinishBlockLiteral(
|
||||
&lit_blocks.plain, /* is_final = */ BROTLI_TRUE);
|
||||
} else {
|
||||
ContextBlockSplitterFinishBlock(
|
||||
&lit_blocks.ctx, m, /* is_final = */ BROTLI_TRUE);
|
||||
if (BROTLI_IS_OOM(m)) return;
|
||||
}
|
||||
BlockSplitterFinishBlockCommand(&cmd_blocks, /* is_final = */ BROTLI_TRUE);
|
||||
BlockSplitterFinishBlockDistance(&dist_blocks, /* is_final = */ BROTLI_TRUE);
|
||||
|
||||
if (num_contexts > 1) {
|
||||
MapStaticContexts(m, num_contexts, static_context_map, mb);
|
||||
}
|
||||
}
|
||||
|
||||
void BrotliBuildMetaBlockGreedy(MemoryManager* m,
|
||||
const uint8_t* ringbuffer,
|
||||
size_t pos,
|
||||
size_t mask,
|
||||
uint8_t prev_byte,
|
||||
uint8_t prev_byte2,
|
||||
ContextType literal_context_mode,
|
||||
size_t num_contexts,
|
||||
const uint32_t* static_context_map,
|
||||
const Command* commands,
|
||||
size_t n_commands,
|
||||
MetaBlockSplit* mb) {
|
||||
if (num_contexts == 1) {
|
||||
BrotliBuildMetaBlockGreedyInternal(m, ringbuffer, pos, mask, prev_byte,
|
||||
prev_byte2, literal_context_mode, 1, NULL, commands, n_commands, mb);
|
||||
} else {
|
||||
BrotliBuildMetaBlockGreedyInternal(m, ringbuffer, pos, mask, prev_byte,
|
||||
prev_byte2, literal_context_mode, num_contexts, static_context_map,
|
||||
commands, n_commands, mb);
|
||||
}
|
||||
}
|
||||
|
||||
void BrotliOptimizeHistograms(size_t num_direct_distance_codes,
|
||||
size_t distance_postfix_bits,
|
||||
MetaBlockSplit* mb) {
|
||||
uint8_t good_for_rle[BROTLI_NUM_COMMAND_SYMBOLS];
|
||||
size_t num_distance_codes;
|
||||
size_t i;
|
||||
for (i = 0; i < mb->literal_histograms_size; ++i) {
|
||||
BrotliOptimizeHuffmanCountsForRle(256, mb->literal_histograms[i].data_,
|
||||
good_for_rle);
|
||||
}
|
||||
for (i = 0; i < mb->command_histograms_size; ++i) {
|
||||
BrotliOptimizeHuffmanCountsForRle(BROTLI_NUM_COMMAND_SYMBOLS,
|
||||
mb->command_histograms[i].data_,
|
||||
good_for_rle);
|
||||
}
|
||||
num_distance_codes = BROTLI_NUM_DISTANCE_SHORT_CODES +
|
||||
num_direct_distance_codes +
|
||||
((2 * BROTLI_MAX_DISTANCE_BITS) << distance_postfix_bits);
|
||||
for (i = 0; i < mb->distance_histograms_size; ++i) {
|
||||
BrotliOptimizeHuffmanCountsForRle(num_distance_codes,
|
||||
mb->distance_histograms[i].data_,
|
||||
good_for_rle);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
} /* extern "C" */
|
||||
#endif
|
|
@ -0,0 +1,168 @@
|
|||
// Copyright 2016 Google Inc. All Rights Reserved.
|
||||
//
|
||||
// Distributed under MIT license.
|
||||
// See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
|
||||
// Package cbrotli compresses and decompresses data with C-Brotli library.
|
||||
package brotli
|
||||
|
||||
/*
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "brotli/decode.h"
|
||||
|
||||
static BrotliDecoderResult DecompressStream(BrotliDecoderState* s,
|
||||
uint8_t* out, size_t out_len,
|
||||
const uint8_t* in, size_t in_len,
|
||||
size_t* bytes_written,
|
||||
size_t* bytes_consumed) {
|
||||
size_t in_remaining = in_len;
|
||||
size_t out_remaining = out_len;
|
||||
BrotliDecoderResult result = BrotliDecoderDecompressStream(
|
||||
s, &in_remaining, &in, &out_remaining, &out, NULL);
|
||||
*bytes_written = out_len - out_remaining;
|
||||
*bytes_consumed = in_len - in_remaining;
|
||||
return result;
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
type decodeError C.BrotliDecoderErrorCode
|
||||
|
||||
func (err decodeError) Error() string {
|
||||
return "cbrotli: " +
|
||||
C.GoString(C.BrotliDecoderErrorString(C.BrotliDecoderErrorCode(err)))
|
||||
}
|
||||
|
||||
var errExcessiveInput = errors.New("cbrotli: excessive input")
|
||||
var errInvalidState = errors.New("cbrotli: invalid state")
|
||||
var errReaderClosed = errors.New("cbrotli: Reader is closed")
|
||||
|
||||
// Reader implements io.ReadCloser by reading Brotli-encoded data from an
|
||||
// underlying Reader.
|
||||
type Reader struct {
|
||||
src io.Reader
|
||||
state *C.BrotliDecoderState
|
||||
buf []byte // scratch space for reading from src
|
||||
in []byte // current chunk to decode; usually aliases buf
|
||||
}
|
||||
|
||||
// readBufSize is a "good" buffer size that avoids excessive round-trips
|
||||
// between C and Go but doesn't waste too much memory on buffering.
|
||||
// It is arbitrarily chosen to be equal to the constant used in io.Copy.
|
||||
const readBufSize = 32 * 1024
|
||||
|
||||
// NewReader initializes new Reader instance.
|
||||
// Close MUST be called to free resources.
|
||||
func NewReader(src io.Reader) *Reader {
|
||||
return &Reader{
|
||||
src: src,
|
||||
state: C.BrotliDecoderCreateInstance(nil, nil, nil),
|
||||
buf: make([]byte, readBufSize),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Reader) SetDictionary(p []byte) {
|
||||
var data *C.uint8_t
|
||||
if len(p) != 0 {
|
||||
data = (*C.uint8_t)(&p[0])
|
||||
}
|
||||
|
||||
C.BrotliDecoderSetCustomDictionary(r.state, C.size_t(len(p)), data)
|
||||
}
|
||||
|
||||
// Close implements io.Closer. Close MUST be invoked to free native resources.
|
||||
func (r *Reader) Close() error {
|
||||
if r.state == nil {
|
||||
return errReaderClosed
|
||||
}
|
||||
// Close despite the state; i.e. there might be some unread decoded data.
|
||||
C.BrotliDecoderDestroyInstance(r.state)
|
||||
r.state = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Reader) Read(p []byte) (n int, err error) {
|
||||
if int(C.BrotliDecoderHasMoreOutput(r.state)) == 0 && len(r.in) == 0 {
|
||||
m, readErr := r.src.Read(r.buf)
|
||||
if m == 0 {
|
||||
// If readErr is `nil`, we just proxy underlying stream behavior.
|
||||
return 0, readErr
|
||||
}
|
||||
r.in = r.buf[:m]
|
||||
}
|
||||
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
for {
|
||||
var written, consumed C.size_t
|
||||
var data *C.uint8_t
|
||||
if len(r.in) != 0 {
|
||||
data = (*C.uint8_t)(&r.in[0])
|
||||
}
|
||||
result := C.DecompressStream(r.state,
|
||||
(*C.uint8_t)(&p[0]), C.size_t(len(p)),
|
||||
data, C.size_t(len(r.in)),
|
||||
&written, &consumed)
|
||||
r.in = r.in[int(consumed):]
|
||||
n = int(written)
|
||||
|
||||
switch result {
|
||||
case C.BROTLI_DECODER_RESULT_SUCCESS:
|
||||
if len(r.in) > 0 {
|
||||
return n, errExcessiveInput
|
||||
}
|
||||
return n, nil
|
||||
case C.BROTLI_DECODER_RESULT_ERROR:
|
||||
return n, decodeError(C.BrotliDecoderGetErrorCode(r.state))
|
||||
case C.BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT:
|
||||
if n == 0 {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
return n, nil
|
||||
case C.BROTLI_DECODER_NEEDS_MORE_INPUT:
|
||||
}
|
||||
|
||||
if len(r.in) != 0 {
|
||||
return 0, errInvalidState
|
||||
}
|
||||
|
||||
// Calling r.src.Read may block. Don't block if we have data to return.
|
||||
if n > 0 {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Top off the buffer.
|
||||
encN, err := r.src.Read(r.buf)
|
||||
if encN == 0 {
|
||||
// Not enough data to complete decoding.
|
||||
if err == io.EOF {
|
||||
return 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
r.in = r.buf[:encN]
|
||||
}
|
||||
}
|
||||
|
||||
// Decode decodes Brotli encoded data.
|
||||
func Decode(encodedData []byte) ([]byte, error) {
|
||||
r := &Reader{
|
||||
src: bytes.NewReader(nil),
|
||||
state: C.BrotliDecoderCreateInstance(nil, nil, nil),
|
||||
buf: make([]byte, 4), // arbitrarily small but nonzero so that r.src.Read returns io.EOF
|
||||
in: encodedData,
|
||||
}
|
||||
defer r.Close()
|
||||
return ioutil.ReadAll(r)
|
||||
}
|
|
@ -0,0 +1,175 @@
|
|||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
#include "./dec/state.h"
|
||||
|
||||
#include <stdlib.h> /* free, malloc */
|
||||
|
||||
#include <brotli/types.h>
|
||||
#include "./dec/huffman.h"
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
static void* DefaultAllocFunc(void* opaque, size_t size) {
|
||||
BROTLI_UNUSED(opaque);
|
||||
return malloc(size);
|
||||
}
|
||||
|
||||
static void DefaultFreeFunc(void* opaque, void* address) {
|
||||
BROTLI_UNUSED(opaque);
|
||||
free(address);
|
||||
}
|
||||
|
||||
void BrotliDecoderStateInit(BrotliDecoderState* s) {
|
||||
BrotliDecoderStateInitWithCustomAllocators(s, 0, 0, 0);
|
||||
}
|
||||
|
||||
void BrotliDecoderStateInitWithCustomAllocators(BrotliDecoderState* s,
|
||||
brotli_alloc_func alloc_func, brotli_free_func free_func, void* opaque) {
|
||||
if (!alloc_func) {
|
||||
s->alloc_func = DefaultAllocFunc;
|
||||
s->free_func = DefaultFreeFunc;
|
||||
s->memory_manager_opaque = 0;
|
||||
} else {
|
||||
s->alloc_func = alloc_func;
|
||||
s->free_func = free_func;
|
||||
s->memory_manager_opaque = opaque;
|
||||
}
|
||||
|
||||
s->error_code = 0; /* BROTLI_DECODER_NO_ERROR */
|
||||
|
||||
BrotliInitBitReader(&s->br);
|
||||
s->state = BROTLI_STATE_UNINITED;
|
||||
s->substate_metablock_header = BROTLI_STATE_METABLOCK_HEADER_NONE;
|
||||
s->substate_tree_group = BROTLI_STATE_TREE_GROUP_NONE;
|
||||
s->substate_context_map = BROTLI_STATE_CONTEXT_MAP_NONE;
|
||||
s->substate_uncompressed = BROTLI_STATE_UNCOMPRESSED_NONE;
|
||||
s->substate_huffman = BROTLI_STATE_HUFFMAN_NONE;
|
||||
s->substate_decode_uint8 = BROTLI_STATE_DECODE_UINT8_NONE;
|
||||
s->substate_read_block_length = BROTLI_STATE_READ_BLOCK_LENGTH_NONE;
|
||||
|
||||
s->dictionary = BrotliGetDictionary();
|
||||
|
||||
s->buffer_length = 0;
|
||||
s->loop_counter = 0;
|
||||
s->pos = 0;
|
||||
s->rb_roundtrips = 0;
|
||||
s->partial_pos_out = 0;
|
||||
|
||||
s->block_type_trees = NULL;
|
||||
s->block_len_trees = NULL;
|
||||
s->ringbuffer = NULL;
|
||||
s->ringbuffer_size = 0;
|
||||
s->new_ringbuffer_size = 0;
|
||||
s->ringbuffer_mask = 0;
|
||||
|
||||
s->context_map = NULL;
|
||||
s->context_modes = NULL;
|
||||
s->dist_context_map = NULL;
|
||||
s->context_map_slice = NULL;
|
||||
s->dist_context_map_slice = NULL;
|
||||
|
||||
s->sub_loop_counter = 0;
|
||||
|
||||
s->literal_hgroup.codes = NULL;
|
||||
s->literal_hgroup.htrees = NULL;
|
||||
s->insert_copy_hgroup.codes = NULL;
|
||||
s->insert_copy_hgroup.htrees = NULL;
|
||||
s->distance_hgroup.codes = NULL;
|
||||
s->distance_hgroup.htrees = NULL;
|
||||
|
||||
s->custom_dict = NULL;
|
||||
s->custom_dict_size = 0;
|
||||
|
||||
s->is_last_metablock = 0;
|
||||
s->is_uncompressed = 0;
|
||||
s->is_metadata = 0;
|
||||
s->should_wrap_ringbuffer = 0;
|
||||
s->canny_ringbuffer_allocation = 1;
|
||||
|
||||
s->window_bits = 0;
|
||||
s->max_distance = 0;
|
||||
s->dist_rb[0] = 16;
|
||||
s->dist_rb[1] = 15;
|
||||
s->dist_rb[2] = 11;
|
||||
s->dist_rb[3] = 4;
|
||||
s->dist_rb_idx = 0;
|
||||
s->block_type_trees = NULL;
|
||||
s->block_len_trees = NULL;
|
||||
|
||||
/* Make small negative indexes addressable. */
|
||||
s->symbol_lists = &s->symbols_lists_array[BROTLI_HUFFMAN_MAX_CODE_LENGTH + 1];
|
||||
|
||||
s->mtf_upper_bound = 63;
|
||||
}
|
||||
|
||||
void BrotliDecoderStateMetablockBegin(BrotliDecoderState* s) {
|
||||
s->meta_block_remaining_len = 0;
|
||||
s->block_length[0] = 1U << 28;
|
||||
s->block_length[1] = 1U << 28;
|
||||
s->block_length[2] = 1U << 28;
|
||||
s->num_block_types[0] = 1;
|
||||
s->num_block_types[1] = 1;
|
||||
s->num_block_types[2] = 1;
|
||||
s->block_type_rb[0] = 1;
|
||||
s->block_type_rb[1] = 0;
|
||||
s->block_type_rb[2] = 1;
|
||||
s->block_type_rb[3] = 0;
|
||||
s->block_type_rb[4] = 1;
|
||||
s->block_type_rb[5] = 0;
|
||||
s->context_map = NULL;
|
||||
s->context_modes = NULL;
|
||||
s->dist_context_map = NULL;
|
||||
s->context_map_slice = NULL;
|
||||
s->literal_htree = NULL;
|
||||
s->dist_context_map_slice = NULL;
|
||||
s->dist_htree_index = 0;
|
||||
s->context_lookup1 = NULL;
|
||||
s->context_lookup2 = NULL;
|
||||
s->literal_hgroup.codes = NULL;
|
||||
s->literal_hgroup.htrees = NULL;
|
||||
s->insert_copy_hgroup.codes = NULL;
|
||||
s->insert_copy_hgroup.htrees = NULL;
|
||||
s->distance_hgroup.codes = NULL;
|
||||
s->distance_hgroup.htrees = NULL;
|
||||
}
|
||||
|
||||
void BrotliDecoderStateCleanupAfterMetablock(BrotliDecoderState* s) {
|
||||
BROTLI_FREE(s, s->context_modes);
|
||||
BROTLI_FREE(s, s->context_map);
|
||||
BROTLI_FREE(s, s->dist_context_map);
|
||||
BROTLI_FREE(s, s->literal_hgroup.htrees);
|
||||
BROTLI_FREE(s, s->insert_copy_hgroup.htrees);
|
||||
BROTLI_FREE(s, s->distance_hgroup.htrees);
|
||||
}
|
||||
|
||||
void BrotliDecoderStateCleanup(BrotliDecoderState* s) {
|
||||
BrotliDecoderStateCleanupAfterMetablock(s);
|
||||
|
||||
BROTLI_FREE(s, s->ringbuffer);
|
||||
BROTLI_FREE(s, s->block_type_trees);
|
||||
}
|
||||
|
||||
BROTLI_BOOL BrotliDecoderHuffmanTreeGroupInit(BrotliDecoderState* s,
|
||||
HuffmanTreeGroup* group, uint32_t alphabet_size, uint32_t ntrees) {
|
||||
/* Pack two allocations into one */
|
||||
const size_t max_table_size = kMaxHuffmanTableSize[(alphabet_size + 31) >> 5];
|
||||
const size_t code_size = sizeof(HuffmanCode) * ntrees * max_table_size;
|
||||
const size_t htree_size = sizeof(HuffmanCode*) * ntrees;
|
||||
/* Pointer alignment is, hopefully, wider than sizeof(HuffmanCode). */
|
||||
HuffmanCode** p = (HuffmanCode**)BROTLI_ALLOC(s, code_size + htree_size);
|
||||
group->alphabet_size = (uint16_t)alphabet_size;
|
||||
group->num_htrees = (uint16_t)ntrees;
|
||||
group->htrees = p;
|
||||
group->codes = (HuffmanCode*)(&p[ntrees]);
|
||||
return !!p;
|
||||
}
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
} /* extern "C" */
|
||||
#endif
|
|
@ -0,0 +1,482 @@
|
|||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
#include "./enc/static_dict.h"
|
||||
|
||||
#include "./common/dictionary.h"
|
||||
#include "./enc/find_match_length.h"
|
||||
#include "./enc/port.h"
|
||||
#include "./enc/static_dict_lut.h"
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
static const uint8_t kUppercaseFirst = 10;
|
||||
static const uint8_t kOmitLastNTransforms[10] = {
|
||||
0, 12, 27, 23, 42, 63, 56, 48, 59, 64,
|
||||
};
|
||||
|
||||
static BROTLI_INLINE uint32_t Hash(const uint8_t *data) {
|
||||
uint32_t h = BROTLI_UNALIGNED_LOAD32(data) * kDictHashMul32;
|
||||
/* The higher bits contain more mixture from the multiplication,
|
||||
so we take our results from there. */
|
||||
return h >> (32 - kDictNumBits);
|
||||
}
|
||||
|
||||
static BROTLI_INLINE void AddMatch(size_t distance, size_t len, size_t len_code,
|
||||
uint32_t* matches) {
|
||||
uint32_t match = (uint32_t)((distance << 5) + len_code);
|
||||
matches[len] = BROTLI_MIN(uint32_t, matches[len], match);
|
||||
}
|
||||
|
||||
static BROTLI_INLINE size_t DictMatchLength(const BrotliDictionary* dictionary,
|
||||
const uint8_t* data,
|
||||
size_t id,
|
||||
size_t len,
|
||||
size_t maxlen) {
|
||||
const size_t offset = dictionary->offsets_by_length[len] + len * id;
|
||||
return FindMatchLengthWithLimit(&dictionary->data[offset], data,
|
||||
BROTLI_MIN(size_t, len, maxlen));
|
||||
}
|
||||
|
||||
static BROTLI_INLINE BROTLI_BOOL IsMatch(const BrotliDictionary* dictionary,
|
||||
DictWord w, const uint8_t* data, size_t max_length) {
|
||||
if (w.len > max_length) {
|
||||
return BROTLI_FALSE;
|
||||
} else {
|
||||
const size_t offset = dictionary->offsets_by_length[w.len] +
|
||||
(size_t)w.len * (size_t)w.idx;
|
||||
const uint8_t* dict = &dictionary->data[offset];
|
||||
if (w.transform == 0) {
|
||||
/* Match against base dictionary word. */
|
||||
return
|
||||
TO_BROTLI_BOOL(FindMatchLengthWithLimit(dict, data, w.len) == w.len);
|
||||
} else if (w.transform == 10) {
|
||||
/* Match against uppercase first transform.
|
||||
Note that there are only ASCII uppercase words in the lookup table. */
|
||||
return TO_BROTLI_BOOL(dict[0] >= 'a' && dict[0] <= 'z' &&
|
||||
(dict[0] ^ 32) == data[0] &&
|
||||
FindMatchLengthWithLimit(&dict[1], &data[1], w.len - 1u) ==
|
||||
w.len - 1u);
|
||||
} else {
|
||||
/* Match against uppercase all transform.
|
||||
Note that there are only ASCII uppercase words in the lookup table. */
|
||||
size_t i;
|
||||
for (i = 0; i < w.len; ++i) {
|
||||
if (dict[i] >= 'a' && dict[i] <= 'z') {
|
||||
if ((dict[i] ^ 32) != data[i]) return BROTLI_FALSE;
|
||||
} else {
|
||||
if (dict[i] != data[i]) return BROTLI_FALSE;
|
||||
}
|
||||
}
|
||||
return BROTLI_TRUE;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
BROTLI_BOOL BrotliFindAllStaticDictionaryMatches(
|
||||
const BrotliDictionary* dictionary, const uint8_t* data, size_t min_length,
|
||||
size_t max_length, uint32_t* matches) {
|
||||
BROTLI_BOOL has_found_match = BROTLI_FALSE;
|
||||
{
|
||||
size_t offset = kStaticDictionaryBuckets[Hash(data)];
|
||||
BROTLI_BOOL end = !offset;
|
||||
while (!end) {
|
||||
DictWord w = kStaticDictionaryWords[offset++];
|
||||
const size_t l = w.len & 0x1F;
|
||||
const size_t n = (size_t)1 << dictionary->size_bits_by_length[l];
|
||||
const size_t id = w.idx;
|
||||
end = !!(w.len & 0x80);
|
||||
w.len = (uint8_t)l;
|
||||
if (w.transform == 0) {
|
||||
const size_t matchlen =
|
||||
DictMatchLength(dictionary, data, id, l, max_length);
|
||||
const uint8_t* s;
|
||||
size_t minlen;
|
||||
size_t maxlen;
|
||||
size_t len;
|
||||
/* Transform "" + kIdentity + "" */
|
||||
if (matchlen == l) {
|
||||
AddMatch(id, l, l, matches);
|
||||
has_found_match = BROTLI_TRUE;
|
||||
}
|
||||
/* Transforms "" + kOmitLast1 + "" and "" + kOmitLast1 + "ing " */
|
||||
if (matchlen >= l - 1) {
|
||||
AddMatch(id + 12 * n, l - 1, l, matches);
|
||||
if (l + 2 < max_length &&
|
||||
data[l - 1] == 'i' && data[l] == 'n' && data[l + 1] == 'g' &&
|
||||
data[l + 2] == ' ') {
|
||||
AddMatch(id + 49 * n, l + 3, l, matches);
|
||||
}
|
||||
has_found_match = BROTLI_TRUE;
|
||||
}
|
||||
/* Transform "" + kOmitLastN + "" (N = 2 .. 9) */
|
||||
minlen = min_length;
|
||||
if (l > 9) minlen = BROTLI_MAX(size_t, minlen, l - 9);
|
||||
maxlen = BROTLI_MIN(size_t, matchlen, l - 2);
|
||||
for (len = minlen; len <= maxlen; ++len) {
|
||||
AddMatch(id + kOmitLastNTransforms[l - len] * n, len, l, matches);
|
||||
has_found_match = BROTLI_TRUE;
|
||||
}
|
||||
if (matchlen < l || l + 6 >= max_length) {
|
||||
continue;
|
||||
}
|
||||
s = &data[l];
|
||||
/* Transforms "" + kIdentity + <suffix> */
|
||||
if (s[0] == ' ') {
|
||||
AddMatch(id + n, l + 1, l, matches);
|
||||
if (s[1] == 'a') {
|
||||
if (s[2] == ' ') {
|
||||
AddMatch(id + 28 * n, l + 3, l, matches);
|
||||
} else if (s[2] == 's') {
|
||||
if (s[3] == ' ') AddMatch(id + 46 * n, l + 4, l, matches);
|
||||
} else if (s[2] == 't') {
|
||||
if (s[3] == ' ') AddMatch(id + 60 * n, l + 4, l, matches);
|
||||
} else if (s[2] == 'n') {
|
||||
if (s[3] == 'd' && s[4] == ' ') {
|
||||
AddMatch(id + 10 * n, l + 5, l, matches);
|
||||
}
|
||||
}
|
||||
} else if (s[1] == 'b') {
|
||||
if (s[2] == 'y' && s[3] == ' ') {
|
||||
AddMatch(id + 38 * n, l + 4, l, matches);
|
||||
}
|
||||
} else if (s[1] == 'i') {
|
||||
if (s[2] == 'n') {
|
||||
if (s[3] == ' ') AddMatch(id + 16 * n, l + 4, l, matches);
|
||||
} else if (s[2] == 's') {
|
||||
if (s[3] == ' ') AddMatch(id + 47 * n, l + 4, l, matches);
|
||||
}
|
||||
} else if (s[1] == 'f') {
|
||||
if (s[2] == 'o') {
|
||||
if (s[3] == 'r' && s[4] == ' ') {
|
||||
AddMatch(id + 25 * n, l + 5, l, matches);
|
||||
}
|
||||
} else if (s[2] == 'r') {
|
||||
if (s[3] == 'o' && s[4] == 'm' && s[5] == ' ') {
|
||||
AddMatch(id + 37 * n, l + 6, l, matches);
|
||||
}
|
||||
}
|
||||
} else if (s[1] == 'o') {
|
||||
if (s[2] == 'f') {
|
||||
if (s[3] == ' ') AddMatch(id + 8 * n, l + 4, l, matches);
|
||||
} else if (s[2] == 'n') {
|
||||
if (s[3] == ' ') AddMatch(id + 45 * n, l + 4, l, matches);
|
||||
}
|
||||
} else if (s[1] == 'n') {
|
||||
if (s[2] == 'o' && s[3] == 't' && s[4] == ' ') {
|
||||
AddMatch(id + 80 * n, l + 5, l, matches);
|
||||
}
|
||||
} else if (s[1] == 't') {
|
||||
if (s[2] == 'h') {
|
||||
if (s[3] == 'e') {
|
||||
if (s[4] == ' ') AddMatch(id + 5 * n, l + 5, l, matches);
|
||||
} else if (s[3] == 'a') {
|
||||
if (s[4] == 't' && s[5] == ' ') {
|
||||
AddMatch(id + 29 * n, l + 6, l, matches);
|
||||
}
|
||||
}
|
||||
} else if (s[2] == 'o') {
|
||||
if (s[3] == ' ') AddMatch(id + 17 * n, l + 4, l, matches);
|
||||
}
|
||||
} else if (s[1] == 'w') {
|
||||
if (s[2] == 'i' && s[3] == 't' && s[4] == 'h' && s[5] == ' ') {
|
||||
AddMatch(id + 35 * n, l + 6, l, matches);
|
||||
}
|
||||
}
|
||||
} else if (s[0] == '"') {
|
||||
AddMatch(id + 19 * n, l + 1, l, matches);
|
||||
if (s[1] == '>') {
|
||||
AddMatch(id + 21 * n, l + 2, l, matches);
|
||||
}
|
||||
} else if (s[0] == '.') {
|
||||
AddMatch(id + 20 * n, l + 1, l, matches);
|
||||
if (s[1] == ' ') {
|
||||
AddMatch(id + 31 * n, l + 2, l, matches);
|
||||
if (s[2] == 'T' && s[3] == 'h') {
|
||||
if (s[4] == 'e') {
|
||||
if (s[5] == ' ') AddMatch(id + 43 * n, l + 6, l, matches);
|
||||
} else if (s[4] == 'i') {
|
||||
if (s[5] == 's' && s[6] == ' ') {
|
||||
AddMatch(id + 75 * n, l + 7, l, matches);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (s[0] == ',') {
|
||||
AddMatch(id + 76 * n, l + 1, l, matches);
|
||||
if (s[1] == ' ') {
|
||||
AddMatch(id + 14 * n, l + 2, l, matches);
|
||||
}
|
||||
} else if (s[0] == '\n') {
|
||||
AddMatch(id + 22 * n, l + 1, l, matches);
|
||||
if (s[1] == '\t') {
|
||||
AddMatch(id + 50 * n, l + 2, l, matches);
|
||||
}
|
||||
} else if (s[0] == ']') {
|
||||
AddMatch(id + 24 * n, l + 1, l, matches);
|
||||
} else if (s[0] == '\'') {
|
||||
AddMatch(id + 36 * n, l + 1, l, matches);
|
||||
} else if (s[0] == ':') {
|
||||
AddMatch(id + 51 * n, l + 1, l, matches);
|
||||
} else if (s[0] == '(') {
|
||||
AddMatch(id + 57 * n, l + 1, l, matches);
|
||||
} else if (s[0] == '=') {
|
||||
if (s[1] == '"') {
|
||||
AddMatch(id + 70 * n, l + 2, l, matches);
|
||||
} else if (s[1] == '\'') {
|
||||
AddMatch(id + 86 * n, l + 2, l, matches);
|
||||
}
|
||||
} else if (s[0] == 'a') {
|
||||
if (s[1] == 'l' && s[2] == ' ') {
|
||||
AddMatch(id + 84 * n, l + 3, l, matches);
|
||||
}
|
||||
} else if (s[0] == 'e') {
|
||||
if (s[1] == 'd') {
|
||||
if (s[2] == ' ') AddMatch(id + 53 * n, l + 3, l, matches);
|
||||
} else if (s[1] == 'r') {
|
||||
if (s[2] == ' ') AddMatch(id + 82 * n, l + 3, l, matches);
|
||||
} else if (s[1] == 's') {
|
||||
if (s[2] == 't' && s[3] == ' ') {
|
||||
AddMatch(id + 95 * n, l + 4, l, matches);
|
||||
}
|
||||
}
|
||||
} else if (s[0] == 'f') {
|
||||
if (s[1] == 'u' && s[2] == 'l' && s[3] == ' ') {
|
||||
AddMatch(id + 90 * n, l + 4, l, matches);
|
||||
}
|
||||
} else if (s[0] == 'i') {
|
||||
if (s[1] == 'v') {
|
||||
if (s[2] == 'e' && s[3] == ' ') {
|
||||
AddMatch(id + 92 * n, l + 4, l, matches);
|
||||
}
|
||||
} else if (s[1] == 'z') {
|
||||
if (s[2] == 'e' && s[3] == ' ') {
|
||||
AddMatch(id + 100 * n, l + 4, l, matches);
|
||||
}
|
||||
}
|
||||
} else if (s[0] == 'l') {
|
||||
if (s[1] == 'e') {
|
||||
if (s[2] == 's' && s[3] == 's' && s[4] == ' ') {
|
||||
AddMatch(id + 93 * n, l + 5, l, matches);
|
||||
}
|
||||
} else if (s[1] == 'y') {
|
||||
if (s[2] == ' ') AddMatch(id + 61 * n, l + 3, l, matches);
|
||||
}
|
||||
} else if (s[0] == 'o') {
|
||||
if (s[1] == 'u' && s[2] == 's' && s[3] == ' ') {
|
||||
AddMatch(id + 106 * n, l + 4, l, matches);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
/* Set is_all_caps=0 for kUppercaseFirst and
|
||||
is_all_caps=1 otherwise (kUppercaseAll) transform. */
|
||||
const BROTLI_BOOL is_all_caps =
|
||||
TO_BROTLI_BOOL(w.transform != kUppercaseFirst);
|
||||
const uint8_t* s;
|
||||
if (!IsMatch(dictionary, w, data, max_length)) {
|
||||
continue;
|
||||
}
|
||||
/* Transform "" + kUppercase{First,All} + "" */
|
||||
AddMatch(id + (is_all_caps ? 44 : 9) * n, l, l, matches);
|
||||
has_found_match = BROTLI_TRUE;
|
||||
if (l + 1 >= max_length) {
|
||||
continue;
|
||||
}
|
||||
/* Transforms "" + kUppercase{First,All} + <suffix> */
|
||||
s = &data[l];
|
||||
if (s[0] == ' ') {
|
||||
AddMatch(id + (is_all_caps ? 68 : 4) * n, l + 1, l, matches);
|
||||
} else if (s[0] == '"') {
|
||||
AddMatch(id + (is_all_caps ? 87 : 66) * n, l + 1, l, matches);
|
||||
if (s[1] == '>') {
|
||||
AddMatch(id + (is_all_caps ? 97 : 69) * n, l + 2, l, matches);
|
||||
}
|
||||
} else if (s[0] == '.') {
|
||||
AddMatch(id + (is_all_caps ? 101 : 79) * n, l + 1, l, matches);
|
||||
if (s[1] == ' ') {
|
||||
AddMatch(id + (is_all_caps ? 114 : 88) * n, l + 2, l, matches);
|
||||
}
|
||||
} else if (s[0] == ',') {
|
||||
AddMatch(id + (is_all_caps ? 112 : 99) * n, l + 1, l, matches);
|
||||
if (s[1] == ' ') {
|
||||
AddMatch(id + (is_all_caps ? 107 : 58) * n, l + 2, l, matches);
|
||||
}
|
||||
} else if (s[0] == '\'') {
|
||||
AddMatch(id + (is_all_caps ? 94 : 74) * n, l + 1, l, matches);
|
||||
} else if (s[0] == '(') {
|
||||
AddMatch(id + (is_all_caps ? 113 : 78) * n, l + 1, l, matches);
|
||||
} else if (s[0] == '=') {
|
||||
if (s[1] == '"') {
|
||||
AddMatch(id + (is_all_caps ? 105 : 104) * n, l + 2, l, matches);
|
||||
} else if (s[1] == '\'') {
|
||||
AddMatch(id + (is_all_caps ? 116 : 108) * n, l + 2, l, matches);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
/* Transforms with prefixes " " and "." */
|
||||
if (max_length >= 5 && (data[0] == ' ' || data[0] == '.')) {
|
||||
BROTLI_BOOL is_space = TO_BROTLI_BOOL(data[0] == ' ');
|
||||
size_t offset = kStaticDictionaryBuckets[Hash(&data[1])];
|
||||
BROTLI_BOOL end = !offset;
|
||||
while (!end) {
|
||||
DictWord w = kStaticDictionaryWords[offset++];
|
||||
const size_t l = w.len & 0x1F;
|
||||
const size_t n = (size_t)1 << dictionary->size_bits_by_length[l];
|
||||
const size_t id = w.idx;
|
||||
end = !!(w.len & 0x80);
|
||||
w.len = (uint8_t)l;
|
||||
if (w.transform == 0) {
|
||||
const uint8_t* s;
|
||||
if (!IsMatch(dictionary, w, &data[1], max_length - 1)) {
|
||||
continue;
|
||||
}
|
||||
/* Transforms " " + kIdentity + "" and "." + kIdentity + "" */
|
||||
AddMatch(id + (is_space ? 6 : 32) * n, l + 1, l, matches);
|
||||
has_found_match = BROTLI_TRUE;
|
||||
if (l + 2 >= max_length) {
|
||||
continue;
|
||||
}
|
||||
/* Transforms " " + kIdentity + <suffix> and "." + kIdentity + <suffix>
|
||||
*/
|
||||
s = &data[l + 1];
|
||||
if (s[0] == ' ') {
|
||||
AddMatch(id + (is_space ? 2 : 77) * n, l + 2, l, matches);
|
||||
} else if (s[0] == '(') {
|
||||
AddMatch(id + (is_space ? 89 : 67) * n, l + 2, l, matches);
|
||||
} else if (is_space) {
|
||||
if (s[0] == ',') {
|
||||
AddMatch(id + 103 * n, l + 2, l, matches);
|
||||
if (s[1] == ' ') {
|
||||
AddMatch(id + 33 * n, l + 3, l, matches);
|
||||
}
|
||||
} else if (s[0] == '.') {
|
||||
AddMatch(id + 71 * n, l + 2, l, matches);
|
||||
if (s[1] == ' ') {
|
||||
AddMatch(id + 52 * n, l + 3, l, matches);
|
||||
}
|
||||
} else if (s[0] == '=') {
|
||||
if (s[1] == '"') {
|
||||
AddMatch(id + 81 * n, l + 3, l, matches);
|
||||
} else if (s[1] == '\'') {
|
||||
AddMatch(id + 98 * n, l + 3, l, matches);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (is_space) {
|
||||
/* Set is_all_caps=0 for kUppercaseFirst and
|
||||
is_all_caps=1 otherwise (kUppercaseAll) transform. */
|
||||
const BROTLI_BOOL is_all_caps =
|
||||
TO_BROTLI_BOOL(w.transform != kUppercaseFirst);
|
||||
const uint8_t* s;
|
||||
if (!IsMatch(dictionary, w, &data[1], max_length - 1)) {
|
||||
continue;
|
||||
}
|
||||
/* Transforms " " + kUppercase{First,All} + "" */
|
||||
AddMatch(id + (is_all_caps ? 85 : 30) * n, l + 1, l, matches);
|
||||
has_found_match = BROTLI_TRUE;
|
||||
if (l + 2 >= max_length) {
|
||||
continue;
|
||||
}
|
||||
/* Transforms " " + kUppercase{First,All} + <suffix> */
|
||||
s = &data[l + 1];
|
||||
if (s[0] == ' ') {
|
||||
AddMatch(id + (is_all_caps ? 83 : 15) * n, l + 2, l, matches);
|
||||
} else if (s[0] == ',') {
|
||||
if (!is_all_caps) {
|
||||
AddMatch(id + 109 * n, l + 2, l, matches);
|
||||
}
|
||||
if (s[1] == ' ') {
|
||||
AddMatch(id + (is_all_caps ? 111 : 65) * n, l + 3, l, matches);
|
||||
}
|
||||
} else if (s[0] == '.') {
|
||||
AddMatch(id + (is_all_caps ? 115 : 96) * n, l + 2, l, matches);
|
||||
if (s[1] == ' ') {
|
||||
AddMatch(id + (is_all_caps ? 117 : 91) * n, l + 3, l, matches);
|
||||
}
|
||||
} else if (s[0] == '=') {
|
||||
if (s[1] == '"') {
|
||||
AddMatch(id + (is_all_caps ? 110 : 118) * n, l + 3, l, matches);
|
||||
} else if (s[1] == '\'') {
|
||||
AddMatch(id + (is_all_caps ? 119 : 120) * n, l + 3, l, matches);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (max_length >= 6) {
|
||||
/* Transforms with prefixes "e ", "s ", ", " and "\xc2\xa0" */
|
||||
if ((data[1] == ' ' &&
|
||||
(data[0] == 'e' || data[0] == 's' || data[0] == ',')) ||
|
||||
(data[0] == 0xc2 && data[1] == 0xa0)) {
|
||||
size_t offset = kStaticDictionaryBuckets[Hash(&data[2])];
|
||||
BROTLI_BOOL end = !offset;
|
||||
while (!end) {
|
||||
DictWord w = kStaticDictionaryWords[offset++];
|
||||
const size_t l = w.len & 0x1F;
|
||||
const size_t n = (size_t)1 << dictionary->size_bits_by_length[l];
|
||||
const size_t id = w.idx;
|
||||
end = !!(w.len & 0x80);
|
||||
w.len = (uint8_t)l;
|
||||
if (w.transform == 0 &&
|
||||
IsMatch(dictionary, w, &data[2], max_length - 2)) {
|
||||
if (data[0] == 0xc2) {
|
||||
AddMatch(id + 102 * n, l + 2, l, matches);
|
||||
has_found_match = BROTLI_TRUE;
|
||||
} else if (l + 2 < max_length && data[l + 2] == ' ') {
|
||||
size_t t = data[0] == 'e' ? 18 : (data[0] == 's' ? 7 : 13);
|
||||
AddMatch(id + t * n, l + 3, l, matches);
|
||||
has_found_match = BROTLI_TRUE;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (max_length >= 9) {
|
||||
/* Transforms with prefixes " the " and ".com/" */
|
||||
if ((data[0] == ' ' && data[1] == 't' && data[2] == 'h' &&
|
||||
data[3] == 'e' && data[4] == ' ') ||
|
||||
(data[0] == '.' && data[1] == 'c' && data[2] == 'o' &&
|
||||
data[3] == 'm' && data[4] == '/')) {
|
||||
size_t offset = kStaticDictionaryBuckets[Hash(&data[5])];
|
||||
BROTLI_BOOL end = !offset;
|
||||
while (!end) {
|
||||
DictWord w = kStaticDictionaryWords[offset++];
|
||||
const size_t l = w.len & 0x1F;
|
||||
const size_t n = (size_t)1 << dictionary->size_bits_by_length[l];
|
||||
const size_t id = w.idx;
|
||||
end = !!(w.len & 0x80);
|
||||
w.len = (uint8_t)l;
|
||||
if (w.transform == 0 &&
|
||||
IsMatch(dictionary, w, &data[5], max_length - 5)) {
|
||||
AddMatch(id + (data[0] == ' ' ? 41 : 72) * n, l + 5, l, matches);
|
||||
has_found_match = BROTLI_TRUE;
|
||||
if (l + 5 < max_length) {
|
||||
const uint8_t* s = &data[l + 5];
|
||||
if (data[0] == ' ') {
|
||||
if (l + 8 < max_length &&
|
||||
s[0] == ' ' && s[1] == 'o' && s[2] == 'f' && s[3] == ' ') {
|
||||
AddMatch(id + 62 * n, l + 9, l, matches);
|
||||
if (l + 12 < max_length &&
|
||||
s[4] == 't' && s[5] == 'h' && s[6] == 'e' && s[7] == ' ') {
|
||||
AddMatch(id + 73 * n, l + 13, l, matches);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return has_found_match;
|
||||
}
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
} /* extern "C" */
|
||||
#endif
|
|
@ -0,0 +1,85 @@
|
|||
/* Copyright 2013 Google Inc. All Rights Reserved.
|
||||
|
||||
Distributed under MIT license.
|
||||
See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
*/
|
||||
|
||||
/* Heuristics for deciding about the UTF8-ness of strings. */
|
||||
|
||||
#include "./enc/utf8_util.h"
|
||||
|
||||
#include <brotli/types.h>
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
static size_t BrotliParseAsUTF8(
|
||||
int* symbol, const uint8_t* input, size_t size) {
|
||||
/* ASCII */
|
||||
if ((input[0] & 0x80) == 0) {
|
||||
*symbol = input[0];
|
||||
if (*symbol > 0) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
/* 2-byte UTF8 */
|
||||
if (size > 1u &&
|
||||
(input[0] & 0xe0) == 0xc0 &&
|
||||
(input[1] & 0xc0) == 0x80) {
|
||||
*symbol = (((input[0] & 0x1f) << 6) |
|
||||
(input[1] & 0x3f));
|
||||
if (*symbol > 0x7f) {
|
||||
return 2;
|
||||
}
|
||||
}
|
||||
/* 3-byte UFT8 */
|
||||
if (size > 2u &&
|
||||
(input[0] & 0xf0) == 0xe0 &&
|
||||
(input[1] & 0xc0) == 0x80 &&
|
||||
(input[2] & 0xc0) == 0x80) {
|
||||
*symbol = (((input[0] & 0x0f) << 12) |
|
||||
((input[1] & 0x3f) << 6) |
|
||||
(input[2] & 0x3f));
|
||||
if (*symbol > 0x7ff) {
|
||||
return 3;
|
||||
}
|
||||
}
|
||||
/* 4-byte UFT8 */
|
||||
if (size > 3u &&
|
||||
(input[0] & 0xf8) == 0xf0 &&
|
||||
(input[1] & 0xc0) == 0x80 &&
|
||||
(input[2] & 0xc0) == 0x80 &&
|
||||
(input[3] & 0xc0) == 0x80) {
|
||||
*symbol = (((input[0] & 0x07) << 18) |
|
||||
((input[1] & 0x3f) << 12) |
|
||||
((input[2] & 0x3f) << 6) |
|
||||
(input[3] & 0x3f));
|
||||
if (*symbol > 0xffff && *symbol <= 0x10ffff) {
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
/* Not UTF8, emit a special symbol above the UTF8-code space */
|
||||
*symbol = 0x110000 | input[0];
|
||||
return 1;
|
||||
}
|
||||
|
||||
/* Returns 1 if at least min_fraction of the data is UTF8-encoded.*/
|
||||
BROTLI_BOOL BrotliIsMostlyUTF8(
|
||||
const uint8_t* data, const size_t pos, const size_t mask,
|
||||
const size_t length, const double min_fraction) {
|
||||
size_t size_utf8 = 0;
|
||||
size_t i = 0;
|
||||
while (i < length) {
|
||||
int symbol;
|
||||
size_t bytes_read =
|
||||
BrotliParseAsUTF8(&symbol, &data[(pos + i) & mask], length - i);
|
||||
i += bytes_read;
|
||||
if (symbol < 0x110000) size_utf8 += bytes_read;
|
||||
}
|
||||
return TO_BROTLI_BOOL(size_utf8 > min_fraction * (double)length);
|
||||
}
|
||||
|
||||
#if defined(__cplusplus) || defined(c_plusplus)
|
||||
} /* extern "C" */
|
||||
#endif
|
|
@ -0,0 +1,169 @@
|
|||
// Copyright 2016 Google Inc. All Rights Reserved.
|
||||
//
|
||||
// Distributed under MIT license.
|
||||
// See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
|
||||
|
||||
package brotli
|
||||
|
||||
/*
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <brotli/encode.h>
|
||||
|
||||
struct CompressStreamResult {
|
||||
size_t bytes_consumed;
|
||||
const uint8_t* output_data;
|
||||
size_t output_data_size;
|
||||
int success;
|
||||
int has_more;
|
||||
};
|
||||
|
||||
static struct CompressStreamResult CompressStream(
|
||||
BrotliEncoderState* s, BrotliEncoderOperation op,
|
||||
const uint8_t* data, size_t data_size) {
|
||||
struct CompressStreamResult result;
|
||||
size_t available_in = data_size;
|
||||
const uint8_t* next_in = data;
|
||||
size_t available_out = 0;
|
||||
result.success = BrotliEncoderCompressStream(s, op,
|
||||
&available_in, &next_in, &available_out, 0, 0) ? 1 : 0;
|
||||
result.bytes_consumed = data_size - available_in;
|
||||
result.output_data = 0;
|
||||
result.output_data_size = 0;
|
||||
if (result.success) {
|
||||
result.output_data = BrotliEncoderTakeOutput(s, &result.output_data_size);
|
||||
}
|
||||
result.has_more = BrotliEncoderHasMoreOutput(s) ? 1 : 0;
|
||||
return result;
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// WriterOptions configures Writer.
|
||||
type WriterOptions struct {
|
||||
// Quality controls the compression-speed vs compression-density trade-offs.
|
||||
// The higher the quality, the slower the compression. Range is 0 to 11.
|
||||
Quality int
|
||||
// LGWin is the base 2 logarithm of the sliding window size.
|
||||
// Range is 10 to 24. 0 indicates automatic configuration based on Quality.
|
||||
LGWin int
|
||||
}
|
||||
|
||||
// Writer implements io.WriteCloser by writing Brotli-encoded data to an
|
||||
// underlying Writer.
|
||||
type Writer struct {
|
||||
dst io.Writer
|
||||
state *C.BrotliEncoderState
|
||||
buf, encoded []byte
|
||||
}
|
||||
|
||||
var (
|
||||
errEncode = errors.New("cbrotli: encode error")
|
||||
errWriterClosed = errors.New("cbrotli: Writer is closed")
|
||||
)
|
||||
|
||||
// NewWriter initializes new Writer instance.
|
||||
// Close MUST be called to free resources.
|
||||
func NewWriter(dst io.Writer, options WriterOptions) *Writer {
|
||||
state := C.BrotliEncoderCreateInstance(nil, nil, nil)
|
||||
C.BrotliEncoderSetParameter(
|
||||
state, C.BROTLI_PARAM_QUALITY, (C.uint32_t)(options.Quality))
|
||||
if options.LGWin > 0 {
|
||||
C.BrotliEncoderSetParameter(
|
||||
state, C.BROTLI_PARAM_LGWIN, (C.uint32_t)(options.LGWin))
|
||||
}
|
||||
return &Writer{
|
||||
dst: dst,
|
||||
state: state,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Writer) SetDictionary(p []byte) {
|
||||
var data *C.uint8_t
|
||||
if len(p) != 0 {
|
||||
data = (*C.uint8_t)(&p[0])
|
||||
}
|
||||
|
||||
C.BrotliEncoderSetCustomDictionary(w.state, C.size_t(len(p)), data)
|
||||
}
|
||||
|
||||
func (w *Writer) writeChunk(p []byte, op C.BrotliEncoderOperation) (n int, err error) {
|
||||
if w.state == nil {
|
||||
return 0, errWriterClosed
|
||||
}
|
||||
|
||||
for {
|
||||
var data *C.uint8_t
|
||||
if len(p) != 0 {
|
||||
data = (*C.uint8_t)(&p[0])
|
||||
}
|
||||
result := C.CompressStream(w.state, op, data, C.size_t(len(p)))
|
||||
if result.success == 0 {
|
||||
return n, errEncode
|
||||
}
|
||||
p = p[int(result.bytes_consumed):]
|
||||
n += int(result.bytes_consumed)
|
||||
|
||||
length := int(result.output_data_size)
|
||||
if length != 0 {
|
||||
// It is a workaround for non-copying-wrapping of native memory.
|
||||
// C-encoder never pushes output block longer than ((2 << 25) + 502).
|
||||
// TODO: use natural wrapper, when it becomes available, see
|
||||
// https://golang.org/issue/13656.
|
||||
output := (*[1 << 30]byte)(unsafe.Pointer(result.output_data))[:length:length]
|
||||
_, err = w.dst.Write(output)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
if len(p) == 0 && result.has_more == 0 {
|
||||
return n, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Flush outputs encoded data for all input provided to Write. The resulting
|
||||
// output can be decoded to match all input before Flush, but the stream is
|
||||
// not yet complete until after Close.
|
||||
// Flush has a negative impact on compression.
|
||||
func (w *Writer) Flush() error {
|
||||
_, err := w.writeChunk(nil, C.BROTLI_OPERATION_FLUSH)
|
||||
return err
|
||||
}
|
||||
|
||||
// Close flushes remaining data to the decorated writer and frees C resources.
|
||||
func (w *Writer) Close() error {
|
||||
// If stream is already closed, it is reported by `writeChunk`.
|
||||
_, err := w.writeChunk(nil, C.BROTLI_OPERATION_FINISH)
|
||||
// C-Brotli tolerates `nil` pointer here.
|
||||
C.BrotliEncoderDestroyInstance(w.state)
|
||||
w.state = nil
|
||||
return err
|
||||
}
|
||||
|
||||
// Write implements io.Writer. Flush or Close must be called to ensure that the
|
||||
// encoded bytes are actually flushed to the underlying Writer.
|
||||
func (w *Writer) Write(p []byte) (n int, err error) {
|
||||
return w.writeChunk(p, C.BROTLI_OPERATION_PROCESS)
|
||||
}
|
||||
|
||||
// Encode returns content encoded with Brotli.
|
||||
func Encode(content []byte, options WriterOptions) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
writer := NewWriter(&buf, options)
|
||||
_, err := writer.Write(content)
|
||||
if closeErr := writer.Close(); err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
return buf.Bytes(), err
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
Copyright (c) 2013 CloudFlare, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of the CloudFlare, Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,3 @@
|
|||
cover.out~
|
||||
benchmark/benchmark
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright (c) 2013 CloudFlare, Inc.
|
||||
|
||||
RACE+=--race
|
||||
|
||||
PKGNAME=github.com/cloudflare/golibs/lrucache
|
||||
SKIPCOVER=list.go|list_extension.go|priorityqueue.go
|
||||
|
||||
.PHONY: all test bench cover clean
|
||||
|
||||
all:
|
||||
@echo "Targets:"
|
||||
@echo " test: run tests with race detector"
|
||||
@echo " cover: print test coverage"
|
||||
@echo " bench: run basic benchmarks"
|
||||
|
||||
test:
|
||||
@go test $(RACE) -bench=. -v $(PKGNAME)
|
||||
|
||||
COVEROUT=cover.out
|
||||
cover:
|
||||
@go test -coverprofile=$(COVEROUT) -v $(PKGNAME)
|
||||
@cat $(COVEROUT) | egrep -v '$(SKIPCOVER)' > $(COVEROUT)~
|
||||
@go tool cover -func=$(COVEROUT)~|sed 's|^.*/\([^/]*/[^/]*/[^/]*\)$$|\1|g'
|
||||
|
||||
bench:
|
||||
@echo "[*] Scalability of cache/lrucache"
|
||||
@echo "[ ] Operations in shared cache using one core"
|
||||
@GOMAXPROCS=1 go test -run=- -bench='.*LRUCache.*' $(PKGNAME) \
|
||||
| egrep -v "^PASS|^ok"
|
||||
|
||||
@echo "[*] Scalability of cache/multilru"
|
||||
@echo "[ ] Operations in four caches using four cores "
|
||||
@GOMAXPROCS=4 go test -run=- -bench='.*MultiLRU.*' $(PKGNAME) \
|
||||
| egrep -v "^PASS|^ok"
|
||||
|
||||
|
||||
@(cd benchmark; go build $(PKGNAME)/benchmark)
|
||||
@./benchmark/benchmark
|
||||
|
||||
clean:
|
||||
rm -rf $(COVEROUT) $(COVEROUT)~ benchmark/benchmark
|
|
@ -0,0 +1,40 @@
|
|||
LRU Cache
|
||||
---------
|
||||
|
||||
A `golang` implementation of last recently used cache data structure.
|
||||
|
||||
To install:
|
||||
|
||||
go get github.com/cloudflare/golibs/lrucache
|
||||
|
||||
To test:
|
||||
|
||||
cd $GOPATH/src/github.com/cloudflare/golibs/lrucache
|
||||
make test
|
||||
|
||||
For coverage:
|
||||
|
||||
make cover
|
||||
|
||||
Basic benchmarks:
|
||||
|
||||
$ make bench # As tested on my two core i5
|
||||
[*] Scalability of cache/lrucache
|
||||
[ ] Operations in shared cache using one core
|
||||
BenchmarkConcurrentGetLRUCache 5000000 450 ns/op
|
||||
BenchmarkConcurrentSetLRUCache 2000000 821 ns/op
|
||||
BenchmarkConcurrentSetNXLRUCache 5000000 664 ns/op
|
||||
|
||||
[*] Scalability of cache/multilru
|
||||
[ ] Operations in four caches using four cores
|
||||
BenchmarkConcurrentGetMultiLRU-4 5000000 475 ns/op
|
||||
BenchmarkConcurrentSetMultiLRU-4 2000000 809 ns/op
|
||||
BenchmarkConcurrentSetNXMultiLRU-4 5000000 643 ns/op
|
||||
|
||||
[*] Capacity=4096 Keys=30000 KeySpace=15625
|
||||
vitess LRUCache MultiLRUCache-4
|
||||
create 1.709us 1.626374ms 343.54us
|
||||
Get (miss) 144.266083ms 132.470397ms 177.277193ms
|
||||
SetNX #1 338.637977ms 380.733302ms 411.709204ms
|
||||
Get (hit) 195.896066ms 173.252112ms 234.109494ms
|
||||
SetNX #2 349.785951ms 367.255624ms 419.129127ms
|
|
@ -0,0 +1,69 @@
|
|||
// Copyright (c) 2013 CloudFlare, Inc.
|
||||
|
||||
// Package lrucache implements a last recently used cache data structure.
|
||||
//
|
||||
// This code tries to avoid dynamic memory allocations - all required
|
||||
// memory is allocated on creation. Access to the data structure is
|
||||
// O(1). Modification O(log(n)) if expiry is used, O(1)
|
||||
// otherwise.
|
||||
//
|
||||
// This package exports three things:
|
||||
// LRUCache: is the main implementation. It supports multithreading by
|
||||
// using guarding mutex lock.
|
||||
//
|
||||
// MultiLRUCache: is a sharded implementation. It supports the same
|
||||
// API as LRUCache and uses it internally, but is not limited to
|
||||
// a single CPU as every shard is separately locked. Use this
|
||||
// data structure instead of LRUCache if you have have lock
|
||||
// contention issues.
|
||||
//
|
||||
// Cache interface: Both implementations fulfill it.
|
||||
package lrucache
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Cache interface is fulfilled by the LRUCache and MultiLRUCache
|
||||
// implementations.
|
||||
type Cache interface {
|
||||
// Methods not needing to know current time.
|
||||
//
|
||||
// Get a key from the cache, possibly stale. Update its LRU
|
||||
// score.
|
||||
Get(key string) (value interface{}, ok bool)
|
||||
// Get a key from the cache, possibly stale. Don't modify its LRU score. O(1)
|
||||
GetQuiet(key string) (value interface{}, ok bool)
|
||||
// Get and remove a key from the cache.
|
||||
Del(key string) (value interface{}, ok bool)
|
||||
// Evict all items from the cache.
|
||||
Clear() int
|
||||
// Number of entries used in the LRU
|
||||
Len() int
|
||||
// Get the total capacity of the LRU
|
||||
Capacity() int
|
||||
|
||||
// Methods use time.Now() when neccessary to determine expiry.
|
||||
//
|
||||
// Add an item to the cache overwriting existing one if it
|
||||
// exists.
|
||||
Set(key string, value interface{}, expire time.Time)
|
||||
// Get a key from the cache, make sure it's not stale. Update
|
||||
// its LRU score.
|
||||
GetNotStale(key string) (value interface{}, ok bool)
|
||||
// Evict all the expired items.
|
||||
Expire() int
|
||||
|
||||
// Methods allowing to explicitly specify time used to
|
||||
// determine if items are expired.
|
||||
//
|
||||
// Add an item to the cache overwriting existing one if it
|
||||
// exists. Allows specifing current time required to expire an
|
||||
// item when no more slots are used.
|
||||
SetNow(key string, value interface{}, expire time.Time, now time.Time)
|
||||
// Get a key from the cache, make sure it's not stale. Update
|
||||
// its LRU score.
|
||||
GetNotStaleNow(key string, now time.Time) (value interface{}, ok bool)
|
||||
// Evict items that expire before Now.
|
||||
ExpireNow(now time.Time) int
|
||||
}
|
|
@ -0,0 +1,238 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/* This file is a slightly modified file from the go package sources
|
||||
and is released on the following license:
|
||||
|
||||
Copyright (c) 2012 The Go Authors. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
|
||||
// Package list implements a doubly linked list.
|
||||
//
|
||||
// To iterate over a list (where l is a *List):
|
||||
// for e := l.Front(); e != nil; e = e.Next() {
|
||||
// // do something with e.Value
|
||||
// }
|
||||
//
|
||||
|
||||
package lrucache
|
||||
|
||||
// Element is an element of a linked list.
|
||||
type element struct {
|
||||
// Next and previous pointers in the doubly-linked list of elements.
|
||||
// To simplify the implementation, internally a list l is implemented
|
||||
// as a ring, such that &l.root is both the next element of the last
|
||||
// list element (l.Back()) and the previous element of the first list
|
||||
// element (l.Front()).
|
||||
next, prev *element
|
||||
|
||||
// The list to which this element belongs.
|
||||
list *list
|
||||
|
||||
// The value stored with this element.
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
// Next returns the next list element or nil.
|
||||
func (e *element) Next() *element {
|
||||
if p := e.next; e.list != nil && p != &e.list.root {
|
||||
return p
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Prev returns the previous list element or nil.
|
||||
func (e *element) Prev() *element {
|
||||
if p := e.prev; e.list != nil && p != &e.list.root {
|
||||
return p
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// List represents a doubly linked list.
|
||||
// The zero value for List is an empty list ready to use.
|
||||
type list struct {
|
||||
root element // sentinel list element, only &root, root.prev, and root.next are used
|
||||
len int // current list length excluding (this) sentinel element
|
||||
}
|
||||
|
||||
// Init initializes or clears list l.
|
||||
func (l *list) Init() *list {
|
||||
l.root.next = &l.root
|
||||
l.root.prev = &l.root
|
||||
l.len = 0
|
||||
return l
|
||||
}
|
||||
|
||||
// New returns an initialized list.
|
||||
// func New() *list { return new(list).Init() }
|
||||
|
||||
// Len returns the number of elements of list l.
|
||||
// The complexity is O(1).
|
||||
func (l *list) Len() int { return l.len }
|
||||
|
||||
// Front returns the first element of list l or nil
|
||||
func (l *list) Front() *element {
|
||||
if l.len == 0 {
|
||||
return nil
|
||||
}
|
||||
return l.root.next
|
||||
}
|
||||
|
||||
// Back returns the last element of list l or nil.
|
||||
func (l *list) Back() *element {
|
||||
if l.len == 0 {
|
||||
return nil
|
||||
}
|
||||
return l.root.prev
|
||||
}
|
||||
|
||||
// insert inserts e after at, increments l.len, and returns e.
|
||||
func (l *list) insert(e, at *element) *element {
|
||||
n := at.next
|
||||
at.next = e
|
||||
e.prev = at
|
||||
e.next = n
|
||||
n.prev = e
|
||||
e.list = l
|
||||
l.len++
|
||||
return e
|
||||
}
|
||||
|
||||
// insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
|
||||
func (l *list) insertValue(v interface{}, at *element) *element {
|
||||
return l.insert(&element{Value: v}, at)
|
||||
}
|
||||
|
||||
// remove removes e from its list, decrements l.len, and returns e.
|
||||
func (l *list) remove(e *element) *element {
|
||||
e.prev.next = e.next
|
||||
e.next.prev = e.prev
|
||||
e.next = nil // avoid memory leaks
|
||||
e.prev = nil // avoid memory leaks
|
||||
e.list = nil
|
||||
l.len--
|
||||
return e
|
||||
}
|
||||
|
||||
// Remove removes e from l if e is an element of list l.
|
||||
// It returns the element value e.Value.
|
||||
func (l *list) Remove(e *element) interface{} {
|
||||
if e.list == l {
|
||||
// if e.list == l, l must have been initialized when e was inserted
|
||||
// in l or l == nil (e is a zero Element) and l.remove will crash
|
||||
l.remove(e)
|
||||
}
|
||||
return e.Value
|
||||
}
|
||||
|
||||
// PushFront inserts a new element e with value v at the front of list l and returns e.
|
||||
func (l *list) PushFront(v interface{}) *element {
|
||||
return l.insertValue(v, &l.root)
|
||||
}
|
||||
|
||||
// PushBack inserts a new element e with value v at the back of list l and returns e.
|
||||
func (l *list) PushBack(v interface{}) *element {
|
||||
return l.insertValue(v, l.root.prev)
|
||||
}
|
||||
|
||||
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
|
||||
// If mark is not an element of l, the list is not modified.
|
||||
func (l *list) InsertBefore(v interface{}, mark *element) *element {
|
||||
if mark.list != l {
|
||||
return nil
|
||||
}
|
||||
// see comment in List.Remove about initialization of l
|
||||
return l.insertValue(v, mark.prev)
|
||||
}
|
||||
|
||||
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
|
||||
// If mark is not an element of l, the list is not modified.
|
||||
func (l *list) InsertAfter(v interface{}, mark *element) *element {
|
||||
if mark.list != l {
|
||||
return nil
|
||||
}
|
||||
// see comment in List.Remove about initialization of l
|
||||
return l.insertValue(v, mark)
|
||||
}
|
||||
|
||||
// MoveToFront moves element e to the front of list l.
|
||||
// If e is not an element of l, the list is not modified.
|
||||
func (l *list) MoveToFront(e *element) {
|
||||
if e.list != l || l.root.next == e {
|
||||
return
|
||||
}
|
||||
// see comment in List.Remove about initialization of l
|
||||
l.insert(l.remove(e), &l.root)
|
||||
}
|
||||
|
||||
// MoveToBack moves element e to the back of list l.
|
||||
// If e is not an element of l, the list is not modified.
|
||||
func (l *list) MoveToBack(e *element) {
|
||||
if e.list != l || l.root.prev == e {
|
||||
return
|
||||
}
|
||||
// see comment in List.Remove about initialization of l
|
||||
l.insert(l.remove(e), l.root.prev)
|
||||
}
|
||||
|
||||
// MoveBefore moves element e to its new position before mark.
|
||||
// If e is not an element of l, or e == mark, the list is not modified.
|
||||
func (l *list) MoveBefore(e, mark *element) {
|
||||
if e.list != l || e == mark {
|
||||
return
|
||||
}
|
||||
l.insert(l.remove(e), mark.prev)
|
||||
}
|
||||
|
||||
// MoveAfter moves element e to its new position after mark.
|
||||
// If e is not an element of l, or e == mark, the list is not modified.
|
||||
func (l *list) MoveAfter(e, mark *element) {
|
||||
if e.list != l || e == mark {
|
||||
return
|
||||
}
|
||||
l.insert(l.remove(e), mark)
|
||||
}
|
||||
|
||||
// PushBackList inserts a copy of another list at the back of list l.
|
||||
// The lists l and other may be the same.
|
||||
func (l *list) PushBackList(other *list) {
|
||||
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
|
||||
l.insertValue(e.Value, l.root.prev)
|
||||
}
|
||||
}
|
||||
|
||||
// PushFrontList inserts a copy of another list at the front of list l.
|
||||
// The lists l and other may be the same.
|
||||
func (l *list) PushFrontList(other *list) {
|
||||
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
|
||||
l.insertValue(e.Value, &l.root)
|
||||
}
|
||||
}
|
|
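The ring-based sentinel layout above means Front()/Back() are O(1) and iteration never dereferences nil until it wraps back to &l.root. A minimal in-package sketch of the iteration idiom the cache code relies on (the helper name is hypothetical, not part of the vendored file):

```go
package lrucache

// collectValues walks the list from front to back and returns every stored
// value. Purely illustrative; the real cache walks entries the same way.
func collectValues(l *list) []interface{} {
	var out []interface{}
	for e := l.Front(); e != nil; e = e.Next() {
		out = append(out, e.Value)
	}
	return out
}
```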
@ -0,0 +1,25 @@
|
|||
// Copyright (c) 2013 CloudFlare, Inc.
|
||||
|
||||
// Extensions to "container/list" that allow reuse of Elements.
|
||||
|
||||
package lrucache
|
||||
|
||||
func (l *list) PushElementFront(e *element) *element {
|
||||
return l.insert(e, &l.root)
|
||||
}
|
||||
|
||||
func (l *list) PushElementBack(e *element) *element {
|
||||
return l.insert(e, l.root.prev)
|
||||
}
|
||||
|
||||
func (l *list) PopElementFront() *element {
|
||||
el := l.Front()
|
||||
l.Remove(el)
|
||||
return el
|
||||
}
|
||||
|
||||
func (l *list) PopFront() interface{} {
|
||||
el := l.Front()
|
||||
l.Remove(el)
|
||||
return el.Value
|
||||
}
|
|
@ -0,0 +1,316 @@
|
|||
// Copyright (c) 2013 CloudFlare, Inc.
|
||||
|
||||
package lrucache
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Every element in the cache is linked to three data structures:
|
||||
// Table map, PriorityQueue heap ordered by expiry and a LruList list
|
||||
// ordered by decreasing popularity.
|
||||
type entry struct {
|
||||
element element // list element. value is a pointer to this entry
|
||||
key string // key is a key!
|
||||
value interface{} //
|
||||
expire time.Time // time when the item is expired. it's okay to be stale.
|
||||
index int // index for priority queue needs. -1 if entry is free
|
||||
}
|
||||
|
||||
// LRUCache data structure. Never dereference it or copy it by
|
||||
// value. Always use it through a pointer.
|
||||
type LRUCache struct {
|
||||
lock sync.Mutex
|
||||
table map[string]*entry // all entries in table must be in lruList
|
||||
priorityQueue priorityQueue // some elements from table may be in priorityQueue
|
||||
lruList list // every entry is either used and resides in lruList
|
||||
freeList list // or free and is linked to freeList
|
||||
|
||||
ExpireGracePeriod time.Duration // time after an expired entry is purged from cache (unless pushed out of LRU)
|
||||
}
|
||||
|
||||
// Initialize the LRU cache instance. O(capacity)
|
||||
func (b *LRUCache) Init(capacity uint) {
|
||||
b.table = make(map[string]*entry, capacity)
|
||||
b.priorityQueue = make([]*entry, 0, capacity)
|
||||
b.lruList.Init()
|
||||
b.freeList.Init()
|
||||
heap.Init(&b.priorityQueue)
|
||||
|
||||
// Reserve all the entries in one giant continuous block of memory
|
||||
arrayOfEntries := make([]entry, capacity)
|
||||
for i := uint(0); i < capacity; i++ {
|
||||
e := &arrayOfEntries[i]
|
||||
e.element.Value = e
|
||||
e.index = -1
|
||||
b.freeList.PushElementBack(&e.element)
|
||||
}
|
||||
}
|
||||
|
||||
// Create new LRU cache instance. Allocate all the needed memory. O(capacity)
|
||||
func NewLRUCache(capacity uint) *LRUCache {
|
||||
b := &LRUCache{}
|
||||
b.Init(capacity)
|
||||
return b
|
||||
}
|
||||
|
||||
// Give me the entry with lowest expiry field if it's before now.
|
||||
func (b *LRUCache) expiredEntry(now time.Time) *entry {
|
||||
if len(b.priorityQueue) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if now.IsZero() {
|
||||
// Fill it only when actually used.
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
if e := b.priorityQueue[0]; e.expire.Before(now) {
|
||||
return e
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Give me the least used entry.
|
||||
func (b *LRUCache) leastUsedEntry() *entry {
|
||||
return b.lruList.Back().Value.(*entry)
|
||||
}
|
||||
|
||||
func (b *LRUCache) freeSomeEntry(now time.Time) (e *entry, used bool) {
|
||||
if b.freeList.Len() > 0 {
|
||||
return b.freeList.Front().Value.(*entry), false
|
||||
}
|
||||
|
||||
e = b.expiredEntry(now)
|
||||
if e != nil {
|
||||
return e, true
|
||||
}
|
||||
|
||||
if b.lruList.Len() == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return b.leastUsedEntry(), true
|
||||
}
|
||||
|
||||
// Move entry from used/lru list to a free list. Clear the entry as well.
|
||||
func (b *LRUCache) removeEntry(e *entry) {
|
||||
if e.element.list != &b.lruList {
|
||||
panic("list lruList")
|
||||
}
|
||||
|
||||
if e.index != -1 {
|
||||
heap.Remove(&b.priorityQueue, e.index)
|
||||
}
|
||||
b.lruList.Remove(&e.element)
|
||||
b.freeList.PushElementFront(&e.element)
|
||||
delete(b.table, e.key)
|
||||
e.key = ""
|
||||
e.value = nil
|
||||
}
|
||||
|
||||
func (b *LRUCache) insertEntry(e *entry) {
|
||||
if e.element.list != &b.freeList {
|
||||
panic("list freeList")
|
||||
}
|
||||
|
||||
if !e.expire.IsZero() {
|
||||
heap.Push(&b.priorityQueue, e)
|
||||
}
|
||||
b.freeList.Remove(&e.element)
|
||||
b.lruList.PushElementFront(&e.element)
|
||||
b.table[e.key] = e
|
||||
}
|
||||
|
||||
func (b *LRUCache) touchEntry(e *entry) {
|
||||
b.lruList.MoveToFront(&e.element)
|
||||
}
|
||||
|
||||
// SetNow adds an item to the cache overwriting existing one if it
|
||||
// exists. Allows specifying the current time, used to expire an item
|
||||
// when no more slots are used. O(log(n)) if expiry is set, O(1) when
|
||||
// clear.
|
||||
func (b *LRUCache) SetNow(key string, value interface{}, expire time.Time, now time.Time) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
var used bool
|
||||
|
||||
e := b.table[key]
|
||||
if e != nil {
|
||||
used = true
|
||||
} else {
|
||||
e, used = b.freeSomeEntry(now)
|
||||
if e == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if used {
|
||||
b.removeEntry(e)
|
||||
}
|
||||
|
||||
e.key = key
|
||||
e.value = value
|
||||
e.expire = expire
|
||||
b.insertEntry(e)
|
||||
}
|
||||
|
||||
// Set adds an item to the cache overwriting existing one if it
|
||||
// exists. O(log(n)) if expiry is set, O(1) when clear.
|
||||
func (b *LRUCache) Set(key string, value interface{}, expire time.Time) {
|
||||
b.SetNow(key, value, expire, time.Time{})
|
||||
}
|
||||
|
||||
// Get a key from the cache, possibly stale. Update its LRU score. O(1)
|
||||
func (b *LRUCache) Get(key string) (v interface{}, ok bool) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
e := b.table[key]
|
||||
if e == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
b.touchEntry(e)
|
||||
return e.value, true
|
||||
}
|
||||
|
||||
// GetQuiet gets a key from the cache, possibly stale. Don't modify its LRU score. O(1)
|
||||
func (b *LRUCache) GetQuiet(key string) (v interface{}, ok bool) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
e := b.table[key]
|
||||
if e == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return e.value, true
|
||||
}
|
||||
|
||||
// GetNotStale gets a key from the cache, making sure it's not stale. Update its
|
||||
// LRU score. O(log(n)) if the item is expired.
|
||||
func (b *LRUCache) GetNotStale(key string) (value interface{}, ok bool) {
|
||||
return b.GetNotStaleNow(key, time.Now())
|
||||
}
|
||||
|
||||
// GetNotStaleNow gets a key from the cache, making sure it's not stale. Update its
|
||||
// LRU score. O(log(n)) if the item is expired.
|
||||
func (b *LRUCache) GetNotStaleNow(key string, now time.Time) (value interface{}, ok bool) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
e := b.table[key]
|
||||
if e == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if e.expire.Before(now) {
|
||||
// Remove entries expired for more than a graceful period
|
||||
if b.ExpireGracePeriod == 0 || e.expire.Sub(now) > b.ExpireGracePeriod {
|
||||
b.removeEntry(e)
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
b.touchEntry(e)
|
||||
return e.value, true
|
||||
}
|
||||
|
||||
// GetStale gets a key from the cache, possibly stale. Update its LRU
|
||||
// score. O(1) always.
|
||||
func (b *LRUCache) GetStale(key string) (value interface{}, ok, expired bool) {
|
||||
return b.GetStaleNow(key, time.Now())
|
||||
}
|
||||
|
||||
// GetStaleNow gets a key from the cache, possibly stale. Update its LRU
|
||||
// score. O(1) always.
|
||||
func (b *LRUCache) GetStaleNow(key string, now time.Time) (value interface{}, ok, expired bool) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
e := b.table[key]
|
||||
if e == nil {
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
b.touchEntry(e)
|
||||
return e.value, true, e.expire.Before(now)
|
||||
}
|
||||
|
||||
// Del gets and removes a key from the cache. O(log(n)) if the item is using expiry, O(1) otherwise.
|
||||
func (b *LRUCache) Del(key string) (v interface{}, ok bool) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
e := b.table[key]
|
||||
if e == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
value := e.value
|
||||
b.removeEntry(e)
|
||||
return value, true
|
||||
}
|
||||
|
||||
// Evict all items from the cache. O(n*log(n))
|
||||
func (b *LRUCache) Clear() int {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
// First, remove entries that have expiry set
|
||||
l := len(b.priorityQueue)
|
||||
for i := 0; i < l; i++ {
|
||||
// This could be reduced to O(n).
|
||||
b.removeEntry(b.priorityQueue[0])
|
||||
}
|
||||
|
||||
// Second, remove all remaining entries
|
||||
r := b.lruList.Len()
|
||||
for i := 0; i < r; i++ {
|
||||
b.removeEntry(b.leastUsedEntry())
|
||||
}
|
||||
return l + r
|
||||
}
|
||||
|
||||
// Evict all the expired items. O(n*log(n))
|
||||
func (b *LRUCache) Expire() int {
|
||||
return b.ExpireNow(time.Now())
|
||||
}
|
||||
|
||||
// Evict items that expire before `now`. O(n*log(n))
|
||||
func (b *LRUCache) ExpireNow(now time.Time) int {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
i := 0
|
||||
for {
|
||||
e := b.expiredEntry(now)
|
||||
if e == nil {
|
||||
break
|
||||
}
|
||||
b.removeEntry(e)
|
||||
i += 1
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
// Len returns the number of entries used in the LRU.
|
||||
func (b *LRUCache) Len() int {
|
||||
// yes. this stupid thing requires locking
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
return b.lruList.Len()
|
||||
}
|
||||
|
||||
// Capacity gets the total capacity of the LRU
|
||||
func (b *LRUCache) Capacity() int {
|
||||
// yes. this stupid thing requires locking
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
return b.lruList.Len() + b.freeList.Len()
|
||||
}
|
|
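To make the Set/Get/expiry flow above concrete, here is a minimal in-package usage sketch (the function name, keys and durations are hypothetical and purely illustrative):

```go
package lrucache

import (
	"fmt"
	"time"
)

// demoCache shows the basic lifecycle: fill the cache, read an entry back,
// observe LRU eviction, then see GetNotStale refuse an expired entry.
func demoCache() {
	c := NewLRUCache(2) // room for two entries

	c.Set("a", "alpha", time.Now().Add(time.Minute)) // expires in a minute
	c.Set("b", "beta", time.Time{})                  // zero expiry: never enters the priority queue

	if v, ok := c.Get("a"); ok {
		fmt.Println(v) // "alpha"; "a" is now the most recently used entry
	}

	// Inserting a third entry evicts the least recently used one ("b").
	c.Set("c", "gamma", time.Time{})
	_, ok := c.Get("b")
	fmt.Println(ok) // false

	// GetNotStale treats entries whose expiry lies in the past as missing.
	c.SetNow("d", "delta", time.Now().Add(-time.Second), time.Time{})
	_, ok = c.GetNotStale("d")
	fmt.Println(ok) // false
}
```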
@ -0,0 +1,118 @@
|
|||
// Copyright (c) 2013 CloudFlare, Inc.
|
||||
|
||||
package lrucache
|
||||
|
||||
import (
|
||||
"hash/crc32"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MultiLRUCache data structure. Never dereference it or copy it by
|
||||
// value. Always use it through a pointer.
|
||||
type MultiLRUCache struct {
|
||||
buckets uint
|
||||
cache []*LRUCache
|
||||
}
|
||||
|
||||
// Using this constructor is almost always wrong. Use NewMultiLRUCache instead.
|
||||
func (m *MultiLRUCache) Init(buckets, bucket_capacity uint) {
|
||||
m.buckets = buckets
|
||||
m.cache = make([]*LRUCache, buckets)
|
||||
for i := uint(0); i < buckets; i++ {
|
||||
m.cache[i] = NewLRUCache(bucket_capacity)
|
||||
}
|
||||
}
|
||||
|
||||
// Set the stale expiry grace period for each cache in the multicache instance.
|
||||
func (m *MultiLRUCache) SetExpireGracePeriod(p time.Duration) {
|
||||
for _, c := range m.cache {
|
||||
c.ExpireGracePeriod = p
|
||||
}
|
||||
}
|
||||
|
||||
func NewMultiLRUCache(buckets, bucket_capacity uint) *MultiLRUCache {
|
||||
m := &MultiLRUCache{}
|
||||
m.Init(buckets, bucket_capacity)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) bucketNo(key string) uint {
|
||||
// Arbitrary choice. Any fast hash will do.
|
||||
return uint(crc32.ChecksumIEEE([]byte(key))) % m.buckets
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) Set(key string, value interface{}, expire time.Time) {
|
||||
m.cache[m.bucketNo(key)].Set(key, value, expire)
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) SetNow(key string, value interface{}, expire time.Time, now time.Time) {
|
||||
m.cache[m.bucketNo(key)].SetNow(key, value, expire, now)
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) Get(key string) (value interface{}, ok bool) {
|
||||
return m.cache[m.bucketNo(key)].Get(key)
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) GetQuiet(key string) (value interface{}, ok bool) {
|
||||
return m.cache[m.bucketNo(key)].GetQuiet(key)
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) GetNotStale(key string) (value interface{}, ok bool) {
|
||||
return m.cache[m.bucketNo(key)].GetNotStale(key)
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) GetNotStaleNow(key string, now time.Time) (value interface{}, ok bool) {
|
||||
return m.cache[m.bucketNo(key)].GetNotStaleNow(key, now)
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) GetStale(key string) (value interface{}, ok, expired bool) {
|
||||
return m.cache[m.bucketNo(key)].GetStale(key)
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) GetStaleNow(key string, now time.Time) (value interface{}, ok, expired bool) {
|
||||
return m.cache[m.bucketNo(key)].GetStaleNow(key, now)
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) Del(key string) (value interface{}, ok bool) {
|
||||
return m.cache[m.bucketNo(key)].Del(key)
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) Clear() int {
|
||||
var s int
|
||||
for _, c := range m.cache {
|
||||
s += c.Clear()
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) Len() int {
|
||||
var s int
|
||||
for _, c := range m.cache {
|
||||
s += c.Len()
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) Capacity() int {
|
||||
var s int
|
||||
for _, c := range m.cache {
|
||||
s += c.Capacity()
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) Expire() int {
|
||||
var s int
|
||||
for _, c := range m.cache {
|
||||
s += c.Expire()
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) ExpireNow(now time.Time) int {
|
||||
var s int
|
||||
for _, c := range m.cache {
|
||||
s += c.ExpireNow(now)
|
||||
}
|
||||
return s
|
||||
}
|
|
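The multi-bucket variant simply shards keys across independent LRUCache instances using a CRC32 hash, which spreads lock contention because each bucket has its own mutex. A hypothetical in-package sketch (bucket count and key are made up):

```go
package lrucache

import (
	"fmt"
	"time"
)

// demoMultiLRU spreads entries over four buckets of 1024 entries each.
// Total capacity is 4*1024, but a single hot bucket can still evict early.
func demoMultiLRU() {
	m := NewMultiLRUCache(4, 1024)

	m.Set("user:42", "cached profile", time.Now().Add(30*time.Second))

	if v, ok := m.Get("user:42"); ok {
		fmt.Println(v) // "cached profile"
	}
	fmt.Println(m.Capacity()) // 4096
}
```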
@ -0,0 +1,37 @@
|
|||
// Copyright (c) 2013 CloudFlare, Inc.
|
||||
|
||||
// This code is based on golang example from "container/heap" package.
|
||||
|
||||
package lrucache
|
||||
|
||||
type priorityQueue []*entry
|
||||
|
||||
func (pq priorityQueue) Len() int {
|
||||
return len(pq)
|
||||
}
|
||||
|
||||
func (pq priorityQueue) Less(i, j int) bool {
|
||||
return pq[i].expire.Before(pq[j].expire)
|
||||
}
|
||||
|
||||
func (pq priorityQueue) Swap(i, j int) {
|
||||
pq[i], pq[j] = pq[j], pq[i]
|
||||
pq[i].index = i
|
||||
pq[j].index = j
|
||||
}
|
||||
|
||||
func (pq *priorityQueue) Push(e interface{}) {
|
||||
n := len(*pq)
|
||||
item := e.(*entry)
|
||||
item.index = n
|
||||
*pq = append(*pq, item)
|
||||
}
|
||||
|
||||
func (pq *priorityQueue) Pop() interface{} {
|
||||
old := *pq
|
||||
n := len(old)
|
||||
item := old[n-1]
|
||||
item.index = -1
|
||||
*pq = old[0 : n-1]
|
||||
return item
|
||||
}
|
|
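priorityQueue implements heap.Interface ordered by expiry, so the cache can always reach the soonest-to-expire entry at index 0. A hypothetical in-package sketch of the interplay with container/heap (keys and durations are illustrative):

```go
package lrucache

import (
	"container/heap"
	"fmt"
	"time"
)

// demoQueue pushes two entries and shows that the earliest expiry
// surfaces at the root of the heap.
func demoQueue() {
	pq := make(priorityQueue, 0, 2)
	heap.Init(&pq)

	later := &entry{key: "later", expire: time.Now().Add(time.Hour)}
	sooner := &entry{key: "sooner", expire: time.Now().Add(time.Minute)}

	heap.Push(&pq, later)
	heap.Push(&pq, sooner)

	fmt.Println(pq[0].key)                  // "sooner"
	fmt.Println(heap.Pop(&pq).(*entry).key) // "sooner"
}
```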
@ -0,0 +1,201 @@
|
|||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "{}"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright {yyyy} {name of copyright owner}
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
|
@ -0,0 +1,105 @@
|
|||
package dnsserver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/coredns/coredns/plugin/pkg/parse"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type zoneAddr struct {
|
||||
Zone string
|
||||
Port string
|
||||
Transport string // dns, tls or grpc
|
||||
IPNet *net.IPNet // if this is a reverse zone, this holds the IPNet
|
||||
Address string // used for bound zoneAddr - validation of overlapping
|
||||
}
|
||||
|
||||
// String returns the string representation of z.
|
||||
func (z zoneAddr) String() string {
|
||||
s := z.Transport + "://" + z.Zone + ":" + z.Port
|
||||
if z.Address != "" {
|
||||
s += " on " + z.Address
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// normalizeZone parses a zone string into a structured format with separate
|
||||
// host and port portions, as well as the original input string.
|
||||
func normalizeZone(str string) (zoneAddr, error) {
|
||||
trans, str := parse.Transport(str)
|
||||
|
||||
host, port, ipnet, err := plugin.SplitHostPort(str)
|
||||
if err != nil {
|
||||
return zoneAddr{}, err
|
||||
}
|
||||
|
||||
if port == "" {
|
||||
switch trans {
|
||||
case transport.DNS:
|
||||
port = Port
|
||||
case transport.TLS:
|
||||
port = transport.TLSPort
|
||||
case transport.GRPC:
|
||||
port = transport.GRPCPort
|
||||
case transport.HTTPS:
|
||||
port = transport.HTTPSPort
|
||||
}
|
||||
}
|
||||
|
||||
return zoneAddr{Zone: dns.Fqdn(host), Port: port, Transport: trans, IPNet: ipnet}, nil
|
||||
}
|
||||
|
||||
// SplitProtocolHostPort splits a full formed address like "dns://[::1]:53" into parts.
|
||||
func SplitProtocolHostPort(address string) (protocol string, ip string, port string, err error) {
|
||||
parts := strings.Split(address, "://")
|
||||
switch len(parts) {
|
||||
case 1:
|
||||
ip, port, err := net.SplitHostPort(parts[0])
|
||||
return "", ip, port, err
|
||||
case 2:
|
||||
ip, port, err := net.SplitHostPort(parts[1])
|
||||
return parts[0], ip, port, err
|
||||
default:
|
||||
return "", "", "", fmt.Errorf("provided value is not in an address format : %s", address)
|
||||
}
|
||||
}
|
||||
|
||||
type zoneOverlap struct {
|
||||
registeredAddr map[zoneAddr]zoneAddr // each zoneAddr is registered once by its key
|
||||
unboundOverlap map[zoneAddr]zoneAddr // the "no bind" equiv ZoneAddr is registered by its original key
|
||||
}
|
||||
|
||||
func newOverlapZone() *zoneOverlap {
|
||||
return &zoneOverlap{registeredAddr: make(map[zoneAddr]zoneAddr), unboundOverlap: make(map[zoneAddr]zoneAddr)}
|
||||
}
|
||||
|
||||
// registerAndCheck adds a new zoneAddr for validation; it returns the already-registered zoneAddr that it duplicates or overlaps with, if any.
|
||||
// We consider an unbound address to overlap all bound addresses for the same zone and port.
|
||||
func (zo *zoneOverlap) registerAndCheck(z zoneAddr) (existingZone *zoneAddr, overlappingZone *zoneAddr) {
|
||||
|
||||
if exist, ok := zo.registeredAddr[z]; ok {
|
||||
// exact same zone already registered
|
||||
return &exist, nil
|
||||
}
|
||||
uz := zoneAddr{Zone: z.Zone, Address: "", Port: z.Port, Transport: z.Transport}
|
||||
if already, ok := zo.unboundOverlap[uz]; ok {
|
||||
if z.Address == "" {
|
||||
// current is not bound to an address, but there is already another zone with a bind address registered
|
||||
return nil, &already
|
||||
}
|
||||
if _, ok := zo.registeredAddr[uz]; ok {
|
||||
// current zone is bound to an address, but there is already an overlapping zone+port with no bind address
|
||||
return nil, &uz
|
||||
}
|
||||
}
|
||||
// there is no overlap, keep the current zoneAddr for future checks
|
||||
zo.registeredAddr[z] = z
|
||||
zo.unboundOverlap[uz] = z
|
||||
return nil, nil
|
||||
}
|
|
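SplitProtocolHostPort is the only exported helper in this file; a quick sketch of its behaviour (the function name demoSplit is hypothetical, the addresses are examples):

```go
package dnsserver

import "fmt"

// demoSplit illustrates SplitProtocolHostPort on addresses with and
// without a protocol prefix.
func demoSplit() {
	proto, ip, port, _ := SplitProtocolHostPort("dns://[::1]:53")
	fmt.Println(proto, ip, port) // dns ::1 53

	proto, ip, port, _ = SplitProtocolHostPort("127.0.0.1:1053")
	fmt.Println(proto, ip, port) // " 127.0.0.1 1053" (empty protocol)
}
```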
@ -0,0 +1,73 @@
|
|||
package dnsserver
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
)
|
||||
|
||||
// Config configuration for a single server.
|
||||
type Config struct {
|
||||
// The zone of the site.
|
||||
Zone string
|
||||
|
||||
// One or several hostnames to bind the server to.
|
||||
// Defaults to a single empty string, which denotes the wildcard address.
|
||||
ListenHosts []string
|
||||
|
||||
// The port to listen on.
|
||||
Port string
|
||||
|
||||
// Root points to a base directory in which we find user-defined "things".
|
||||
// The first consumer is the file plugin, which looks for zone files in this place.
|
||||
Root string
|
||||
|
||||
// Debug controls the panic/recover mechanism that is enabled by default.
|
||||
Debug bool
|
||||
|
||||
// The transport we implement, normally just "dns" over TCP/UDP, but could be
|
||||
// DNS-over-TLS or DNS-over-gRPC.
|
||||
Transport string
|
||||
|
||||
// If this function is not nil it will be used to further filter access
|
||||
// to this handler. The primary use is to limit access to a reverse zone
|
||||
// on a non-octet boundary, i.e. /17
|
||||
FilterFunc func(string) bool
|
||||
|
||||
// TLSConfig when listening for encrypted connections (gRPC, DNS-over-TLS).
|
||||
TLSConfig *tls.Config
|
||||
|
||||
// Plugin stack.
|
||||
Plugin []plugin.Plugin
|
||||
|
||||
// Compiled plugin stack.
|
||||
pluginChain plugin.Handler
|
||||
|
||||
// Plugins interested in announcing that they exist, so other plugins can call methods
|
||||
// on them, should register themselves here. The name should be the name as returned by the
|
||||
// Handler's Name method.
|
||||
registry map[string]plugin.Handler
|
||||
}
|
||||
|
||||
// keyForConfig builds a key for identifying the configs during setup time.
|
||||
func keyForConfig(blocIndex int, blocKeyIndex int) string {
|
||||
return fmt.Sprintf("%d:%d", blocIndex, blocKeyIndex)
|
||||
}
|
||||
|
||||
// GetConfig gets the Config that corresponds to c.
|
||||
// If none exist nil is returned.
|
||||
func GetConfig(c *caddy.Controller) *Config {
|
||||
ctx := c.Context().(*dnsContext)
|
||||
key := keyForConfig(c.ServerBlockIndex, c.ServerBlockKeyIndex)
|
||||
if cfg, ok := ctx.keysToConfigs[key]; ok {
|
||||
return cfg
|
||||
}
|
||||
// we should only get here during tests because directive
|
||||
// actions typically skip the server blocks where we make
|
||||
// the configs.
|
||||
ctx.saveConfig(key, &Config{ListenHosts: []string{""}})
|
||||
return GetConfig(c)
|
||||
}
|
|
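keyForConfig just joins the server-block index and the key index, which is how GetConfig finds the Config for a caddy.Controller. A small sketch of the key shape (indices are illustrative, demoKey is hypothetical):

```go
package dnsserver

import "fmt"

// demoKey shows the shape of the lookup key used in keysToConfigs.
func demoKey() {
	key := keyForConfig(2, 0) // third server block, first key
	fmt.Println(key)          // "2:0"
}
```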
@ -0,0 +1,23 @@
|
|||
package dnsserver
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/nonwriter"
|
||||
)
|
||||
|
||||
// DoHWriter is a nonwriter.Writer that adds more specific LocalAddr and RemoteAddr methods.
|
||||
type DoHWriter struct {
|
||||
nonwriter.Writer
|
||||
|
||||
// raddr is the remote's address. This can be optionally set.
|
||||
raddr net.Addr
|
||||
// laddr is our address. This can be optionally set.
|
||||
laddr net.Addr
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote address.
|
||||
func (d *DoHWriter) RemoteAddr() net.Addr { return d.raddr }
|
||||
|
||||
// LocalAddr returns the local address.
|
||||
func (d *DoHWriter) LocalAddr() net.Addr { return d.laddr }
|
|
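A hedged in-package sketch of how the DoH path could populate and read back these addresses; the constructor-style literal and the addresses are assumptions for illustration only, and the embedded nonwriter.Writer is left at its zero value:

```go
package dnsserver

import (
	"fmt"
	"net"
)

// demoDoH constructs a DoHWriter with explicit local and remote addresses,
// then reads them back through the ResponseWriter-style accessors.
func demoDoH() {
	d := &DoHWriter{
		laddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 443},
		raddr: &net.TCPAddr{IP: net.ParseIP("192.0.2.1"), Port: 53007},
	}
	fmt.Println(d.LocalAddr(), d.RemoteAddr()) // 127.0.0.1:443 192.0.2.1:53007
}
```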
@ -0,0 +1,33 @@
|
|||
// +build go1.11
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd
|
||||
|
||||
package dnsserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"syscall"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/log"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func reuseportControl(network, address string, c syscall.RawConn) error {
|
||||
c.Control(func(fd uintptr) {
|
||||
if err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
|
||||
log.Warningf("Failed to set SO_REUSEPORT on socket: %s", err)
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func listen(network, addr string) (net.Listener, error) {
|
||||
lc := net.ListenConfig{Control: reuseportControl}
|
||||
return lc.Listen(context.Background(), network, addr)
|
||||
}
|
||||
|
||||
func listenPacket(network, addr string) (net.PacketConn, error) {
|
||||
lc := net.ListenConfig{Control: reuseportControl}
|
||||
return lc.ListenPacket(context.Background(), network, addr)
|
||||
}
|
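With SO_REUSEPORT set in the Control hook, several listeners can bind the same address on the platforms covered by the build tags above, which is what allows multiple servers (or a graceful restart) to share a port. A hedged in-package sketch; the address is arbitrary and error handling is minimal:

```go
package dnsserver

import "fmt"

// demoReusePort opens two TCP listeners on the same port; when SO_REUSEPORT
// is honoured both Listen calls succeed.
func demoReusePort() {
	l1, err1 := listen("tcp", "127.0.0.1:1053")
	l2, err2 := listen("tcp", "127.0.0.1:1053")
	fmt.Println(err1, err2) // <nil> <nil> when SO_REUSEPORT is honoured

	if l1 != nil {
		l1.Close()
	}
	if l2 != nil {
		l2.Close()
	}
}
```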
vendor/github.com/coredns/coredns/core/dnsserver/listen_go_not111.go (generated, vendored)
|
@ -0,0 +1,11 @@
|
|||
// +build !go1.11 !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd
|
||||
|
||||
package dnsserver
|
||||
|
||||
import "net"
|
||||
|
||||
func listen(network, addr string) (net.Listener, error) { return net.Listen(network, addr) }
|
||||
|
||||
func listenPacket(network, addr string) (net.PacketConn, error) {
|
||||
return net.ListenPacket(network, addr)
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package dnsserver
|
||||
|
||||
import "fmt"
|
||||
|
||||
// startUpZones creates the text that we show when starting up:
|
||||
// grpc://example.com.:1055
|
||||
// example.com.:1053 on 127.0.0.1
|
||||
func startUpZones(protocol, addr string, zones map[string]*Config) string {
|
||||
s := ""
|
||||
|
||||
for zone := range zones {
|
||||
// split addr into protocol, IP and Port
|
||||
_, ip, port, err := SplitProtocolHostPort(addr)
|
||||
|
||||
if err != nil {
|
||||
// this should not happen, but we need to take care of it anyway
|
||||
s += fmt.Sprintln(protocol + zone + ":" + addr)
|
||||
continue
|
||||
}
|
||||
if ip == "" {
|
||||
s += fmt.Sprintln(protocol + zone + ":" + port)
|
||||
continue
|
||||
}
|
||||
// if the server is listening on a specific address let's make it visible in the log,
|
||||
// so one can differentiate between all active listeners
|
||||
s += fmt.Sprintln(protocol + zone + ":" + port + " on " + ip)
|
||||
}
|
||||
return s
|
||||
}
|
|
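For reference, a sketch of what startUpZones produces for a small zones map (the zone name and address are made up, demoStartUp is hypothetical):

```go
package dnsserver

import "fmt"

// demoStartUp prints:
//   dns://example.org.:1053 on 127.0.0.1
func demoStartUp() {
	zones := map[string]*Config{"example.org.": {}}
	fmt.Print(startUpZones("dns://", "127.0.0.1:1053", zones))
}
```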
@ -0,0 +1,252 @@
|
|||
package dnsserver
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/coredns/coredns/plugin/pkg/dnsutil"
|
||||
"github.com/coredns/coredns/plugin/pkg/parse"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyfile"
|
||||
)
|
||||
|
||||
const serverType = "dns"
|
||||
|
||||
// Any flags defined here need to be namespaced to the serverType, otherwise
|
||||
// they potentially clash with other server types.
|
||||
func init() {
|
||||
flag.StringVar(&Port, serverType+".port", DefaultPort, "Default port")
|
||||
|
||||
caddy.RegisterServerType(serverType, caddy.ServerType{
|
||||
Directives: func() []string { return Directives },
|
||||
DefaultInput: func() caddy.Input {
|
||||
return caddy.CaddyfileInput{
|
||||
Filepath: "Corefile",
|
||||
Contents: []byte(".:" + Port + " {\nwhoami\n}\n"),
|
||||
ServerTypeName: serverType,
|
||||
}
|
||||
},
|
||||
NewContext: newContext,
|
||||
})
|
||||
}
|
||||
|
||||
func newContext(i *caddy.Instance) caddy.Context {
|
||||
return &dnsContext{keysToConfigs: make(map[string]*Config)}
|
||||
}
|
||||
|
||||
type dnsContext struct {
|
||||
keysToConfigs map[string]*Config
|
||||
|
||||
// configs is the master list of all site configs.
|
||||
configs []*Config
|
||||
}
|
||||
|
||||
func (h *dnsContext) saveConfig(key string, cfg *Config) {
|
||||
h.configs = append(h.configs, cfg)
|
||||
h.keysToConfigs[key] = cfg
|
||||
}
|
||||
|
||||
// InspectServerBlocks makes sure that everything checks out before
|
||||
// executing directives and otherwise prepares the directives to
|
||||
// be parsed and executed.
|
||||
func (h *dnsContext) InspectServerBlocks(sourceFile string, serverBlocks []caddyfile.ServerBlock) ([]caddyfile.ServerBlock, error) {
|
||||
// Normalize and check all the zone names and check for duplicates
|
||||
for ib, s := range serverBlocks {
|
||||
for ik, k := range s.Keys {
|
||||
za, err := normalizeZone(k)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.Keys[ik] = za.String()
|
||||
// Save the config to our master list, and key it for lookups.
|
||||
cfg := &Config{
|
||||
Zone: za.Zone,
|
||||
ListenHosts: []string{""},
|
||||
Port: za.Port,
|
||||
Transport: za.Transport,
|
||||
}
|
||||
keyConfig := keyForConfig(ib, ik)
|
||||
if za.IPNet == nil {
|
||||
h.saveConfig(keyConfig, cfg)
|
||||
continue
|
||||
}
|
||||
|
||||
ones, bits := za.IPNet.Mask.Size()
|
||||
if (bits-ones)%8 != 0 { // only do this for non-octet boundaries
|
||||
cfg.FilterFunc = func(s string) bool {
|
||||
// TODO(miek): strings.ToLower! Slow and allocates new string.
|
||||
addr := dnsutil.ExtractAddressFromReverse(strings.ToLower(s))
|
||||
if addr == "" {
|
||||
return true
|
||||
}
|
||||
return za.IPNet.Contains(net.ParseIP(addr))
|
||||
}
|
||||
}
|
||||
h.saveConfig(keyConfig, cfg)
|
||||
}
|
||||
}
|
||||
return serverBlocks, nil
|
||||
}
|
||||
|
||||
// MakeServers uses the newly-created siteConfigs to create and return a list of server instances.
|
||||
func (h *dnsContext) MakeServers() ([]caddy.Server, error) {
|
||||
|
||||
// Now that all Keys and Directives are parsed and initialized
|
||||
// let's verify that there is no overlap in the zones and addresses to listen on
|
||||
errValid := h.validateZonesAndListeningAddresses()
|
||||
if errValid != nil {
|
||||
return nil, errValid
|
||||
}
|
||||
|
||||
// we must map (group) each config to a bind address
|
||||
groups, err := groupConfigsByListenAddr(h.configs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// then we create a server for each group
|
||||
var servers []caddy.Server
|
||||
for addr, group := range groups {
|
||||
// switch on addr
|
||||
switch tr, _ := parse.Transport(addr); tr {
|
||||
case transport.DNS:
|
||||
s, err := NewServer(addr, group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
servers = append(servers, s)
|
||||
|
||||
case transport.TLS:
|
||||
s, err := NewServerTLS(addr, group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
servers = append(servers, s)
|
||||
|
||||
case transport.GRPC:
|
||||
s, err := NewServergRPC(addr, group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
servers = append(servers, s)
|
||||
|
||||
case transport.HTTPS:
|
||||
s, err := NewServerHTTPS(addr, group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
servers = append(servers, s)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
// AddPlugin adds a plugin to a site's plugin stack.
|
||||
func (c *Config) AddPlugin(m plugin.Plugin) {
|
||||
c.Plugin = append(c.Plugin, m)
|
||||
}
|
||||
|
||||
// registerHandler adds a handler to a site's handler registration. Handlers
|
||||
// use this to announce that they exist to other plugins.
|
||||
func (c *Config) registerHandler(h plugin.Handler) {
|
||||
if c.registry == nil {
|
||||
c.registry = make(map[string]plugin.Handler)
|
||||
}
|
||||
|
||||
// Just overwrite...
|
||||
c.registry[h.Name()] = h
|
||||
}
|
||||
|
||||
// Handler returns the plugin handler that has been added to the config under its name.
|
||||
// This is useful to inspect if a certain plugin is active in this server.
|
||||
// Note that this is order dependent and the order is defined in directives.go, i.e. if your plugin
|
||||
// comes before the plugin you are checking; it will not be there (yet).
|
||||
func (c *Config) Handler(name string) plugin.Handler {
|
||||
if c.registry == nil {
|
||||
return nil
|
||||
}
|
||||
if h, ok := c.registry[name]; ok {
|
||||
return h
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handlers returns a slice of plugins that have been registered. This can be used to
|
||||
// inspect and interact with registered plugins but cannot be used to remove or add plugins.
|
||||
// Note that this is order dependent and the order is defined in directives.go, i.e. if your plugin
|
||||
// comes before the plugin you are checking; it will not be there (yet).
|
||||
func (c *Config) Handlers() []plugin.Handler {
|
||||
if c.registry == nil {
|
||||
return nil
|
||||
}
|
||||
hs := make([]plugin.Handler, 0, len(c.registry))
|
||||
for k := range c.registry {
|
||||
hs = append(hs, c.registry[k])
|
||||
}
|
||||
return hs
|
||||
}
|
||||
|
||||
func (h *dnsContext) validateZonesAndListeningAddresses() error {
|
||||
// Validate zones and addresses.
|
||||
checker := newOverlapZone()
|
||||
for _, conf := range h.configs {
|
||||
for _, h := range conf.ListenHosts {
|
||||
// Validate the overlapping of ZoneAddr
|
||||
akey := zoneAddr{Transport: conf.Transport, Zone: conf.Zone, Address: h, Port: conf.Port}
|
||||
existZone, overlapZone := checker.registerAndCheck(akey)
|
||||
if existZone != nil {
|
||||
return fmt.Errorf("cannot serve %s - it is already defined", akey.String())
|
||||
}
|
||||
if overlapZone != nil {
|
||||
return fmt.Errorf("cannot serve %s - zone overlap listener capacity with %v", akey.String(), overlapZone.String())
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// groupConfigsByListenAddr groups site configs by their listen
|
||||
// (bind) address, so sites that use the same listener can be served
|
||||
// on the same server instance. The return value maps the listen
|
||||
// address (what you pass into net.Listen) to the list of site configs.
|
||||
// This function does NOT vet the configs to ensure they are compatible.
|
||||
func groupConfigsByListenAddr(configs []*Config) (map[string][]*Config, error) {
|
||||
|
||||
groups := make(map[string][]*Config)
|
||||
for _, conf := range configs {
|
||||
for _, h := range conf.ListenHosts {
|
||||
addr, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(h, conf.Port))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addrstr := conf.Transport + "://" + addr.String()
|
||||
groups[addrstr] = append(groups[addrstr], conf)
|
||||
}
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// DefaultPort is the default port.
|
||||
const DefaultPort = transport.Port
|
||||
|
||||
// These "soft defaults" are configurable by
|
||||
// command line flags, etc.
|
||||
var (
|
||||
// Port is the port we listen on by default.
|
||||
Port = DefaultPort
|
||||
|
||||
// GracefulTimeout is the maximum duration of a graceful shutdown.
|
||||
GracefulTimeout time.Duration
|
||||
)
|
||||
|
||||
var _ caddy.GracefulServer = new(Server)
|
|
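To see how configs are grouped per listener before servers are built, a hypothetical in-package sketch (the zones and port are made up; demoGrouping is not part of the vendored file):

```go
package dnsserver

import "fmt"

// demoGrouping shows that two zones sharing a port end up in the same group,
// keyed by "transport://host:port".
func demoGrouping() {
	configs := []*Config{
		{Zone: "example.org.", Port: "1053", Transport: "dns", ListenHosts: []string{""}},
		{Zone: "example.net.", Port: "1053", Transport: "dns", ListenHosts: []string{""}},
	}
	groups, err := groupConfigsByListenAddr(configs)
	if err != nil {
		panic(err)
	}
	for addr, group := range groups {
		fmt.Println(addr, len(group)) // dns://:1053 2
	}
}
```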
@ -0,0 +1,355 @@
|
|||
// Package dnsserver implements all the interfaces from Caddy, so that CoreDNS can be a servertype plugin.
|
||||
package dnsserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/coredns/coredns/plugin/metrics/vars"
|
||||
"github.com/coredns/coredns/plugin/pkg/edns"
|
||||
"github.com/coredns/coredns/plugin/pkg/log"
|
||||
"github.com/coredns/coredns/plugin/pkg/rcode"
|
||||
"github.com/coredns/coredns/plugin/pkg/trace"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
ot "github.com/opentracing/opentracing-go"
|
||||
)
|
||||
|
||||
// Server represents an instance of a server, which serves
|
||||
// DNS requests at a particular address (host and port). A
|
||||
// server is capable of serving numerous zones on
|
||||
// the same address and the listener may be stopped for
|
||||
// graceful termination (POSIX only).
|
||||
type Server struct {
|
||||
Addr string // Address we listen on
|
||||
|
||||
server [2]*dns.Server // 0 is a net.Listener, 1 is a net.PacketConn (a *UDPConn) in our case.
|
||||
m sync.Mutex // protects the servers
|
||||
|
||||
zones map[string]*Config // zones keyed by their address
|
||||
dnsWg sync.WaitGroup // used to wait on outstanding connections
|
||||
graceTimeout time.Duration // the maximum duration of a graceful shutdown
|
||||
trace trace.Trace // the trace plugin for the server
|
||||
debug bool // disable recover()
|
||||
classChaos bool // allow non-INET class queries
|
||||
}
|
||||
|
||||
// NewServer returns a new CoreDNS server and compiles all plugins in to it. By default CH class
|
||||
// queries are blocked unless queries from enableChaos are loaded.
|
||||
func NewServer(addr string, group []*Config) (*Server, error) {
|
||||
|
||||
s := &Server{
|
||||
Addr: addr,
|
||||
zones: make(map[string]*Config),
|
||||
graceTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
// We have to bound our wg with one increment
|
||||
// to prevent a "race condition" that is hard-coded
|
||||
// into sync.WaitGroup.Wait() - basically, an add
|
||||
// with a positive delta must be guaranteed to
|
||||
// occur before Wait() is called on the wg.
|
||||
// In a way, this kind of acts as a safety barrier.
|
||||
s.dnsWg.Add(1)
|
||||
|
||||
for _, site := range group {
|
||||
if site.Debug {
|
||||
s.debug = true
|
||||
log.D = true
|
||||
}
|
||||
// set the config per zone
|
||||
s.zones[site.Zone] = site
|
||||
|
||||
// compile custom plugin for everything
|
||||
var stack plugin.Handler
|
||||
for i := len(site.Plugin) - 1; i >= 0; i-- {
|
||||
stack = site.Plugin[i](stack)
|
||||
|
||||
// register the *handler* also
|
||||
site.registerHandler(stack)
|
||||
|
||||
if s.trace == nil && stack.Name() == "trace" {
|
||||
// we have to stash away the plugin, not the
|
||||
// Tracer object, because the Tracer won't be initialized yet
|
||||
if t, ok := stack.(trace.Trace); ok {
|
||||
s.trace = t
|
||||
}
|
||||
}
|
||||
// Unblock CH class queries when any of these plugins are loaded.
|
||||
if _, ok := EnableChaos[stack.Name()]; ok {
|
||||
s.classChaos = true
|
||||
}
|
||||
}
|
||||
site.pluginChain = stack
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Serve starts the server with an existing listener. It blocks until the server stops.
|
||||
// This implements caddy.TCPServer interface.
|
||||
func (s *Server) Serve(l net.Listener) error {
|
||||
s.m.Lock()
|
||||
s.server[tcp] = &dns.Server{Listener: l, Net: "tcp", Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
ctx := context.WithValue(context.Background(), Key{}, s)
|
||||
s.ServeDNS(ctx, w, r)
|
||||
})}
|
||||
s.m.Unlock()
|
||||
|
||||
return s.server[tcp].ActivateAndServe()
|
||||
}
|
||||
|
||||
// ServePacket starts the server with an existing packetconn. It blocks until the server stops.
|
||||
// This implements caddy.UDPServer interface.
|
||||
func (s *Server) ServePacket(p net.PacketConn) error {
|
||||
s.m.Lock()
|
||||
s.server[udp] = &dns.Server{PacketConn: p, Net: "udp", Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
ctx := context.WithValue(context.Background(), Key{}, s)
|
||||
s.ServeDNS(ctx, w, r)
|
||||
})}
|
||||
s.m.Unlock()
|
||||
|
||||
return s.server[udp].ActivateAndServe()
|
||||
}
|
||||
|
||||
// Listen implements caddy.TCPServer interface.
|
||||
func (s *Server) Listen() (net.Listener, error) {
|
||||
l, err := listen("tcp", s.Addr[len(transport.DNS+"://"):])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return l, nil
|
||||
}
|
||||
|
||||
// WrapListener implements the caddy.GracefulServer interface.
|
||||
func (s *Server) WrapListener(ln net.Listener) net.Listener {
|
||||
return ln
|
||||
}
|
||||
|
||||
// ListenPacket implements caddy.UDPServer interface.
|
||||
func (s *Server) ListenPacket() (net.PacketConn, error) {
|
||||
p, err := listenPacket("udp", s.Addr[len(transport.DNS+"://"):])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Stop stops the server. It blocks until the server is
|
||||
// totally stopped. On POSIX systems, it will wait for
|
||||
// connections to close (up to a max timeout of a few
|
||||
// seconds); on Windows it will close the listener
|
||||
// immediately.
|
||||
// This implements Caddy.Stopper interface.
|
||||
func (s *Server) Stop() (err error) {
|
||||
|
||||
if runtime.GOOS != "windows" {
|
||||
// force connections to close after timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
s.dnsWg.Done() // decrement our initial increment used as a barrier
|
||||
s.dnsWg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Wait for remaining connections to finish or
|
||||
// force them all to close after timeout
|
||||
select {
|
||||
case <-time.After(s.graceTimeout):
|
||||
case <-done:
|
||||
}
|
||||
}
|
||||
|
||||
// Close the listener now; this stops the server without delay
|
||||
s.m.Lock()
|
||||
for _, s1 := range s.server {
|
||||
// We might not have started and initialized the full set of servers
|
||||
if s1 != nil {
|
||||
err = s1.Shutdown()
|
||||
}
|
||||
}
|
||||
s.m.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Address, together with Stop(), implements caddy.GracefulServer.
|
||||
func (s *Server) Address() string { return s.Addr }
|
||||
|
||||
// ServeDNS is the entry point for every request to the address that s
|
||||
// is bound to. It acts as a multiplexer for the request's zone name as
|
||||
// defined in the request so that the correct zone
|
||||
// (configuration and plugin stack) will handle the request.
|
||||
func (s *Server) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) {
|
||||
// The default dns.Mux checks the question section size, but we have our
|
||||
// own mux here. Check if we have a question section; if not, drop the request here.
|
||||
if r == nil || len(r.Question) == 0 {
|
||||
errorAndMetricsFunc(s.Addr, w, r, dns.RcodeServerFailure)
|
||||
return
|
||||
}
|
||||
|
||||
if !s.debug {
|
||||
defer func() {
|
||||
// In case the user doesn't enable the error plugin, we still
|
||||
// need to make sure that we stay alive up here
|
||||
if rec := recover(); rec != nil {
|
||||
vars.Panic.Inc()
|
||||
errorAndMetricsFunc(s.Addr, w, r, dns.RcodeServerFailure)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if !s.classChaos && r.Question[0].Qclass != dns.ClassINET {
|
||||
errorAndMetricsFunc(s.Addr, w, r, dns.RcodeRefused)
|
||||
return
|
||||
}
|
||||
|
||||
if m, err := edns.Version(r); err != nil { // Wrong EDNS version, return at once.
|
||||
w.WriteMsg(m)
|
||||
return
|
||||
}
|
||||
|
||||
q := r.Question[0].Name
|
||||
b := make([]byte, len(q))
|
||||
var off int
|
||||
var end bool
|
||||
|
||||
var dshandler *Config
|
||||
|
||||
// Wrap the response writer in a ScrubWriter so we automatically make the reply fit in the client's buffer.
|
||||
w = request.NewScrubWriter(r, w)
|
||||
|
||||
for {
|
||||
l := len(q[off:])
|
||||
for i := 0; i < l; i++ {
|
||||
b[i] = q[off+i]
|
||||
// normalize the name for the lookup
|
||||
if b[i] >= 'A' && b[i] <= 'Z' {
|
||||
b[i] |= ('a' - 'A')
|
||||
}
|
||||
}
|
||||
|
||||
if h, ok := s.zones[string(b[:l])]; ok {
|
||||
if r.Question[0].Qtype != dns.TypeDS {
|
||||
if h.FilterFunc == nil {
|
||||
rcode, _ := h.pluginChain.ServeDNS(ctx, w, r)
|
||||
if !plugin.ClientWrite(rcode) {
|
||||
errorFunc(s.Addr, w, r, rcode)
|
||||
}
|
||||
return
|
||||
}
|
||||
// FilterFunc is set, call it to see if we should use this handler.
|
||||
// It is given the full query name.
|
||||
if h.FilterFunc(q) {
|
||||
rcode, _ := h.pluginChain.ServeDNS(ctx, w, r)
|
||||
if !plugin.ClientWrite(rcode) {
|
||||
errorFunc(s.Addr, w, r, rcode)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
// The type is DS, keep the handler, but keep on searching as maybe we are serving
|
||||
// the parent as well and the DS should be routed to it - this will probably *misroute* DS
|
||||
// queries to a possibly grandparent, but there is no way for us to know at this point
|
||||
// if there is an actual delegation from grandparent -> parent -> zone.
|
||||
// In all fairness: direct DS queries should not be needed.
|
||||
dshandler = h
|
||||
}
|
||||
off, end = dns.NextLabel(q, off)
|
||||
if end {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if r.Question[0].Qtype == dns.TypeDS && dshandler != nil && dshandler.pluginChain != nil {
|
||||
// DS request, and we found a zone, use the handler for the query.
|
||||
rcode, _ := dshandler.pluginChain.ServeDNS(ctx, w, r)
|
||||
if !plugin.ClientWrite(rcode) {
|
||||
errorFunc(s.Addr, w, r, rcode)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Wildcard match, if we have found nothing try the root zone as a last resort.
|
||||
if h, ok := s.zones["."]; ok && h.pluginChain != nil {
|
||||
rcode, _ := h.pluginChain.ServeDNS(ctx, w, r)
|
||||
if !plugin.ClientWrite(rcode) {
|
||||
errorFunc(s.Addr, w, r, rcode)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Still here? Error out with REFUSED.
|
||||
errorAndMetricsFunc(s.Addr, w, r, dns.RcodeRefused)
|
||||
}
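The zone matching that ServeDNS performs above (lowercase the query name, then walk it label by label until a configured zone matches, falling back to the root zone) can be sketched on its own. The `zones` map and the query below are made up for the example; only `dns.Fqdn` and `dns.NextLabel` are real miekg/dns calls.

~~~ go
package main

import (
	"fmt"
	"strings"

	"github.com/miekg/dns"
)

// matchZone returns the longest configured zone that is a suffix of qname.
func matchZone(zones map[string]bool, qname string) (string, bool) {
	q := strings.ToLower(dns.Fqdn(qname))
	for off, end := 0, false; !end; off, end = dns.NextLabel(q, off) {
		if zones[q[off:]] {
			return q[off:], true
		}
	}
	if zones["."] { // last resort: the root zone, as ServeDNS does above
		return ".", true
	}
	return "", false
}

func main() {
	zones := map[string]bool{"example.org.": true, ".": true}
	fmt.Println(matchZone(zones, "www.example.org")) // example.org. true
}
~~~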
|
||||
|
||||
// OnStartupComplete lists the sites served by this server
|
||||
// and any relevant information, assuming Quiet is false.
|
||||
func (s *Server) OnStartupComplete() {
|
||||
if Quiet {
|
||||
return
|
||||
}
|
||||
|
||||
out := startUpZones("", s.Addr, s.zones)
|
||||
if out != "" {
|
||||
fmt.Print(out)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Tracer returns the tracer in the server if defined.
|
||||
func (s *Server) Tracer() ot.Tracer {
|
||||
if s.trace == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.trace.Tracer()
|
||||
}
|
||||
|
||||
// errorFunc responds to a DNS request with an error.
|
||||
func errorFunc(server string, w dns.ResponseWriter, r *dns.Msg, rc int) {
|
||||
state := request.Request{W: w, Req: r}
|
||||
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(r, rc)
|
||||
state.SizeAndDo(answer)
|
||||
|
||||
w.WriteMsg(answer)
|
||||
}
|
||||
|
||||
func errorAndMetricsFunc(server string, w dns.ResponseWriter, r *dns.Msg, rc int) {
|
||||
state := request.Request{W: w, Req: r}
|
||||
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(r, rc)
|
||||
state.SizeAndDo(answer)
|
||||
|
||||
vars.Report(server, state, vars.Dropped, rcode.ToString(rc), answer.Len(), time.Now())
|
||||
|
||||
w.WriteMsg(answer)
|
||||
}
|
||||
|
||||
const (
|
||||
tcp = 0
|
||||
udp = 1
|
||||
)
|
||||
|
||||
// Key is the context key for the current server added to the context.
|
||||
type Key struct{}
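Because Serve and ServePacket store the server under `Key{}`, a plugin further down the chain could recover it from the request context. The helper below is a hypothetical sketch, not part of this commit; it would live in this package:

~~~ go
package dnsserver

import "context"

// serverFromContext shows how the *Server stored under Key{} by Serve and
// ServePacket can be read back out of the request context.
func serverFromContext(ctx context.Context) (*Server, bool) {
	s, ok := ctx.Value(Key{}).(*Server)
	return s, ok
}
~~~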
|
||||
|
||||
// EnableChaos is a map with plugin names for which we should open CH class queries as we block these by default.
|
||||
var EnableChaos = map[string]struct{}{
|
||||
"chaos": struct{}{},
|
||||
"forward": struct{}{},
|
||||
"proxy": struct{}{},
|
||||
}
|
||||
|
||||
// Quiet mode will not show any informative output on initialization.
|
||||
var Quiet bool
|
|
@ -0,0 +1,171 @@
|
|||
package dnsserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/coredns/coredns/pb"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
|
||||
"github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
// ServergRPC represents an instance of a DNS-over-gRPC server.
|
||||
type ServergRPC struct {
|
||||
*Server
|
||||
grpcServer *grpc.Server
|
||||
listenAddr net.Addr
|
||||
tlsConfig *tls.Config
|
||||
}
|
||||
|
||||
// NewServergRPC returns a new CoreDNS gRPC server and compiles all plugins into it.
|
||||
func NewServergRPC(addr string, group []*Config) (*ServergRPC, error) {
|
||||
s, err := NewServer(addr, group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// The *tls* plugin must make sure that multiple conflicting
|
||||
// TLS configurations return an error: it can only be specified once.
|
||||
var tlsConfig *tls.Config
|
||||
for _, conf := range s.zones {
|
||||
// Should we error if some configs *don't* have TLS?
|
||||
tlsConfig = conf.TLSConfig
|
||||
}
|
||||
|
||||
return &ServergRPC{Server: s, tlsConfig: tlsConfig}, nil
|
||||
}
|
||||
|
||||
// Serve implements caddy.TCPServer interface.
|
||||
func (s *ServergRPC) Serve(l net.Listener) error {
|
||||
s.m.Lock()
|
||||
s.listenAddr = l.Addr()
|
||||
s.m.Unlock()
|
||||
|
||||
if s.Tracer() != nil {
|
||||
onlyIfParent := func(parentSpanCtx opentracing.SpanContext, method string, req, resp interface{}) bool {
|
||||
return parentSpanCtx != nil
|
||||
}
|
||||
intercept := otgrpc.OpenTracingServerInterceptor(s.Tracer(), otgrpc.IncludingSpans(onlyIfParent))
|
||||
s.grpcServer = grpc.NewServer(grpc.UnaryInterceptor(intercept))
|
||||
} else {
|
||||
s.grpcServer = grpc.NewServer()
|
||||
}
|
||||
|
||||
pb.RegisterDnsServiceServer(s.grpcServer, s)
|
||||
|
||||
if s.tlsConfig != nil {
|
||||
l = tls.NewListener(l, s.tlsConfig)
|
||||
}
|
||||
return s.grpcServer.Serve(l)
|
||||
}
|
||||
|
||||
// ServePacket implements caddy.UDPServer interface.
|
||||
func (s *ServergRPC) ServePacket(p net.PacketConn) error { return nil }
|
||||
|
||||
// Listen implements caddy.TCPServer interface.
|
||||
func (s *ServergRPC) Listen() (net.Listener, error) {
|
||||
|
||||
l, err := net.Listen("tcp", s.Addr[len(transport.GRPC+"://"):])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return l, nil
|
||||
}
|
||||
|
||||
// ListenPacket implements caddy.UDPServer interface.
|
||||
func (s *ServergRPC) ListenPacket() (net.PacketConn, error) { return nil, nil }
|
||||
|
||||
// OnStartupComplete lists the sites served by this server
|
||||
// and any relevant information, assuming Quiet is false.
|
||||
func (s *ServergRPC) OnStartupComplete() {
|
||||
if Quiet {
|
||||
return
|
||||
}
|
||||
|
||||
out := startUpZones(transport.GRPC+"://", s.Addr, s.zones)
|
||||
if out != "" {
|
||||
fmt.Print(out)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Stop stops the server. It blocks until the server is
|
||||
// totally stopped.
|
||||
func (s *ServergRPC) Stop() (err error) {
|
||||
s.m.Lock()
|
||||
defer s.m.Unlock()
|
||||
if s.grpcServer != nil {
|
||||
s.grpcServer.GracefulStop()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Query is the main entry-point into the gRPC server. From here we call ServeDNS like
|
||||
// any normal server. We use a custom responseWriter to pick up the bytes we need to write
|
||||
// back to the client as a protobuf.
|
||||
func (s *ServergRPC) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket, error) {
|
||||
msg := new(dns.Msg)
|
||||
err := msg.Unpack(in.Msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p, ok := peer.FromContext(ctx)
|
||||
if !ok {
|
||||
return nil, errors.New("no peer in gRPC context")
|
||||
}
|
||||
|
||||
a, ok := p.Addr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no TCP peer in gRPC context: %v", p.Addr)
|
||||
}
|
||||
|
||||
w := &gRPCresponse{localAddr: s.listenAddr, remoteAddr: a, Msg: msg}
|
||||
|
||||
s.ServeDNS(ctx, w, msg)
|
||||
|
||||
packed, err := w.Msg.Pack()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pb.DnsPacket{Msg: packed}, nil
|
||||
}
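A client of this Query endpoint packs a `dns.Msg`, sends it as a `DnsPacket`, and unpacks the reply. The sketch below is illustrative only: the address and the plaintext (insecure) transport are assumptions; a real DNS-over-gRPC endpoint is normally served over TLS.

~~~ go
package main

import (
	"context"
	"log"

	"github.com/coredns/coredns/pb"
	"github.com/miekg/dns"
	"google.golang.org/grpc"
)

func main() {
	// Assumed address; real deployments would use TLS credentials instead
	// of the insecure dial option.
	conn, err := grpc.Dial("127.0.0.1:443", grpc.WithInsecure())
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	m := new(dns.Msg)
	m.SetQuestion("example.org.", dns.TypeA)
	packed, err := m.Pack()
	if err != nil {
		log.Fatal(err)
	}

	client := pb.NewDnsServiceClient(conn)
	reply, err := client.Query(context.Background(), &pb.DnsPacket{Msg: packed})
	if err != nil {
		log.Fatal(err)
	}

	out := new(dns.Msg)
	if err := out.Unpack(reply.Msg); err != nil {
		log.Fatal(err)
	}
	log.Printf("%s", out)
}
~~~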
|
||||
|
||||
// Shutdown stops the server (non gracefully).
|
||||
func (s *ServergRPC) Shutdown() error {
|
||||
if s.grpcServer != nil {
|
||||
s.grpcServer.Stop()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type gRPCresponse struct {
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
Msg *dns.Msg
|
||||
}
|
||||
|
||||
// Write is the hack that makes this work. It does not actually write the message
|
||||
// but returns the bytes we need to write in r. We can then pick this up in Query
|
||||
// and write a proper protobuf back to the client.
|
||||
func (r *gRPCresponse) Write(b []byte) (int, error) {
|
||||
r.Msg = new(dns.Msg)
|
||||
return len(b), r.Msg.Unpack(b)
|
||||
}
|
||||
|
||||
// These methods implement the dns.ResponseWriter interface from Go DNS.
|
||||
func (r *gRPCresponse) Close() error { return nil }
|
||||
func (r *gRPCresponse) TsigStatus() error { return nil }
|
||||
func (r *gRPCresponse) TsigTimersOnly(b bool) { return }
|
||||
func (r *gRPCresponse) Hijack() { return }
|
||||
func (r *gRPCresponse) LocalAddr() net.Addr { return r.localAddr }
|
||||
func (r *gRPCresponse) RemoteAddr() net.Addr { return r.remoteAddr }
|
||||
func (r *gRPCresponse) WriteMsg(m *dns.Msg) error { r.Msg = m; return nil }
|
|
@ -0,0 +1,149 @@
|
|||
package dnsserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/dnsutil"
|
||||
"github.com/coredns/coredns/plugin/pkg/doh"
|
||||
"github.com/coredns/coredns/plugin/pkg/response"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
)
|
||||
|
||||
// ServerHTTPS represents an instance of a DNS-over-HTTPS server.
|
||||
type ServerHTTPS struct {
|
||||
*Server
|
||||
httpsServer *http.Server
|
||||
listenAddr net.Addr
|
||||
tlsConfig *tls.Config
|
||||
}
|
||||
|
||||
// NewServerHTTPS returns a new CoreDNS HTTPS server and compiles all plugins into it.
|
||||
func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) {
|
||||
s, err := NewServer(addr, group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// The *tls* plugin must make sure that multiple conflicting
|
||||
// TLS configurations return an error: it can only be specified once.
|
||||
var tlsConfig *tls.Config
|
||||
for _, conf := range s.zones {
|
||||
// Should we error if some configs *don't* have TLS?
|
||||
tlsConfig = conf.TLSConfig
|
||||
}
|
||||
|
||||
sh := &ServerHTTPS{Server: s, tlsConfig: tlsConfig, httpsServer: new(http.Server)}
|
||||
sh.httpsServer.Handler = sh
|
||||
|
||||
return sh, nil
|
||||
}
|
||||
|
||||
// Serve implements caddy.TCPServer interface.
|
||||
func (s *ServerHTTPS) Serve(l net.Listener) error {
|
||||
s.m.Lock()
|
||||
s.listenAddr = l.Addr()
|
||||
s.m.Unlock()
|
||||
|
||||
if s.tlsConfig != nil {
|
||||
l = tls.NewListener(l, s.tlsConfig)
|
||||
}
|
||||
return s.httpsServer.Serve(l)
|
||||
}
|
||||
|
||||
// ServePacket implements caddy.UDPServer interface.
|
||||
func (s *ServerHTTPS) ServePacket(p net.PacketConn) error { return nil }
|
||||
|
||||
// Listen implements caddy.TCPServer interface.
|
||||
func (s *ServerHTTPS) Listen() (net.Listener, error) {
|
||||
|
||||
l, err := net.Listen("tcp", s.Addr[len(transport.HTTPS+"://"):])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return l, nil
|
||||
}
|
||||
|
||||
// ListenPacket implements caddy.UDPServer interface.
|
||||
func (s *ServerHTTPS) ListenPacket() (net.PacketConn, error) { return nil, nil }
|
||||
|
||||
// OnStartupComplete lists the sites served by this server
|
||||
// and any relevant information, assuming Quiet is false.
|
||||
func (s *ServerHTTPS) OnStartupComplete() {
|
||||
if Quiet {
|
||||
return
|
||||
}
|
||||
|
||||
out := startUpZones(transport.HTTPS+"://", s.Addr, s.zones)
|
||||
if out != "" {
|
||||
fmt.Print(out)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Stop stops the server. It blocks until the server is totally stopped.
|
||||
func (s *ServerHTTPS) Stop() error {
|
||||
s.m.Lock()
|
||||
defer s.m.Unlock()
|
||||
if s.httpsServer != nil {
|
||||
s.httpsServer.Shutdown(context.Background())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ServeHTTP is the handler that gets the HTTP request, converts it to the DNS format, calls the plugin
|
||||
// chain, converts the reply back, and writes it to the client.
|
||||
func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if r.URL.Path != doh.Path {
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := doh.RequestToMsg(r)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Create a DoHWriter with the correct addresses in it.
|
||||
h, p, _ := net.SplitHostPort(r.RemoteAddr)
|
||||
port, _ := strconv.Atoi(p)
|
||||
dw := &DoHWriter{laddr: s.listenAddr, raddr: &net.TCPAddr{IP: net.ParseIP(h), Port: port}}
|
||||
|
||||
// We just call the normal chain handler - all error handling is done there.
|
||||
// We should expect a packet to be returned that we can send to the client.
|
||||
s.ServeDNS(context.Background(), dw, msg)
|
||||
|
||||
// See section 4.2.1 of RFC 8484.
|
||||
// We are using code 500 to indicate an unexpected situation when the chain
|
||||
// handler has not provided any response message.
|
||||
if dw.Msg == nil {
|
||||
http.Error(w, "No response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
buf, _ := dw.Msg.Pack()
|
||||
|
||||
mt, _ := response.Typify(dw.Msg, time.Now().UTC())
|
||||
age := dnsutil.MinimalTTL(dw.Msg, mt)
|
||||
|
||||
w.Header().Set("Content-Type", doh.MimeType)
|
||||
w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%f", age.Seconds()))
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(buf)))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
w.Write(buf)
|
||||
}
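Seen from the client side, the handler above expects a packed DNS message POSTed to the DoH path with the `application/dns-message` content type, and answers with a packed message in the body. The sketch below is a hypothetical client; the host name assumes a CoreDNS `https://` server block with a certificate the client already trusts, and the `/dns-query` path is the package default.

~~~ go
package main

import (
	"bytes"
	"io/ioutil"
	"log"
	"net/http"

	"github.com/miekg/dns"
)

func main() {
	m := new(dns.Msg)
	m.SetQuestion("example.org.", dns.TypeA)
	body, err := m.Pack()
	if err != nil {
		log.Fatal(err)
	}

	resp, err := http.Post("https://dns.example.net/dns-query",
		"application/dns-message", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	buf, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		log.Fatal(err)
	}

	out := new(dns.Msg)
	if err := out.Unpack(buf); err != nil {
		log.Fatal(err)
	}
	log.Printf("%s", out)
}
~~~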
|
||||
|
||||
// Shutdown stops the server (non gracefully).
|
||||
func (s *ServerHTTPS) Shutdown() error {
|
||||
if s.httpsServer != nil {
|
||||
s.httpsServer.Shutdown(context.Background())
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
package dnsserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// ServerTLS represents an instance of a DNS-over-TLS server.
|
||||
type ServerTLS struct {
|
||||
*Server
|
||||
tlsConfig *tls.Config
|
||||
}
|
||||
|
||||
// NewServerTLS returns a new CoreDNS TLS server and compiles all plugins into it.
|
||||
func NewServerTLS(addr string, group []*Config) (*ServerTLS, error) {
|
||||
s, err := NewServer(addr, group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// The *tls* plugin must make sure that multiple conflicting
|
||||
// TLS configurations return an error: it can only be specified once.
|
||||
var tlsConfig *tls.Config
|
||||
for _, conf := range s.zones {
|
||||
// Should we error if some configs *don't* have TLS?
|
||||
tlsConfig = conf.TLSConfig
|
||||
}
|
||||
|
||||
return &ServerTLS{Server: s, tlsConfig: tlsConfig}, nil
|
||||
}
|
||||
|
||||
// Serve implements caddy.TCPServer interface.
|
||||
func (s *ServerTLS) Serve(l net.Listener) error {
|
||||
s.m.Lock()
|
||||
|
||||
if s.tlsConfig != nil {
|
||||
l = tls.NewListener(l, s.tlsConfig)
|
||||
}
|
||||
|
||||
// Only fill out the TCP server for this one.
|
||||
s.server[tcp] = &dns.Server{Listener: l, Net: "tcp-tls", Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
ctx := context.Background()
|
||||
s.ServeDNS(ctx, w, r)
|
||||
})}
|
||||
s.m.Unlock()
|
||||
|
||||
return s.server[tcp].ActivateAndServe()
|
||||
}
|
||||
|
||||
// ServePacket implements caddy.UDPServer interface.
|
||||
func (s *ServerTLS) ServePacket(p net.PacketConn) error { return nil }
|
||||
|
||||
// Listen implements caddy.TCPServer interface.
|
||||
func (s *ServerTLS) Listen() (net.Listener, error) {
|
||||
l, err := net.Listen("tcp", s.Addr[len(transport.TLS+"://"):])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return l, nil
|
||||
}
|
||||
|
||||
// ListenPacket implements caddy.UDPServer interface.
|
||||
func (s *ServerTLS) ListenPacket() (net.PacketConn, error) { return nil, nil }
|
||||
|
||||
// OnStartupComplete lists the sites served by this server
|
||||
// and any relevant information, assuming Quiet is false.
|
||||
func (s *ServerTLS) OnStartupComplete() {
|
||||
if Quiet {
|
||||
return
|
||||
}
|
||||
|
||||
out := startUpZones(transport.TLS+"://", s.Addr, s.zones)
|
||||
if out != "" {
|
||||
fmt.Print(out)
|
||||
}
|
||||
return
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
// generated by directives_generate.go; DO NOT EDIT
|
||||
|
||||
package dnsserver
|
||||
|
||||
// Directives are registered in the order they should be
|
||||
// executed.
|
||||
//
|
||||
// Ordering is VERY important. Every plugin will
|
||||
// feel the effects of all other plugins below
|
||||
// (after) them during a request, but they must not
|
||||
// care what plugins above them are doing.
|
||||
var Directives = []string{
|
||||
"metadata",
|
||||
"cancel",
|
||||
"tls",
|
||||
"reload",
|
||||
"nsid",
|
||||
"root",
|
||||
"bind",
|
||||
"debug",
|
||||
"trace",
|
||||
"ready",
|
||||
"health",
|
||||
"pprof",
|
||||
"prometheus",
|
||||
"errors",
|
||||
"log",
|
||||
"dnstap",
|
||||
"chaos",
|
||||
"loadbalance",
|
||||
"cache",
|
||||
"rewrite",
|
||||
"dnssec",
|
||||
"autopath",
|
||||
"template",
|
||||
"hosts",
|
||||
"route53",
|
||||
"federation",
|
||||
"k8s_external",
|
||||
"kubernetes",
|
||||
"file",
|
||||
"auto",
|
||||
"secondary",
|
||||
"etcd",
|
||||
"loop",
|
||||
"forward",
|
||||
"grpc",
|
||||
"erratic",
|
||||
"whoami",
|
||||
"on",
|
||||
}
|
|
@ -0,0 +1,267 @@
|
|||
// Package coremain contains the functions for starting CoreDNS.
|
||||
package coremain
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
clog "github.com/coredns/coredns/plugin/pkg/log"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.DefaultConfigFile = "Corefile"
|
||||
caddy.Quiet = true // don't show init stuff from caddy
|
||||
setVersion()
|
||||
|
||||
flag.StringVar(&conf, "conf", "", "Corefile to load (default \""+caddy.DefaultConfigFile+"\")")
|
||||
flag.StringVar(&cpu, "cpu", "100%", "CPU cap")
|
||||
flag.BoolVar(&plugins, "plugins", false, "List installed plugins")
|
||||
flag.StringVar(&caddy.PidFile, "pidfile", "", "Path to write pid file")
|
||||
flag.BoolVar(&version, "version", false, "Show version")
|
||||
flag.BoolVar(&dnsserver.Quiet, "quiet", false, "Quiet mode (no initialization output)")
|
||||
|
||||
caddy.RegisterCaddyfileLoader("flag", caddy.LoaderFunc(confLoader))
|
||||
caddy.SetDefaultCaddyfileLoader("default", caddy.LoaderFunc(defaultLoader))
|
||||
|
||||
caddy.AppName = coreName
|
||||
caddy.AppVersion = CoreVersion
|
||||
}
|
||||
|
||||
// Run is CoreDNS's main() function.
|
||||
func Run() {
|
||||
caddy.TrapSignals()
|
||||
|
||||
// Reset flag.CommandLine to get rid of unwanted flags for instance from glog (used in kubernetes).
|
||||
// And read the ones we want to keep.
|
||||
flag.VisitAll(func(f *flag.Flag) {
|
||||
if _, ok := flagsBlacklist[f.Name]; ok {
|
||||
return
|
||||
}
|
||||
flagsToKeep = append(flagsToKeep, f)
|
||||
})
|
||||
|
||||
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
|
||||
for _, f := range flagsToKeep {
|
||||
flag.Var(f.Value, f.Name, f.Usage)
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if len(flag.Args()) > 0 {
|
||||
mustLogFatal(fmt.Errorf("extra command line arguments: %s", flag.Args()))
|
||||
}
|
||||
|
||||
log.SetOutput(os.Stdout)
|
||||
log.SetFlags(0) // Set to 0 because we're doing our own time, with timezone
|
||||
|
||||
if version {
|
||||
showVersion()
|
||||
os.Exit(0)
|
||||
}
|
||||
if plugins {
|
||||
fmt.Println(caddy.DescribePlugins())
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// Set CPU cap
|
||||
if err := setCPU(cpu); err != nil {
|
||||
mustLogFatal(err)
|
||||
}
|
||||
|
||||
// Get Corefile input
|
||||
corefile, err := caddy.LoadCaddyfile(serverType)
|
||||
if err != nil {
|
||||
mustLogFatal(err)
|
||||
}
|
||||
|
||||
// Start your engines
|
||||
instance, err := caddy.Start(corefile)
|
||||
if err != nil {
|
||||
mustLogFatal(err)
|
||||
}
|
||||
|
||||
logVersion()
|
||||
if !dnsserver.Quiet {
|
||||
showVersion()
|
||||
}
|
||||
|
||||
// Twiddle your thumbs
|
||||
instance.Wait()
|
||||
}
|
||||
|
||||
// mustLogFatal wraps log.Fatal() in a way that ensures the
|
||||
// output is always printed to stderr so the user can see it
|
||||
// if the user is still there, even if the process log was not
|
||||
// enabled. If this process is an upgrade, however, and the user
|
||||
// might not be there anymore, this just logs to the process
|
||||
// log and exits.
|
||||
func mustLogFatal(args ...interface{}) {
|
||||
if !caddy.IsUpgrade() {
|
||||
log.SetOutput(os.Stderr)
|
||||
}
|
||||
log.Fatal(args...)
|
||||
}
|
||||
|
||||
// confLoader loads the Caddyfile using the -conf flag.
|
||||
func confLoader(serverType string) (caddy.Input, error) {
|
||||
if conf == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if conf == "stdin" {
|
||||
return caddy.CaddyfileFromPipe(os.Stdin, serverType)
|
||||
}
|
||||
|
||||
contents, err := ioutil.ReadFile(conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return caddy.CaddyfileInput{
|
||||
Contents: contents,
|
||||
Filepath: conf,
|
||||
ServerTypeName: serverType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// defaultLoader loads the Corefile from the current working directory.
|
||||
func defaultLoader(serverType string) (caddy.Input, error) {
|
||||
contents, err := ioutil.ReadFile(caddy.DefaultConfigFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return caddy.CaddyfileInput{
|
||||
Contents: contents,
|
||||
Filepath: caddy.DefaultConfigFile,
|
||||
ServerTypeName: serverType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// logVersion logs the version that is starting.
|
||||
func logVersion() {
|
||||
clog.Info(versionString())
|
||||
clog.Info(releaseString())
|
||||
}
|
||||
|
||||
// showVersion prints the version that is starting.
|
||||
func showVersion() {
|
||||
fmt.Print(versionString())
|
||||
fmt.Print(releaseString())
|
||||
if devBuild && gitShortStat != "" {
|
||||
fmt.Printf("%s\n%s\n", gitShortStat, gitFilesModified)
|
||||
}
|
||||
}
|
||||
|
||||
// versionString returns the CoreDNS version as a string.
|
||||
func versionString() string {
|
||||
return fmt.Sprintf("%s-%s\n", caddy.AppName, caddy.AppVersion)
|
||||
}
|
||||
|
||||
// releaseString returns the release information related to CoreDNS version:
|
||||
// <OS>/<ARCH>, <go version>, <commit>
|
||||
// e.g.,
|
||||
// linux/amd64, go1.8.3, a6d2d7b5
|
||||
func releaseString() string {
|
||||
return fmt.Sprintf("%s/%s, %s, %s\n", runtime.GOOS, runtime.GOARCH, runtime.Version(), GitCommit)
|
||||
}
|
||||
|
||||
// setVersion figures out the version information
|
||||
// based on variables set by -ldflags.
|
||||
func setVersion() {
|
||||
// A development build is one that's not at a tag or has uncommitted changes
|
||||
devBuild = gitTag == "" || gitShortStat != ""
|
||||
|
||||
// Only set the appVersion if -ldflags was used
|
||||
if gitNearestTag != "" || gitTag != "" {
|
||||
if devBuild && gitNearestTag != "" {
|
||||
appVersion = fmt.Sprintf("%s (+%s %s)",
|
||||
strings.TrimPrefix(gitNearestTag, "v"), GitCommit, buildDate)
|
||||
} else if gitTag != "" {
|
||||
appVersion = strings.TrimPrefix(gitTag, "v")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setCPU parses string cpu and sets GOMAXPROCS
|
||||
// according to its value. It accepts either
|
||||
// a number (e.g. 3) or a percent (e.g. 50%).
|
||||
func setCPU(cpu string) error {
|
||||
var numCPU int
|
||||
|
||||
availCPU := runtime.NumCPU()
|
||||
|
||||
if strings.HasSuffix(cpu, "%") {
|
||||
// Percent
|
||||
var percent float32
|
||||
pctStr := cpu[:len(cpu)-1]
|
||||
pctInt, err := strconv.Atoi(pctStr)
|
||||
if err != nil || pctInt < 1 || pctInt > 100 {
|
||||
return errors.New("invalid CPU value: percentage must be between 1-100")
|
||||
}
|
||||
percent = float32(pctInt) / 100
|
||||
numCPU = int(float32(availCPU) * percent)
|
||||
} else {
|
||||
// Number
|
||||
num, err := strconv.Atoi(cpu)
|
||||
if err != nil || num < 1 {
|
||||
return errors.New("invalid CPU value: provide a number or percent greater than 0")
|
||||
}
|
||||
numCPU = num
|
||||
}
|
||||
|
||||
if numCPU > availCPU {
|
||||
numCPU = availCPU
|
||||
}
|
||||
|
||||
runtime.GOMAXPROCS(numCPU)
|
||||
return nil
|
||||
}
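A worked illustration of the two branches above (all values assumed):

~~~ go
// Assuming runtime.NumCPU() == 8 for the sake of the example:
//   setCPU("50%") -> percent = 50/100, numCPU = int(8 * 0.5) = 4 -> GOMAXPROCS(4)
//   setCPU("3")   -> numCPU = 3                                  -> GOMAXPROCS(3)
//   setCPU("16")  -> 16 > 8 available CPUs, clamped              -> GOMAXPROCS(8)
//   setCPU("0")   -> error: "provide a number or percent greater than 0"
~~~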
|
||||
|
||||
// Flags that control program flow or startup
|
||||
var (
|
||||
conf string
|
||||
cpu string
|
||||
logfile bool
|
||||
version bool
|
||||
plugins bool
|
||||
)
|
||||
|
||||
// Build information obtained with the help of -ldflags
|
||||
var (
|
||||
appVersion = "(untracked dev build)" // inferred at startup
|
||||
devBuild = true // inferred at startup
|
||||
|
||||
buildDate string // date -u
|
||||
gitTag string // git describe --exact-match HEAD 2> /dev/null
|
||||
gitNearestTag string // git describe --abbrev=0 --tags HEAD
|
||||
gitShortStat string // git diff-index --shortstat
|
||||
gitFilesModified string // git diff-index --name-only HEAD
|
||||
|
||||
// GitCommit contains the commit that CoreDNS was built from.
|
||||
GitCommit string
|
||||
)
|
||||
|
||||
// flagsBlacklist removes flags with these names from our flagset.
|
||||
var flagsBlacklist = map[string]struct{}{
|
||||
"logtostderr": struct{}{},
|
||||
"alsologtostderr": struct{}{},
|
||||
"v": struct{}{},
|
||||
"stderrthreshold": struct{}{},
|
||||
"vmodule": struct{}{},
|
||||
"log_backtrace_at": struct{}{},
|
||||
"log_dir": struct{}{},
|
||||
}
|
||||
|
||||
var flagsToKeep []*flag.Flag
|
|
@ -0,0 +1,8 @@
|
|||
package coremain
|
||||
|
||||
// Various CoreDNS constants.
|
||||
const (
|
||||
CoreVersion = "1.5.0"
|
||||
coreName = "CoreDNS"
|
||||
serverType = "dns"
|
||||
)
|
|
@ -0,0 +1,13 @@
|
|||
# Generate the Go files from the dns.proto protobuf; you need the utilities
|
||||
# from: https://github.com/golang/protobuf to make this work.
|
||||
# The generated dns.pb.go is checked into git, so for normal builds we don't need
|
||||
# to run this generation step.
|
||||
|
||||
all: dns.pb.go
|
||||
|
||||
dns.pb.go: dns.proto
|
||||
protoc --go_out=plugins=grpc:. dns.proto
|
||||
|
||||
.PHONY: clean
|
||||
clean:
|
||||
rm dns.pb.go
|
|
@ -0,0 +1,156 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// source: dns.proto
|
||||
|
||||
package pb
|
||||
|
||||
import (
|
||||
context "context"
|
||||
fmt "fmt"
|
||||
math "math"
|
||||
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
grpc "google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal
|
||||
var _ = fmt.Errorf
|
||||
var _ = math.Inf
|
||||
|
||||
/* Miek: disabled this manually, because I don't know what the heck */
|
||||
/*
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the proto package it is being compiled against.
|
||||
// A compilation error at this line likely means your copy of the
|
||||
// proto package needs to be updated.
|
||||
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
|
||||
*/
|
||||
|
||||
type DnsPacket struct {
|
||||
Msg []byte `protobuf:"bytes,1,opt,name=msg,proto3" json:"msg,omitempty"`
|
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
XXX_sizecache int32 `json:"-"`
|
||||
}
|
||||
|
||||
func (m *DnsPacket) Reset() { *m = DnsPacket{} }
|
||||
func (m *DnsPacket) String() string { return proto.CompactTextString(m) }
|
||||
func (*DnsPacket) ProtoMessage() {}
|
||||
func (*DnsPacket) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_638ff8d8aaf3d8ae, []int{0}
|
||||
}
|
||||
|
||||
func (m *DnsPacket) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_DnsPacket.Unmarshal(m, b)
|
||||
}
|
||||
func (m *DnsPacket) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
return xxx_messageInfo_DnsPacket.Marshal(b, m, deterministic)
|
||||
}
|
||||
func (m *DnsPacket) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_DnsPacket.Merge(m, src)
|
||||
}
|
||||
func (m *DnsPacket) XXX_Size() int {
|
||||
return xxx_messageInfo_DnsPacket.Size(m)
|
||||
}
|
||||
func (m *DnsPacket) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_DnsPacket.DiscardUnknown(m)
|
||||
}
|
||||
|
||||
var xxx_messageInfo_DnsPacket proto.InternalMessageInfo
|
||||
|
||||
func (m *DnsPacket) GetMsg() []byte {
|
||||
if m != nil {
|
||||
return m.Msg
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
proto.RegisterType((*DnsPacket)(nil), "coredns.dns.DnsPacket")
|
||||
}
|
||||
|
||||
func init() { proto.RegisterFile("dns.proto", fileDescriptor_638ff8d8aaf3d8ae) }
|
||||
|
||||
var fileDescriptor_638ff8d8aaf3d8ae = []byte{
|
||||
// 120 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x4c, 0xc9, 0x2b, 0xd6,
|
||||
0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x4e, 0xce, 0x2f, 0x4a, 0x05, 0x71, 0x53, 0xf2, 0x8a,
|
||||
0x95, 0x64, 0xb9, 0x38, 0x5d, 0xf2, 0x8a, 0x03, 0x12, 0x93, 0xb3, 0x53, 0x4b, 0x84, 0x04, 0xb8,
|
||||
0x98, 0x73, 0x8b, 0xd3, 0x25, 0x18, 0x15, 0x18, 0x35, 0x78, 0x82, 0x40, 0x4c, 0x23, 0x57, 0x2e,
|
||||
0x2e, 0x97, 0xbc, 0xe2, 0xe0, 0xd4, 0xa2, 0xb2, 0xcc, 0xe4, 0x54, 0x21, 0x73, 0x2e, 0xd6, 0xc0,
|
||||
0xd2, 0xd4, 0xa2, 0x4a, 0x21, 0x31, 0x3d, 0x24, 0x33, 0xf4, 0xe0, 0x06, 0x48, 0xe1, 0x10, 0x77,
|
||||
0x62, 0x89, 0x62, 0x2a, 0x48, 0x4a, 0x62, 0x03, 0xdb, 0x6f, 0x0c, 0x08, 0x00, 0x00, 0xff, 0xff,
|
||||
0xf5, 0xd1, 0x3f, 0x26, 0x8c, 0x00, 0x00, 0x00,
|
||||
}
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ context.Context
|
||||
var _ grpc.ClientConn
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
const _ = grpc.SupportPackageIsVersion4
|
||||
|
||||
// DnsServiceClient is the client API for DnsService service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
|
||||
type DnsServiceClient interface {
|
||||
Query(ctx context.Context, in *DnsPacket, opts ...grpc.CallOption) (*DnsPacket, error)
|
||||
}
|
||||
|
||||
type dnsServiceClient struct {
|
||||
cc *grpc.ClientConn
|
||||
}
|
||||
|
||||
func NewDnsServiceClient(cc *grpc.ClientConn) DnsServiceClient {
|
||||
return &dnsServiceClient{cc}
|
||||
}
|
||||
|
||||
func (c *dnsServiceClient) Query(ctx context.Context, in *DnsPacket, opts ...grpc.CallOption) (*DnsPacket, error) {
|
||||
out := new(DnsPacket)
|
||||
err := c.cc.Invoke(ctx, "/coredns.dns.DnsService/Query", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// DnsServiceServer is the server API for DnsService service.
|
||||
type DnsServiceServer interface {
|
||||
Query(context.Context, *DnsPacket) (*DnsPacket, error)
|
||||
}
|
||||
|
||||
func RegisterDnsServiceServer(s *grpc.Server, srv DnsServiceServer) {
|
||||
s.RegisterService(&_DnsService_serviceDesc, srv)
|
||||
}
|
||||
|
||||
func _DnsService_Query_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(DnsPacket)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DnsServiceServer).Query(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/coredns.dns.DnsService/Query",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DnsServiceServer).Query(ctx, req.(*DnsPacket))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
var _DnsService_serviceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "coredns.dns.DnsService",
|
||||
HandlerType: (*DnsServiceServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "Query",
|
||||
Handler: _DnsService_Query_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "dns.proto",
|
||||
}
|
|
@ -0,0 +1,12 @@
|
|||
syntax = "proto3";
|
||||
|
||||
package coredns.dns;
|
||||
option go_package = "pb";
|
||||
|
||||
message DnsPacket {
|
||||
bytes msg = 1;
|
||||
}
|
||||
|
||||
service DnsService {
|
||||
rpc Query (DnsPacket) returns (DnsPacket);
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coredns/coredns/plugin/etcd/msg"
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// ServiceBackend defines a (dynamic) backend that returns a slice of service definitions.
|
||||
type ServiceBackend interface {
|
||||
// Services communicates with the backend to retrieve the service definitions. Exact indicates
|
||||
// whether an exact match should be returned.
|
||||
Services(ctx context.Context, state request.Request, exact bool, opt Options) ([]msg.Service, error)
|
||||
|
||||
// Reverse communicates with the backend to retrieve service definitions based on an IP address
|
||||
// instead of a name. I.e. a reverse DNS lookup.
|
||||
Reverse(ctx context.Context, state request.Request, exact bool, opt Options) ([]msg.Service, error)
|
||||
|
||||
// Lookup is used to find records elsewhere.
|
||||
Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error)
|
||||
|
||||
// Records returns _all_ services that match a certain name.
|
||||
// Note: it does not implement a specific service.
|
||||
Records(ctx context.Context, state request.Request, exact bool) ([]msg.Service, error)
|
||||
|
||||
// IsNameError returns true if err indicates a record-not-found condition.
|
||||
IsNameError(err error) bool
|
||||
|
||||
Transferer
|
||||
}
|
||||
|
||||
// Transferer defines an interface for backends that provide AXFR of all records.
|
||||
type Transferer interface {
|
||||
// Serial returns a SOA serial number to construct a SOA record.
|
||||
Serial(state request.Request) uint32
|
||||
|
||||
// MinTTL returns the minimum TTL to be used in the SOA record.
|
||||
MinTTL(state request.Request) uint32
|
||||
|
||||
// Transfer handles a zone transfer; it writes to the client just
|
||||
// like any other handler.
|
||||
Transfer(ctx context.Context, state request.Request) (int, error)
|
||||
}
|
||||
|
||||
// Options are extra options that can be specified for a lookup.
|
||||
type Options struct{}
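A minimal, hypothetical implementation of this interface might look like the sketch below. The `staticBackend` name, its single hard-coded service, and its placement in this package are all assumptions for illustration; real backends (etcd, kubernetes) are far more involved.

~~~ go
package plugin

import (
	"context"
	"errors"

	"github.com/coredns/coredns/plugin/etcd/msg"
	"github.com/coredns/coredns/request"

	"github.com/miekg/dns"
)

// staticBackend is a made-up backend that always answers with one service.
type staticBackend struct{}

func (staticBackend) Services(ctx context.Context, state request.Request, exact bool, opt Options) ([]msg.Service, error) {
	return []msg.Service{{Host: "10.0.0.1", Port: 8080, Key: "/static/org/example/www", TTL: 30}}, nil
}

func (b staticBackend) Records(ctx context.Context, state request.Request, exact bool) ([]msg.Service, error) {
	return b.Services(ctx, state, exact, Options{})
}

func (staticBackend) Reverse(ctx context.Context, state request.Request, exact bool, opt Options) ([]msg.Service, error) {
	return nil, errors.New("reverse lookups not supported")
}

func (staticBackend) Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error) {
	return nil, errors.New("external lookups not supported")
}

func (staticBackend) IsNameError(err error) bool { return false }

// Transferer part of the interface.
func (staticBackend) Serial(state request.Request) uint32 { return 1 }
func (staticBackend) MinTTL(state request.Request) uint32 { return 30 }
func (staticBackend) Transfer(ctx context.Context, state request.Request) (int, error) {
	return dns.RcodeServerFailure, errors.New("zone transfers not supported")
}
~~~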
|
|
@ -0,0 +1,488 @@
|
|||
package plugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
|
||||
"github.com/coredns/coredns/plugin/etcd/msg"
|
||||
"github.com/coredns/coredns/plugin/pkg/dnsutil"
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// A returns A records from Backend or an error.
|
||||
func A(ctx context.Context, b ServiceBackend, zone string, state request.Request, previousRecords []dns.RR, opt Options) (records []dns.RR, err error) {
|
||||
services, err := checkForApex(ctx, b, zone, state, opt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dup := make(map[string]struct{})
|
||||
|
||||
for _, serv := range services {
|
||||
|
||||
what, ip := serv.HostType()
|
||||
|
||||
switch what {
|
||||
case dns.TypeCNAME:
|
||||
if Name(state.Name()).Matches(dns.Fqdn(serv.Host)) {
|
||||
// x CNAME x is a direct loop, don't add those
|
||||
continue
|
||||
}
|
||||
|
||||
newRecord := serv.NewCNAME(state.QName(), serv.Host)
|
||||
if len(previousRecords) > 7 {
|
||||
// don't add it, and just continue
|
||||
continue
|
||||
}
|
||||
if dnsutil.DuplicateCNAME(newRecord, previousRecords) {
|
||||
continue
|
||||
}
|
||||
if dns.IsSubDomain(zone, dns.Fqdn(serv.Host)) {
|
||||
state1 := state.NewWithQuestion(serv.Host, state.QType())
|
||||
state1.Zone = zone
|
||||
nextRecords, err := A(ctx, b, zone, state1, append(previousRecords, newRecord), opt)
|
||||
|
||||
if err == nil {
|
||||
// We found something, so add both the CNAME and the IP addresses.
|
||||
if len(nextRecords) > 0 {
|
||||
records = append(records, newRecord)
|
||||
records = append(records, nextRecords...)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
// This means we cannot complete the CNAME, try to look elsewhere.
|
||||
target := newRecord.Target
|
||||
// Lookup
|
||||
m1, e1 := b.Lookup(ctx, state, target, state.QType())
|
||||
if e1 != nil {
|
||||
continue
|
||||
}
|
||||
// Should we also check len(m1.Answer) > 0 here?
|
||||
records = append(records, newRecord)
|
||||
records = append(records, m1.Answer...)
|
||||
continue
|
||||
|
||||
case dns.TypeA:
|
||||
if _, ok := dup[serv.Host]; !ok {
|
||||
dup[serv.Host] = struct{}{}
|
||||
records = append(records, serv.NewA(state.QName(), ip))
|
||||
}
|
||||
|
||||
case dns.TypeAAAA:
|
||||
// nada
|
||||
}
|
||||
}
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// AAAA returns AAAA records from Backend or an error.
|
||||
func AAAA(ctx context.Context, b ServiceBackend, zone string, state request.Request, previousRecords []dns.RR, opt Options) (records []dns.RR, err error) {
|
||||
services, err := checkForApex(ctx, b, zone, state, opt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dup := make(map[string]struct{})
|
||||
|
||||
for _, serv := range services {
|
||||
|
||||
what, ip := serv.HostType()
|
||||
|
||||
switch what {
|
||||
case dns.TypeCNAME:
|
||||
// Try to resolve as CNAME if it's not an IP, but only if we don't create loops.
|
||||
if Name(state.Name()).Matches(dns.Fqdn(serv.Host)) {
|
||||
// x CNAME x is a direct loop, don't add those
|
||||
continue
|
||||
}
|
||||
|
||||
newRecord := serv.NewCNAME(state.QName(), serv.Host)
|
||||
if len(previousRecords) > 7 {
|
||||
// don't add it, and just continue
|
||||
continue
|
||||
}
|
||||
if dnsutil.DuplicateCNAME(newRecord, previousRecords) {
|
||||
continue
|
||||
}
|
||||
if dns.IsSubDomain(zone, dns.Fqdn(serv.Host)) {
|
||||
state1 := state.NewWithQuestion(serv.Host, state.QType())
|
||||
state1.Zone = zone
|
||||
nextRecords, err := AAAA(ctx, b, zone, state1, append(previousRecords, newRecord), opt)
|
||||
|
||||
if err == nil {
|
||||
// We found something, so add both the CNAME and the IP addresses.
|
||||
if len(nextRecords) > 0 {
|
||||
records = append(records, newRecord)
|
||||
records = append(records, nextRecords...)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
// This means we cannot complete the CNAME, try to look elsewhere.
|
||||
target := newRecord.Target
|
||||
m1, e1 := b.Lookup(ctx, state, target, state.QType())
|
||||
if e1 != nil {
|
||||
continue
|
||||
}
|
||||
// Should we also check len(m1.Answer) > 0 here?
|
||||
records = append(records, newRecord)
|
||||
records = append(records, m1.Answer...)
|
||||
continue
|
||||
// both here again
|
||||
|
||||
case dns.TypeA:
|
||||
// nada
|
||||
|
||||
case dns.TypeAAAA:
|
||||
if _, ok := dup[serv.Host]; !ok {
|
||||
dup[serv.Host] = struct{}{}
|
||||
records = append(records, serv.NewAAAA(state.QName(), ip))
|
||||
}
|
||||
}
|
||||
}
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// SRV returns SRV records from the Backend.
|
||||
// If the Target is not a name but an IP address, a name is created on the fly.
|
||||
func SRV(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) (records, extra []dns.RR, err error) {
|
||||
services, err := b.Services(ctx, state, false, opt)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
dup := make(map[item]struct{})
|
||||
lookup := make(map[string]struct{})
|
||||
|
||||
// Looping twice to get the right weight vs priority. This might break because we may drop duplicate SRV records later on.
|
||||
w := make(map[int]int)
|
||||
for _, serv := range services {
|
||||
weight := 100
|
||||
if serv.Weight != 0 {
|
||||
weight = serv.Weight
|
||||
}
|
||||
if _, ok := w[serv.Priority]; !ok {
|
||||
w[serv.Priority] = weight
|
||||
continue
|
||||
}
|
||||
w[serv.Priority] += weight
|
||||
}
|
||||
for _, serv := range services {
|
||||
// Don't add the entry if the port is -1 (invalid). The kubernetes plugin uses port -1 when a service/endpoint
|
||||
// does not have any declared ports.
|
||||
if serv.Port == -1 {
|
||||
continue
|
||||
}
|
||||
w1 := 100.0 / float64(w[serv.Priority])
|
||||
if serv.Weight == 0 {
|
||||
w1 *= 100
|
||||
} else {
|
||||
w1 *= float64(serv.Weight)
|
||||
}
|
||||
weight := uint16(math.Floor(w1))
|
||||
|
||||
what, ip := serv.HostType()
|
||||
|
||||
switch what {
|
||||
case dns.TypeCNAME:
|
||||
srv := serv.NewSRV(state.QName(), weight)
|
||||
records = append(records, srv)
|
||||
|
||||
if _, ok := lookup[srv.Target]; ok {
|
||||
break
|
||||
}
|
||||
|
||||
lookup[srv.Target] = struct{}{}
|
||||
|
||||
if !dns.IsSubDomain(zone, srv.Target) {
|
||||
m1, e1 := b.Lookup(ctx, state, srv.Target, dns.TypeA)
|
||||
if e1 == nil {
|
||||
extra = append(extra, m1.Answer...)
|
||||
}
|
||||
|
||||
m1, e1 = b.Lookup(ctx, state, srv.Target, dns.TypeAAAA)
|
||||
if e1 == nil {
|
||||
// If we have seen CNAME's we *assume* that they are already added.
|
||||
for _, a := range m1.Answer {
|
||||
if _, ok := a.(*dns.CNAME); !ok {
|
||||
extra = append(extra, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
// Internal name; we should have some info on it, either v4 or v6.
|
||||
// Clients expect a complete answer, because we are a recursor in their view.
|
||||
state1 := state.NewWithQuestion(srv.Target, dns.TypeA)
|
||||
addr, e1 := A(ctx, b, zone, state1, nil, opt)
|
||||
if e1 == nil {
|
||||
extra = append(extra, addr...)
|
||||
}
|
||||
// TODO(miek): AAAA as well here.
|
||||
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
addr := serv.Host
|
||||
serv.Host = msg.Domain(serv.Key)
|
||||
srv := serv.NewSRV(state.QName(), weight)
|
||||
|
||||
if ok := isDuplicate(dup, srv.Target, "", srv.Port); !ok {
|
||||
records = append(records, srv)
|
||||
}
|
||||
|
||||
if ok := isDuplicate(dup, srv.Target, addr, 0); !ok {
|
||||
extra = append(extra, newAddress(serv, srv.Target, ip, what))
|
||||
}
|
||||
}
|
||||
}
|
||||
return records, extra, nil
|
||||
}
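A small worked example of the weight normalization above, with assumed weights 60 and 20 at the same priority (total 80), shows how the advertised SRV weights come out as 75 and 25:

~~~ go
package main

import (
	"fmt"
	"math"
)

func main() {
	// Assumed input: two services at the same priority with weights 60 and 20.
	weights := []int{60, 20}
	total := 0
	for _, w := range weights {
		total += w // first pass: sum the weights per priority (0 counts as 100 above)
	}
	for _, w := range weights {
		// second pass: normalize so the weights at this priority sum to roughly 100
		fmt.Println(uint16(math.Floor(100.0 / float64(total) * float64(w))))
	}
	// Prints 75, then 25.
}
~~~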
|
||||
|
||||
// MX returns MX records from the Backend. If the Target is not a name but an IP address, a name is created on the fly.
|
||||
func MX(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) (records, extra []dns.RR, err error) {
|
||||
services, err := b.Services(ctx, state, false, opt)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
dup := make(map[item]struct{})
|
||||
lookup := make(map[string]struct{})
|
||||
for _, serv := range services {
|
||||
if !serv.Mail {
|
||||
continue
|
||||
}
|
||||
what, ip := serv.HostType()
|
||||
switch what {
|
||||
case dns.TypeCNAME:
|
||||
mx := serv.NewMX(state.QName())
|
||||
records = append(records, mx)
|
||||
if _, ok := lookup[mx.Mx]; ok {
|
||||
break
|
||||
}
|
||||
|
||||
lookup[mx.Mx] = struct{}{}
|
||||
|
||||
if !dns.IsSubDomain(zone, mx.Mx) {
|
||||
m1, e1 := b.Lookup(ctx, state, mx.Mx, dns.TypeA)
|
||||
if e1 == nil {
|
||||
extra = append(extra, m1.Answer...)
|
||||
}
|
||||
|
||||
m1, e1 = b.Lookup(ctx, state, mx.Mx, dns.TypeAAAA)
|
||||
if e1 == nil {
|
||||
// If we have seen CNAME's we *assume* that they are already added.
|
||||
for _, a := range m1.Answer {
|
||||
if _, ok := a.(*dns.CNAME); !ok {
|
||||
extra = append(extra, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
// Internal name
|
||||
state1 := state.NewWithQuestion(mx.Mx, dns.TypeA)
|
||||
addr, e1 := A(ctx, b, zone, state1, nil, opt)
|
||||
if e1 == nil {
|
||||
extra = append(extra, addr...)
|
||||
}
|
||||
// TODO(miek): AAAA as well here.
|
||||
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
addr := serv.Host
|
||||
serv.Host = msg.Domain(serv.Key)
|
||||
mx := serv.NewMX(state.QName())
|
||||
|
||||
if ok := isDuplicate(dup, mx.Mx, "", mx.Preference); !ok {
|
||||
records = append(records, mx)
|
||||
}
|
||||
// Fake port to be 0 for address...
|
||||
if ok := isDuplicate(dup, serv.Host, addr, 0); !ok {
|
||||
extra = append(extra, newAddress(serv, serv.Host, ip, what))
|
||||
}
|
||||
}
|
||||
}
|
||||
return records, extra, nil
|
||||
}
|
||||
|
||||
// CNAME returns CNAME records from the backend or an error.
|
||||
func CNAME(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) (records []dns.RR, err error) {
|
||||
services, err := b.Services(ctx, state, true, opt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(services) > 0 {
|
||||
serv := services[0]
|
||||
if ip := net.ParseIP(serv.Host); ip == nil {
|
||||
records = append(records, serv.NewCNAME(state.QName(), serv.Host))
|
||||
}
|
||||
}
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// TXT returns TXT records from Backend or an error.
|
||||
func TXT(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) (records []dns.RR, err error) {
|
||||
services, err := b.Services(ctx, state, false, opt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, serv := range services {
|
||||
records = append(records, serv.NewTXT(state.QName()))
|
||||
}
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// PTR returns the PTR records from the backend, only services that have a domain name as host are included.
|
||||
func PTR(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) (records []dns.RR, err error) {
|
||||
services, err := b.Reverse(ctx, state, true, opt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dup := make(map[string]struct{})
|
||||
|
||||
for _, serv := range services {
|
||||
if ip := net.ParseIP(serv.Host); ip == nil {
|
||||
if _, ok := dup[serv.Host]; !ok {
|
||||
dup[serv.Host] = struct{}{}
|
||||
records = append(records, serv.NewPTR(state.QName(), serv.Host))
|
||||
}
|
||||
}
|
||||
}
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// NS returns NS records from the backend
|
||||
func NS(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) (records, extra []dns.RR, err error) {
|
||||
// NS records for this zone live in a special place, ns.dns.<zone>. Fake our lookup.
|
||||
// only a tad bit fishy...
|
||||
old := state.QName()
|
||||
|
||||
state.Clear()
|
||||
state.Req.Question[0].Name = "ns.dns." + zone
|
||||
services, err := b.Services(ctx, state, false, opt)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
// ... and reset
|
||||
state.Req.Question[0].Name = old
|
||||
|
||||
for _, serv := range services {
|
||||
what, ip := serv.HostType()
|
||||
switch what {
|
||||
case dns.TypeCNAME:
|
||||
return nil, nil, fmt.Errorf("NS record must be an IP address: %s", serv.Host)
|
||||
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
serv.Host = msg.Domain(serv.Key)
|
||||
records = append(records, serv.NewNS(state.QName()))
|
||||
extra = append(extra, newAddress(serv, serv.Host, ip, what))
|
||||
}
|
||||
}
|
||||
return records, extra, nil
|
||||
}
|
||||
|
||||
// SOA returns a SOA record from the backend.
|
||||
func SOA(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) ([]dns.RR, error) {
|
||||
minTTL := b.MinTTL(state)
|
||||
ttl := uint32(300)
|
||||
if minTTL < ttl {
|
||||
ttl = minTTL
|
||||
}
|
||||
|
||||
header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: ttl, Class: dns.ClassINET}
|
||||
|
||||
Mbox := hostmaster + "."
|
||||
Ns := "ns.dns."
|
||||
if zone[0] != '.' {
|
||||
Mbox += zone
|
||||
Ns += zone
|
||||
}
|
||||
|
||||
soa := &dns.SOA{Hdr: header,
|
||||
Mbox: Mbox,
|
||||
Ns: Ns,
|
||||
Serial: b.Serial(state),
|
||||
Refresh: 7200,
|
||||
Retry: 1800,
|
||||
Expire: 86400,
|
||||
Minttl: minTTL,
|
||||
}
|
||||
return []dns.RR{soa}, nil
|
||||
}
|
||||
|
||||
// BackendError writes an error response to the client.
|
||||
func BackendError(ctx context.Context, b ServiceBackend, zone string, rcode int, state request.Request, err error, opt Options) (int, error) {
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(state.Req, rcode)
|
||||
m.Authoritative = true
|
||||
m.Ns, _ = SOA(ctx, b, zone, state, opt)
|
||||
|
||||
state.W.WriteMsg(m)
|
||||
// Return success as the rcode to signal we have written to the client.
|
||||
return dns.RcodeSuccess, err
|
||||
}
|
||||
|
||||
func newAddress(s msg.Service, name string, ip net.IP, what uint16) dns.RR {
|
||||
|
||||
hdr := dns.RR_Header{Name: name, Rrtype: what, Class: dns.ClassINET, Ttl: s.TTL}
|
||||
|
||||
if what == dns.TypeA {
|
||||
return &dns.A{Hdr: hdr, A: ip}
|
||||
}
|
||||
// Should always be dns.TypeAAAA
|
||||
return &dns.AAAA{Hdr: hdr, AAAA: ip}
|
||||
}
|
||||
|
||||
// checkForApex checks the special apex.dns directory for records that will be returned as A or AAAA.
|
||||
func checkForApex(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) ([]msg.Service, error) {
|
||||
if state.Name() != zone {
|
||||
return b.Services(ctx, state, false, opt)
|
||||
}
|
||||
|
||||
// If the zone name itself is queried, we fake the query to search for a special entry;
|
||||
// this is equivalent to the NS search code.
|
||||
old := state.QName()
|
||||
state.Clear()
|
||||
state.Req.Question[0].Name = dnsutil.Join("apex.dns", zone)
|
||||
|
||||
services, err := b.Services(ctx, state, false, opt)
|
||||
if err == nil {
|
||||
state.Req.Question[0].Name = old
|
||||
return services, err
|
||||
}
|
||||
|
||||
state.Req.Question[0].Name = old
|
||||
return b.Services(ctx, state, false, opt)
|
||||
}
|
||||
|
||||
// item holds records.
|
||||
type item struct {
|
||||
name string // name of the record (either owner or something else unique).
|
||||
port uint16 // port of the record (used for address records, A and AAAA).
|
||||
addr string // address of the record (A and AAAA).
|
||||
}
|
||||
|
||||
// isDuplicate uses m to see if the combo (name, addr, port) already exists. If it does
|
||||
// not already exist, isDuplicate will also add the record to the map.
|
||||
func isDuplicate(m map[item]struct{}, name, addr string, port uint16) bool {
|
||||
if addr != "" {
|
||||
_, ok := m[item{name, 0, addr}]
|
||||
if !ok {
|
||||
m[item{name, 0, addr}] = struct{}{}
|
||||
}
|
||||
return ok
|
||||
}
|
||||
_, ok := m[item{name, port, ""}]
|
||||
if !ok {
|
||||
m[item{name, port, ""}] = struct{}{}
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
const hostmaster = "hostmaster"
|
|
@ -0,0 +1,6 @@
|
|||
reviewers:
|
||||
- grobie
|
||||
- miekg
|
||||
approvers:
|
||||
- grobie
|
||||
- miekg
|
|
@ -0,0 +1,105 @@
|
|||
# cache
|
||||
|
||||
## Name
|
||||
|
||||
*cache* - enables a frontend cache.
|
||||
|
||||
## Description
|
||||
|
||||
With *cache* enabled, all records except zone transfers and metadata records will be cached for up to
|
||||
3600s. Caching is mostly useful in a scenario when fetching data from the backend (upstream,
|
||||
database, etc.) is expensive.
|
||||
|
||||
This plugin can only be used once per Server Block.
|
||||
|
||||
## Syntax
|
||||
|
||||
~~~ txt
|
||||
cache [TTL] [ZONES...]
|
||||
~~~
|
||||
|
||||
* **TTL** max TTL in seconds. If not specified, the maximum TTL will be used, which is 3600 for
|
||||
NOERROR responses and 1800 for denial of existence ones.
|
||||
For example, `cache 300` caches records for up to 300 seconds.
|
||||
* **ZONES** zones it should cache for. If empty, the zones from the configuration block are used.
|
||||
|
||||
Each element in the cache is cached according to its TTL (with **TTL** as the max).
|
||||
A cache is divided into 256 shards, each holding up to 39 items by default - for a total size
|
||||
of 256 * 39 = 9984 items.
|
||||
|
||||
If you want more control:
|
||||
|
||||
~~~ txt
|
||||
cache [TTL] [ZONES...] {
|
||||
success CAPACITY [TTL] [MINTTL]
|
||||
denial CAPACITY [TTL] [MINTTL]
|
||||
prefetch AMOUNT [[DURATION] [PERCENTAGE%]]
|
||||
}
|
||||
~~~
|
||||
|
||||
* **TTL** and **ZONES** as above.
|
||||
* `success`, override the settings for caching successful responses. **CAPACITY** indicates the maximum
|
||||
number of packets we cache before we start evicting (*randomly*). **TTL** overrides the cache maximum TTL.
|
||||
**MINTTL** overrides the cache minimum TTL (default 5), which can be useful to limit queries to the backend.
|
||||
* `denial`, override the settings for caching denial of existence responses. **CAPACITY** indicates the maximum
|
||||
number of packets we cache before we start evicting (LRU). **TTL** overrides the cache maximum TTL.
|
||||
**MINTTL** overrides the cache minimum TTL (default 5), which can be useful to limit queries to the backend.
|
||||
There is a third category (`error`) but those responses are never cached.
|
||||
* `prefetch` will prefetch popular items when they are about to be expunged from the cache.
|
||||
Popular means **AMOUNT** queries have been seen with no gaps of **DURATION** or more between them.
|
||||
**DURATION** defaults to 1m. Prefetching will happen when the TTL drops below **PERCENTAGE**,
|
||||
which defaults to `10%`, or at the latest 1 second before the TTL expires. Values should be in the range `[10%, 90%]`.
|
||||
Note the percent sign is mandatory. **PERCENTAGE** is treated as an `int`. A combined example follows this list.
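For illustration, a Corefile that exercises all three sub-directives could look like the following; the capacities, TTLs and prefetch values are made up:

~~~ corefile
example.org {
    cache 300 {
        success 4096 120 10
        denial 1024 60 5
        prefetch 5 10m 20%
    }
    forward . 8.8.8.8:53
}
~~~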
|
||||
|
||||
## Capacity and Eviction
|
||||
|
||||
If **CAPACITY** _is not_ specified, the default cache size is 9984 per cache. The minimum allowed cache size is 1024.
|
||||
If **CAPACITY** _is_ specified, the actual cache size used will be rounded down to the nearest number divisible by 256 (so all shards are equal in size).
|
||||
|
||||
Eviction is done per shard. In effect, when a shard reaches capacity, items are evicted from that shard.
|
||||
Since shards don't fill up perfectly evenly, evictions will occur before the entire cache reaches full capacity.
|
||||
Each shard capacity is equal to the total cache size / number of shards (256). Eviction is random, not TTL based.
|
||||
Entries with 0 TTL will remain in the cache until randomly evicted when the shard reaches capacity.
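
As a worked example of the rounding just described (a standalone sketch, not code from the plugin; the 1024 minimum is not modelled), requesting a capacity of 5000 yields 19 items per shard and an effective total of 4864:

~~~ go
package main

import "fmt"

// effectiveCacheSize mirrors the rounding described above: the requested
// capacity is divided over 256 shards and rounded down to a multiple of 256.
func effectiveCacheSize(requested int) (total, perShard int) {
	const shards = 256
	perShard = requested / shards
	return perShard * shards, perShard
}

func main() {
	total, perShard := effectiveCacheSize(5000)
	fmt.Println(total, perShard) // 4864 19
}
~~~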
|
||||
|
||||
## Metrics
|
||||
|
||||
If monitoring is enabled (via the *prometheus* directive) then the following metrics are exported:
|
||||
|
||||
* `coredns_cache_size{server, type}` - Total elements in the cache by cache type.
|
||||
* `coredns_cache_hits_total{server, type}` - Counter of cache hits by cache type.
|
||||
* `coredns_cache_misses_total{server}` - Counter of cache misses.
|
||||
* `coredns_cache_drops_total{server}` - Counter of dropped messages.
|
||||
|
||||
Cache types are either "denial" or "success". `Server` is the server handling the request; see the
|
||||
metrics plugin for documentation.
|
||||
|
||||
## Examples
|
||||
|
||||
Enable caching for all zones, but cap everything to a TTL of 10 seconds:
|
||||
|
||||
~~~ corefile
|
||||
. {
|
||||
cache 10
|
||||
whoami
|
||||
}
|
||||
~~~
|
||||
|
||||
Proxy to Google Public DNS and only cache responses for example.org (or below).
|
||||
|
||||
~~~ corefile
|
||||
. {
|
||||
forward . 8.8.8.8:53
|
||||
cache example.org
|
||||
}
|
||||
~~~
|
||||
|
||||
Enable caching for all zones, keep a positive cache size of 5000 and a negative cache size of 2500:
|
||||
|
||||
~~~ corefile
|
||||
. {
|
||||
cache {
|
||||
success 5000
|
||||
denial 2500
|
||||
}
|
||||
}
|
||||
~~~
|
|
@ -0,0 +1,242 @@
|
|||
// Package cache implements a cache.
|
||||
package cache
|
||||
|
||||
import (
|
||||
"hash/fnv"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/coredns/coredns/plugin/pkg/cache"
|
||||
"github.com/coredns/coredns/plugin/pkg/dnsutil"
|
||||
"github.com/coredns/coredns/plugin/pkg/response"
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Cache is plugin that looks up responses in a cache and caches replies.
|
||||
// It has a success and a denial of existence cache.
|
||||
type Cache struct {
|
||||
Next plugin.Handler
|
||||
Zones []string
|
||||
|
||||
ncache *cache.Cache
|
||||
ncap int
|
||||
nttl time.Duration
|
||||
minnttl time.Duration
|
||||
|
||||
pcache *cache.Cache
|
||||
pcap int
|
||||
pttl time.Duration
|
||||
minpttl time.Duration
|
||||
|
||||
// Prefetch.
|
||||
prefetch int
|
||||
duration time.Duration
|
||||
percentage int
|
||||
|
||||
// Testing.
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
// New returns an initialized Cache with default settings. It's up to the
|
||||
// caller to set the Next handler.
|
||||
func New() *Cache {
|
||||
return &Cache{
|
||||
Zones: []string{"."},
|
||||
pcap: defaultCap,
|
||||
pcache: cache.New(defaultCap),
|
||||
pttl: maxTTL,
|
||||
minpttl: minTTL,
|
||||
ncap: defaultCap,
|
||||
ncache: cache.New(defaultCap),
|
||||
nttl: maxNTTL,
|
||||
minnttl: minNTTL,
|
||||
prefetch: 0,
|
||||
duration: 1 * time.Minute,
|
||||
percentage: 10,
|
||||
now: time.Now,
|
||||
}
|
||||
}
|
||||
|
||||
// key returns the key under which we store the item; the returned bool is false if we don't store the message.
|
||||
// Currently we do not cache truncated messages, errors, zone transfers or dynamic update messages.
|
||||
// qname holds the already lowercased qname.
|
||||
func key(qname string, m *dns.Msg, t response.Type, do bool) (bool, uint64) {
|
||||
// We don't store truncated responses.
|
||||
if m.Truncated {
|
||||
return false, 0
|
||||
}
|
||||
// Nor errors or Meta or Update
|
||||
if t == response.OtherError || t == response.Meta || t == response.Update {
|
||||
return false, 0
|
||||
}
|
||||
|
||||
return true, hash(qname, m.Question[0].Qtype, do)
|
||||
}
|
||||
|
||||
var one = []byte("1")
|
||||
var zero = []byte("0")
|
||||
|
||||
func hash(qname string, qtype uint16, do bool) uint64 {
|
||||
h := fnv.New64()
|
||||
|
||||
if do {
|
||||
h.Write(one)
|
||||
} else {
|
||||
h.Write(zero)
|
||||
}
|
||||
|
||||
h.Write([]byte{byte(qtype >> 8)})
|
||||
h.Write([]byte{byte(qtype)})
|
||||
h.Write([]byte(qname))
|
||||
return h.Sum64()
|
||||
}
|
||||
|
||||
func computeTTL(msgTTL, minTTL, maxTTL time.Duration) time.Duration {
|
||||
ttl := msgTTL
|
||||
if ttl < minTTL {
|
||||
ttl = minTTL
|
||||
}
|
||||
if ttl > maxTTL {
|
||||
ttl = maxTTL
|
||||
}
|
||||
return ttl
|
||||
}
|
||||
|
||||
// ResponseWriter is a response writer that caches the reply message.
|
||||
type ResponseWriter struct {
|
||||
dns.ResponseWriter
|
||||
*Cache
|
||||
state request.Request
|
||||
server string // Server handling the request.
|
||||
|
||||
prefetch bool // When true write nothing back to the client.
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
// newPrefetchResponseWriter returns a Cache ResponseWriter to be used in
|
||||
// prefetch requests. It ensures RemoteAddr() can be called even after the
|
||||
// original connection has already been closed.
|
||||
func newPrefetchResponseWriter(server string, state request.Request, c *Cache) *ResponseWriter {
|
||||
// Resolve the address now, the connection might be already closed when the
|
||||
// actual prefetch request is made.
|
||||
addr := state.W.RemoteAddr()
|
||||
// The protocol of the client triggering a cache prefetch doesn't matter.
|
||||
// The address type is used by request.Proto to determine the response size,
|
||||
// and using TCP ensures the message isn't unnecessarily truncated.
|
||||
if u, ok := addr.(*net.UDPAddr); ok {
|
||||
addr = &net.TCPAddr{IP: u.IP, Port: u.Port, Zone: u.Zone}
|
||||
}
|
||||
|
||||
return &ResponseWriter{
|
||||
ResponseWriter: state.W,
|
||||
Cache: c,
|
||||
state: state,
|
||||
server: server,
|
||||
prefetch: true,
|
||||
remoteAddr: addr,
|
||||
}
|
||||
}
|
||||
|
||||
// RemoteAddr implements the dns.ResponseWriter interface.
|
||||
func (w *ResponseWriter) RemoteAddr() net.Addr {
|
||||
if w.remoteAddr != nil {
|
||||
return w.remoteAddr
|
||||
}
|
||||
return w.ResponseWriter.RemoteAddr()
|
||||
}
|
||||
|
||||
// WriteMsg implements the dns.ResponseWriter interface.
|
||||
func (w *ResponseWriter) WriteMsg(res *dns.Msg) error {
|
||||
do := false
|
||||
mt, opt := response.Typify(res, w.now().UTC())
|
||||
if opt != nil {
|
||||
do = opt.Do()
|
||||
}
|
||||
|
||||
// key returns false for anything we don't want to cache.
|
||||
hasKey, key := key(w.state.Name(), res, mt, do)
|
||||
|
||||
msgTTL := dnsutil.MinimalTTL(res, mt)
|
||||
var duration time.Duration
|
||||
if mt == response.NameError || mt == response.NoData {
|
||||
duration = computeTTL(msgTTL, w.minnttl, w.nttl)
|
||||
} else {
|
||||
duration = computeTTL(msgTTL, w.minpttl, w.pttl)
|
||||
}
|
||||
|
||||
if hasKey && duration > 0 {
|
||||
if w.state.Match(res) {
|
||||
w.set(res, key, mt, duration)
|
||||
cacheSize.WithLabelValues(w.server, Success).Set(float64(w.pcache.Len()))
|
||||
cacheSize.WithLabelValues(w.server, Denial).Set(float64(w.ncache.Len()))
|
||||
} else {
|
||||
// Don't log it, but increment counter
|
||||
cacheDrops.WithLabelValues(w.server).Inc()
|
||||
}
|
||||
}
|
||||
|
||||
if w.prefetch {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Apply the capped TTL to this reply to avoid a jarring TTL experience (e.g. 1799 -> 8).
|
||||
ttl := uint32(duration.Seconds())
|
||||
for i := range res.Answer {
|
||||
res.Answer[i].Header().Ttl = ttl
|
||||
}
|
||||
for i := range res.Ns {
|
||||
res.Ns[i].Header().Ttl = ttl
|
||||
}
|
||||
for i := range res.Extra {
|
||||
if res.Extra[i].Header().Rrtype != dns.TypeOPT {
|
||||
res.Extra[i].Header().Ttl = ttl
|
||||
}
|
||||
}
|
||||
return w.ResponseWriter.WriteMsg(res)
|
||||
}
|
||||
|
||||
func (w *ResponseWriter) set(m *dns.Msg, key uint64, mt response.Type, duration time.Duration) {
|
||||
// duration is expected > 0
|
||||
// and key is valid
|
||||
switch mt {
|
||||
case response.NoError, response.Delegation:
|
||||
i := newItem(m, w.now(), duration)
|
||||
w.pcache.Add(key, i)
|
||||
|
||||
case response.NameError, response.NoData:
|
||||
i := newItem(m, w.now(), duration)
|
||||
w.ncache.Add(key, i)
|
||||
|
||||
case response.OtherError:
|
||||
// don't cache these
|
||||
default:
|
||||
log.Warningf("Caching called with unknown classification: %d", mt)
|
||||
}
|
||||
}
|
||||
|
||||
// Write implements the dns.ResponseWriter interface.
|
||||
func (w *ResponseWriter) Write(buf []byte) (int, error) {
|
||||
log.Warning("Caching called with Write: not caching reply")
|
||||
if w.prefetch {
|
||||
return 0, nil
|
||||
}
|
||||
n, err := w.ResponseWriter.Write(buf)
|
||||
return n, err
|
||||
}
|
||||
|
||||
const (
|
||||
maxTTL = dnsutil.MaximumDefaulTTL
|
||||
minTTL = dnsutil.MinimalDefaultTTL
|
||||
maxNTTL = dnsutil.MaximumDefaulTTL / 2
|
||||
minNTTL = dnsutil.MinimalDefaultTTL
|
||||
|
||||
defaultCap = 10000 // default capacity of the cache.
|
||||
|
||||
// Success is the class used for positive caching.
|
||||
Success = "success"
|
||||
// Denial is the class defined for negative caching.
|
||||
Denial = "denial"
|
||||
)
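
To make the key derivation above concrete, here is a small standalone sketch that mirrors the unexported `hash` function (an illustration, not an export of it): the DO bit, the query type and the already lowercased query name are fed into a 64-bit FNV hash, so two queries that differ only in the DO bit land in different cache entries.

~~~ go
package main

import (
	"fmt"
	"hash/fnv"
)

// cacheKey mirrors hash() above: DO bit, qtype (big-endian) and qname are
// written into an FNV-64 hash.
func cacheKey(qname string, qtype uint16, do bool) uint64 {
	h := fnv.New64()
	if do {
		h.Write([]byte("1"))
	} else {
		h.Write([]byte("0"))
	}
	h.Write([]byte{byte(qtype >> 8), byte(qtype)})
	h.Write([]byte(qname))
	return h.Sum64()
}

func main() {
	// Same name and type, different DO bit: two distinct cache keys.
	fmt.Println(cacheKey("example.org.", 1, false)) // qtype 1 = A
	fmt.Println(cacheKey("example.org.", 1, true))
}
~~~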
|
|
@ -0,0 +1,55 @@
|
|||
// Package freq keeps track of the last X seen events. The events themselves are not stored
|
||||
// here. So the Freq type should be added next to the thing it is tracking.
|
||||
package freq
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Freq tracks the frequencies of things.
|
||||
type Freq struct {
|
||||
// Last time we saw a query for this element.
|
||||
last time.Time
|
||||
// Number of hits seen in the last time slice.
|
||||
hits int
|
||||
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// New returns a new initialized Freq.
|
||||
func New(t time.Time) *Freq {
|
||||
return &Freq{last: t, hits: 0}
|
||||
}
|
||||
|
||||
// Update updates the number of hits. Last time seen will be set to now.
|
||||
// If the last time we've seen this entity is within now - d, we increment hits, otherwise
|
||||
// we reset hits to 1. It returns the number of hits.
|
||||
func (f *Freq) Update(d time.Duration, now time.Time) int {
|
||||
earliest := now.Add(-1 * d)
|
||||
f.Lock()
|
||||
defer f.Unlock()
|
||||
if f.last.Before(earliest) {
|
||||
f.last = now
|
||||
f.hits = 1
|
||||
return f.hits
|
||||
}
|
||||
f.last = now
|
||||
f.hits++
|
||||
return f.hits
|
||||
}
|
||||
|
||||
// Hits returns the number of hits that we have seen, according to the updates we have done to f.
|
||||
func (f *Freq) Hits() int {
|
||||
f.RLock()
|
||||
defer f.RUnlock()
|
||||
return f.hits
|
||||
}
|
||||
|
||||
// Reset resets f to time t and hits to hits.
|
||||
func (f *Freq) Reset(t time.Time, hits int) {
|
||||
f.Lock()
|
||||
defer f.Unlock()
|
||||
f.last = t
|
||||
f.hits = hits
|
||||
}
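
As a standalone illustration of the behaviour described above (using the import path that item.go in this change already uses, `github.com/coredns/coredns/plugin/cache/freq`): updates that fall within the window keep incrementing the hit count, while a gap longer than the window resets it to 1.

~~~ go
package main

import (
	"fmt"
	"time"

	"github.com/coredns/coredns/plugin/cache/freq"
)

func main() {
	now := time.Now()
	f := freq.New(now)

	// Two queries within a one-minute window: hits accumulate.
	fmt.Println(f.Update(time.Minute, now.Add(10*time.Second))) // 1
	fmt.Println(f.Update(time.Minute, now.Add(20*time.Second))) // 2

	// More than a minute of silence: the counter resets to 1.
	fmt.Println(f.Update(time.Minute, now.Add(2*time.Minute))) // 1
}
~~~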
|
|
@ -0,0 +1,12 @@
|
|||
// +build fuzz
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"github.com/coredns/coredns/plugin/pkg/fuzz"
|
||||
)
|
||||
|
||||
// Fuzz fuzzes cache.
|
||||
func Fuzz(data []byte) int {
|
||||
return fuzz.Do(New(), data)
|
||||
}
|
|
@ -0,0 +1,127 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/coredns/coredns/plugin/metrics"
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
// ServeDNS implements the plugin.Handler interface.
|
||||
func (c *Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
state := request.Request{W: w, Req: r}
|
||||
|
||||
zone := plugin.Zones(c.Zones).Matches(state.Name())
|
||||
if zone == "" {
|
||||
return plugin.NextOrFailure(c.Name(), c.Next, ctx, w, r)
|
||||
}
|
||||
|
||||
now := c.now().UTC()
|
||||
|
||||
server := metrics.WithServer(ctx)
|
||||
|
||||
i, found := c.get(now, state, server)
|
||||
if i != nil && found {
|
||||
resp := i.toMsg(r, now)
|
||||
|
||||
w.WriteMsg(resp)
|
||||
|
||||
if c.prefetch > 0 {
|
||||
ttl := i.ttl(now)
|
||||
i.Freq.Update(c.duration, now)
|
||||
|
||||
threshold := int(math.Ceil(float64(c.percentage) / 100 * float64(i.origTTL)))
|
||||
if i.Freq.Hits() >= c.prefetch && ttl <= threshold {
|
||||
cw := newPrefetchResponseWriter(server, state, c)
|
||||
go func(w dns.ResponseWriter) {
|
||||
cachePrefetches.WithLabelValues(server).Inc()
|
||||
plugin.NextOrFailure(c.Name(), c.Next, ctx, w, r)
|
||||
|
||||
// When prefetching we lose the item i, and with it the frequency
|
||||
// that we've gathered so far. So we copy the frequency info back
|
||||
// into the new item that was stored in the cache.
|
||||
if i1 := c.exists(state); i1 != nil {
|
||||
i1.Freq.Reset(now, i.Freq.Hits())
|
||||
}
|
||||
}(cw)
|
||||
}
|
||||
}
|
||||
return dns.RcodeSuccess, nil
|
||||
}
|
||||
|
||||
crr := &ResponseWriter{ResponseWriter: w, Cache: c, state: state, server: server}
|
||||
return plugin.NextOrFailure(c.Name(), c.Next, ctx, crr, r)
|
||||
}
|
||||
|
||||
// Name implements the Handler interface.
|
||||
func (c *Cache) Name() string { return "cache" }
|
||||
|
||||
func (c *Cache) get(now time.Time, state request.Request, server string) (*item, bool) {
|
||||
k := hash(state.Name(), state.QType(), state.Do())
|
||||
|
||||
if i, ok := c.ncache.Get(k); ok && i.(*item).ttl(now) > 0 {
|
||||
cacheHits.WithLabelValues(server, Denial).Inc()
|
||||
return i.(*item), true
|
||||
}
|
||||
|
||||
if i, ok := c.pcache.Get(k); ok && i.(*item).ttl(now) > 0 {
|
||||
cacheHits.WithLabelValues(server, Success).Inc()
|
||||
return i.(*item), true
|
||||
}
|
||||
cacheMisses.WithLabelValues(server).Inc()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (c *Cache) exists(state request.Request) *item {
|
||||
k := hash(state.Name(), state.QType(), state.Do())
|
||||
if i, ok := c.ncache.Get(k); ok {
|
||||
return i.(*item)
|
||||
}
|
||||
if i, ok := c.pcache.Get(k); ok {
|
||||
return i.(*item)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
cacheSize = prometheus.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "cache",
|
||||
Name: "size",
|
||||
Help: "The number of elements in the cache.",
|
||||
}, []string{"server", "type"})
|
||||
|
||||
cacheHits = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "cache",
|
||||
Name: "hits_total",
|
||||
Help: "The count of cache hits.",
|
||||
}, []string{"server", "type"})
|
||||
|
||||
cacheMisses = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "cache",
|
||||
Name: "misses_total",
|
||||
Help: "The count of cache misses.",
|
||||
}, []string{"server"})
|
||||
|
||||
cachePrefetches = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "cache",
|
||||
Name: "prefetch_total",
|
||||
Help: "The number of time the cache has prefetched a cached item.",
|
||||
}, []string{"server"})
|
||||
|
||||
cacheDrops = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "cache",
|
||||
Name: "drops_total",
|
||||
Help: "The number responses that are not cached, because the reply is malformed.",
|
||||
}, []string{"server"})
|
||||
)
|
|
@ -0,0 +1,88 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin/cache/freq"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type item struct {
|
||||
Rcode int
|
||||
Authoritative bool
|
||||
AuthenticatedData bool
|
||||
RecursionAvailable bool
|
||||
Answer []dns.RR
|
||||
Ns []dns.RR
|
||||
Extra []dns.RR
|
||||
|
||||
origTTL uint32
|
||||
stored time.Time
|
||||
|
||||
*freq.Freq
|
||||
}
|
||||
|
||||
func newItem(m *dns.Msg, now time.Time, d time.Duration) *item {
|
||||
i := new(item)
|
||||
i.Rcode = m.Rcode
|
||||
i.Authoritative = m.Authoritative
|
||||
i.AuthenticatedData = m.AuthenticatedData
|
||||
i.RecursionAvailable = m.RecursionAvailable
|
||||
i.Answer = m.Answer
|
||||
i.Ns = m.Ns
|
||||
i.Extra = make([]dns.RR, len(m.Extra))
|
||||
// Don't copy OPT records as these are hop-by-hop.
|
||||
j := 0
|
||||
for _, e := range m.Extra {
|
||||
if e.Header().Rrtype == dns.TypeOPT {
|
||||
continue
|
||||
}
|
||||
i.Extra[j] = e
|
||||
j++
|
||||
}
|
||||
i.Extra = i.Extra[:j]
|
||||
|
||||
i.origTTL = uint32(d.Seconds())
|
||||
i.stored = now.UTC()
|
||||
|
||||
i.Freq = new(freq.Freq)
|
||||
|
||||
return i
|
||||
}
|
||||
|
||||
// toMsg turns i into a message, it tailors the reply to m.
|
||||
// The Authoritative bit is always set to 0, because the answer is from the cache.
|
||||
func (i *item) toMsg(m *dns.Msg, now time.Time) *dns.Msg {
|
||||
m1 := new(dns.Msg)
|
||||
m1.SetReply(m)
|
||||
|
||||
m1.Authoritative = false
|
||||
m1.AuthenticatedData = i.AuthenticatedData
|
||||
m1.RecursionAvailable = i.RecursionAvailable
|
||||
m1.Rcode = i.Rcode
|
||||
|
||||
m1.Answer = make([]dns.RR, len(i.Answer))
|
||||
m1.Ns = make([]dns.RR, len(i.Ns))
|
||||
m1.Extra = make([]dns.RR, len(i.Extra))
|
||||
|
||||
ttl := uint32(i.ttl(now))
|
||||
for j, r := range i.Answer {
|
||||
m1.Answer[j] = dns.Copy(r)
|
||||
m1.Answer[j].Header().Ttl = ttl
|
||||
}
|
||||
for j, r := range i.Ns {
|
||||
m1.Ns[j] = dns.Copy(r)
|
||||
m1.Ns[j].Header().Ttl = ttl
|
||||
}
|
||||
// newItem skips OPT records, so we can just use i.Extra as is.
|
||||
for j, r := range i.Extra {
|
||||
m1.Extra[j] = dns.Copy(r)
|
||||
m1.Extra[j].Header().Ttl = ttl
|
||||
}
|
||||
return m1
|
||||
}
|
||||
|
||||
func (i *item) ttl(now time.Time) int {
|
||||
ttl := int(i.origTTL) - int(now.UTC().Sub(i.stored).Seconds())
|
||||
return ttl
|
||||
}
|
|
@ -0,0 +1,199 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/coredns/coredns/plugin/metrics"
|
||||
"github.com/coredns/coredns/plugin/pkg/cache"
|
||||
clog "github.com/coredns/coredns/plugin/pkg/log"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
)
|
||||
|
||||
var log = clog.NewWithPlugin("cache")
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin("cache", caddy.Plugin{
|
||||
ServerType: "dns",
|
||||
Action: setup,
|
||||
})
|
||||
}
|
||||
|
||||
func setup(c *caddy.Controller) error {
|
||||
ca, err := cacheParse(c)
|
||||
if err != nil {
|
||||
return plugin.Error("cache", err)
|
||||
}
|
||||
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
|
||||
ca.Next = next
|
||||
return ca
|
||||
})
|
||||
|
||||
c.OnStartup(func() error {
|
||||
metrics.MustRegister(c,
|
||||
cacheSize, cacheHits, cacheMisses,
|
||||
cachePrefetches, cacheDrops)
|
||||
return nil
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func cacheParse(c *caddy.Controller) (*Cache, error) {
|
||||
ca := New()
|
||||
|
||||
j := 0
|
||||
for c.Next() {
|
||||
if j > 0 {
|
||||
return nil, plugin.ErrOnce
|
||||
}
|
||||
j++
|
||||
|
||||
// cache [ttl] [zones..]
|
||||
origins := make([]string, len(c.ServerBlockKeys))
|
||||
copy(origins, c.ServerBlockKeys)
|
||||
args := c.RemainingArgs()
|
||||
|
||||
if len(args) > 0 {
|
||||
// The first arg may be just a number, in which case it is the TTL; if not, it is a zone.
|
||||
ttl, err := strconv.Atoi(args[0])
|
||||
if err == nil {
|
||||
// Reserve 0 (and smaller for future things)
|
||||
if ttl <= 0 {
|
||||
return nil, fmt.Errorf("cache TTL can not be zero or negative: %d", ttl)
|
||||
}
|
||||
ca.pttl = time.Duration(ttl) * time.Second
|
||||
ca.nttl = time.Duration(ttl) * time.Second
|
||||
args = args[1:]
|
||||
}
|
||||
if len(args) > 0 {
|
||||
copy(origins, args)
|
||||
}
|
||||
}
|
||||
|
||||
// Refinements? In an extra block.
|
||||
for c.NextBlock() {
|
||||
switch c.Val() {
|
||||
// The first number is the capacity, the second is a new TTL.
|
||||
case Success:
|
||||
args := c.RemainingArgs()
|
||||
if len(args) == 0 {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
pcap, err := strconv.Atoi(args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ca.pcap = pcap
|
||||
if len(args) > 1 {
|
||||
pttl, err := strconv.Atoi(args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Reserve 0 (and smaller for future things)
|
||||
if pttl <= 0 {
|
||||
return nil, fmt.Errorf("cache TTL can not be zero or negative: %d", pttl)
|
||||
}
|
||||
ca.pttl = time.Duration(pttl) * time.Second
|
||||
if len(args) > 2 {
|
||||
minpttl, err := strconv.Atoi(args[2])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Reserve < 0
|
||||
if minpttl < 0 {
|
||||
return nil, fmt.Errorf("cache min TTL can not be negative: %d", minpttl)
|
||||
}
|
||||
ca.minpttl = time.Duration(minpttl) * time.Second
|
||||
}
|
||||
}
|
||||
case Denial:
|
||||
args := c.RemainingArgs()
|
||||
if len(args) == 0 {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
ncap, err := strconv.Atoi(args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ca.ncap = ncap
|
||||
if len(args) > 1 {
|
||||
nttl, err := strconv.Atoi(args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Reserve 0 (and smaller for future things)
|
||||
if nttl <= 0 {
|
||||
return nil, fmt.Errorf("cache TTL can not be zero or negative: %d", nttl)
|
||||
}
|
||||
ca.nttl = time.Duration(nttl) * time.Second
|
||||
if len(args) > 2 {
|
||||
minnttl, err := strconv.Atoi(args[2])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Reserve < 0
|
||||
if minnttl < 0 {
|
||||
return nil, fmt.Errorf("cache min TTL can not be negative: %d", minnttl)
|
||||
}
|
||||
ca.minnttl = time.Duration(minnttl) * time.Second
|
||||
}
|
||||
}
|
||||
case "prefetch":
|
||||
args := c.RemainingArgs()
|
||||
if len(args) == 0 || len(args) > 3 {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
amount, err := strconv.Atoi(args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if amount < 0 {
|
||||
return nil, fmt.Errorf("prefetch amount should be positive: %d", amount)
|
||||
}
|
||||
ca.prefetch = amount
|
||||
|
||||
if len(args) > 1 {
|
||||
dur, err := time.ParseDuration(args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ca.duration = dur
|
||||
}
|
||||
if len(args) > 2 {
|
||||
pct := args[2]
|
||||
if x := pct[len(pct)-1]; x != '%' {
|
||||
return nil, fmt.Errorf("last character of percentage should be `%%`, but is: %q", x)
|
||||
}
|
||||
pct = pct[:len(pct)-1]
|
||||
|
||||
num, err := strconv.Atoi(pct)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if num < 10 || num > 90 {
|
||||
return nil, fmt.Errorf("percentage should fall in range [10, 90]: %d", num)
|
||||
}
|
||||
ca.percentage = num
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
}
|
||||
|
||||
for i := range origins {
|
||||
origins[i] = plugin.Host(origins[i]).Normalize()
|
||||
}
|
||||
ca.Zones = origins
|
||||
|
||||
ca.pcache = cache.New(ca.pcap)
|
||||
ca.ncache = cache.New(ca.ncap)
|
||||
}
|
||||
|
||||
return ca, nil
|
||||
}
|
|
@ -0,0 +1,14 @@
|
|||
package plugin
|
||||
|
||||
import "context"
|
||||
|
||||
// Done is a non-blocking function that returns true if the context has been canceled.
|
||||
func Done(ctx context.Context) bool {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
package msg
|
||||
|
||||
import (
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/dnsutil"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Path converts a domainname to an etcd path. If s looks like service.staging.skydns.local.,
|
||||
// the resulting key will be /skydns/local/skydns/staging/service .
|
||||
func Path(s, prefix string) string {
|
||||
l := dns.SplitDomainName(s)
|
||||
for i, j := 0, len(l)-1; i < j; i, j = i+1, j-1 {
|
||||
l[i], l[j] = l[j], l[i]
|
||||
}
|
||||
return path.Join(append([]string{"/" + prefix + "/"}, l...)...)
|
||||
}
|
||||
|
||||
// Domain is the opposite of Path.
|
||||
func Domain(s string) string {
|
||||
l := strings.Split(s, "/")
|
||||
// start with 1, to strip /skydns
|
||||
for i, j := 1, len(l)-1; i < j; i, j = i+1, j-1 {
|
||||
l[i], l[j] = l[j], l[i]
|
||||
}
|
||||
return dnsutil.Join(l[1 : len(l)-1]...)
|
||||
}
|
||||
|
||||
// PathWithWildcard acts as Path, but if a name contains wildcards (* or any), the name will be
|
||||
// chopped off before the (first) wildcard, and we do a higher level search and
|
||||
// later find the matching names. So service.*.skydns.local, will look for all
|
||||
// services under skydns.local and will later check for names that match
|
||||
// service.*.skydns.local. If a wildcard is found the returned bool is true.
|
||||
func PathWithWildcard(s, prefix string) (string, bool) {
|
||||
l := dns.SplitDomainName(s)
|
||||
for i, j := 0, len(l)-1; i < j; i, j = i+1, j-1 {
|
||||
l[i], l[j] = l[j], l[i]
|
||||
}
|
||||
for i, k := range l {
|
||||
if k == "*" || k == "any" {
|
||||
return path.Join(append([]string{"/" + prefix + "/"}, l[:i]...)...), true
|
||||
}
|
||||
}
|
||||
return path.Join(append([]string{"/" + prefix + "/"}, l...)...), false
|
||||
}
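
A quick round-trip sketch of the two conversions above (assuming this package lives at CoreDNS's usual `plugin/etcd/msg` import path; the printed key matches the example in the Path comment):

~~~ go
package main

import (
	"fmt"

	"github.com/coredns/coredns/plugin/etcd/msg"
)

func main() {
	key := msg.Path("service.staging.skydns.local.", "skydns")
	fmt.Println(key) // /skydns/local/skydns/staging/service

	// Domain reverses the conversion back to a fully qualified name.
	fmt.Println(msg.Domain(key)) // service.staging.skydns.local.
}
~~~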
|
|
@ -0,0 +1,177 @@
|
|||
// Package msg defines the Service structure which is used for service discovery.
|
||||
package msg
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Service defines a discoverable service in etcd. It is the rdata from a SRV
|
||||
// record, but with a twist. Host (Target in SRV) must be a domain name, but
|
||||
// if it looks like an IP address (4/6), we will treat it like an IP address.
|
||||
type Service struct {
|
||||
Host string `json:"host,omitempty"`
|
||||
Port int `json:"port,omitempty"`
|
||||
Priority int `json:"priority,omitempty"`
|
||||
Weight int `json:"weight,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Mail bool `json:"mail,omitempty"` // Be an MX record. Priority becomes Preference.
|
||||
TTL uint32 `json:"ttl,omitempty"`
|
||||
|
||||
// When a SRV record with a "Host: IP-address" is added, we synthesize
|
||||
// a srv.Target domain name. Normally we convert the full Key where
|
||||
// the record lives to a DNS name and use this as the srv.Target. When
|
||||
// TargetStrip > 0 we strip the left most TargetStrip labels from the
|
||||
// DNS name.
|
||||
TargetStrip int `json:"targetstrip,omitempty"`
|
||||
|
||||
// Group is used to group (or *not* to group) different services
|
||||
// together. Services with an identical Group are returned in the same
|
||||
// answer.
|
||||
Group string `json:"group,omitempty"`
|
||||
|
||||
// Etcd key where we found this service; it is ignored during JSON (un)marshalling.
|
||||
Key string `json:"-"`
|
||||
}
|
||||
|
||||
// NewSRV returns a new SRV record based on the Service.
|
||||
func (s *Service) NewSRV(name string, weight uint16) *dns.SRV {
|
||||
host := dns.Fqdn(s.Host)
|
||||
if s.TargetStrip > 0 {
|
||||
host = targetStrip(host, s.TargetStrip)
|
||||
}
|
||||
|
||||
return &dns.SRV{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeSRV, Class: dns.ClassINET, Ttl: s.TTL},
|
||||
Priority: uint16(s.Priority), Weight: weight, Port: uint16(s.Port), Target: host}
|
||||
}
|
||||
|
||||
// NewMX returns a new MX record based on the Service.
|
||||
func (s *Service) NewMX(name string) *dns.MX {
|
||||
host := dns.Fqdn(s.Host)
|
||||
if s.TargetStrip > 0 {
|
||||
host = targetStrip(host, s.TargetStrip)
|
||||
}
|
||||
|
||||
return &dns.MX{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeMX, Class: dns.ClassINET, Ttl: s.TTL},
|
||||
Preference: uint16(s.Priority), Mx: host}
|
||||
}
|
||||
|
||||
// NewA returns a new A record based on the Service.
|
||||
func (s *Service) NewA(name string, ip net.IP) *dns.A {
|
||||
return &dns.A{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: s.TTL}, A: ip}
|
||||
}
|
||||
|
||||
// NewAAAA returns a new AAAA record based on the Service.
|
||||
func (s *Service) NewAAAA(name string, ip net.IP) *dns.AAAA {
|
||||
return &dns.AAAA{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: s.TTL}, AAAA: ip}
|
||||
}
|
||||
|
||||
// NewCNAME returns a new CNAME record based on the Service.
|
||||
func (s *Service) NewCNAME(name string, target string) *dns.CNAME {
|
||||
return &dns.CNAME{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: s.TTL}, Target: dns.Fqdn(target)}
|
||||
}
|
||||
|
||||
// NewTXT returns a new TXT record based on the Service.
|
||||
func (s *Service) NewTXT(name string) *dns.TXT {
|
||||
return &dns.TXT{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: s.TTL}, Txt: split255(s.Text)}
|
||||
}
|
||||
|
||||
// NewPTR returns a new PTR record based on the Service.
|
||||
func (s *Service) NewPTR(name string, target string) *dns.PTR {
|
||||
return &dns.PTR{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: s.TTL}, Ptr: dns.Fqdn(target)}
|
||||
}
|
||||
|
||||
// NewNS returns a new NS record based on the Service.
|
||||
func (s *Service) NewNS(name string) *dns.NS {
|
||||
host := dns.Fqdn(s.Host)
|
||||
if s.TargetStrip > 0 {
|
||||
host = targetStrip(host, s.TargetStrip)
|
||||
}
|
||||
return &dns.NS{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: s.TTL}, Ns: host}
|
||||
}
|
||||
|
||||
// Group checks the services in sx; it looks for a Group attribute on the shortest
|
||||
// keys. If there are multiple shortest keys *and* the group attribute disagrees (and
|
||||
// is not empty), we don't consider it a group.
|
||||
// If a group is found, only services with *that* group (or no group) will be returned.
|
||||
func Group(sx []Service) []Service {
|
||||
if len(sx) == 0 {
|
||||
return sx
|
||||
}
|
||||
|
||||
// Shortest key with group attribute sets the group for this set.
|
||||
group := sx[0].Group
|
||||
slashes := strings.Count(sx[0].Key, "/")
|
||||
length := make([]int, len(sx))
|
||||
for i, s := range sx {
|
||||
x := strings.Count(s.Key, "/")
|
||||
length[i] = x
|
||||
if x < slashes {
|
||||
if s.Group == "" {
|
||||
break
|
||||
}
|
||||
slashes = x
|
||||
group = s.Group
|
||||
}
|
||||
}
|
||||
|
||||
if group == "" {
|
||||
return sx
|
||||
}
|
||||
|
||||
ret := []Service{} // with slice-tricks in sx we can prolly save this allocation (TODO)
|
||||
|
||||
for i, s := range sx {
|
||||
if s.Group == "" {
|
||||
ret = append(ret, s)
|
||||
continue
|
||||
}
|
||||
|
||||
// Disagreement on the same level
|
||||
if length[i] == slashes && s.Group != group {
|
||||
return sx
|
||||
}
|
||||
|
||||
if s.Group == group {
|
||||
ret = append(ret, s)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// split255 splits a string into 255-byte chunks.
|
||||
func split255(s string) []string {
|
||||
if len(s) < 255 {
|
||||
return []string{s}
|
||||
}
|
||||
sx := []string{}
|
||||
p, i := 0, 255
|
||||
for {
|
||||
if i <= len(s) {
|
||||
sx = append(sx, s[p:i])
|
||||
} else {
|
||||
sx = append(sx, s[p:])
|
||||
break
|
||||
|
||||
}
|
||||
p, i = p+255, i+255
|
||||
}
|
||||
|
||||
return sx
|
||||
}
|
||||
|
||||
// targetStrip strips "targetstrip" labels from the left side of the fully qualified name.
|
||||
func targetStrip(name string, targetStrip int) string {
|
||||
offset, end := 0, false
|
||||
for i := 0; i < targetStrip; i++ {
|
||||
offset, end = dns.NextLabel(name, offset)
|
||||
}
|
||||
if end {
|
||||
// We overshot the name, use the original one.
|
||||
offset = 0
|
||||
}
|
||||
name = name[offset:]
|
||||
return name
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
package msg
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// HostType returns the DNS type of what is encoded in the Service Host field. We're reusing
|
||||
// dns.TypeXXX to not reinvent a new set of identifiers.
|
||||
//
|
||||
// dns.TypeA: the service's Host field contains an A record.
|
||||
// dns.TypeAAAA: the service's Host field contains an AAAA record.
|
||||
// dns.TypeCNAME: the service's Host field contains a name.
|
||||
//
|
||||
// Note that a service can double/triple as a TXT record or MX record.
|
||||
func (s *Service) HostType() (what uint16, normalized net.IP) {
|
||||
|
||||
ip := net.ParseIP(s.Host)
|
||||
|
||||
switch {
|
||||
case ip == nil:
|
||||
return dns.TypeCNAME, nil
|
||||
|
||||
case ip.To4() != nil:
|
||||
return dns.TypeA, ip.To4()
|
||||
|
||||
case ip.To4() == nil:
|
||||
return dns.TypeAAAA, ip.To16()
|
||||
}
|
||||
// This should never be reached.
|
||||
return dns.TypeNone, nil
|
||||
}
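
As a rough usage sketch of the Service type and its record constructors (again assuming the usual `plugin/etcd/msg` import path; names and values below are made up): a Host that parses as an IP yields an A/AAAA record, while a domain-name Host becomes the target of an SRV record.

~~~ go
package main

import (
	"fmt"

	"github.com/coredns/coredns/plugin/etcd/msg"

	"github.com/miekg/dns"
)

func main() {
	// Host is an IPv4 address, so HostType reports dns.TypeA.
	a := msg.Service{Host: "10.0.0.1", TTL: 60}
	if typ, ip := a.HostType(); typ == dns.TypeA {
		fmt.Println(a.NewA("web.example.org.", ip))
	}

	// Host is a name, so it becomes the SRV target; the weight is supplied
	// by the caller (e.g. after load balancing over a group of services).
	srv := msg.Service{Host: "backend.example.org", Port: 8080, Priority: 10, TTL: 60}
	fmt.Println(srv.NewSRV("_http._tcp.example.org.", 100))
}
~~~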
|
|
@ -0,0 +1,9 @@
|
|||
reviewers:
|
||||
- fastest963
|
||||
- miekg
|
||||
- superq
|
||||
- greenpau
|
||||
approvers:
|
||||
- fastest963
|
||||
- miekg
|
||||
- superq
|
|
@ -0,0 +1,79 @@
|
|||
# prometheus
|
||||
|
||||
## Name
|
||||
|
||||
*prometheus* - enables [Prometheus](https://prometheus.io/) metrics.
|
||||
|
||||
## Description
|
||||
|
||||
With *prometheus* you export metrics from CoreDNS and any plugin that has them.
|
||||
The default location for the metrics is `localhost:9153`. The metrics path is fixed to `/metrics`.
|
||||
The following metrics are exported:
|
||||
|
||||
* `coredns_build_info{version, revision, goversion}` - info about CoreDNS itself.
|
||||
* `coredns_panic_count_total{}` - total number of panics.
|
||||
* `coredns_dns_request_count_total{server, zone, proto, family}` - total query count.
|
||||
* `coredns_dns_request_duration_seconds{server, zone}` - duration to process each query.
|
||||
* `coredns_dns_request_size_bytes{server, zone, proto}` - size of the request in bytes.
|
||||
* `coredns_dns_request_do_count_total{server, zone}` - queries that have the DO bit set.
|
||||
* `coredns_dns_request_type_count_total{server, zone, type}` - counter of queries per zone and type.
|
||||
* `coredns_dns_response_size_bytes{server, zone, proto}` - response size in bytes.
|
||||
* `coredns_dns_response_rcode_count_total{server, zone, rcode}` - response per zone and rcode.
|
||||
* `coredns_plugin_enabled{server, zone, name}` - indicates whether a plugin is enabled on per server and zone basis.
|
||||
|
||||
Each counter has a label `zone` which is the zonename used for the request/response.
|
||||
|
||||
Extra labels used are:
|
||||
|
||||
* `server` identifies the server responsible for the request. This is a string formatted
|
||||
as the server's listening address: `<scheme>://[<bind>]:<port>`. For example, for a "normal" DNS server
|
||||
this is `dns://:53`. If you are using the *bind* plugin an IP address is included, e.g.: `dns://127.0.0.53:53`.
|
||||
* `proto`, which holds the transport of the response ("udp" or "tcp").
|
||||
* The address family (`family`) of the transport (1 = IP (IP version 4), 2 = IP6 (IP version 6)).
|
||||
* `type` which holds the query type. It holds most common types (A, AAAA, MX, SOA, CNAME, PTR, TXT,
|
||||
NS, SRV, DS, DNSKEY, RRSIG, NSEC, NSEC3, IXFR, AXFR and ANY) and "other" which lumps together all
|
||||
other types.
|
||||
* The `response_rcode_count_total` has an extra label `rcode` which holds the rcode of the response.
|
||||
|
||||
If monitoring is enabled, queries that do not enter the plugin chain are exported under the fake
|
||||
name "dropped" (without a closing dot - this is never a valid domain name).
|
||||
|
||||
This plugin can only be used once per Server Block.
|
||||
|
||||
## Syntax
|
||||
|
||||
~~~
|
||||
prometheus [ADDRESS]
|
||||
~~~
|
||||
|
||||
Add *prometheus* to each Server Block for which you want to see metrics.
|
||||
|
||||
It optionally takes a bind address to which the metrics are exported; the default
|
||||
listens on `localhost:9153`. The metrics path is fixed to `/metrics`.
|
||||
|
||||
## Examples
|
||||
|
||||
Use an alternative listening address:
|
||||
|
||||
~~~ corefile
|
||||
. {
|
||||
prometheus localhost:9253
|
||||
}
|
||||
~~~
|
||||
|
||||
Or via an environment variable (this is supported throughout the Corefile): `export PORT=9253`, and
|
||||
then:
|
||||
|
||||
~~~ corefile
|
||||
. {
|
||||
prometheus localhost:{$PORT}
|
||||
}
|
||||
~~~
|
||||
|
||||
## Bugs
|
||||
|
||||
When reloading, the Prometheus handler is stopped before the new server instance is started.
|
||||
If that new server fails to start, the initial server instance is still available and DNS queries are still served,
|
||||
but the Prometheus handler stays down.
|
||||
The metrics endpoint will not reply to HTTP requests until a successful reload or a complete restart of CoreDNS.
|
||||
Only the plugins that register as a Handler are visible in `coredns_plugin_enabled{server, zone, name}`. As of today, the *reload* and *bind* plugins are not reported.
|
|
@ -0,0 +1,24 @@
|
|||
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
)
|
||||
|
||||
// WithServer returns the current server handling the request. It returns the
|
||||
// server listening address: <scheme>://[<bind>]:<port> Normally this is
|
||||
// something like "dns://:53", but if the bind plugin is used, i.e. "bind
|
||||
// 127.0.0.53", it will be "dns://127.0.0.53:53", etc. If not address is found
|
||||
// the empty string is returned.
|
||||
//
|
||||
// Basic usage with a metric:
|
||||
//
|
||||
// <metric>.WithLabelValues(metrics.WithServer(ctx), labels..).Add(1)
|
||||
func WithServer(ctx context.Context) string {
|
||||
srv := ctx.Value(dnsserver.Key{})
|
||||
if srv == nil {
|
||||
return ""
|
||||
}
|
||||
return srv.(*dnsserver.Server).Addr
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/coredns/coredns/plugin/metrics/vars"
|
||||
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
||||
"github.com/coredns/coredns/plugin/pkg/rcode"
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// ServeDNS implements the Handler interface.
|
||||
func (m *Metrics) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
state := request.Request{W: w, Req: r}
|
||||
|
||||
qname := state.QName()
|
||||
zone := plugin.Zones(m.ZoneNames()).Matches(qname)
|
||||
if zone == "" {
|
||||
zone = "."
|
||||
}
|
||||
|
||||
// Record response to get status code and size of the reply.
|
||||
rw := dnstest.NewRecorder(w)
|
||||
status, err := plugin.NextOrFailure(m.Name(), m.Next, ctx, rw, r)
|
||||
|
||||
vars.Report(WithServer(ctx), state, zone, rcode.ToString(rw.Rcode), rw.Len, rw.Start)
|
||||
|
||||
return status, err
|
||||
}
|
||||
|
||||
// Name implements the Handler interface.
|
||||
func (m *Metrics) Name() string { return "prometheus" }
|
|
@ -0,0 +1,164 @@
|
|||
// Package metrics implement a handler and plugin that provides Prometheus metrics.
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/coredns/coredns/plugin/metrics/vars"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
// Metrics holds the prometheus configuration. The metrics' path is fixed to be /metrics
|
||||
type Metrics struct {
|
||||
Next plugin.Handler
|
||||
Addr string
|
||||
Reg *prometheus.Registry
|
||||
ln net.Listener
|
||||
lnSetup bool
|
||||
mux *http.ServeMux
|
||||
srv *http.Server
|
||||
|
||||
zoneNames []string
|
||||
zoneMap map[string]struct{}
|
||||
zoneMu sync.RWMutex
|
||||
}
|
||||
|
||||
// New returns a new instance of Metrics with the given address
|
||||
func New(addr string) *Metrics {
|
||||
met := &Metrics{
|
||||
Addr: addr,
|
||||
Reg: prometheus.NewRegistry(),
|
||||
zoneMap: make(map[string]struct{}),
|
||||
}
|
||||
// Add the default collectors
|
||||
met.MustRegister(prometheus.NewGoCollector())
|
||||
met.MustRegister(prometheus.NewProcessCollector(prometheus.ProcessCollectorOpts{}))
|
||||
|
||||
// Add all of our collectors
|
||||
met.MustRegister(buildInfo)
|
||||
met.MustRegister(vars.Panic)
|
||||
met.MustRegister(vars.RequestCount)
|
||||
met.MustRegister(vars.RequestDuration)
|
||||
met.MustRegister(vars.RequestSize)
|
||||
met.MustRegister(vars.RequestDo)
|
||||
met.MustRegister(vars.RequestType)
|
||||
met.MustRegister(vars.ResponseSize)
|
||||
met.MustRegister(vars.ResponseRcode)
|
||||
met.MustRegister(vars.PluginEnabled)
|
||||
|
||||
return met
|
||||
}
|
||||
|
||||
// MustRegister wraps m.Reg.MustRegister.
|
||||
func (m *Metrics) MustRegister(c prometheus.Collector) {
|
||||
err := m.Reg.Register(c)
|
||||
if err != nil {
|
||||
// ignore any duplicate error, but fatal on any other kind of error
|
||||
if _, ok := err.(prometheus.AlreadyRegisteredError); !ok {
|
||||
log.Fatalf("Cannot register metrics collector: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AddZone adds zone z to m.
|
||||
func (m *Metrics) AddZone(z string) {
|
||||
m.zoneMu.Lock()
|
||||
m.zoneMap[z] = struct{}{}
|
||||
m.zoneNames = keys(m.zoneMap)
|
||||
m.zoneMu.Unlock()
|
||||
}
|
||||
|
||||
// RemoveZone removes zone z from m.
|
||||
func (m *Metrics) RemoveZone(z string) {
|
||||
m.zoneMu.Lock()
|
||||
delete(m.zoneMap, z)
|
||||
m.zoneNames = keys(m.zoneMap)
|
||||
m.zoneMu.Unlock()
|
||||
}
|
||||
|
||||
// ZoneNames returns the zones of m.
|
||||
func (m *Metrics) ZoneNames() []string {
|
||||
m.zoneMu.RLock()
|
||||
s := m.zoneNames
|
||||
m.zoneMu.RUnlock()
|
||||
return s
|
||||
}
|
||||
|
||||
// OnStartup sets up the metrics on startup.
|
||||
func (m *Metrics) OnStartup() error {
|
||||
ln, err := net.Listen("tcp", m.Addr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to start metrics handler: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
m.ln = ln
|
||||
m.lnSetup = true
|
||||
ListenAddr = m.ln.Addr().String() // For tests
|
||||
|
||||
m.mux = http.NewServeMux()
|
||||
m.mux.Handle("/metrics", promhttp.HandlerFor(m.Reg, promhttp.HandlerOpts{}))
|
||||
m.srv = &http.Server{Handler: m.mux}
|
||||
go func() {
|
||||
m.srv.Serve(m.ln)
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnRestart stops the listener on reload.
|
||||
func (m *Metrics) OnRestart() error {
|
||||
if !m.lnSetup {
|
||||
return nil
|
||||
}
|
||||
uniqAddr.Unset(m.Addr)
|
||||
return m.stopServer()
|
||||
}
|
||||
|
||||
func (m *Metrics) stopServer() error {
|
||||
if !m.lnSetup {
|
||||
return nil
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||
defer cancel()
|
||||
if err := m.srv.Shutdown(ctx); err != nil {
|
||||
log.Infof("Failed to stop prometheus http server: %s", err)
|
||||
return err
|
||||
}
|
||||
m.lnSetup = false
|
||||
m.ln.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnFinalShutdown tears down the metrics listener on shutdown and restart.
|
||||
func (m *Metrics) OnFinalShutdown() error {
|
||||
return m.stopServer()
|
||||
}
|
||||
|
||||
func keys(m map[string]struct{}) []string {
|
||||
sx := []string{}
|
||||
for k := range m {
|
||||
sx = append(sx, k)
|
||||
}
|
||||
return sx
|
||||
}
|
||||
|
||||
// ListenAddr is assigned the address of the prometheus listener. Its use is mainly in tests where
|
||||
// we listen on "localhost:0" and need to retrieve the actual address.
|
||||
var ListenAddr string
|
||||
|
||||
// shutdownTimeout is the maximum amount of time the metrics plugin will wait
|
||||
// before erroring when it tries to close the metrics server
|
||||
const shutdownTimeout time.Duration = time.Second * 5
|
||||
|
||||
var buildInfo = prometheus.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Name: "build_info",
|
||||
Help: "A metric with a constant '1' value labeled by version, revision, and goversion from which CoreDNS was built.",
|
||||
}, []string{"version", "revision", "goversion"})
|
|
@ -0,0 +1,23 @@
|
|||
package metrics
|
||||
|
||||
import (
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
// MustRegister registers the prometheus Collectors when the metrics middleware is used.
|
||||
func MustRegister(c *caddy.Controller, cs ...prometheus.Collector) {
|
||||
m := dnsserver.GetConfig(c).Handler("prometheus")
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
x, ok := m.(*Metrics)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, c := range cs {
|
||||
x.MustRegister(c)
|
||||
}
|
||||
}
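
For orientation, this is the hook other plugins in this change use to expose their own collectors (the cache plugin's setup above calls it from OnStartup). A minimal sketch of a hypothetical plugin doing the same from its Caddy setup function; the plugin name and metric are invented for illustration:

~~~ go
package example

import (
	"github.com/coredns/coredns/plugin"
	"github.com/coredns/coredns/plugin/metrics"

	"github.com/mholt/caddy"
	"github.com/prometheus/client_golang/prometheus"
)

// requestCount is a hypothetical collector owned by this example plugin.
var requestCount = prometheus.NewCounterVec(prometheus.CounterOpts{
	Namespace: plugin.Namespace,
	Subsystem: "example",
	Name:      "requests_total",
	Help:      "Counter of requests seen by the example plugin.",
}, []string{"server"})

func setup(c *caddy.Controller) error {
	// Register once the server starts; MustRegister is a no-op when the
	// *prometheus* directive is not present in the Corefile.
	c.OnStartup(func() error {
		metrics.MustRegister(c, requestCount)
		return nil
	})
	return nil
}
~~~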
|
|
@ -0,0 +1,111 @@
|
|||
package metrics
|
||||
|
||||
import (
|
||||
"net"
|
||||
"runtime"
|
||||
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
"github.com/coredns/coredns/coremain"
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/coredns/coredns/plugin/metrics/vars"
|
||||
clog "github.com/coredns/coredns/plugin/pkg/log"
|
||||
"github.com/coredns/coredns/plugin/pkg/uniq"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
)
|
||||
|
||||
var (
|
||||
log = clog.NewWithPlugin("prometheus")
|
||||
uniqAddr = uniq.New()
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin("prometheus", caddy.Plugin{
|
||||
ServerType: "dns",
|
||||
Action: setup,
|
||||
})
|
||||
}
|
||||
|
||||
func setup(c *caddy.Controller) error {
|
||||
m, err := prometheusParse(c)
|
||||
if err != nil {
|
||||
return plugin.Error("prometheus", err)
|
||||
}
|
||||
|
||||
// register the metrics to its address (ensure only one active metrics per address)
|
||||
obj := uniqAddr.Set(m.Addr, m.OnStartup, m)
|
||||
// Propagate the real active Registry to the current metrics.
|
||||
if om, ok := obj.(*Metrics); ok {
|
||||
m.Reg = om.Reg
|
||||
}
|
||||
|
||||
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
|
||||
m.Next = next
|
||||
return m
|
||||
})
|
||||
|
||||
c.OncePerServerBlock(func() error {
|
||||
c.OnStartup(func() error {
|
||||
return uniqAddr.ForEach()
|
||||
})
|
||||
return nil
|
||||
})
|
||||
|
||||
c.OnRestart(func() error {
|
||||
vars.PluginEnabled.Reset()
|
||||
return nil
|
||||
})
|
||||
|
||||
c.OnStartup(func() error {
|
||||
conf := dnsserver.GetConfig(c)
|
||||
plugins := conf.Handlers()
|
||||
for _, h := range conf.ListenHosts {
|
||||
addrstr := conf.Transport + "://" + net.JoinHostPort(h, conf.Port)
|
||||
for _, p := range plugins {
|
||||
vars.PluginEnabled.WithLabelValues(addrstr, conf.Zone, p.Name()).Set(1)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
})
|
||||
c.OnRestart(m.OnRestart)
|
||||
c.OnFinalShutdown(m.OnFinalShutdown)
|
||||
|
||||
// Initialize metrics.
|
||||
buildInfo.WithLabelValues(coremain.CoreVersion, coremain.GitCommit, runtime.Version()).Set(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func prometheusParse(c *caddy.Controller) (*Metrics, error) {
|
||||
var met = New(defaultAddr)
|
||||
|
||||
i := 0
|
||||
for c.Next() {
|
||||
if i > 0 {
|
||||
return nil, plugin.ErrOnce
|
||||
}
|
||||
i++
|
||||
|
||||
for _, z := range c.ServerBlockKeys {
|
||||
met.AddZone(plugin.Host(z).Normalize())
|
||||
}
|
||||
args := c.RemainingArgs()
|
||||
|
||||
switch len(args) {
|
||||
case 0:
|
||||
case 1:
|
||||
met.Addr = args[0]
|
||||
_, _, e := net.SplitHostPort(met.Addr)
|
||||
if e != nil {
|
||||
return met, e
|
||||
}
|
||||
default:
|
||||
return met, c.ArgErr()
|
||||
}
|
||||
}
|
||||
return met, nil
|
||||
}
|
||||
|
||||
// defaultAddr is the address where the metrics are exported by default.
|
||||
const defaultAddr = "localhost:9153"
|
|
@ -0,0 +1,63 @@
|
|||
package vars
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Report reports the metrics data associated with request. This function is exported because it is also
|
||||
// called from core/dnsserver to report requests hitting the server that should not be handled and are thus
|
||||
// not sent down the plugin chain.
|
||||
func Report(server string, req request.Request, zone, rcode string, size int, start time.Time) {
|
||||
// Proto and Family.
|
||||
net := req.Proto()
|
||||
fam := "1"
|
||||
if req.Family() == 2 {
|
||||
fam = "2"
|
||||
}
|
||||
|
||||
typ := req.QType()
|
||||
RequestCount.WithLabelValues(server, zone, net, fam).Inc()
|
||||
RequestDuration.WithLabelValues(server, zone).Observe(time.Since(start).Seconds())
|
||||
|
||||
if req.Do() {
|
||||
RequestDo.WithLabelValues(server, zone).Inc()
|
||||
}
|
||||
|
||||
if _, known := monitorType[typ]; known {
|
||||
RequestType.WithLabelValues(server, zone, dns.Type(typ).String()).Inc()
|
||||
} else {
|
||||
RequestType.WithLabelValues(server, zone, other).Inc()
|
||||
}
|
||||
|
||||
ResponseSize.WithLabelValues(server, zone, net).Observe(float64(size))
|
||||
RequestSize.WithLabelValues(server, zone, net).Observe(float64(req.Len()))
|
||||
|
||||
ResponseRcode.WithLabelValues(server, zone, rcode).Inc()
|
||||
}
|
||||
|
||||
var monitorType = map[uint16]struct{}{
|
||||
dns.TypeAAAA: struct{}{},
|
||||
dns.TypeA: struct{}{},
|
||||
dns.TypeCNAME: struct{}{},
|
||||
dns.TypeDNSKEY: struct{}{},
|
||||
dns.TypeDS: struct{}{},
|
||||
dns.TypeMX: struct{}{},
|
||||
dns.TypeNSEC3: struct{}{},
|
||||
dns.TypeNSEC: struct{}{},
|
||||
dns.TypeNS: struct{}{},
|
||||
dns.TypePTR: struct{}{},
|
||||
dns.TypeRRSIG: struct{}{},
|
||||
dns.TypeSOA: struct{}{},
|
||||
dns.TypeSRV: struct{}{},
|
||||
dns.TypeTXT: struct{}{},
|
||||
// Meta Qtypes
|
||||
dns.TypeIXFR: struct{}{},
|
||||
dns.TypeAXFR: struct{}{},
|
||||
dns.TypeANY: struct{}{},
|
||||
}
|
||||
|
||||
const other = "other"
|
|
@ -0,0 +1,81 @@
|
|||
package vars
|
||||
|
||||
import (
|
||||
"github.com/coredns/coredns/plugin"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
// Request* and Response* are the prometheus counters and gauges we are using for exporting metrics.
|
||||
var (
|
||||
RequestCount = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: subsystem,
|
||||
Name: "request_count_total",
|
||||
Help: "Counter of DNS requests made per zone, protocol and family.",
|
||||
}, []string{"server", "zone", "proto", "family"})
|
||||
|
||||
RequestDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: subsystem,
|
||||
Name: "request_duration_seconds",
|
||||
Buckets: plugin.TimeBuckets,
|
||||
Help: "Histogram of the time (in seconds) each request took.",
|
||||
}, []string{"server", "zone"})
|
||||
|
||||
RequestSize = prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: subsystem,
|
||||
Name: "request_size_bytes",
|
||||
Help: "Size of the EDNS0 UDP buffer in bytes (64K for TCP).",
|
||||
Buckets: []float64{0, 100, 200, 300, 400, 511, 1023, 2047, 4095, 8291, 16e3, 32e3, 48e3, 64e3},
|
||||
}, []string{"server", "zone", "proto"})
|
||||
|
||||
RequestDo = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: subsystem,
|
||||
Name: "request_do_count_total",
|
||||
Help: "Counter of DNS requests with DO bit set per zone.",
|
||||
}, []string{"server", "zone"})
|
||||
|
||||
RequestType = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: subsystem,
|
||||
Name: "request_type_count_total",
|
||||
Help: "Counter of DNS requests per type, per zone.",
|
||||
}, []string{"server", "zone", "type"})
|
||||
|
||||
ResponseSize = prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: subsystem,
|
||||
Name: "response_size_bytes",
|
||||
Help: "Size of the returned response in bytes.",
|
||||
Buckets: []float64{0, 100, 200, 300, 400, 511, 1023, 2047, 4095, 8291, 16e3, 32e3, 48e3, 64e3},
|
||||
}, []string{"server", "zone", "proto"})
|
||||
|
||||
ResponseRcode = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: subsystem,
|
||||
Name: "response_rcode_count_total",
|
||||
Help: "Counter of response status codes.",
|
||||
}, []string{"server", "zone", "rcode"})
|
||||
|
||||
Panic = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Name: "panic_count_total",
|
||||
Help: "A metrics that counts the number of panics.",
|
||||
})
|
||||
|
||||
PluginEnabled = prometheus.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Name: "plugin_enabled",
|
||||
Help: "A metric that indicates whether a plugin is enabled on per server and zone basis.",
|
||||
}, []string{"server", "zone", "name"})
|
||||
)
|
||||
|
||||
const (
|
||||
subsystem = "dns"
|
||||
|
||||
// Dropped indicates we dropped the query before any handling. It has no closing dot, so it can not be a valid zone.
|
||||
Dropped = "dropped"
|
||||
)
|