TUN-7125: Add management streaming logs WebSocket protocol
This commit is contained in:
parent
5972540efa
commit
93acdaface
2
go.mod
2
go.mod
|
@ -46,6 +46,7 @@ require (
|
||||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
||||||
gopkg.in/square/go-jose.v2 v2.6.0
|
gopkg.in/square/go-jose.v2 v2.6.0
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
|
nhooyr.io/websocket v1.8.7
|
||||||
zombiezen.com/go/capnproto2 v2.18.0+incompatible
|
zombiezen.com/go/capnproto2 v2.18.0+incompatible
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -73,6 +74,7 @@ require (
|
||||||
github.com/golang/protobuf v1.5.2 // indirect
|
github.com/golang/protobuf v1.5.2 // indirect
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0 // indirect
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0 // indirect
|
||||||
github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 // indirect
|
github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 // indirect
|
||||||
|
github.com/klauspost/compress v1.15.11 // indirect
|
||||||
github.com/kr/text v0.2.0 // indirect
|
github.com/kr/text v0.2.0 // indirect
|
||||||
github.com/kylelemons/godebug v1.1.0 // indirect
|
github.com/kylelemons/godebug v1.1.0 // indirect
|
||||||
github.com/marten-seemann/qtls-go1-16 v0.1.5 // indirect
|
github.com/marten-seemann/qtls-go1-16 v0.1.5 // indirect
|
||||||
|
|
32
go.sum
32
go.sum
|
@ -166,6 +166,10 @@ github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49P
|
||||||
github.com/getsentry/sentry-go v0.16.0 h1:owk+S+5XcgJLlGR/3+3s6N4d+uKwqYvh/eS0AIMjPWo=
|
github.com/getsentry/sentry-go v0.16.0 h1:owk+S+5XcgJLlGR/3+3s6N4d+uKwqYvh/eS0AIMjPWo=
|
||||||
github.com/getsentry/sentry-go v0.16.0/go.mod h1:ZXCloQLj0pG7mja5NK6NPf2V4A88YJ4pNlc2mOHwh6Y=
|
github.com/getsentry/sentry-go v0.16.0/go.mod h1:ZXCloQLj0pG7mja5NK6NPf2V4A88YJ4pNlc2mOHwh6Y=
|
||||||
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
|
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
|
||||||
|
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||||
|
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||||
|
github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M=
|
||||||
|
github.com/gin-gonic/gin v1.8.1 h1:4+fr/el88TOO3ewCmQr8cx/CtZ/umlIRIs5M4NTNjf8=
|
||||||
github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0=
|
github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0=
|
||||||
github.com/go-chi/chi/v5 v5.0.8 h1:lD+NLqFcAi1ovnVZpsnObHGW4xb4J8lNmoYVfECH1Y0=
|
github.com/go-chi/chi/v5 v5.0.8 h1:lD+NLqFcAi1ovnVZpsnObHGW4xb4J8lNmoYVfECH1Y0=
|
||||||
github.com/go-chi/chi/v5 v5.0.8/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
|
github.com/go-chi/chi/v5 v5.0.8/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
|
||||||
|
@ -187,15 +191,26 @@ github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
|
||||||
github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||||
|
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||||
|
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
|
||||||
|
github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU=
|
||||||
|
github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA=
|
||||||
|
github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/jYrnRPArHwAcmLoJZxyho=
|
||||||
|
github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
|
||||||
|
github.com/go-playground/validator/v10 v10.11.1 h1:prmOlTVv+YjZjmRmNSF3VmspqJIxJWXmqUsHwfTRRkQ=
|
||||||
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I=
|
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I=
|
||||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
|
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
|
||||||
|
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
|
||||||
github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58 h1:YyrUZvJaU8Q0QsoVo+xLFBgWDTam29PKea6GYmwvSiQ=
|
github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58 h1:YyrUZvJaU8Q0QsoVo+xLFBgWDTam29PKea6GYmwvSiQ=
|
||||||
github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
|
github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
|
||||||
|
github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
|
||||||
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
|
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
|
||||||
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
|
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
|
||||||
|
github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
|
||||||
github.com/gobwas/ws v1.0.4 h1:5eXU1CZhpQdq5kXbKb+sECH5Ia5KiO6CYzIzdlVx6Bs=
|
github.com/gobwas/ws v1.0.4 h1:5eXU1CZhpQdq5kXbKb+sECH5Ia5KiO6CYzIzdlVx6Bs=
|
||||||
github.com/gobwas/ws v1.0.4/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
|
github.com/gobwas/ws v1.0.4/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
|
||||||
|
github.com/goccy/go-json v0.9.11 h1:/pAaQDLHEoCq/5FFmSKBswWmK6H0e8g4159Kc/X/nqk=
|
||||||
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
||||||
github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3 h1:zN2lZNZRflqFyxVaTIU61KNKQ9C0055u9CAfpmqUvo4=
|
github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3 h1:zN2lZNZRflqFyxVaTIU61KNKQ9C0055u9CAfpmqUvo4=
|
||||||
github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3/go.mod h1:nPpo7qLxd6XL3hWJG/O60sR8ZKfMCiIoNap5GvD12KU=
|
github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3/go.mod h1:nPpo7qLxd6XL3hWJG/O60sR8ZKfMCiIoNap5GvD12KU=
|
||||||
|
@ -292,6 +307,7 @@ github.com/googleapis/gax-go/v2 v2.3.0/go.mod h1:b8LNqSzNabLiUpXKkY7HAR5jr6bIT99
|
||||||
github.com/googleapis/gax-go/v2 v2.4.0/go.mod h1:XOTVJ59hdnfJLIP/dh8n5CGryZR2LxK9wbMD5+iXC6c=
|
github.com/googleapis/gax-go/v2 v2.4.0/go.mod h1:XOTVJ59hdnfJLIP/dh8n5CGryZR2LxK9wbMD5+iXC6c=
|
||||||
github.com/googleapis/go-type-adapters v1.0.0/go.mod h1:zHW75FOG2aur7gAO2B+MLby+cLsWGBF62rFAi7WjWO4=
|
github.com/googleapis/go-type-adapters v1.0.0/go.mod h1:zHW75FOG2aur7gAO2B+MLby+cLsWGBF62rFAi7WjWO4=
|
||||||
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
|
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
|
||||||
|
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
|
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
|
||||||
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA=
|
github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA=
|
||||||
|
@ -311,6 +327,7 @@ github.com/ipostelnik/cli/v2 v2.3.1-0.20210324024421-b6ea8234fe3d/go.mod h1:LJmU
|
||||||
github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU=
|
github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU=
|
||||||
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
|
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
|
||||||
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
|
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
|
||||||
|
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||||
github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||||
github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
|
@ -320,6 +337,9 @@ github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/X
|
||||||
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
|
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
|
||||||
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
|
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
|
||||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||||
|
github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
|
||||||
|
github.com/klauspost/compress v1.15.11 h1:Lcadnb3RKGin4FYM/orgq0qde+nc15E5Cbqg4B9Sx9c=
|
||||||
|
github.com/klauspost/compress v1.15.11/go.mod h1:QPwzmACJjUTFsnSHH934V6woptycfrDDJnH7hvFVbGM=
|
||||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||||
github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||||
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
|
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
|
||||||
|
@ -333,6 +353,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||||
|
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
|
||||||
|
github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w=
|
||||||
github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI=
|
github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI=
|
||||||
github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
|
github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
|
||||||
github.com/marten-seemann/qpack v0.2.1/go.mod h1:F7Gl5L1jIgN1D11ucXefiuJS9UMVP2opoCp2jDKb7wc=
|
github.com/marten-seemann/qpack v0.2.1/go.mod h1:F7Gl5L1jIgN1D11ucXefiuJS9UMVP2opoCp2jDKb7wc=
|
||||||
|
@ -342,6 +364,7 @@ github.com/marten-seemann/qtls-go1-17 v0.1.2 h1:JADBlm0LYiVbuSySCHeY863dNkcpMmDR
|
||||||
github.com/marten-seemann/qtls-go1-17 v0.1.2/go.mod h1:C2ekUKcDdz9SDWxec1N/MvcXBpaX9l3Nx67XaR84L5s=
|
github.com/marten-seemann/qtls-go1-17 v0.1.2/go.mod h1:C2ekUKcDdz9SDWxec1N/MvcXBpaX9l3Nx67XaR84L5s=
|
||||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||||
|
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
|
||||||
github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ=
|
github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ=
|
||||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||||
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
|
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
|
||||||
|
@ -380,6 +403,7 @@ github.com/onsi/gomega v1.23.0/go.mod h1:Z/NWtiqwBrwUt4/2loMmHL63EDLnYHmVbuBpDr2
|
||||||
github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs=
|
github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs=
|
||||||
github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc=
|
github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc=
|
||||||
github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8=
|
github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.0.5 h1:ipoSadvV8oGUjnUbMub59IDPPwfxF694nG/jwbMiyQg=
|
||||||
github.com/philhofer/fwd v1.1.1 h1:GdGcTjf5RNAxwS4QLsiMzJYj5KEvPJD3Abr261yRQXQ=
|
github.com/philhofer/fwd v1.1.1 h1:GdGcTjf5RNAxwS4QLsiMzJYj5KEvPJD3Abr261yRQXQ=
|
||||||
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
|
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
|
||||||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
@ -473,6 +497,10 @@ github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKs
|
||||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA=
|
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA=
|
||||||
github.com/tinylib/msgp v1.1.2 h1:gWmO7n0Ys2RBEb7GPYB9Ujq8Mk5p2U08lRnmMcGy6BQ=
|
github.com/tinylib/msgp v1.1.2 h1:gWmO7n0Ys2RBEb7GPYB9Ujq8Mk5p2U08lRnmMcGy6BQ=
|
||||||
|
github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo=
|
||||||
|
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
|
||||||
|
github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
|
||||||
|
github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0=
|
||||||
github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU=
|
github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU=
|
||||||
github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
|
github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
|
||||||
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
|
@ -677,6 +705,7 @@ golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||||
golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
@ -1031,6 +1060,7 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
|
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||||
|
@ -1045,6 +1075,8 @@ honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWh
|
||||||
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
|
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
|
||||||
honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
|
honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
|
||||||
honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
|
honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
|
||||||
|
nhooyr.io/websocket v1.8.7 h1:usjR2uOr/zjjkVMy0lW+PPohFok7PCow5sDjLgX4P4g=
|
||||||
|
nhooyr.io/websocket v1.8.7/go.mod h1:B70DZP8IakI65RVQ51MsWP/8jndNma26DVA/nFSCgW0=
|
||||||
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
|
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
|
||||||
rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
|
rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
|
||||||
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
|
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
|
||||||
|
|
|
@ -0,0 +1,66 @@
|
||||||
|
package test
|
||||||
|
|
||||||
|
// copied from https://github.com/nhooyr/websocket/blob/master/internal/test/wstest/pipe.go
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
|
||||||
|
"nhooyr.io/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Pipe is used to create an in memory connection
|
||||||
|
// between two websockets analogous to net.Pipe.
|
||||||
|
func WSPipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (clientConn, serverConn *websocket.Conn) {
|
||||||
|
tt := fakeTransport{
|
||||||
|
h: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
serverConn, _ = websocket.Accept(w, r, acceptOpts)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if dialOpts == nil {
|
||||||
|
dialOpts = &websocket.DialOptions{}
|
||||||
|
}
|
||||||
|
dialOpts = &*dialOpts
|
||||||
|
dialOpts.HTTPClient = &http.Client{
|
||||||
|
Transport: tt,
|
||||||
|
}
|
||||||
|
|
||||||
|
clientConn, _, _ = websocket.Dial(context.Background(), "ws://example.com", dialOpts)
|
||||||
|
return clientConn, serverConn
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeTransport struct {
|
||||||
|
h http.HandlerFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t fakeTransport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||||
|
clientConn, serverConn := net.Pipe()
|
||||||
|
|
||||||
|
hj := testHijacker{
|
||||||
|
ResponseRecorder: httptest.NewRecorder(),
|
||||||
|
serverConn: serverConn,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.h.ServeHTTP(hj, r)
|
||||||
|
|
||||||
|
resp := hj.ResponseRecorder.Result()
|
||||||
|
if resp.StatusCode == http.StatusSwitchingProtocols {
|
||||||
|
resp.Body = clientConn
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type testHijacker struct {
|
||||||
|
*httptest.ResponseRecorder
|
||||||
|
serverConn net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ http.Hijacker = testHijacker{}
|
||||||
|
|
||||||
|
func (hj testHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
return hj.serverConn, bufio.NewReadWriter(bufio.NewReader(hj.serverConn), bufio.NewWriter(hj.serverConn)), nil
|
||||||
|
}
|
|
@ -1,5 +1,67 @@
|
||||||
package management
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
jsoniter "github.com/json-iterator/go"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"nhooyr.io/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errInvalidMessageType = fmt.Errorf("invalid message type was provided")
|
||||||
|
)
|
||||||
|
|
||||||
|
// ServerEventType represents the event types that can come from the server
|
||||||
|
type ServerEventType string
|
||||||
|
|
||||||
|
// ClientEventType represents the event types that can come from the client
|
||||||
|
type ClientEventType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
UnknownClientEventType ClientEventType = ""
|
||||||
|
StartStreaming ClientEventType = "start_streaming"
|
||||||
|
StopStreaming ClientEventType = "stop_streaming"
|
||||||
|
|
||||||
|
UnknownServerEventType ServerEventType = ""
|
||||||
|
Logs ServerEventType = "logs"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ServerEvent is the base struct that informs, based of the Type field, which Event type was provided from the server.
|
||||||
|
type ServerEvent struct {
|
||||||
|
Type ServerEventType `json:"type,omitempty"`
|
||||||
|
// The raw json message is provided to allow better deserialization once the type is known
|
||||||
|
event jsoniter.RawMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientEvent is the base struct that informs, based of the Type field, which Event type was provided from the client.
|
||||||
|
type ClientEvent struct {
|
||||||
|
Type ClientEventType `json:"type,omitempty"`
|
||||||
|
// The raw json message is provided to allow better deserialization once the type is known
|
||||||
|
event jsoniter.RawMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventStartStreaming signifies that the client wishes to start receiving log events.
|
||||||
|
// Additional filters can be provided to augment the log events requested.
|
||||||
|
type EventStartStreaming struct {
|
||||||
|
ClientEvent
|
||||||
|
Filters []string `json:"filters"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventStopStreaming signifies that the client wishes to halt receiving log events.
|
||||||
|
type EventStopStreaming struct {
|
||||||
|
ClientEvent
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventLog is the event that the server sends to the client with the log events.
|
||||||
|
type EventLog struct {
|
||||||
|
ServerEvent
|
||||||
|
Logs []Log `json:"logs"`
|
||||||
|
}
|
||||||
|
|
||||||
// LogEventType is the way that logging messages are able to be filtered.
|
// LogEventType is the way that logging messages are able to be filtered.
|
||||||
// Example: assigning LogEventType.Cloudflared to a zerolog event will allow the client to filter for only
|
// Example: assigning LogEventType.Cloudflared to a zerolog event will allow the client to filter for only
|
||||||
// the Cloudflared-related events.
|
// the Cloudflared-related events.
|
||||||
|
@ -38,3 +100,113 @@ const (
|
||||||
Warn LogLevel = "warn"
|
Warn LogLevel = "warn"
|
||||||
Error LogLevel = "error"
|
Error LogLevel = "error"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Log is the basic structure of the events that are sent to the client.
|
||||||
|
type Log struct {
|
||||||
|
Event LogEventType `json:"event"`
|
||||||
|
Timestamp string `json:"timestamp"`
|
||||||
|
Level LogLevel `json:"level"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntoClientEvent unmarshals the provided ClientEvent into the proper type.
|
||||||
|
func IntoClientEvent[T EventStartStreaming | EventStopStreaming](e *ClientEvent, eventType ClientEventType) (*T, bool) {
|
||||||
|
if e.Type != eventType {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
event := new(T)
|
||||||
|
err := json.Unmarshal(e.event, event)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return event, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntoServerEvent unmarshals the provided ServerEvent into the proper type.
|
||||||
|
func IntoServerEvent[T EventLog](e *ServerEvent, eventType ServerEventType) (*T, bool) {
|
||||||
|
if e.Type != eventType {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
event := new(T)
|
||||||
|
err := json.Unmarshal(e.event, event)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return event, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadEvent will read a message from the websocket connection and parse it into a valid ServerEvent.
|
||||||
|
func ReadServerEvent(c *websocket.Conn, ctx context.Context) (*ServerEvent, error) {
|
||||||
|
message, err := readMessage(c, ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
event := ServerEvent{}
|
||||||
|
if err := json.Unmarshal(message, &event); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch event.Type {
|
||||||
|
case Logs:
|
||||||
|
event.event = message
|
||||||
|
return &event, nil
|
||||||
|
case UnknownServerEventType:
|
||||||
|
return nil, errInvalidMessageType
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid server message type was provided: %s", event.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadEvent will read a message from the websocket connection and parse it into a valid ClientEvent.
|
||||||
|
func ReadClientEvent(c *websocket.Conn, ctx context.Context) (*ClientEvent, error) {
|
||||||
|
message, err := readMessage(c, ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
event := ClientEvent{}
|
||||||
|
if err := json.Unmarshal(message, &event); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch event.Type {
|
||||||
|
case StartStreaming, StopStreaming:
|
||||||
|
event.event = message
|
||||||
|
return &event, nil
|
||||||
|
case UnknownClientEventType:
|
||||||
|
return nil, errInvalidMessageType
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid client message type was provided: %s", event.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// readMessage will read a message from the websocket connection and return the payload.
|
||||||
|
func readMessage(c *websocket.Conn, ctx context.Context) ([]byte, error) {
|
||||||
|
messageType, reader, err := c.Reader(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if messageType != websocket.MessageText {
|
||||||
|
return nil, errInvalidMessageType
|
||||||
|
}
|
||||||
|
return io.ReadAll(reader)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteEvent will write a Event type message to the websocket connection.
|
||||||
|
func WriteEvent(c *websocket.Conn, ctx context.Context, event any) error {
|
||||||
|
payload, err := json.Marshal(event)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.Write(ctx, websocket.MessageText, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsClosed returns true if the websocket error is a websocket.CloseError; returns false if not a
|
||||||
|
// websocket.CloseError
|
||||||
|
func IsClosed(err error, log *zerolog.Logger) bool {
|
||||||
|
var closeErr websocket.CloseError
|
||||||
|
if errors.As(err, &closeErr) {
|
||||||
|
if closeErr.Code != websocket.StatusNormalClosure {
|
||||||
|
log.Debug().Msgf("connection is already closed: (%d) %s", closeErr.Code, closeErr.Reason)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,168 @@
|
||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"nhooyr.io/websocket"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/internal/test"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIntoClientEvent_StartStreaming(t *testing.T) {
|
||||||
|
event := ClientEvent{
|
||||||
|
Type: StartStreaming,
|
||||||
|
event: []byte(`{"type": "start_streaming"}`),
|
||||||
|
}
|
||||||
|
ce, ok := IntoClientEvent[EventStartStreaming](&event, StartStreaming)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, EventStartStreaming{ClientEvent: ClientEvent{Type: StartStreaming}}, *ce)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntoClientEvent_StopStreaming(t *testing.T) {
|
||||||
|
event := ClientEvent{
|
||||||
|
Type: StopStreaming,
|
||||||
|
event: []byte(`{"type": "stop_streaming"}`),
|
||||||
|
}
|
||||||
|
ce, ok := IntoClientEvent[EventStopStreaming](&event, StopStreaming)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, EventStopStreaming{ClientEvent: ClientEvent{Type: StopStreaming}}, *ce)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntoClientEvent_Invalid(t *testing.T) {
|
||||||
|
event := ClientEvent{
|
||||||
|
Type: UnknownClientEventType,
|
||||||
|
event: []byte(`{"type": "invalid"}`),
|
||||||
|
}
|
||||||
|
_, ok := IntoClientEvent[EventStartStreaming](&event, StartStreaming)
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntoServerEvent_Logs(t *testing.T) {
|
||||||
|
event := ServerEvent{
|
||||||
|
Type: Logs,
|
||||||
|
event: []byte(`{"type": "logs"}`),
|
||||||
|
}
|
||||||
|
ce, ok := IntoServerEvent(&event, Logs)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, EventLog{ServerEvent: ServerEvent{Type: Logs}}, *ce)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntoServerEvent_Invalid(t *testing.T) {
|
||||||
|
event := ServerEvent{
|
||||||
|
Type: UnknownServerEventType,
|
||||||
|
event: []byte(`{"type": "invalid"}`),
|
||||||
|
}
|
||||||
|
_, ok := IntoServerEvent(&event, Logs)
|
||||||
|
require.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadServerEvent(t *testing.T) {
|
||||||
|
sentEvent := EventLog{
|
||||||
|
ServerEvent: ServerEvent{Type: Logs},
|
||||||
|
Logs: []Log{
|
||||||
|
{
|
||||||
|
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||||
|
Event: HTTP,
|
||||||
|
Level: Info,
|
||||||
|
Message: "test",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client, server := test.WSPipe(nil, nil)
|
||||||
|
server.CloseRead(context.Background())
|
||||||
|
defer func() {
|
||||||
|
server.Close(websocket.StatusInternalError, "")
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
err := WriteEvent(server, context.Background(), &sentEvent)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
event, err := ReadServerEvent(client, context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, sentEvent.Type, event.Type)
|
||||||
|
client.Close(websocket.StatusInternalError, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadServerEvent_InvalidWebSocketMessageType(t *testing.T) {
|
||||||
|
client, server := test.WSPipe(nil, nil)
|
||||||
|
server.CloseRead(context.Background())
|
||||||
|
defer func() {
|
||||||
|
server.Close(websocket.StatusInternalError, "")
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
err := server.Write(context.Background(), websocket.MessageBinary, []byte("test1234"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
_, err := ReadServerEvent(client, context.Background())
|
||||||
|
require.Error(t, err)
|
||||||
|
client.Close(websocket.StatusInternalError, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadServerEvent_InvalidMessageType(t *testing.T) {
|
||||||
|
sentEvent := ClientEvent{Type: ClientEventType(UnknownServerEventType)}
|
||||||
|
client, server := test.WSPipe(nil, nil)
|
||||||
|
server.CloseRead(context.Background())
|
||||||
|
defer func() {
|
||||||
|
server.Close(websocket.StatusInternalError, "")
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
err := WriteEvent(server, context.Background(), &sentEvent)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
_, err := ReadServerEvent(client, context.Background())
|
||||||
|
require.ErrorIs(t, err, errInvalidMessageType)
|
||||||
|
client.Close(websocket.StatusInternalError, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadClientEvent(t *testing.T) {
|
||||||
|
sentEvent := EventStartStreaming{
|
||||||
|
ClientEvent: ClientEvent{Type: StartStreaming},
|
||||||
|
}
|
||||||
|
client, server := test.WSPipe(nil, nil)
|
||||||
|
client.CloseRead(context.Background())
|
||||||
|
defer func() {
|
||||||
|
client.Close(websocket.StatusInternalError, "")
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
err := WriteEvent(client, context.Background(), &sentEvent)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
event, err := ReadClientEvent(server, context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, sentEvent.Type, event.Type)
|
||||||
|
server.Close(websocket.StatusInternalError, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadClientEvent_InvalidWebSocketMessageType(t *testing.T) {
|
||||||
|
client, server := test.WSPipe(nil, nil)
|
||||||
|
client.CloseRead(context.Background())
|
||||||
|
defer func() {
|
||||||
|
client.Close(websocket.StatusInternalError, "")
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
err := client.Write(context.Background(), websocket.MessageBinary, []byte("test1234"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
_, err := ReadClientEvent(server, context.Background())
|
||||||
|
require.Error(t, err)
|
||||||
|
server.Close(websocket.StatusInternalError, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadClientEvent_InvalidMessageType(t *testing.T) {
|
||||||
|
sentEvent := ClientEvent{Type: UnknownClientEventType}
|
||||||
|
client, server := test.WSPipe(nil, nil)
|
||||||
|
client.CloseRead(context.Background())
|
||||||
|
defer func() {
|
||||||
|
client.Close(websocket.StatusInternalError, "")
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
err := WriteEvent(client, context.Background(), &sentEvent)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
_, err := ReadClientEvent(server, context.Background())
|
||||||
|
require.ErrorIs(t, err, errInvalidMessageType)
|
||||||
|
server.Close(websocket.StatusInternalError, "")
|
||||||
|
}
|
|
@ -1,10 +1,24 @@
|
||||||
package management
|
package management
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
"nhooyr.io/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// In the current state, an invalid command was provided by the client
|
||||||
|
StatusInvalidCommand websocket.StatusCode = 4001
|
||||||
|
reasonInvalidCommand = "expected start streaming as first event"
|
||||||
|
// There are a limited number of available streaming log sessions that cloudflared will service, exceeding this
|
||||||
|
// value will return this error to incoming requests.
|
||||||
|
StatusSessionLimitExceeded websocket.StatusCode = 4002
|
||||||
|
reasonSessionLimitExceeded = "limit exceeded for streaming sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ManagementService struct {
|
type ManagementService struct {
|
||||||
|
@ -13,6 +27,14 @@ type ManagementService struct {
|
||||||
|
|
||||||
log *zerolog.Logger
|
log *zerolog.Logger
|
||||||
router chi.Router
|
router chi.Router
|
||||||
|
|
||||||
|
// streaming signifies if the service is already streaming logs. Helps limit the number of active users streaming logs
|
||||||
|
// from this cloudflared instance.
|
||||||
|
streaming atomic.Bool
|
||||||
|
// streamingMut is a lock to prevent concurrent requests to start streaming. Utilizing the atomic.Bool is not
|
||||||
|
// sufficient to complete this operation since many other checks during an incoming new request are needed
|
||||||
|
// to validate this before setting streaming to true.
|
||||||
|
streamingMut sync.Mutex
|
||||||
logger LoggerListener
|
logger LoggerListener
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,6 +47,7 @@ func New(managementHostname string, log *zerolog.Logger, logger LoggerListener)
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
r.Get("/ping", ping)
|
r.Get("/ping", ping)
|
||||||
r.Head("/ping", ping)
|
r.Head("/ping", ping)
|
||||||
|
r.Get("/logs", s.logs)
|
||||||
s.router = r
|
s.router = r
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
@ -37,3 +60,127 @@ func (m *ManagementService) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
func ping(w http.ResponseWriter, r *http.Request) {
|
func ping(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// readEvents will loop through all incoming websocket messages from a client and marshal them into the
|
||||||
|
// proper Event structure and pass through to the events channel. Any invalid messages sent will automatically
|
||||||
|
// terminate the connection.
|
||||||
|
func (m *ManagementService) readEvents(c *websocket.Conn, ctx context.Context, events chan<- *ClientEvent) {
|
||||||
|
for {
|
||||||
|
event, err := ReadClientEvent(c, ctx)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
if err != nil {
|
||||||
|
// If the client (or the server) already closed the connection, don't attempt to close it again
|
||||||
|
if !IsClosed(err, m.log) {
|
||||||
|
m.log.Err(err).Send()
|
||||||
|
m.log.Err(c.Close(websocket.StatusUnsupportedData, err.Error())).Send()
|
||||||
|
}
|
||||||
|
// Any errors when reading the messages from the client will close the connection
|
||||||
|
return
|
||||||
|
}
|
||||||
|
events <- event
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// streamLogs will begin the process of reading from the Session listener and write the log events to the client.
|
||||||
|
func (m *ManagementService) streamLogs(c *websocket.Conn, ctx context.Context, session *Session) {
|
||||||
|
defer m.logger.Close(session)
|
||||||
|
for m.streaming.Load() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
m.streaming.Store(false)
|
||||||
|
return
|
||||||
|
case event := <-session.listener:
|
||||||
|
err := WriteEvent(c, ctx, &EventLog{
|
||||||
|
ServerEvent: ServerEvent{Type: Logs},
|
||||||
|
Logs: []Log{{
|
||||||
|
Event: Cloudflared,
|
||||||
|
Timestamp: event.Time,
|
||||||
|
Level: event.Level,
|
||||||
|
Message: event.Message,
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
// If the client (or the server) already closed the connection, don't attempt to close it again
|
||||||
|
if !IsClosed(err, m.log) {
|
||||||
|
m.log.Err(err).Send()
|
||||||
|
m.log.Err(c.Close(websocket.StatusInternalError, err.Error())).Send()
|
||||||
|
}
|
||||||
|
// Any errors when writing the messages to the client will stop streaming and close the connection
|
||||||
|
m.streaming.Store(false)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// No messages to send
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// startStreaming will check the conditions of the request and begin streaming or close the connection for invalid
|
||||||
|
// requests.
|
||||||
|
func (m *ManagementService) startStreaming(c *websocket.Conn, ctx context.Context, event *ClientEvent) {
|
||||||
|
m.streamingMut.Lock()
|
||||||
|
defer m.streamingMut.Unlock()
|
||||||
|
// Limits to one user for streaming logs
|
||||||
|
if m.streaming.Load() {
|
||||||
|
m.log.Warn().
|
||||||
|
Msgf("Another management session request was attempted but one session already being served; there is a limit of streaming log sessions to reduce overall performance impact.")
|
||||||
|
m.log.Err(c.Close(StatusSessionLimitExceeded, reasonSessionLimitExceeded)).Send()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Expect the first incoming request
|
||||||
|
_, ok := IntoClientEvent[EventStartStreaming](event, StartStreaming)
|
||||||
|
if !ok {
|
||||||
|
m.log.Err(c.Close(StatusInvalidCommand, reasonInvalidCommand)).Msgf("expected start_streaming as first recieved event")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.streaming.Store(true)
|
||||||
|
listener := m.logger.Listen()
|
||||||
|
m.log.Debug().Msgf("Streaming logs")
|
||||||
|
go m.streamLogs(c, ctx, listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Management Streaming Logs accept handler
|
||||||
|
func (m *ManagementService) logs(w http.ResponseWriter, r *http.Request) {
|
||||||
|
c, err := websocket.Accept(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
m.log.Debug().Msgf("management handshake: %s", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Make sure the connection is closed if other go routines fail to close the connection after completing.
|
||||||
|
defer c.Close(websocket.StatusInternalError, "")
|
||||||
|
ctx := r.Context()
|
||||||
|
events := make(chan *ClientEvent)
|
||||||
|
go m.readEvents(c, ctx, events)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
m.log.Debug().Msgf("management logs: context cancelled")
|
||||||
|
c.Close(websocket.StatusNormalClosure, "context closed")
|
||||||
|
return
|
||||||
|
case event := <-events:
|
||||||
|
switch event.Type {
|
||||||
|
case StartStreaming:
|
||||||
|
m.startStreaming(c, ctx, event)
|
||||||
|
continue
|
||||||
|
case StopStreaming:
|
||||||
|
// TODO: limit StopStreaming to only halt streaming for clients that are already streaming
|
||||||
|
m.streaming.Store(false)
|
||||||
|
case UnknownClientEventType:
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
// Drop unknown events and close connection
|
||||||
|
m.log.Debug().Msgf("unexpected management message received: %s", event.Type)
|
||||||
|
// If the client (or the server) already closed the connection, don't attempt to close it again
|
||||||
|
if !IsClosed(err, m.log) {
|
||||||
|
m.log.Err(err).Err(c.Close(websocket.StatusUnsupportedData, err.Error())).Send()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"nhooyr.io/websocket"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/internal/test"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
noopLogger = zerolog.New(io.Discard)
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestReadEventsLoop(t *testing.T) {
|
||||||
|
sentEvent := EventStartStreaming{
|
||||||
|
ClientEvent: ClientEvent{Type: StartStreaming},
|
||||||
|
}
|
||||||
|
client, server := test.WSPipe(nil, nil)
|
||||||
|
client.CloseRead(context.Background())
|
||||||
|
defer func() {
|
||||||
|
client.Close(websocket.StatusInternalError, "")
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
err := WriteEvent(client, context.Background(), &sentEvent)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
m := ManagementService{
|
||||||
|
log: &noopLogger,
|
||||||
|
}
|
||||||
|
events := make(chan *ClientEvent)
|
||||||
|
go m.readEvents(server, context.Background(), events)
|
||||||
|
event := <-events
|
||||||
|
require.Equal(t, sentEvent.Type, event.Type)
|
||||||
|
server.Close(websocket.StatusInternalError, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadEventsLoop_ContextCancelled(t *testing.T) {
|
||||||
|
client, server := test.WSPipe(nil, nil)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
client.CloseRead(ctx)
|
||||||
|
defer func() {
|
||||||
|
client.Close(websocket.StatusInternalError, "")
|
||||||
|
}()
|
||||||
|
m := ManagementService{
|
||||||
|
log: &noopLogger,
|
||||||
|
}
|
||||||
|
events := make(chan *ClientEvent)
|
||||||
|
go func() {
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
// Want to make sure this function returns when context is cancelled
|
||||||
|
m.readEvents(server, ctx, events)
|
||||||
|
server.Close(websocket.StatusInternalError, "")
|
||||||
|
}
|
|
@ -142,7 +142,7 @@ func (p *Proxy) ProxyHTTP(
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
case ingress.HTTPLocalProxy:
|
case ingress.HTTPLocalProxy:
|
||||||
originProxy.ServeHTTP(w, req)
|
p.proxyLocalRequest(originProxy, w, req, isWebsocket)
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("Unrecognized service: %s, %t", rule.Service, originProxy)
|
return fmt.Errorf("Unrecognized service: %s, %t", rule.Service, originProxy)
|
||||||
|
@ -306,6 +306,17 @@ func (p *Proxy) proxyStream(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Proxy) proxyLocalRequest(proxy ingress.HTTPLocalProxy, w connection.ResponseWriter, req *http.Request, isWebsocket bool) {
|
||||||
|
if isWebsocket {
|
||||||
|
// These headers are added since they are stripped off during an eyeball request to origintunneld, but they
|
||||||
|
// are required during the Handshake process of a WebSocket request.
|
||||||
|
req.Header.Set("Connection", "Upgrade")
|
||||||
|
req.Header.Set("Upgrade", "websocket")
|
||||||
|
req.Header.Set("Sec-Websocket-Version", "13")
|
||||||
|
}
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
}
|
||||||
|
|
||||||
type bidirectionalStream struct {
|
type bidirectionalStream struct {
|
||||||
reader io.Reader
|
reader io.Reader
|
||||||
writer io.Writer
|
writer io.Writer
|
||||||
|
|
|
@ -0,0 +1,304 @@
|
||||||
|
Copyright (c) 2012 The Go Authors. All rights reserved.
|
||||||
|
Copyright (c) 2019 Klaus Post. All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are
|
||||||
|
met:
|
||||||
|
|
||||||
|
* Redistributions of source code must retain the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer.
|
||||||
|
* Redistributions in binary form must reproduce the above
|
||||||
|
copyright notice, this list of conditions and the following disclaimer
|
||||||
|
in the documentation and/or other materials provided with the
|
||||||
|
distribution.
|
||||||
|
* Neither the name of Google Inc. nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
------------------
|
||||||
|
|
||||||
|
Files: gzhttp/*
|
||||||
|
|
||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright 2016-2017 The New York Times Company
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
|
||||||
|
------------------
|
||||||
|
|
||||||
|
Files: s2/cmd/internal/readahead/*
|
||||||
|
|
||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2015 Klaus Post
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
|
|
||||||
|
---------------------
|
||||||
|
Files: snappy/*
|
||||||
|
Files: internal/snapref/*
|
||||||
|
|
||||||
|
Copyright (c) 2011 The Snappy-Go Authors. All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are
|
||||||
|
met:
|
||||||
|
|
||||||
|
* Redistributions of source code must retain the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer.
|
||||||
|
* Redistributions in binary form must reproduce the above
|
||||||
|
copyright notice, this list of conditions and the following disclaimer
|
||||||
|
in the documentation and/or other materials provided with the
|
||||||
|
distribution.
|
||||||
|
* Neither the name of Google Inc. nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
-----------------
|
||||||
|
|
||||||
|
Files: s2/cmd/internal/filepathx/*
|
||||||
|
|
||||||
|
Copyright 2016 The filepathx Authors
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
@ -0,0 +1,910 @@
|
||||||
|
// Copyright 2009 The Go Authors. All rights reserved.
|
||||||
|
// Copyright (c) 2015 Klaus Post
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package flate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
NoCompression = 0
|
||||||
|
BestSpeed = 1
|
||||||
|
BestCompression = 9
|
||||||
|
DefaultCompression = -1
|
||||||
|
|
||||||
|
// HuffmanOnly disables Lempel-Ziv match searching and only performs Huffman
|
||||||
|
// entropy encoding. This mode is useful in compressing data that has
|
||||||
|
// already been compressed with an LZ style algorithm (e.g. Snappy or LZ4)
|
||||||
|
// that lacks an entropy encoder. Compression gains are achieved when
|
||||||
|
// certain bytes in the input stream occur more frequently than others.
|
||||||
|
//
|
||||||
|
// Note that HuffmanOnly produces a compressed output that is
|
||||||
|
// RFC 1951 compliant. That is, any valid DEFLATE decompressor will
|
||||||
|
// continue to be able to decompress this output.
|
||||||
|
HuffmanOnly = -2
|
||||||
|
ConstantCompression = HuffmanOnly // compatibility alias.
|
||||||
|
|
||||||
|
logWindowSize = 15
|
||||||
|
windowSize = 1 << logWindowSize
|
||||||
|
windowMask = windowSize - 1
|
||||||
|
logMaxOffsetSize = 15 // Standard DEFLATE
|
||||||
|
minMatchLength = 4 // The smallest match that the compressor looks for
|
||||||
|
maxMatchLength = 258 // The longest match for the compressor
|
||||||
|
minOffsetSize = 1 // The shortest offset that makes any sense
|
||||||
|
|
||||||
|
// The maximum number of tokens we will encode at the time.
|
||||||
|
// Smaller sizes usually creates less optimal blocks.
|
||||||
|
// Bigger can make context switching slow.
|
||||||
|
// We use this for levels 7-9, so we make it big.
|
||||||
|
maxFlateBlockTokens = 1 << 15
|
||||||
|
maxStoreBlockSize = 65535
|
||||||
|
hashBits = 17 // After 17 performance degrades
|
||||||
|
hashSize = 1 << hashBits
|
||||||
|
hashMask = (1 << hashBits) - 1
|
||||||
|
hashShift = (hashBits + minMatchLength - 1) / minMatchLength
|
||||||
|
maxHashOffset = 1 << 28
|
||||||
|
|
||||||
|
skipNever = math.MaxInt32
|
||||||
|
|
||||||
|
debugDeflate = false
|
||||||
|
)
|
||||||
|
|
||||||
|
type compressionLevel struct {
|
||||||
|
good, lazy, nice, chain, fastSkipHashing, level int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compression levels have been rebalanced from zlib deflate defaults
|
||||||
|
// to give a bigger spread in speed and compression.
|
||||||
|
// See https://blog.klauspost.com/rebalancing-deflate-compression-levels/
|
||||||
|
var levels = []compressionLevel{
|
||||||
|
{}, // 0
|
||||||
|
// Level 1-6 uses specialized algorithm - values not used
|
||||||
|
{0, 0, 0, 0, 0, 1},
|
||||||
|
{0, 0, 0, 0, 0, 2},
|
||||||
|
{0, 0, 0, 0, 0, 3},
|
||||||
|
{0, 0, 0, 0, 0, 4},
|
||||||
|
{0, 0, 0, 0, 0, 5},
|
||||||
|
{0, 0, 0, 0, 0, 6},
|
||||||
|
// Levels 7-9 use increasingly more lazy matching
|
||||||
|
// and increasingly stringent conditions for "good enough".
|
||||||
|
{8, 12, 16, 24, skipNever, 7},
|
||||||
|
{16, 30, 40, 64, skipNever, 8},
|
||||||
|
{32, 258, 258, 1024, skipNever, 9},
|
||||||
|
}
|
||||||
|
|
||||||
|
// advancedState contains state for the advanced levels, with bigger hash tables, etc.
|
||||||
|
type advancedState struct {
|
||||||
|
// deflate state
|
||||||
|
length int
|
||||||
|
offset int
|
||||||
|
maxInsertIndex int
|
||||||
|
chainHead int
|
||||||
|
hashOffset int
|
||||||
|
|
||||||
|
ii uint16 // position of last match, intended to overflow to reset.
|
||||||
|
|
||||||
|
// input window: unprocessed data is window[index:windowEnd]
|
||||||
|
index int
|
||||||
|
estBitsPerByte int
|
||||||
|
hashMatch [maxMatchLength + minMatchLength]uint32
|
||||||
|
|
||||||
|
// Input hash chains
|
||||||
|
// hashHead[hashValue] contains the largest inputIndex with the specified hash value
|
||||||
|
// If hashHead[hashValue] is within the current window, then
|
||||||
|
// hashPrev[hashHead[hashValue] & windowMask] contains the previous index
|
||||||
|
// with the same hash value.
|
||||||
|
hashHead [hashSize]uint32
|
||||||
|
hashPrev [windowSize]uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type compressor struct {
|
||||||
|
compressionLevel
|
||||||
|
|
||||||
|
h *huffmanEncoder
|
||||||
|
w *huffmanBitWriter
|
||||||
|
|
||||||
|
// compression algorithm
|
||||||
|
fill func(*compressor, []byte) int // copy data to window
|
||||||
|
step func(*compressor) // process window
|
||||||
|
|
||||||
|
window []byte
|
||||||
|
windowEnd int
|
||||||
|
blockStart int // window index where current tokens start
|
||||||
|
err error
|
||||||
|
|
||||||
|
// queued output tokens
|
||||||
|
tokens tokens
|
||||||
|
fast fastEnc
|
||||||
|
state *advancedState
|
||||||
|
|
||||||
|
sync bool // requesting flush
|
||||||
|
byteAvailable bool // if true, still need to process window[index-1].
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *compressor) fillDeflate(b []byte) int {
|
||||||
|
s := d.state
|
||||||
|
if s.index >= 2*windowSize-(minMatchLength+maxMatchLength) {
|
||||||
|
// shift the window by windowSize
|
||||||
|
//copy(d.window[:], d.window[windowSize:2*windowSize])
|
||||||
|
*(*[windowSize]byte)(d.window) = *(*[windowSize]byte)(d.window[windowSize:])
|
||||||
|
s.index -= windowSize
|
||||||
|
d.windowEnd -= windowSize
|
||||||
|
if d.blockStart >= windowSize {
|
||||||
|
d.blockStart -= windowSize
|
||||||
|
} else {
|
||||||
|
d.blockStart = math.MaxInt32
|
||||||
|
}
|
||||||
|
s.hashOffset += windowSize
|
||||||
|
if s.hashOffset > maxHashOffset {
|
||||||
|
delta := s.hashOffset - 1
|
||||||
|
s.hashOffset -= delta
|
||||||
|
s.chainHead -= delta
|
||||||
|
// Iterate over slices instead of arrays to avoid copying
|
||||||
|
// the entire table onto the stack (Issue #18625).
|
||||||
|
for i, v := range s.hashPrev[:] {
|
||||||
|
if int(v) > delta {
|
||||||
|
s.hashPrev[i] = uint32(int(v) - delta)
|
||||||
|
} else {
|
||||||
|
s.hashPrev[i] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i, v := range s.hashHead[:] {
|
||||||
|
if int(v) > delta {
|
||||||
|
s.hashHead[i] = uint32(int(v) - delta)
|
||||||
|
} else {
|
||||||
|
s.hashHead[i] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
n := copy(d.window[d.windowEnd:], b)
|
||||||
|
d.windowEnd += n
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *compressor) writeBlock(tok *tokens, index int, eof bool) error {
|
||||||
|
if index > 0 || eof {
|
||||||
|
var window []byte
|
||||||
|
if d.blockStart <= index {
|
||||||
|
window = d.window[d.blockStart:index]
|
||||||
|
}
|
||||||
|
d.blockStart = index
|
||||||
|
//d.w.writeBlock(tok, eof, window)
|
||||||
|
d.w.writeBlockDynamic(tok, eof, window, d.sync)
|
||||||
|
return d.w.err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeBlockSkip writes the current block and uses the number of tokens
|
||||||
|
// to determine if the block should be stored on no matches, or
|
||||||
|
// only huffman encoded.
|
||||||
|
func (d *compressor) writeBlockSkip(tok *tokens, index int, eof bool) error {
|
||||||
|
if index > 0 || eof {
|
||||||
|
if d.blockStart <= index {
|
||||||
|
window := d.window[d.blockStart:index]
|
||||||
|
// If we removed less than a 64th of all literals
|
||||||
|
// we huffman compress the block.
|
||||||
|
if int(tok.n) > len(window)-int(tok.n>>6) {
|
||||||
|
d.w.writeBlockHuff(eof, window, d.sync)
|
||||||
|
} else {
|
||||||
|
// Write a dynamic huffman block.
|
||||||
|
d.w.writeBlockDynamic(tok, eof, window, d.sync)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
d.w.writeBlock(tok, eof, nil)
|
||||||
|
}
|
||||||
|
d.blockStart = index
|
||||||
|
return d.w.err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// fillWindow will fill the current window with the supplied
|
||||||
|
// dictionary and calculate all hashes.
|
||||||
|
// This is much faster than doing a full encode.
|
||||||
|
// Should only be used after a start/reset.
|
||||||
|
func (d *compressor) fillWindow(b []byte) {
|
||||||
|
// Do not fill window if we are in store-only or huffman mode.
|
||||||
|
if d.level <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if d.fast != nil {
|
||||||
|
// encode the last data, but discard the result
|
||||||
|
if len(b) > maxMatchOffset {
|
||||||
|
b = b[len(b)-maxMatchOffset:]
|
||||||
|
}
|
||||||
|
d.fast.Encode(&d.tokens, b)
|
||||||
|
d.tokens.Reset()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s := d.state
|
||||||
|
// If we are given too much, cut it.
|
||||||
|
if len(b) > windowSize {
|
||||||
|
b = b[len(b)-windowSize:]
|
||||||
|
}
|
||||||
|
// Add all to window.
|
||||||
|
n := copy(d.window[d.windowEnd:], b)
|
||||||
|
|
||||||
|
// Calculate 256 hashes at the time (more L1 cache hits)
|
||||||
|
loops := (n + 256 - minMatchLength) / 256
|
||||||
|
for j := 0; j < loops; j++ {
|
||||||
|
startindex := j * 256
|
||||||
|
end := startindex + 256 + minMatchLength - 1
|
||||||
|
if end > n {
|
||||||
|
end = n
|
||||||
|
}
|
||||||
|
tocheck := d.window[startindex:end]
|
||||||
|
dstSize := len(tocheck) - minMatchLength + 1
|
||||||
|
|
||||||
|
if dstSize <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dst := s.hashMatch[:dstSize]
|
||||||
|
bulkHash4(tocheck, dst)
|
||||||
|
var newH uint32
|
||||||
|
for i, val := range dst {
|
||||||
|
di := i + startindex
|
||||||
|
newH = val & hashMask
|
||||||
|
// Get previous value with the same hash.
|
||||||
|
// Our chain should point to the previous value.
|
||||||
|
s.hashPrev[di&windowMask] = s.hashHead[newH]
|
||||||
|
// Set the head of the hash chain to us.
|
||||||
|
s.hashHead[newH] = uint32(di + s.hashOffset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Update window information.
|
||||||
|
d.windowEnd += n
|
||||||
|
s.index = n
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to find a match starting at index whose length is greater than prevSize.
|
||||||
|
// We only look at chainCount possibilities before giving up.
|
||||||
|
// pos = s.index, prevHead = s.chainHead-s.hashOffset, prevLength=minMatchLength-1, lookahead
|
||||||
|
func (d *compressor) findMatch(pos int, prevHead int, lookahead int) (length, offset int, ok bool) {
|
||||||
|
minMatchLook := maxMatchLength
|
||||||
|
if lookahead < minMatchLook {
|
||||||
|
minMatchLook = lookahead
|
||||||
|
}
|
||||||
|
|
||||||
|
win := d.window[0 : pos+minMatchLook]
|
||||||
|
|
||||||
|
// We quit when we get a match that's at least nice long
|
||||||
|
nice := len(win) - pos
|
||||||
|
if d.nice < nice {
|
||||||
|
nice = d.nice
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we've got a match that's good enough, only look in 1/4 the chain.
|
||||||
|
tries := d.chain
|
||||||
|
length = minMatchLength - 1
|
||||||
|
|
||||||
|
wEnd := win[pos+length]
|
||||||
|
wPos := win[pos:]
|
||||||
|
minIndex := pos - windowSize
|
||||||
|
if minIndex < 0 {
|
||||||
|
minIndex = 0
|
||||||
|
}
|
||||||
|
offset = 0
|
||||||
|
|
||||||
|
cGain := 0
|
||||||
|
if d.chain < 100 {
|
||||||
|
for i := prevHead; tries > 0; tries-- {
|
||||||
|
if wEnd == win[i+length] {
|
||||||
|
n := matchLen(win[i:i+minMatchLook], wPos)
|
||||||
|
if n > length {
|
||||||
|
length = n
|
||||||
|
offset = pos - i
|
||||||
|
ok = true
|
||||||
|
if n >= nice {
|
||||||
|
// The match is good enough that we don't try to find a better one.
|
||||||
|
break
|
||||||
|
}
|
||||||
|
wEnd = win[pos+n]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if i <= minIndex {
|
||||||
|
// hashPrev[i & windowMask] has already been overwritten, so stop now.
|
||||||
|
break
|
||||||
|
}
|
||||||
|
i = int(d.state.hashPrev[i&windowMask]) - d.state.hashOffset
|
||||||
|
if i < minIndex {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Some like it higher (CSV), some like it lower (JSON)
|
||||||
|
const baseCost = 6
|
||||||
|
// Base is 4 bytes at with an additional cost.
|
||||||
|
// Matches must be better than this.
|
||||||
|
for i := prevHead; tries > 0; tries-- {
|
||||||
|
if wEnd == win[i+length] {
|
||||||
|
n := matchLen(win[i:i+minMatchLook], wPos)
|
||||||
|
if n > length {
|
||||||
|
// Calculate gain. Estimate
|
||||||
|
newGain := d.h.bitLengthRaw(wPos[:n]) - int(offsetExtraBits[offsetCode(uint32(pos-i))]) - baseCost - int(lengthExtraBits[lengthCodes[(n-3)&255]])
|
||||||
|
|
||||||
|
//fmt.Println(n, "gain:", newGain, "prev:", cGain, "raw:", d.h.bitLengthRaw(wPos[:n]))
|
||||||
|
if newGain > cGain {
|
||||||
|
length = n
|
||||||
|
offset = pos - i
|
||||||
|
cGain = newGain
|
||||||
|
ok = true
|
||||||
|
if n >= nice {
|
||||||
|
// The match is good enough that we don't try to find a better one.
|
||||||
|
break
|
||||||
|
}
|
||||||
|
wEnd = win[pos+n]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if i <= minIndex {
|
||||||
|
// hashPrev[i & windowMask] has already been overwritten, so stop now.
|
||||||
|
break
|
||||||
|
}
|
||||||
|
i = int(d.state.hashPrev[i&windowMask]) - d.state.hashOffset
|
||||||
|
if i < minIndex {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *compressor) writeStoredBlock(buf []byte) error {
|
||||||
|
if d.w.writeStoredHeader(len(buf), false); d.w.err != nil {
|
||||||
|
return d.w.err
|
||||||
|
}
|
||||||
|
d.w.writeBytes(buf)
|
||||||
|
return d.w.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// hash4 returns a hash representation of the first 4 bytes
|
||||||
|
// of the supplied slice.
|
||||||
|
// The caller must ensure that len(b) >= 4.
|
||||||
|
func hash4(b []byte) uint32 {
|
||||||
|
return hash4u(binary.LittleEndian.Uint32(b), hashBits)
|
||||||
|
}
|
||||||
|
|
||||||
|
// hash4 returns the hash of u to fit in a hash table with h bits.
|
||||||
|
// Preferably h should be a constant and should always be <32.
|
||||||
|
func hash4u(u uint32, h uint8) uint32 {
|
||||||
|
return (u * prime4bytes) >> (32 - h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// bulkHash4 will compute hashes using the same
|
||||||
|
// algorithm as hash4
|
||||||
|
func bulkHash4(b []byte, dst []uint32) {
|
||||||
|
if len(b) < 4 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
hb := binary.LittleEndian.Uint32(b)
|
||||||
|
|
||||||
|
dst[0] = hash4u(hb, hashBits)
|
||||||
|
end := len(b) - 4 + 1
|
||||||
|
for i := 1; i < end; i++ {
|
||||||
|
hb = (hb >> 8) | uint32(b[i+3])<<24
|
||||||
|
dst[i] = hash4u(hb, hashBits)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *compressor) initDeflate() {
|
||||||
|
d.window = make([]byte, 2*windowSize)
|
||||||
|
d.byteAvailable = false
|
||||||
|
d.err = nil
|
||||||
|
if d.state == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s := d.state
|
||||||
|
s.index = 0
|
||||||
|
s.hashOffset = 1
|
||||||
|
s.length = minMatchLength - 1
|
||||||
|
s.offset = 0
|
||||||
|
s.chainHead = -1
|
||||||
|
}
|
||||||
|
|
||||||
|
// deflateLazy is the same as deflate, but with d.fastSkipHashing == skipNever,
|
||||||
|
// meaning it always has lazy matching on.
|
||||||
|
func (d *compressor) deflateLazy() {
|
||||||
|
s := d.state
|
||||||
|
// Sanity enables additional runtime tests.
|
||||||
|
// It's intended to be used during development
|
||||||
|
// to supplement the currently ad-hoc unit tests.
|
||||||
|
const sanity = debugDeflate
|
||||||
|
|
||||||
|
if d.windowEnd-s.index < minMatchLength+maxMatchLength && !d.sync {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if d.windowEnd != s.index && d.chain > 100 {
|
||||||
|
// Get literal huffman coder.
|
||||||
|
if d.h == nil {
|
||||||
|
d.h = newHuffmanEncoder(maxFlateBlockTokens)
|
||||||
|
}
|
||||||
|
var tmp [256]uint16
|
||||||
|
for _, v := range d.window[s.index:d.windowEnd] {
|
||||||
|
tmp[v]++
|
||||||
|
}
|
||||||
|
d.h.generate(tmp[:], 15)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.maxInsertIndex = d.windowEnd - (minMatchLength - 1)
|
||||||
|
|
||||||
|
for {
|
||||||
|
if sanity && s.index > d.windowEnd {
|
||||||
|
panic("index > windowEnd")
|
||||||
|
}
|
||||||
|
lookahead := d.windowEnd - s.index
|
||||||
|
if lookahead < minMatchLength+maxMatchLength {
|
||||||
|
if !d.sync {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if sanity && s.index > d.windowEnd {
|
||||||
|
panic("index > windowEnd")
|
||||||
|
}
|
||||||
|
if lookahead == 0 {
|
||||||
|
// Flush current output block if any.
|
||||||
|
if d.byteAvailable {
|
||||||
|
// There is still one pending token that needs to be flushed
|
||||||
|
d.tokens.AddLiteral(d.window[s.index-1])
|
||||||
|
d.byteAvailable = false
|
||||||
|
}
|
||||||
|
if d.tokens.n > 0 {
|
||||||
|
if d.err = d.writeBlock(&d.tokens, s.index, false); d.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
d.tokens.Reset()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.index < s.maxInsertIndex {
|
||||||
|
// Update the hash
|
||||||
|
hash := hash4(d.window[s.index:])
|
||||||
|
ch := s.hashHead[hash]
|
||||||
|
s.chainHead = int(ch)
|
||||||
|
s.hashPrev[s.index&windowMask] = ch
|
||||||
|
s.hashHead[hash] = uint32(s.index + s.hashOffset)
|
||||||
|
}
|
||||||
|
prevLength := s.length
|
||||||
|
prevOffset := s.offset
|
||||||
|
s.length = minMatchLength - 1
|
||||||
|
s.offset = 0
|
||||||
|
minIndex := s.index - windowSize
|
||||||
|
if minIndex < 0 {
|
||||||
|
minIndex = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.chainHead-s.hashOffset >= minIndex && lookahead > prevLength && prevLength < d.lazy {
|
||||||
|
if newLength, newOffset, ok := d.findMatch(s.index, s.chainHead-s.hashOffset, lookahead); ok {
|
||||||
|
s.length = newLength
|
||||||
|
s.offset = newOffset
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if prevLength >= minMatchLength && s.length <= prevLength {
|
||||||
|
// Check for better match at end...
|
||||||
|
//
|
||||||
|
// checkOff must be >=2 since we otherwise risk checking s.index
|
||||||
|
// Offset of 2 seems to yield best results.
|
||||||
|
const checkOff = 2
|
||||||
|
prevIndex := s.index - 1
|
||||||
|
if prevIndex+prevLength+checkOff < s.maxInsertIndex {
|
||||||
|
end := lookahead
|
||||||
|
if lookahead > maxMatchLength {
|
||||||
|
end = maxMatchLength
|
||||||
|
}
|
||||||
|
end += prevIndex
|
||||||
|
idx := prevIndex + prevLength - (4 - checkOff)
|
||||||
|
h := hash4(d.window[idx:])
|
||||||
|
ch2 := int(s.hashHead[h]) - s.hashOffset - prevLength + (4 - checkOff)
|
||||||
|
if ch2 > minIndex {
|
||||||
|
length := matchLen(d.window[prevIndex:end], d.window[ch2:])
|
||||||
|
// It seems like a pure length metric is best.
|
||||||
|
if length > prevLength {
|
||||||
|
prevLength = length
|
||||||
|
prevOffset = prevIndex - ch2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// There was a match at the previous step, and the current match is
|
||||||
|
// not better. Output the previous match.
|
||||||
|
d.tokens.AddMatch(uint32(prevLength-3), uint32(prevOffset-minOffsetSize))
|
||||||
|
|
||||||
|
// Insert in the hash table all strings up to the end of the match.
|
||||||
|
// index and index-1 are already inserted. If there is not enough
|
||||||
|
// lookahead, the last two strings are not inserted into the hash
|
||||||
|
// table.
|
||||||
|
newIndex := s.index + prevLength - 1
|
||||||
|
// Calculate missing hashes
|
||||||
|
end := newIndex
|
||||||
|
if end > s.maxInsertIndex {
|
||||||
|
end = s.maxInsertIndex
|
||||||
|
}
|
||||||
|
end += minMatchLength - 1
|
||||||
|
startindex := s.index + 1
|
||||||
|
if startindex > s.maxInsertIndex {
|
||||||
|
startindex = s.maxInsertIndex
|
||||||
|
}
|
||||||
|
tocheck := d.window[startindex:end]
|
||||||
|
dstSize := len(tocheck) - minMatchLength + 1
|
||||||
|
if dstSize > 0 {
|
||||||
|
dst := s.hashMatch[:dstSize]
|
||||||
|
bulkHash4(tocheck, dst)
|
||||||
|
var newH uint32
|
||||||
|
for i, val := range dst {
|
||||||
|
di := i + startindex
|
||||||
|
newH = val & hashMask
|
||||||
|
// Get previous value with the same hash.
|
||||||
|
// Our chain should point to the previous value.
|
||||||
|
s.hashPrev[di&windowMask] = s.hashHead[newH]
|
||||||
|
// Set the head of the hash chain to us.
|
||||||
|
s.hashHead[newH] = uint32(di + s.hashOffset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.index = newIndex
|
||||||
|
d.byteAvailable = false
|
||||||
|
s.length = minMatchLength - 1
|
||||||
|
if d.tokens.n == maxFlateBlockTokens {
|
||||||
|
// The block includes the current character
|
||||||
|
if d.err = d.writeBlock(&d.tokens, s.index, false); d.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
d.tokens.Reset()
|
||||||
|
}
|
||||||
|
s.ii = 0
|
||||||
|
} else {
|
||||||
|
// Reset, if we got a match this run.
|
||||||
|
if s.length >= minMatchLength {
|
||||||
|
s.ii = 0
|
||||||
|
}
|
||||||
|
// We have a byte waiting. Emit it.
|
||||||
|
if d.byteAvailable {
|
||||||
|
s.ii++
|
||||||
|
d.tokens.AddLiteral(d.window[s.index-1])
|
||||||
|
if d.tokens.n == maxFlateBlockTokens {
|
||||||
|
if d.err = d.writeBlock(&d.tokens, s.index, false); d.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
d.tokens.Reset()
|
||||||
|
}
|
||||||
|
s.index++
|
||||||
|
|
||||||
|
// If we have a long run of no matches, skip additional bytes
|
||||||
|
// Resets when s.ii overflows after 64KB.
|
||||||
|
if n := int(s.ii) - d.chain; n > 0 {
|
||||||
|
n = 1 + int(n>>6)
|
||||||
|
for j := 0; j < n; j++ {
|
||||||
|
if s.index >= d.windowEnd-1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
d.tokens.AddLiteral(d.window[s.index-1])
|
||||||
|
if d.tokens.n == maxFlateBlockTokens {
|
||||||
|
if d.err = d.writeBlock(&d.tokens, s.index, false); d.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
d.tokens.Reset()
|
||||||
|
}
|
||||||
|
// Index...
|
||||||
|
if s.index < s.maxInsertIndex {
|
||||||
|
h := hash4(d.window[s.index:])
|
||||||
|
ch := s.hashHead[h]
|
||||||
|
s.chainHead = int(ch)
|
||||||
|
s.hashPrev[s.index&windowMask] = ch
|
||||||
|
s.hashHead[h] = uint32(s.index + s.hashOffset)
|
||||||
|
}
|
||||||
|
s.index++
|
||||||
|
}
|
||||||
|
// Flush last byte
|
||||||
|
d.tokens.AddLiteral(d.window[s.index-1])
|
||||||
|
d.byteAvailable = false
|
||||||
|
// s.length = minMatchLength - 1 // not needed, since s.ii is reset above, so it should never be > minMatchLength
|
||||||
|
if d.tokens.n == maxFlateBlockTokens {
|
||||||
|
if d.err = d.writeBlock(&d.tokens, s.index, false); d.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
d.tokens.Reset()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
s.index++
|
||||||
|
d.byteAvailable = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *compressor) store() {
|
||||||
|
if d.windowEnd > 0 && (d.windowEnd == maxStoreBlockSize || d.sync) {
|
||||||
|
d.err = d.writeStoredBlock(d.window[:d.windowEnd])
|
||||||
|
d.windowEnd = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fillWindow will fill the buffer with data for huffman-only compression.
|
||||||
|
// The number of bytes copied is returned.
|
||||||
|
func (d *compressor) fillBlock(b []byte) int {
|
||||||
|
n := copy(d.window[d.windowEnd:], b)
|
||||||
|
d.windowEnd += n
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// storeHuff will compress and store the currently added data,
|
||||||
|
// if enough has been accumulated or we at the end of the stream.
|
||||||
|
// Any error that occurred will be in d.err
|
||||||
|
func (d *compressor) storeHuff() {
|
||||||
|
if d.windowEnd < len(d.window) && !d.sync || d.windowEnd == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
d.w.writeBlockHuff(false, d.window[:d.windowEnd], d.sync)
|
||||||
|
d.err = d.w.err
|
||||||
|
d.windowEnd = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// storeFast will compress and store the currently added data,
|
||||||
|
// if enough has been accumulated or we at the end of the stream.
|
||||||
|
// Any error that occurred will be in d.err
|
||||||
|
func (d *compressor) storeFast() {
|
||||||
|
// We only compress if we have maxStoreBlockSize.
|
||||||
|
if d.windowEnd < len(d.window) {
|
||||||
|
if !d.sync {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Handle extremely small sizes.
|
||||||
|
if d.windowEnd < 128 {
|
||||||
|
if d.windowEnd == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if d.windowEnd <= 32 {
|
||||||
|
d.err = d.writeStoredBlock(d.window[:d.windowEnd])
|
||||||
|
} else {
|
||||||
|
d.w.writeBlockHuff(false, d.window[:d.windowEnd], true)
|
||||||
|
d.err = d.w.err
|
||||||
|
}
|
||||||
|
d.tokens.Reset()
|
||||||
|
d.windowEnd = 0
|
||||||
|
d.fast.Reset()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
d.fast.Encode(&d.tokens, d.window[:d.windowEnd])
|
||||||
|
// If we made zero matches, store the block as is.
|
||||||
|
if d.tokens.n == 0 {
|
||||||
|
d.err = d.writeStoredBlock(d.window[:d.windowEnd])
|
||||||
|
// If we removed less than 1/16th, huffman compress the block.
|
||||||
|
} else if int(d.tokens.n) > d.windowEnd-(d.windowEnd>>4) {
|
||||||
|
d.w.writeBlockHuff(false, d.window[:d.windowEnd], d.sync)
|
||||||
|
d.err = d.w.err
|
||||||
|
} else {
|
||||||
|
d.w.writeBlockDynamic(&d.tokens, false, d.window[:d.windowEnd], d.sync)
|
||||||
|
d.err = d.w.err
|
||||||
|
}
|
||||||
|
d.tokens.Reset()
|
||||||
|
d.windowEnd = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// write will add input byte to the stream.
|
||||||
|
// Unless an error occurs all bytes will be consumed.
|
||||||
|
func (d *compressor) write(b []byte) (n int, err error) {
|
||||||
|
if d.err != nil {
|
||||||
|
return 0, d.err
|
||||||
|
}
|
||||||
|
n = len(b)
|
||||||
|
for len(b) > 0 {
|
||||||
|
if d.windowEnd == len(d.window) || d.sync {
|
||||||
|
d.step(d)
|
||||||
|
}
|
||||||
|
b = b[d.fill(d, b):]
|
||||||
|
if d.err != nil {
|
||||||
|
return 0, d.err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n, d.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *compressor) syncFlush() error {
|
||||||
|
d.sync = true
|
||||||
|
if d.err != nil {
|
||||||
|
return d.err
|
||||||
|
}
|
||||||
|
d.step(d)
|
||||||
|
if d.err == nil {
|
||||||
|
d.w.writeStoredHeader(0, false)
|
||||||
|
d.w.flush()
|
||||||
|
d.err = d.w.err
|
||||||
|
}
|
||||||
|
d.sync = false
|
||||||
|
return d.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *compressor) init(w io.Writer, level int) (err error) {
|
||||||
|
d.w = newHuffmanBitWriter(w)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case level == NoCompression:
|
||||||
|
d.window = make([]byte, maxStoreBlockSize)
|
||||||
|
d.fill = (*compressor).fillBlock
|
||||||
|
d.step = (*compressor).store
|
||||||
|
case level == ConstantCompression:
|
||||||
|
d.w.logNewTablePenalty = 10
|
||||||
|
d.window = make([]byte, 32<<10)
|
||||||
|
d.fill = (*compressor).fillBlock
|
||||||
|
d.step = (*compressor).storeHuff
|
||||||
|
case level == DefaultCompression:
|
||||||
|
level = 5
|
||||||
|
fallthrough
|
||||||
|
case level >= 1 && level <= 6:
|
||||||
|
d.w.logNewTablePenalty = 7
|
||||||
|
d.fast = newFastEnc(level)
|
||||||
|
d.window = make([]byte, maxStoreBlockSize)
|
||||||
|
d.fill = (*compressor).fillBlock
|
||||||
|
d.step = (*compressor).storeFast
|
||||||
|
case 7 <= level && level <= 9:
|
||||||
|
d.w.logNewTablePenalty = 8
|
||||||
|
d.state = &advancedState{}
|
||||||
|
d.compressionLevel = levels[level]
|
||||||
|
d.initDeflate()
|
||||||
|
d.fill = (*compressor).fillDeflate
|
||||||
|
d.step = (*compressor).deflateLazy
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("flate: invalid compression level %d: want value in range [-2, 9]", level)
|
||||||
|
}
|
||||||
|
d.level = level
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// reset the state of the compressor.
|
||||||
|
func (d *compressor) reset(w io.Writer) {
|
||||||
|
d.w.reset(w)
|
||||||
|
d.sync = false
|
||||||
|
d.err = nil
|
||||||
|
// We only need to reset a few things for Snappy.
|
||||||
|
if d.fast != nil {
|
||||||
|
d.fast.Reset()
|
||||||
|
d.windowEnd = 0
|
||||||
|
d.tokens.Reset()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch d.compressionLevel.chain {
|
||||||
|
case 0:
|
||||||
|
// level was NoCompression or ConstantCompresssion.
|
||||||
|
d.windowEnd = 0
|
||||||
|
default:
|
||||||
|
s := d.state
|
||||||
|
s.chainHead = -1
|
||||||
|
for i := range s.hashHead {
|
||||||
|
s.hashHead[i] = 0
|
||||||
|
}
|
||||||
|
for i := range s.hashPrev {
|
||||||
|
s.hashPrev[i] = 0
|
||||||
|
}
|
||||||
|
s.hashOffset = 1
|
||||||
|
s.index, d.windowEnd = 0, 0
|
||||||
|
d.blockStart, d.byteAvailable = 0, false
|
||||||
|
d.tokens.Reset()
|
||||||
|
s.length = minMatchLength - 1
|
||||||
|
s.offset = 0
|
||||||
|
s.ii = 0
|
||||||
|
s.maxInsertIndex = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *compressor) close() error {
|
||||||
|
if d.err != nil {
|
||||||
|
return d.err
|
||||||
|
}
|
||||||
|
d.sync = true
|
||||||
|
d.step(d)
|
||||||
|
if d.err != nil {
|
||||||
|
return d.err
|
||||||
|
}
|
||||||
|
if d.w.writeStoredHeader(0, true); d.w.err != nil {
|
||||||
|
return d.w.err
|
||||||
|
}
|
||||||
|
d.w.flush()
|
||||||
|
d.w.reset(nil)
|
||||||
|
return d.w.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWriter returns a new Writer compressing data at the given level.
|
||||||
|
// Following zlib, levels range from 1 (BestSpeed) to 9 (BestCompression);
|
||||||
|
// higher levels typically run slower but compress more.
|
||||||
|
// Level 0 (NoCompression) does not attempt any compression; it only adds the
|
||||||
|
// necessary DEFLATE framing.
|
||||||
|
// Level -1 (DefaultCompression) uses the default compression level.
|
||||||
|
// Level -2 (ConstantCompression) will use Huffman compression only, giving
|
||||||
|
// a very fast compression for all types of input, but sacrificing considerable
|
||||||
|
// compression efficiency.
|
||||||
|
//
|
||||||
|
// If level is in the range [-2, 9] then the error returned will be nil.
|
||||||
|
// Otherwise the error returned will be non-nil.
|
||||||
|
func NewWriter(w io.Writer, level int) (*Writer, error) {
|
||||||
|
var dw Writer
|
||||||
|
if err := dw.d.init(w, level); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &dw, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWriterDict is like NewWriter but initializes the new
|
||||||
|
// Writer with a preset dictionary. The returned Writer behaves
|
||||||
|
// as if the dictionary had been written to it without producing
|
||||||
|
// any compressed output. The compressed data written to w
|
||||||
|
// can only be decompressed by a Reader initialized with the
|
||||||
|
// same dictionary.
|
||||||
|
func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) {
|
||||||
|
zw, err := NewWriter(w, level)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
zw.d.fillWindow(dict)
|
||||||
|
zw.dict = append(zw.dict, dict...) // duplicate dictionary for Reset method.
|
||||||
|
return zw, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// A Writer takes data written to it and writes the compressed
|
||||||
|
// form of that data to an underlying writer (see NewWriter).
|
||||||
|
type Writer struct {
|
||||||
|
d compressor
|
||||||
|
dict []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write writes data to w, which will eventually write the
|
||||||
|
// compressed form of data to its underlying writer.
|
||||||
|
func (w *Writer) Write(data []byte) (n int, err error) {
|
||||||
|
return w.d.write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush flushes any pending data to the underlying writer.
|
||||||
|
// It is useful mainly in compressed network protocols, to ensure that
|
||||||
|
// a remote reader has enough data to reconstruct a packet.
|
||||||
|
// Flush does not return until the data has been written.
|
||||||
|
// Calling Flush when there is no pending data still causes the Writer
|
||||||
|
// to emit a sync marker of at least 4 bytes.
|
||||||
|
// If the underlying writer returns an error, Flush returns that error.
|
||||||
|
//
|
||||||
|
// In the terminology of the zlib library, Flush is equivalent to Z_SYNC_FLUSH.
|
||||||
|
func (w *Writer) Flush() error {
|
||||||
|
// For more about flushing:
|
||||||
|
// http://www.bolet.org/~pornin/deflate-flush.html
|
||||||
|
return w.d.syncFlush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close flushes and closes the writer.
|
||||||
|
func (w *Writer) Close() error {
|
||||||
|
return w.d.close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset discards the writer's state and makes it equivalent to
|
||||||
|
// the result of NewWriter or NewWriterDict called with dst
|
||||||
|
// and w's level and dictionary.
|
||||||
|
func (w *Writer) Reset(dst io.Writer) {
|
||||||
|
if len(w.dict) > 0 {
|
||||||
|
// w was created with NewWriterDict
|
||||||
|
w.d.reset(dst)
|
||||||
|
if dst != nil {
|
||||||
|
w.d.fillWindow(w.dict)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// w was created with NewWriter
|
||||||
|
w.d.reset(dst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetDict discards the writer's state and makes it equivalent to
|
||||||
|
// the result of NewWriter or NewWriterDict called with dst
|
||||||
|
// and w's level, but sets a specific dictionary.
|
||||||
|
func (w *Writer) ResetDict(dst io.Writer, dict []byte) {
|
||||||
|
w.dict = dict
|
||||||
|
w.d.reset(dst)
|
||||||
|
w.d.fillWindow(w.dict)
|
||||||
|
}
|
|
@ -0,0 +1,184 @@
|
||||||
|
// Copyright 2016 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package flate
|
||||||
|
|
||||||
|
// dictDecoder implements the LZ77 sliding dictionary as used in decompression.
|
||||||
|
// LZ77 decompresses data through sequences of two forms of commands:
|
||||||
|
//
|
||||||
|
// - Literal insertions: Runs of one or more symbols are inserted into the data
|
||||||
|
// stream as is. This is accomplished through the writeByte method for a
|
||||||
|
// single symbol, or combinations of writeSlice/writeMark for multiple symbols.
|
||||||
|
// Any valid stream must start with a literal insertion if no preset dictionary
|
||||||
|
// is used.
|
||||||
|
//
|
||||||
|
// - Backward copies: Runs of one or more symbols are copied from previously
|
||||||
|
// emitted data. Backward copies come as the tuple (dist, length) where dist
|
||||||
|
// determines how far back in the stream to copy from and length determines how
|
||||||
|
// many bytes to copy. Note that it is valid for the length to be greater than
|
||||||
|
// the distance. Since LZ77 uses forward copies, that situation is used to
|
||||||
|
// perform a form of run-length encoding on repeated runs of symbols.
|
||||||
|
// The writeCopy and tryWriteCopy are used to implement this command.
|
||||||
|
//
|
||||||
|
// For performance reasons, this implementation performs little to no sanity
|
||||||
|
// checks about the arguments. As such, the invariants documented for each
|
||||||
|
// method call must be respected.
|
||||||
|
type dictDecoder struct {
|
||||||
|
hist []byte // Sliding window history
|
||||||
|
|
||||||
|
// Invariant: 0 <= rdPos <= wrPos <= len(hist)
|
||||||
|
wrPos int // Current output position in buffer
|
||||||
|
rdPos int // Have emitted hist[:rdPos] already
|
||||||
|
full bool // Has a full window length been written yet?
|
||||||
|
}
|
||||||
|
|
||||||
|
// init initializes dictDecoder to have a sliding window dictionary of the given
|
||||||
|
// size. If a preset dict is provided, it will initialize the dictionary with
|
||||||
|
// the contents of dict.
|
||||||
|
func (dd *dictDecoder) init(size int, dict []byte) {
|
||||||
|
*dd = dictDecoder{hist: dd.hist}
|
||||||
|
|
||||||
|
if cap(dd.hist) < size {
|
||||||
|
dd.hist = make([]byte, size)
|
||||||
|
}
|
||||||
|
dd.hist = dd.hist[:size]
|
||||||
|
|
||||||
|
if len(dict) > len(dd.hist) {
|
||||||
|
dict = dict[len(dict)-len(dd.hist):]
|
||||||
|
}
|
||||||
|
dd.wrPos = copy(dd.hist, dict)
|
||||||
|
if dd.wrPos == len(dd.hist) {
|
||||||
|
dd.wrPos = 0
|
||||||
|
dd.full = true
|
||||||
|
}
|
||||||
|
dd.rdPos = dd.wrPos
|
||||||
|
}
|
||||||
|
|
||||||
|
// histSize reports the total amount of historical data in the dictionary.
|
||||||
|
func (dd *dictDecoder) histSize() int {
|
||||||
|
if dd.full {
|
||||||
|
return len(dd.hist)
|
||||||
|
}
|
||||||
|
return dd.wrPos
|
||||||
|
}
|
||||||
|
|
||||||
|
// availRead reports the number of bytes that can be flushed by readFlush.
|
||||||
|
func (dd *dictDecoder) availRead() int {
|
||||||
|
return dd.wrPos - dd.rdPos
|
||||||
|
}
|
||||||
|
|
||||||
|
// availWrite reports the available amount of output buffer space.
|
||||||
|
func (dd *dictDecoder) availWrite() int {
|
||||||
|
return len(dd.hist) - dd.wrPos
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeSlice returns a slice of the available buffer to write data to.
|
||||||
|
//
|
||||||
|
// This invariant will be kept: len(s) <= availWrite()
|
||||||
|
func (dd *dictDecoder) writeSlice() []byte {
|
||||||
|
return dd.hist[dd.wrPos:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeMark advances the writer pointer by cnt.
|
||||||
|
//
|
||||||
|
// This invariant must be kept: 0 <= cnt <= availWrite()
|
||||||
|
func (dd *dictDecoder) writeMark(cnt int) {
|
||||||
|
dd.wrPos += cnt
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeByte writes a single byte to the dictionary.
|
||||||
|
//
|
||||||
|
// This invariant must be kept: 0 < availWrite()
|
||||||
|
func (dd *dictDecoder) writeByte(c byte) {
|
||||||
|
dd.hist[dd.wrPos] = c
|
||||||
|
dd.wrPos++
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeCopy copies a string at a given (dist, length) to the output.
|
||||||
|
// This returns the number of bytes copied and may be less than the requested
|
||||||
|
// length if the available space in the output buffer is too small.
|
||||||
|
//
|
||||||
|
// This invariant must be kept: 0 < dist <= histSize()
|
||||||
|
func (dd *dictDecoder) writeCopy(dist, length int) int {
|
||||||
|
dstBase := dd.wrPos
|
||||||
|
dstPos := dstBase
|
||||||
|
srcPos := dstPos - dist
|
||||||
|
endPos := dstPos + length
|
||||||
|
if endPos > len(dd.hist) {
|
||||||
|
endPos = len(dd.hist)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy non-overlapping section after destination position.
|
||||||
|
//
|
||||||
|
// This section is non-overlapping in that the copy length for this section
|
||||||
|
// is always less than or equal to the backwards distance. This can occur
|
||||||
|
// if a distance refers to data that wraps-around in the buffer.
|
||||||
|
// Thus, a backwards copy is performed here; that is, the exact bytes in
|
||||||
|
// the source prior to the copy is placed in the destination.
|
||||||
|
if srcPos < 0 {
|
||||||
|
srcPos += len(dd.hist)
|
||||||
|
dstPos += copy(dd.hist[dstPos:endPos], dd.hist[srcPos:])
|
||||||
|
srcPos = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy possibly overlapping section before destination position.
|
||||||
|
//
|
||||||
|
// This section can overlap if the copy length for this section is larger
|
||||||
|
// than the backwards distance. This is allowed by LZ77 so that repeated
|
||||||
|
// strings can be succinctly represented using (dist, length) pairs.
|
||||||
|
// Thus, a forwards copy is performed here; that is, the bytes copied is
|
||||||
|
// possibly dependent on the resulting bytes in the destination as the copy
|
||||||
|
// progresses along. This is functionally equivalent to the following:
|
||||||
|
//
|
||||||
|
// for i := 0; i < endPos-dstPos; i++ {
|
||||||
|
// dd.hist[dstPos+i] = dd.hist[srcPos+i]
|
||||||
|
// }
|
||||||
|
// dstPos = endPos
|
||||||
|
//
|
||||||
|
for dstPos < endPos {
|
||||||
|
dstPos += copy(dd.hist[dstPos:endPos], dd.hist[srcPos:dstPos])
|
||||||
|
}
|
||||||
|
|
||||||
|
dd.wrPos = dstPos
|
||||||
|
return dstPos - dstBase
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryWriteCopy tries to copy a string at a given (distance, length) to the
|
||||||
|
// output. This specialized version is optimized for short distances.
|
||||||
|
//
|
||||||
|
// This method is designed to be inlined for performance reasons.
|
||||||
|
//
|
||||||
|
// This invariant must be kept: 0 < dist <= histSize()
|
||||||
|
func (dd *dictDecoder) tryWriteCopy(dist, length int) int {
|
||||||
|
dstPos := dd.wrPos
|
||||||
|
endPos := dstPos + length
|
||||||
|
if dstPos < dist || endPos > len(dd.hist) {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
dstBase := dstPos
|
||||||
|
srcPos := dstPos - dist
|
||||||
|
|
||||||
|
// Copy possibly overlapping section before destination position.
|
||||||
|
loop:
|
||||||
|
dstPos += copy(dd.hist[dstPos:endPos], dd.hist[srcPos:dstPos])
|
||||||
|
if dstPos < endPos {
|
||||||
|
goto loop // Avoid for-loop so that this function can be inlined
|
||||||
|
}
|
||||||
|
|
||||||
|
dd.wrPos = dstPos
|
||||||
|
return dstPos - dstBase
|
||||||
|
}
|
||||||
|
|
||||||
|
// readFlush returns a slice of the historical buffer that is ready to be
|
||||||
|
// emitted to the user. The data returned by readFlush must be fully consumed
|
||||||
|
// before calling any other dictDecoder methods.
|
||||||
|
func (dd *dictDecoder) readFlush() []byte {
|
||||||
|
toRead := dd.hist[dd.rdPos:dd.wrPos]
|
||||||
|
dd.rdPos = dd.wrPos
|
||||||
|
if dd.wrPos == len(dd.hist) {
|
||||||
|
dd.wrPos, dd.rdPos = 0, 0
|
||||||
|
dd.full = true
|
||||||
|
}
|
||||||
|
return toRead
|
||||||
|
}
|
|
@ -0,0 +1,216 @@
|
||||||
|
// Copyright 2011 The Snappy-Go Authors. All rights reserved.
|
||||||
|
// Modified for deflate by Klaus Post (c) 2015.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package flate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"math/bits"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fastEnc interface {
|
||||||
|
Encode(dst *tokens, src []byte)
|
||||||
|
Reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFastEnc(level int) fastEnc {
|
||||||
|
switch level {
|
||||||
|
case 1:
|
||||||
|
return &fastEncL1{fastGen: fastGen{cur: maxStoreBlockSize}}
|
||||||
|
case 2:
|
||||||
|
return &fastEncL2{fastGen: fastGen{cur: maxStoreBlockSize}}
|
||||||
|
case 3:
|
||||||
|
return &fastEncL3{fastGen: fastGen{cur: maxStoreBlockSize}}
|
||||||
|
case 4:
|
||||||
|
return &fastEncL4{fastGen: fastGen{cur: maxStoreBlockSize}}
|
||||||
|
case 5:
|
||||||
|
return &fastEncL5{fastGen: fastGen{cur: maxStoreBlockSize}}
|
||||||
|
case 6:
|
||||||
|
return &fastEncL6{fastGen: fastGen{cur: maxStoreBlockSize}}
|
||||||
|
default:
|
||||||
|
panic("invalid level specified")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
tableBits = 15 // Bits used in the table
|
||||||
|
tableSize = 1 << tableBits // Size of the table
|
||||||
|
tableShift = 32 - tableBits // Right-shift to get the tableBits most significant bits of a uint32.
|
||||||
|
baseMatchOffset = 1 // The smallest match offset
|
||||||
|
baseMatchLength = 3 // The smallest match length per the RFC section 3.2.5
|
||||||
|
maxMatchOffset = 1 << 15 // The largest match offset
|
||||||
|
|
||||||
|
bTableBits = 17 // Bits used in the big tables
|
||||||
|
bTableSize = 1 << bTableBits // Size of the table
|
||||||
|
allocHistory = maxStoreBlockSize * 5 // Size to preallocate for history.
|
||||||
|
bufferReset = (1 << 31) - allocHistory - maxStoreBlockSize - 1 // Reset the buffer offset when reaching this.
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
prime3bytes = 506832829
|
||||||
|
prime4bytes = 2654435761
|
||||||
|
prime5bytes = 889523592379
|
||||||
|
prime6bytes = 227718039650203
|
||||||
|
prime7bytes = 58295818150454627
|
||||||
|
prime8bytes = 0xcf1bbcdcb7a56463
|
||||||
|
)
|
||||||
|
|
||||||
|
func load3232(b []byte, i int32) uint32 {
|
||||||
|
return binary.LittleEndian.Uint32(b[i:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func load6432(b []byte, i int32) uint64 {
|
||||||
|
return binary.LittleEndian.Uint64(b[i:])
|
||||||
|
}
|
||||||
|
|
||||||
|
type tableEntry struct {
|
||||||
|
offset int32
|
||||||
|
}
|
||||||
|
|
||||||
|
// fastGen maintains the table for matches,
|
||||||
|
// and the previous byte block for level 2.
|
||||||
|
// This is the generic implementation.
|
||||||
|
type fastGen struct {
|
||||||
|
hist []byte
|
||||||
|
cur int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *fastGen) addBlock(src []byte) int32 {
|
||||||
|
// check if we have space already
|
||||||
|
if len(e.hist)+len(src) > cap(e.hist) {
|
||||||
|
if cap(e.hist) == 0 {
|
||||||
|
e.hist = make([]byte, 0, allocHistory)
|
||||||
|
} else {
|
||||||
|
if cap(e.hist) < maxMatchOffset*2 {
|
||||||
|
panic("unexpected buffer size")
|
||||||
|
}
|
||||||
|
// Move down
|
||||||
|
offset := int32(len(e.hist)) - maxMatchOffset
|
||||||
|
// copy(e.hist[0:maxMatchOffset], e.hist[offset:])
|
||||||
|
*(*[maxMatchOffset]byte)(e.hist) = *(*[maxMatchOffset]byte)(e.hist[offset:])
|
||||||
|
e.cur += offset
|
||||||
|
e.hist = e.hist[:maxMatchOffset]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s := int32(len(e.hist))
|
||||||
|
e.hist = append(e.hist, src...)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
type tableEntryPrev struct {
|
||||||
|
Cur tableEntry
|
||||||
|
Prev tableEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
// hash7 returns the hash of the lowest 7 bytes of u to fit in a hash table with h bits.
|
||||||
|
// Preferably h should be a constant and should always be <64.
|
||||||
|
func hash7(u uint64, h uint8) uint32 {
|
||||||
|
return uint32(((u << (64 - 56)) * prime7bytes) >> ((64 - h) & reg8SizeMask64))
|
||||||
|
}
|
||||||
|
|
||||||
|
// hashLen returns a hash of the lowest mls bytes of with length output bits.
|
||||||
|
// mls must be >=3 and <=8. Any other value will return hash for 4 bytes.
|
||||||
|
// length should always be < 32.
|
||||||
|
// Preferably length and mls should be a constant for inlining.
|
||||||
|
func hashLen(u uint64, length, mls uint8) uint32 {
|
||||||
|
switch mls {
|
||||||
|
case 3:
|
||||||
|
return (uint32(u<<8) * prime3bytes) >> (32 - length)
|
||||||
|
case 5:
|
||||||
|
return uint32(((u << (64 - 40)) * prime5bytes) >> (64 - length))
|
||||||
|
case 6:
|
||||||
|
return uint32(((u << (64 - 48)) * prime6bytes) >> (64 - length))
|
||||||
|
case 7:
|
||||||
|
return uint32(((u << (64 - 56)) * prime7bytes) >> (64 - length))
|
||||||
|
case 8:
|
||||||
|
return uint32((u * prime8bytes) >> (64 - length))
|
||||||
|
default:
|
||||||
|
return (uint32(u) * prime4bytes) >> (32 - length)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchlen will return the match length between offsets and t in src.
|
||||||
|
// The maximum length returned is maxMatchLength - 4.
|
||||||
|
// It is assumed that s > t, that t >=0 and s < len(src).
|
||||||
|
func (e *fastGen) matchlen(s, t int32, src []byte) int32 {
|
||||||
|
if debugDecode {
|
||||||
|
if t >= s {
|
||||||
|
panic(fmt.Sprint("t >=s:", t, s))
|
||||||
|
}
|
||||||
|
if int(s) >= len(src) {
|
||||||
|
panic(fmt.Sprint("s >= len(src):", s, len(src)))
|
||||||
|
}
|
||||||
|
if t < 0 {
|
||||||
|
panic(fmt.Sprint("t < 0:", t))
|
||||||
|
}
|
||||||
|
if s-t > maxMatchOffset {
|
||||||
|
panic(fmt.Sprint(s, "-", t, "(", s-t, ") > maxMatchLength (", maxMatchOffset, ")"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s1 := int(s) + maxMatchLength - 4
|
||||||
|
if s1 > len(src) {
|
||||||
|
s1 = len(src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extend the match to be as long as possible.
|
||||||
|
return int32(matchLen(src[s:s1], src[t:]))
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchlenLong will return the match length between offsets and t in src.
|
||||||
|
// It is assumed that s > t, that t >=0 and s < len(src).
|
||||||
|
func (e *fastGen) matchlenLong(s, t int32, src []byte) int32 {
|
||||||
|
if debugDeflate {
|
||||||
|
if t >= s {
|
||||||
|
panic(fmt.Sprint("t >=s:", t, s))
|
||||||
|
}
|
||||||
|
if int(s) >= len(src) {
|
||||||
|
panic(fmt.Sprint("s >= len(src):", s, len(src)))
|
||||||
|
}
|
||||||
|
if t < 0 {
|
||||||
|
panic(fmt.Sprint("t < 0:", t))
|
||||||
|
}
|
||||||
|
if s-t > maxMatchOffset {
|
||||||
|
panic(fmt.Sprint(s, "-", t, "(", s-t, ") > maxMatchLength (", maxMatchOffset, ")"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Extend the match to be as long as possible.
|
||||||
|
return int32(matchLen(src[s:], src[t:]))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset the encoding table.
|
||||||
|
func (e *fastGen) Reset() {
|
||||||
|
if cap(e.hist) < allocHistory {
|
||||||
|
e.hist = make([]byte, 0, allocHistory)
|
||||||
|
}
|
||||||
|
// We offset current position so everything will be out of reach.
|
||||||
|
// If we are above the buffer reset it will be cleared anyway since len(hist) == 0.
|
||||||
|
if e.cur <= bufferReset {
|
||||||
|
e.cur += maxMatchOffset + int32(len(e.hist))
|
||||||
|
}
|
||||||
|
e.hist = e.hist[:0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchLen returns the maximum length.
|
||||||
|
// 'a' must be the shortest of the two.
|
||||||
|
func matchLen(a, b []byte) int {
|
||||||
|
var checked int
|
||||||
|
|
||||||
|
for len(a) >= 8 {
|
||||||
|
if diff := binary.LittleEndian.Uint64(a) ^ binary.LittleEndian.Uint64(b); diff != 0 {
|
||||||
|
return checked + (bits.TrailingZeros64(diff) >> 3)
|
||||||
|
}
|
||||||
|
checked += 8
|
||||||
|
a = a[8:]
|
||||||
|
b = b[8:]
|
||||||
|
}
|
||||||
|
b = b[:len(a)]
|
||||||
|
for i := range a {
|
||||||
|
if a[i] != b[i] {
|
||||||
|
return i + checked
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(a) + checked
|
||||||
|
}
|
1187
vendor/github.com/klauspost/compress/flate/huffman_bit_writer.go
generated
vendored
Normal file
1187
vendor/github.com/klauspost/compress/flate/huffman_bit_writer.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,417 @@
|
||||||
|
// Copyright 2009 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package flate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"math/bits"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxBitsLimit = 16
|
||||||
|
// number of valid literals
|
||||||
|
literalCount = 286
|
||||||
|
)
|
||||||
|
|
||||||
|
// hcode is a huffman code with a bit code and bit length.
|
||||||
|
type hcode uint32
|
||||||
|
|
||||||
|
func (h hcode) len() uint8 {
|
||||||
|
return uint8(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h hcode) code64() uint64 {
|
||||||
|
return uint64(h >> 8)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h hcode) zero() bool {
|
||||||
|
return h == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
type huffmanEncoder struct {
|
||||||
|
codes []hcode
|
||||||
|
bitCount [17]int32
|
||||||
|
|
||||||
|
// Allocate a reusable buffer with the longest possible frequency table.
|
||||||
|
// Possible lengths are codegenCodeCount, offsetCodeCount and literalCount.
|
||||||
|
// The largest of these is literalCount, so we allocate for that case.
|
||||||
|
freqcache [literalCount + 1]literalNode
|
||||||
|
}
|
||||||
|
|
||||||
|
type literalNode struct {
|
||||||
|
literal uint16
|
||||||
|
freq uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// A levelInfo describes the state of the constructed tree for a given depth.
|
||||||
|
type levelInfo struct {
|
||||||
|
// Our level. for better printing
|
||||||
|
level int32
|
||||||
|
|
||||||
|
// The frequency of the last node at this level
|
||||||
|
lastFreq int32
|
||||||
|
|
||||||
|
// The frequency of the next character to add to this level
|
||||||
|
nextCharFreq int32
|
||||||
|
|
||||||
|
// The frequency of the next pair (from level below) to add to this level.
|
||||||
|
// Only valid if the "needed" value of the next lower level is 0.
|
||||||
|
nextPairFreq int32
|
||||||
|
|
||||||
|
// The number of chains remaining to generate for this level before moving
|
||||||
|
// up to the next level
|
||||||
|
needed int32
|
||||||
|
}
|
||||||
|
|
||||||
|
// set sets the code and length of an hcode.
|
||||||
|
func (h *hcode) set(code uint16, length uint8) {
|
||||||
|
*h = hcode(length) | (hcode(code) << 8)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newhcode(code uint16, length uint8) hcode {
|
||||||
|
return hcode(length) | (hcode(code) << 8)
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverseBits(number uint16, bitLength byte) uint16 {
|
||||||
|
return bits.Reverse16(number << ((16 - bitLength) & 15))
|
||||||
|
}
|
||||||
|
|
||||||
|
func maxNode() literalNode { return literalNode{math.MaxUint16, math.MaxUint16} }
|
||||||
|
|
||||||
|
func newHuffmanEncoder(size int) *huffmanEncoder {
|
||||||
|
// Make capacity to next power of two.
|
||||||
|
c := uint(bits.Len32(uint32(size - 1)))
|
||||||
|
return &huffmanEncoder{codes: make([]hcode, size, 1<<c)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generates a HuffmanCode corresponding to the fixed literal table
|
||||||
|
func generateFixedLiteralEncoding() *huffmanEncoder {
|
||||||
|
h := newHuffmanEncoder(literalCount)
|
||||||
|
codes := h.codes
|
||||||
|
var ch uint16
|
||||||
|
for ch = 0; ch < literalCount; ch++ {
|
||||||
|
var bits uint16
|
||||||
|
var size uint8
|
||||||
|
switch {
|
||||||
|
case ch < 144:
|
||||||
|
// size 8, 000110000 .. 10111111
|
||||||
|
bits = ch + 48
|
||||||
|
size = 8
|
||||||
|
case ch < 256:
|
||||||
|
// size 9, 110010000 .. 111111111
|
||||||
|
bits = ch + 400 - 144
|
||||||
|
size = 9
|
||||||
|
case ch < 280:
|
||||||
|
// size 7, 0000000 .. 0010111
|
||||||
|
bits = ch - 256
|
||||||
|
size = 7
|
||||||
|
default:
|
||||||
|
// size 8, 11000000 .. 11000111
|
||||||
|
bits = ch + 192 - 280
|
||||||
|
size = 8
|
||||||
|
}
|
||||||
|
codes[ch] = newhcode(reverseBits(bits, size), size)
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateFixedOffsetEncoding() *huffmanEncoder {
|
||||||
|
h := newHuffmanEncoder(30)
|
||||||
|
codes := h.codes
|
||||||
|
for ch := range codes {
|
||||||
|
codes[ch] = newhcode(reverseBits(uint16(ch), 5), 5)
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
var fixedLiteralEncoding = generateFixedLiteralEncoding()
|
||||||
|
var fixedOffsetEncoding = generateFixedOffsetEncoding()
|
||||||
|
|
||||||
|
func (h *huffmanEncoder) bitLength(freq []uint16) int {
|
||||||
|
var total int
|
||||||
|
for i, f := range freq {
|
||||||
|
if f != 0 {
|
||||||
|
total += int(f) * int(h.codes[i].len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return total
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *huffmanEncoder) bitLengthRaw(b []byte) int {
|
||||||
|
var total int
|
||||||
|
for _, f := range b {
|
||||||
|
total += int(h.codes[f].len())
|
||||||
|
}
|
||||||
|
return total
|
||||||
|
}
|
||||||
|
|
||||||
|
// canReuseBits returns the number of bits or math.MaxInt32 if the encoder cannot be reused.
|
||||||
|
func (h *huffmanEncoder) canReuseBits(freq []uint16) int {
|
||||||
|
var total int
|
||||||
|
for i, f := range freq {
|
||||||
|
if f != 0 {
|
||||||
|
code := h.codes[i]
|
||||||
|
if code.zero() {
|
||||||
|
return math.MaxInt32
|
||||||
|
}
|
||||||
|
total += int(f) * int(code.len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return total
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the number of literals assigned to each bit size in the Huffman encoding
|
||||||
|
//
|
||||||
|
// This method is only called when list.length >= 3
|
||||||
|
// The cases of 0, 1, and 2 literals are handled by special case code.
|
||||||
|
//
|
||||||
|
// list An array of the literals with non-zero frequencies
|
||||||
|
//
|
||||||
|
// and their associated frequencies. The array is in order of increasing
|
||||||
|
// frequency, and has as its last element a special element with frequency
|
||||||
|
// MaxInt32
|
||||||
|
//
|
||||||
|
// maxBits The maximum number of bits that should be used to encode any literal.
|
||||||
|
//
|
||||||
|
// Must be less than 16.
|
||||||
|
//
|
||||||
|
// return An integer array in which array[i] indicates the number of literals
|
||||||
|
//
|
||||||
|
// that should be encoded in i bits.
|
||||||
|
func (h *huffmanEncoder) bitCounts(list []literalNode, maxBits int32) []int32 {
|
||||||
|
if maxBits >= maxBitsLimit {
|
||||||
|
panic("flate: maxBits too large")
|
||||||
|
}
|
||||||
|
n := int32(len(list))
|
||||||
|
list = list[0 : n+1]
|
||||||
|
list[n] = maxNode()
|
||||||
|
|
||||||
|
// The tree can't have greater depth than n - 1, no matter what. This
|
||||||
|
// saves a little bit of work in some small cases
|
||||||
|
if maxBits > n-1 {
|
||||||
|
maxBits = n - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create information about each of the levels.
|
||||||
|
// A bogus "Level 0" whose sole purpose is so that
|
||||||
|
// level1.prev.needed==0. This makes level1.nextPairFreq
|
||||||
|
// be a legitimate value that never gets chosen.
|
||||||
|
var levels [maxBitsLimit]levelInfo
|
||||||
|
// leafCounts[i] counts the number of literals at the left
|
||||||
|
// of ancestors of the rightmost node at level i.
|
||||||
|
// leafCounts[i][j] is the number of literals at the left
|
||||||
|
// of the level j ancestor.
|
||||||
|
var leafCounts [maxBitsLimit][maxBitsLimit]int32
|
||||||
|
|
||||||
|
// Descending to only have 1 bounds check.
|
||||||
|
l2f := int32(list[2].freq)
|
||||||
|
l1f := int32(list[1].freq)
|
||||||
|
l0f := int32(list[0].freq) + int32(list[1].freq)
|
||||||
|
|
||||||
|
for level := int32(1); level <= maxBits; level++ {
|
||||||
|
// For every level, the first two items are the first two characters.
|
||||||
|
// We initialize the levels as if we had already figured this out.
|
||||||
|
levels[level] = levelInfo{
|
||||||
|
level: level,
|
||||||
|
lastFreq: l1f,
|
||||||
|
nextCharFreq: l2f,
|
||||||
|
nextPairFreq: l0f,
|
||||||
|
}
|
||||||
|
leafCounts[level][level] = 2
|
||||||
|
if level == 1 {
|
||||||
|
levels[level].nextPairFreq = math.MaxInt32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We need a total of 2*n - 2 items at top level and have already generated 2.
|
||||||
|
levels[maxBits].needed = 2*n - 4
|
||||||
|
|
||||||
|
level := uint32(maxBits)
|
||||||
|
for level < 16 {
|
||||||
|
l := &levels[level]
|
||||||
|
if l.nextPairFreq == math.MaxInt32 && l.nextCharFreq == math.MaxInt32 {
|
||||||
|
// We've run out of both leafs and pairs.
|
||||||
|
// End all calculations for this level.
|
||||||
|
// To make sure we never come back to this level or any lower level,
|
||||||
|
// set nextPairFreq impossibly large.
|
||||||
|
l.needed = 0
|
||||||
|
levels[level+1].nextPairFreq = math.MaxInt32
|
||||||
|
level++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
prevFreq := l.lastFreq
|
||||||
|
if l.nextCharFreq < l.nextPairFreq {
|
||||||
|
// The next item on this row is a leaf node.
|
||||||
|
n := leafCounts[level][level] + 1
|
||||||
|
l.lastFreq = l.nextCharFreq
|
||||||
|
// Lower leafCounts are the same of the previous node.
|
||||||
|
leafCounts[level][level] = n
|
||||||
|
e := list[n]
|
||||||
|
if e.literal < math.MaxUint16 {
|
||||||
|
l.nextCharFreq = int32(e.freq)
|
||||||
|
} else {
|
||||||
|
l.nextCharFreq = math.MaxInt32
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// The next item on this row is a pair from the previous row.
|
||||||
|
// nextPairFreq isn't valid until we generate two
|
||||||
|
// more values in the level below
|
||||||
|
l.lastFreq = l.nextPairFreq
|
||||||
|
// Take leaf counts from the lower level, except counts[level] remains the same.
|
||||||
|
if true {
|
||||||
|
save := leafCounts[level][level]
|
||||||
|
leafCounts[level] = leafCounts[level-1]
|
||||||
|
leafCounts[level][level] = save
|
||||||
|
} else {
|
||||||
|
copy(leafCounts[level][:level], leafCounts[level-1][:level])
|
||||||
|
}
|
||||||
|
levels[l.level-1].needed = 2
|
||||||
|
}
|
||||||
|
|
||||||
|
if l.needed--; l.needed == 0 {
|
||||||
|
// We've done everything we need to do for this level.
|
||||||
|
// Continue calculating one level up. Fill in nextPairFreq
|
||||||
|
// of that level with the sum of the two nodes we've just calculated on
|
||||||
|
// this level.
|
||||||
|
if l.level == maxBits {
|
||||||
|
// All done!
|
||||||
|
break
|
||||||
|
}
|
||||||
|
levels[l.level+1].nextPairFreq = prevFreq + l.lastFreq
|
||||||
|
level++
|
||||||
|
} else {
|
||||||
|
// If we stole from below, move down temporarily to replenish it.
|
||||||
|
for levels[level-1].needed > 0 {
|
||||||
|
level--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Somethings is wrong if at the end, the top level is null or hasn't used
|
||||||
|
// all of the leaves.
|
||||||
|
if leafCounts[maxBits][maxBits] != n {
|
||||||
|
panic("leafCounts[maxBits][maxBits] != n")
|
||||||
|
}
|
||||||
|
|
||||||
|
bitCount := h.bitCount[:maxBits+1]
|
||||||
|
bits := 1
|
||||||
|
counts := &leafCounts[maxBits]
|
||||||
|
for level := maxBits; level > 0; level-- {
|
||||||
|
// chain.leafCount gives the number of literals requiring at least "bits"
|
||||||
|
// bits to encode.
|
||||||
|
bitCount[bits] = counts[level] - counts[level-1]
|
||||||
|
bits++
|
||||||
|
}
|
||||||
|
return bitCount
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look at the leaves and assign them a bit count and an encoding as specified
|
||||||
|
// in RFC 1951 3.2.2
|
||||||
|
func (h *huffmanEncoder) assignEncodingAndSize(bitCount []int32, list []literalNode) {
|
||||||
|
code := uint16(0)
|
||||||
|
for n, bits := range bitCount {
|
||||||
|
code <<= 1
|
||||||
|
if n == 0 || bits == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// The literals list[len(list)-bits] .. list[len(list)-bits]
|
||||||
|
// are encoded using "bits" bits, and get the values
|
||||||
|
// code, code + 1, .... The code values are
|
||||||
|
// assigned in literal order (not frequency order).
|
||||||
|
chunk := list[len(list)-int(bits):]
|
||||||
|
|
||||||
|
sortByLiteral(chunk)
|
||||||
|
for _, node := range chunk {
|
||||||
|
h.codes[node.literal] = newhcode(reverseBits(code, uint8(n)), uint8(n))
|
||||||
|
code++
|
||||||
|
}
|
||||||
|
list = list[0 : len(list)-int(bits)]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update this Huffman Code object to be the minimum code for the specified frequency count.
|
||||||
|
//
|
||||||
|
// freq An array of frequencies, in which frequency[i] gives the frequency of literal i.
|
||||||
|
// maxBits The maximum number of bits to use for any literal.
|
||||||
|
func (h *huffmanEncoder) generate(freq []uint16, maxBits int32) {
|
||||||
|
list := h.freqcache[:len(freq)+1]
|
||||||
|
codes := h.codes[:len(freq)]
|
||||||
|
// Number of non-zero literals
|
||||||
|
count := 0
|
||||||
|
// Set list to be the set of all non-zero literals and their frequencies
|
||||||
|
for i, f := range freq {
|
||||||
|
if f != 0 {
|
||||||
|
list[count] = literalNode{uint16(i), f}
|
||||||
|
count++
|
||||||
|
} else {
|
||||||
|
codes[i] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
list[count] = literalNode{}
|
||||||
|
|
||||||
|
list = list[:count]
|
||||||
|
if count <= 2 {
|
||||||
|
// Handle the small cases here, because they are awkward for the general case code. With
|
||||||
|
// two or fewer literals, everything has bit length 1.
|
||||||
|
for i, node := range list {
|
||||||
|
// "list" is in order of increasing literal value.
|
||||||
|
h.codes[node.literal].set(uint16(i), 1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sortByFreq(list)
|
||||||
|
|
||||||
|
// Get the number of literals for each bit count
|
||||||
|
bitCount := h.bitCounts(list, maxBits)
|
||||||
|
// And do the assignment
|
||||||
|
h.assignEncodingAndSize(bitCount, list)
|
||||||
|
}
|
||||||
|
|
||||||
|
// atLeastOne clamps the result between 1 and 15.
|
||||||
|
func atLeastOne(v float32) float32 {
|
||||||
|
if v < 1 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if v > 15 {
|
||||||
|
return 15
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func histogram(b []byte, h []uint16) {
|
||||||
|
if true && len(b) >= 8<<10 {
|
||||||
|
// Split for bigger inputs
|
||||||
|
histogramSplit(b, h)
|
||||||
|
} else {
|
||||||
|
h = h[:256]
|
||||||
|
for _, t := range b {
|
||||||
|
h[t]++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func histogramSplit(b []byte, h []uint16) {
|
||||||
|
// Tested, and slightly faster than 2-way.
|
||||||
|
// Writing to separate arrays and combining is also slightly slower.
|
||||||
|
h = h[:256]
|
||||||
|
for len(b)&3 != 0 {
|
||||||
|
h[b[0]]++
|
||||||
|
b = b[1:]
|
||||||
|
}
|
||||||
|
n := len(b) / 4
|
||||||
|
x, y, z, w := b[:n], b[n:], b[n+n:], b[n+n+n:]
|
||||||
|
y, z, w = y[:len(x)], z[:len(x)], w[:len(x)]
|
||||||
|
for i, t := range x {
|
||||||
|
v0 := &h[t]
|
||||||
|
v1 := &h[y[i]]
|
||||||
|
v3 := &h[w[i]]
|
||||||
|
v2 := &h[z[i]]
|
||||||
|
*v0++
|
||||||
|
*v1++
|
||||||
|
*v2++
|
||||||
|
*v3++
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,178 @@
|
||||||
|
// Copyright 2009 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package flate
|
||||||
|
|
||||||
|
// Sort sorts data.
|
||||||
|
// It makes one call to data.Len to determine n, and O(n*log(n)) calls to
|
||||||
|
// data.Less and data.Swap. The sort is not guaranteed to be stable.
|
||||||
|
func sortByFreq(data []literalNode) {
|
||||||
|
n := len(data)
|
||||||
|
quickSortByFreq(data, 0, n, maxDepth(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
func quickSortByFreq(data []literalNode, a, b, maxDepth int) {
|
||||||
|
for b-a > 12 { // Use ShellSort for slices <= 12 elements
|
||||||
|
if maxDepth == 0 {
|
||||||
|
heapSort(data, a, b)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
maxDepth--
|
||||||
|
mlo, mhi := doPivotByFreq(data, a, b)
|
||||||
|
// Avoiding recursion on the larger subproblem guarantees
|
||||||
|
// a stack depth of at most lg(b-a).
|
||||||
|
if mlo-a < b-mhi {
|
||||||
|
quickSortByFreq(data, a, mlo, maxDepth)
|
||||||
|
a = mhi // i.e., quickSortByFreq(data, mhi, b)
|
||||||
|
} else {
|
||||||
|
quickSortByFreq(data, mhi, b, maxDepth)
|
||||||
|
b = mlo // i.e., quickSortByFreq(data, a, mlo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if b-a > 1 {
|
||||||
|
// Do ShellSort pass with gap 6
|
||||||
|
// It could be written in this simplified form cause b-a <= 12
|
||||||
|
for i := a + 6; i < b; i++ {
|
||||||
|
if data[i].freq == data[i-6].freq && data[i].literal < data[i-6].literal || data[i].freq < data[i-6].freq {
|
||||||
|
data[i], data[i-6] = data[i-6], data[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
insertionSortByFreq(data, a, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// siftDownByFreq implements the heap property on data[lo, hi).
|
||||||
|
// first is an offset into the array where the root of the heap lies.
|
||||||
|
func siftDownByFreq(data []literalNode, lo, hi, first int) {
|
||||||
|
root := lo
|
||||||
|
for {
|
||||||
|
child := 2*root + 1
|
||||||
|
if child >= hi {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if child+1 < hi && (data[first+child].freq == data[first+child+1].freq && data[first+child].literal < data[first+child+1].literal || data[first+child].freq < data[first+child+1].freq) {
|
||||||
|
child++
|
||||||
|
}
|
||||||
|
if data[first+root].freq == data[first+child].freq && data[first+root].literal > data[first+child].literal || data[first+root].freq > data[first+child].freq {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data[first+root], data[first+child] = data[first+child], data[first+root]
|
||||||
|
root = child
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func doPivotByFreq(data []literalNode, lo, hi int) (midlo, midhi int) {
|
||||||
|
m := int(uint(lo+hi) >> 1) // Written like this to avoid integer overflow.
|
||||||
|
if hi-lo > 40 {
|
||||||
|
// Tukey's ``Ninther,'' median of three medians of three.
|
||||||
|
s := (hi - lo) / 8
|
||||||
|
medianOfThreeSortByFreq(data, lo, lo+s, lo+2*s)
|
||||||
|
medianOfThreeSortByFreq(data, m, m-s, m+s)
|
||||||
|
medianOfThreeSortByFreq(data, hi-1, hi-1-s, hi-1-2*s)
|
||||||
|
}
|
||||||
|
medianOfThreeSortByFreq(data, lo, m, hi-1)
|
||||||
|
|
||||||
|
// Invariants are:
|
||||||
|
// data[lo] = pivot (set up by ChoosePivot)
|
||||||
|
// data[lo < i < a] < pivot
|
||||||
|
// data[a <= i < b] <= pivot
|
||||||
|
// data[b <= i < c] unexamined
|
||||||
|
// data[c <= i < hi-1] > pivot
|
||||||
|
// data[hi-1] >= pivot
|
||||||
|
pivot := lo
|
||||||
|
a, c := lo+1, hi-1
|
||||||
|
|
||||||
|
for ; a < c && (data[a].freq == data[pivot].freq && data[a].literal < data[pivot].literal || data[a].freq < data[pivot].freq); a++ {
|
||||||
|
}
|
||||||
|
b := a
|
||||||
|
for {
|
||||||
|
for ; b < c && (data[pivot].freq == data[b].freq && data[pivot].literal > data[b].literal || data[pivot].freq > data[b].freq); b++ { // data[b] <= pivot
|
||||||
|
}
|
||||||
|
for ; b < c && (data[pivot].freq == data[c-1].freq && data[pivot].literal < data[c-1].literal || data[pivot].freq < data[c-1].freq); c-- { // data[c-1] > pivot
|
||||||
|
}
|
||||||
|
if b >= c {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// data[b] > pivot; data[c-1] <= pivot
|
||||||
|
data[b], data[c-1] = data[c-1], data[b]
|
||||||
|
b++
|
||||||
|
c--
|
||||||
|
}
|
||||||
|
// If hi-c<3 then there are duplicates (by property of median of nine).
|
||||||
|
// Let's be a bit more conservative, and set border to 5.
|
||||||
|
protect := hi-c < 5
|
||||||
|
if !protect && hi-c < (hi-lo)/4 {
|
||||||
|
// Lets test some points for equality to pivot
|
||||||
|
dups := 0
|
||||||
|
if data[pivot].freq == data[hi-1].freq && data[pivot].literal > data[hi-1].literal || data[pivot].freq > data[hi-1].freq { // data[hi-1] = pivot
|
||||||
|
data[c], data[hi-1] = data[hi-1], data[c]
|
||||||
|
c++
|
||||||
|
dups++
|
||||||
|
}
|
||||||
|
if data[b-1].freq == data[pivot].freq && data[b-1].literal > data[pivot].literal || data[b-1].freq > data[pivot].freq { // data[b-1] = pivot
|
||||||
|
b--
|
||||||
|
dups++
|
||||||
|
}
|
||||||
|
// m-lo = (hi-lo)/2 > 6
|
||||||
|
// b-lo > (hi-lo)*3/4-1 > 8
|
||||||
|
// ==> m < b ==> data[m] <= pivot
|
||||||
|
if data[m].freq == data[pivot].freq && data[m].literal > data[pivot].literal || data[m].freq > data[pivot].freq { // data[m] = pivot
|
||||||
|
data[m], data[b-1] = data[b-1], data[m]
|
||||||
|
b--
|
||||||
|
dups++
|
||||||
|
}
|
||||||
|
// if at least 2 points are equal to pivot, assume skewed distribution
|
||||||
|
protect = dups > 1
|
||||||
|
}
|
||||||
|
if protect {
|
||||||
|
// Protect against a lot of duplicates
|
||||||
|
// Add invariant:
|
||||||
|
// data[a <= i < b] unexamined
|
||||||
|
// data[b <= i < c] = pivot
|
||||||
|
for {
|
||||||
|
for ; a < b && (data[b-1].freq == data[pivot].freq && data[b-1].literal > data[pivot].literal || data[b-1].freq > data[pivot].freq); b-- { // data[b] == pivot
|
||||||
|
}
|
||||||
|
for ; a < b && (data[a].freq == data[pivot].freq && data[a].literal < data[pivot].literal || data[a].freq < data[pivot].freq); a++ { // data[a] < pivot
|
||||||
|
}
|
||||||
|
if a >= b {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// data[a] == pivot; data[b-1] < pivot
|
||||||
|
data[a], data[b-1] = data[b-1], data[a]
|
||||||
|
a++
|
||||||
|
b--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Swap pivot into middle
|
||||||
|
data[pivot], data[b-1] = data[b-1], data[pivot]
|
||||||
|
return b - 1, c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insertion sort
|
||||||
|
func insertionSortByFreq(data []literalNode, a, b int) {
|
||||||
|
for i := a + 1; i < b; i++ {
|
||||||
|
for j := i; j > a && (data[j].freq == data[j-1].freq && data[j].literal < data[j-1].literal || data[j].freq < data[j-1].freq); j-- {
|
||||||
|
data[j], data[j-1] = data[j-1], data[j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// quickSortByFreq, loosely following Bentley and McIlroy,
|
||||||
|
// ``Engineering a Sort Function,'' SP&E November 1993.
|
||||||
|
|
||||||
|
// medianOfThreeSortByFreq moves the median of the three values data[m0], data[m1], data[m2] into data[m1].
|
||||||
|
func medianOfThreeSortByFreq(data []literalNode, m1, m0, m2 int) {
|
||||||
|
// sort 3 elements
|
||||||
|
if data[m1].freq == data[m0].freq && data[m1].literal < data[m0].literal || data[m1].freq < data[m0].freq {
|
||||||
|
data[m1], data[m0] = data[m0], data[m1]
|
||||||
|
}
|
||||||
|
// data[m0] <= data[m1]
|
||||||
|
if data[m2].freq == data[m1].freq && data[m2].literal < data[m1].literal || data[m2].freq < data[m1].freq {
|
||||||
|
data[m2], data[m1] = data[m1], data[m2]
|
||||||
|
// data[m0] <= data[m2] && data[m1] < data[m2]
|
||||||
|
if data[m1].freq == data[m0].freq && data[m1].literal < data[m0].literal || data[m1].freq < data[m0].freq {
|
||||||
|
data[m1], data[m0] = data[m0], data[m1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// now data[m0] <= data[m1] <= data[m2]
|
||||||
|
}
|
201
vendor/github.com/klauspost/compress/flate/huffman_sortByLiteral.go
generated
vendored
Normal file
201
vendor/github.com/klauspost/compress/flate/huffman_sortByLiteral.go
generated
vendored
Normal file
|
@ -0,0 +1,201 @@
|
||||||
|
// Copyright 2009 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package flate
|
||||||
|
|
||||||
|
// Sort sorts data.
|
||||||
|
// It makes one call to data.Len to determine n, and O(n*log(n)) calls to
|
||||||
|
// data.Less and data.Swap. The sort is not guaranteed to be stable.
|
||||||
|
func sortByLiteral(data []literalNode) {
|
||||||
|
n := len(data)
|
||||||
|
quickSort(data, 0, n, maxDepth(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
func quickSort(data []literalNode, a, b, maxDepth int) {
|
||||||
|
for b-a > 12 { // Use ShellSort for slices <= 12 elements
|
||||||
|
if maxDepth == 0 {
|
||||||
|
heapSort(data, a, b)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
maxDepth--
|
||||||
|
mlo, mhi := doPivot(data, a, b)
|
||||||
|
// Avoiding recursion on the larger subproblem guarantees
|
||||||
|
// a stack depth of at most lg(b-a).
|
||||||
|
if mlo-a < b-mhi {
|
||||||
|
quickSort(data, a, mlo, maxDepth)
|
||||||
|
a = mhi // i.e., quickSort(data, mhi, b)
|
||||||
|
} else {
|
||||||
|
quickSort(data, mhi, b, maxDepth)
|
||||||
|
b = mlo // i.e., quickSort(data, a, mlo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if b-a > 1 {
|
||||||
|
// Do ShellSort pass with gap 6
|
||||||
|
// It could be written in this simplified form cause b-a <= 12
|
||||||
|
for i := a + 6; i < b; i++ {
|
||||||
|
if data[i].literal < data[i-6].literal {
|
||||||
|
data[i], data[i-6] = data[i-6], data[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
insertionSort(data, a, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func heapSort(data []literalNode, a, b int) {
|
||||||
|
first := a
|
||||||
|
lo := 0
|
||||||
|
hi := b - a
|
||||||
|
|
||||||
|
// Build heap with greatest element at top.
|
||||||
|
for i := (hi - 1) / 2; i >= 0; i-- {
|
||||||
|
siftDown(data, i, hi, first)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pop elements, largest first, into end of data.
|
||||||
|
for i := hi - 1; i >= 0; i-- {
|
||||||
|
data[first], data[first+i] = data[first+i], data[first]
|
||||||
|
siftDown(data, lo, i, first)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// siftDown implements the heap property on data[lo, hi).
|
||||||
|
// first is an offset into the array where the root of the heap lies.
|
||||||
|
func siftDown(data []literalNode, lo, hi, first int) {
|
||||||
|
root := lo
|
||||||
|
for {
|
||||||
|
child := 2*root + 1
|
||||||
|
if child >= hi {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if child+1 < hi && data[first+child].literal < data[first+child+1].literal {
|
||||||
|
child++
|
||||||
|
}
|
||||||
|
if data[first+root].literal > data[first+child].literal {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data[first+root], data[first+child] = data[first+child], data[first+root]
|
||||||
|
root = child
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func doPivot(data []literalNode, lo, hi int) (midlo, midhi int) {
|
||||||
|
m := int(uint(lo+hi) >> 1) // Written like this to avoid integer overflow.
|
||||||
|
if hi-lo > 40 {
|
||||||
|
// Tukey's ``Ninther,'' median of three medians of three.
|
||||||
|
s := (hi - lo) / 8
|
||||||
|
medianOfThree(data, lo, lo+s, lo+2*s)
|
||||||
|
medianOfThree(data, m, m-s, m+s)
|
||||||
|
medianOfThree(data, hi-1, hi-1-s, hi-1-2*s)
|
||||||
|
}
|
||||||
|
medianOfThree(data, lo, m, hi-1)
|
||||||
|
|
||||||
|
// Invariants are:
|
||||||
|
// data[lo] = pivot (set up by ChoosePivot)
|
||||||
|
// data[lo < i < a] < pivot
|
||||||
|
// data[a <= i < b] <= pivot
|
||||||
|
// data[b <= i < c] unexamined
|
||||||
|
// data[c <= i < hi-1] > pivot
|
||||||
|
// data[hi-1] >= pivot
|
||||||
|
pivot := lo
|
||||||
|
a, c := lo+1, hi-1
|
||||||
|
|
||||||
|
for ; a < c && data[a].literal < data[pivot].literal; a++ {
|
||||||
|
}
|
||||||
|
b := a
|
||||||
|
for {
|
||||||
|
for ; b < c && data[pivot].literal > data[b].literal; b++ { // data[b] <= pivot
|
||||||
|
}
|
||||||
|
for ; b < c && data[pivot].literal < data[c-1].literal; c-- { // data[c-1] > pivot
|
||||||
|
}
|
||||||
|
if b >= c {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// data[b] > pivot; data[c-1] <= pivot
|
||||||
|
data[b], data[c-1] = data[c-1], data[b]
|
||||||
|
b++
|
||||||
|
c--
|
||||||
|
}
|
||||||
|
// If hi-c<3 then there are duplicates (by property of median of nine).
|
||||||
|
// Let's be a bit more conservative, and set border to 5.
|
||||||
|
protect := hi-c < 5
|
||||||
|
if !protect && hi-c < (hi-lo)/4 {
|
||||||
|
// Lets test some points for equality to pivot
|
||||||
|
dups := 0
|
||||||
|
if data[pivot].literal > data[hi-1].literal { // data[hi-1] = pivot
|
||||||
|
data[c], data[hi-1] = data[hi-1], data[c]
|
||||||
|
c++
|
||||||
|
dups++
|
||||||
|
}
|
||||||
|
if data[b-1].literal > data[pivot].literal { // data[b-1] = pivot
|
||||||
|
b--
|
||||||
|
dups++
|
||||||
|
}
|
||||||
|
// m-lo = (hi-lo)/2 > 6
|
||||||
|
// b-lo > (hi-lo)*3/4-1 > 8
|
||||||
|
// ==> m < b ==> data[m] <= pivot
|
||||||
|
if data[m].literal > data[pivot].literal { // data[m] = pivot
|
||||||
|
data[m], data[b-1] = data[b-1], data[m]
|
||||||
|
b--
|
||||||
|
dups++
|
||||||
|
}
|
||||||
|
// if at least 2 points are equal to pivot, assume skewed distribution
|
||||||
|
protect = dups > 1
|
||||||
|
}
|
||||||
|
if protect {
|
||||||
|
// Protect against a lot of duplicates
|
||||||
|
// Add invariant:
|
||||||
|
// data[a <= i < b] unexamined
|
||||||
|
// data[b <= i < c] = pivot
|
||||||
|
for {
|
||||||
|
for ; a < b && data[b-1].literal > data[pivot].literal; b-- { // data[b] == pivot
|
||||||
|
}
|
||||||
|
for ; a < b && data[a].literal < data[pivot].literal; a++ { // data[a] < pivot
|
||||||
|
}
|
||||||
|
if a >= b {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// data[a] == pivot; data[b-1] < pivot
|
||||||
|
data[a], data[b-1] = data[b-1], data[a]
|
||||||
|
a++
|
||||||
|
b--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Swap pivot into middle
|
||||||
|
data[pivot], data[b-1] = data[b-1], data[pivot]
|
||||||
|
return b - 1, c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insertion sort
|
||||||
|
func insertionSort(data []literalNode, a, b int) {
|
||||||
|
for i := a + 1; i < b; i++ {
|
||||||
|
for j := i; j > a && data[j].literal < data[j-1].literal; j-- {
|
||||||
|
data[j], data[j-1] = data[j-1], data[j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// maxDepth returns a threshold at which quicksort should switch
|
||||||
|
// to heapsort. It returns 2*ceil(lg(n+1)).
|
||||||
|
func maxDepth(n int) int {
|
||||||
|
var depth int
|
||||||
|
for i := n; i > 0; i >>= 1 {
|
||||||
|
depth++
|
||||||
|
}
|
||||||
|
return depth * 2
|
||||||
|
}
|
||||||
|
|
||||||
|
// medianOfThree moves the median of the three values data[m0], data[m1], data[m2] into data[m1].
|
||||||
|
func medianOfThree(data []literalNode, m1, m0, m2 int) {
|
||||||
|
// sort 3 elements
|
||||||
|
if data[m1].literal < data[m0].literal {
|
||||||
|
data[m1], data[m0] = data[m0], data[m1]
|
||||||
|
}
|
||||||
|
// data[m0] <= data[m1]
|
||||||
|
if data[m2].literal < data[m1].literal {
|
||||||
|
data[m2], data[m1] = data[m1], data[m2]
|
||||||
|
// data[m0] <= data[m2] && data[m1] < data[m2]
|
||||||
|
if data[m1].literal < data[m0].literal {
|
||||||
|
data[m1], data[m0] = data[m0], data[m1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// now data[m0] <= data[m1] <= data[m2]
|
||||||
|
}
|
|
@ -0,0 +1,793 @@
|
||||||
|
// Copyright 2009 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// Package flate implements the DEFLATE compressed data format, described in
|
||||||
|
// RFC 1951. The gzip and zlib packages implement access to DEFLATE-based file
|
||||||
|
// formats.
|
||||||
|
package flate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"compress/flate"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math/bits"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxCodeLen = 16 // max length of Huffman code
|
||||||
|
maxCodeLenMask = 15 // mask for max length of Huffman code
|
||||||
|
// The next three numbers come from the RFC section 3.2.7, with the
|
||||||
|
// additional proviso in section 3.2.5 which implies that distance codes
|
||||||
|
// 30 and 31 should never occur in compressed data.
|
||||||
|
maxNumLit = 286
|
||||||
|
maxNumDist = 30
|
||||||
|
numCodes = 19 // number of codes in Huffman meta-code
|
||||||
|
|
||||||
|
debugDecode = false
|
||||||
|
)
|
||||||
|
|
||||||
|
// Value of length - 3 and extra bits.
|
||||||
|
type lengthExtra struct {
|
||||||
|
length, extra uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
var decCodeToLen = [32]lengthExtra{{length: 0x0, extra: 0x0}, {length: 0x1, extra: 0x0}, {length: 0x2, extra: 0x0}, {length: 0x3, extra: 0x0}, {length: 0x4, extra: 0x0}, {length: 0x5, extra: 0x0}, {length: 0x6, extra: 0x0}, {length: 0x7, extra: 0x0}, {length: 0x8, extra: 0x1}, {length: 0xa, extra: 0x1}, {length: 0xc, extra: 0x1}, {length: 0xe, extra: 0x1}, {length: 0x10, extra: 0x2}, {length: 0x14, extra: 0x2}, {length: 0x18, extra: 0x2}, {length: 0x1c, extra: 0x2}, {length: 0x20, extra: 0x3}, {length: 0x28, extra: 0x3}, {length: 0x30, extra: 0x3}, {length: 0x38, extra: 0x3}, {length: 0x40, extra: 0x4}, {length: 0x50, extra: 0x4}, {length: 0x60, extra: 0x4}, {length: 0x70, extra: 0x4}, {length: 0x80, extra: 0x5}, {length: 0xa0, extra: 0x5}, {length: 0xc0, extra: 0x5}, {length: 0xe0, extra: 0x5}, {length: 0xff, extra: 0x0}, {length: 0x0, extra: 0x0}, {length: 0x0, extra: 0x0}, {length: 0x0, extra: 0x0}}
|
||||||
|
|
||||||
|
var bitMask32 = [32]uint32{
|
||||||
|
0, 1, 3, 7, 0xF, 0x1F, 0x3F, 0x7F, 0xFF,
|
||||||
|
0x1FF, 0x3FF, 0x7FF, 0xFFF, 0x1FFF, 0x3FFF, 0x7FFF, 0xFFFF,
|
||||||
|
0x1ffff, 0x3ffff, 0x7FFFF, 0xfFFFF, 0x1fFFFF, 0x3fFFFF, 0x7fFFFF, 0xffFFFF,
|
||||||
|
0x1ffFFFF, 0x3ffFFFF, 0x7ffFFFF, 0xfffFFFF, 0x1fffFFFF, 0x3fffFFFF, 0x7fffFFFF,
|
||||||
|
} // up to 32 bits
|
||||||
|
|
||||||
|
// Initialize the fixedHuffmanDecoder only once upon first use.
|
||||||
|
var fixedOnce sync.Once
|
||||||
|
var fixedHuffmanDecoder huffmanDecoder
|
||||||
|
|
||||||
|
// A CorruptInputError reports the presence of corrupt input at a given offset.
|
||||||
|
type CorruptInputError = flate.CorruptInputError
|
||||||
|
|
||||||
|
// An InternalError reports an error in the flate code itself.
|
||||||
|
type InternalError string
|
||||||
|
|
||||||
|
func (e InternalError) Error() string { return "flate: internal error: " + string(e) }
|
||||||
|
|
||||||
|
// A ReadError reports an error encountered while reading input.
|
||||||
|
//
|
||||||
|
// Deprecated: No longer returned.
|
||||||
|
type ReadError = flate.ReadError
|
||||||
|
|
||||||
|
// A WriteError reports an error encountered while writing output.
|
||||||
|
//
|
||||||
|
// Deprecated: No longer returned.
|
||||||
|
type WriteError = flate.WriteError
|
||||||
|
|
||||||
|
// Resetter resets a ReadCloser returned by NewReader or NewReaderDict to
|
||||||
|
// to switch to a new underlying Reader. This permits reusing a ReadCloser
|
||||||
|
// instead of allocating a new one.
|
||||||
|
type Resetter interface {
|
||||||
|
// Reset discards any buffered data and resets the Resetter as if it was
|
||||||
|
// newly initialized with the given reader.
|
||||||
|
Reset(r io.Reader, dict []byte) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// The data structure for decoding Huffman tables is based on that of
|
||||||
|
// zlib. There is a lookup table of a fixed bit width (huffmanChunkBits),
|
||||||
|
// For codes smaller than the table width, there are multiple entries
|
||||||
|
// (each combination of trailing bits has the same value). For codes
|
||||||
|
// larger than the table width, the table contains a link to an overflow
|
||||||
|
// table. The width of each entry in the link table is the maximum code
|
||||||
|
// size minus the chunk width.
|
||||||
|
//
|
||||||
|
// Note that you can do a lookup in the table even without all bits
|
||||||
|
// filled. Since the extra bits are zero, and the DEFLATE Huffman codes
|
||||||
|
// have the property that shorter codes come before longer ones, the
|
||||||
|
// bit length estimate in the result is a lower bound on the actual
|
||||||
|
// number of bits.
|
||||||
|
//
|
||||||
|
// See the following:
|
||||||
|
// http://www.gzip.org/algorithm.txt
|
||||||
|
|
||||||
|
// chunk & 15 is number of bits
|
||||||
|
// chunk >> 4 is value, including table link
|
||||||
|
|
||||||
|
const (
|
||||||
|
huffmanChunkBits = 9
|
||||||
|
huffmanNumChunks = 1 << huffmanChunkBits
|
||||||
|
huffmanCountMask = 15
|
||||||
|
huffmanValueShift = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
type huffmanDecoder struct {
|
||||||
|
maxRead int // the maximum number of bits we can read and not overread
|
||||||
|
chunks *[huffmanNumChunks]uint16 // chunks as described above
|
||||||
|
links [][]uint16 // overflow links
|
||||||
|
linkMask uint32 // mask the width of the link table
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize Huffman decoding tables from array of code lengths.
|
||||||
|
// Following this function, h is guaranteed to be initialized into a complete
|
||||||
|
// tree (i.e., neither over-subscribed nor under-subscribed). The exception is a
|
||||||
|
// degenerate case where the tree has only a single symbol with length 1. Empty
|
||||||
|
// trees are permitted.
|
||||||
|
func (h *huffmanDecoder) init(lengths []int) bool {
|
||||||
|
// Sanity enables additional runtime tests during Huffman
|
||||||
|
// table construction. It's intended to be used during
|
||||||
|
// development to supplement the currently ad-hoc unit tests.
|
||||||
|
const sanity = false
|
||||||
|
|
||||||
|
if h.chunks == nil {
|
||||||
|
h.chunks = &[huffmanNumChunks]uint16{}
|
||||||
|
}
|
||||||
|
if h.maxRead != 0 {
|
||||||
|
*h = huffmanDecoder{chunks: h.chunks, links: h.links}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count number of codes of each length,
|
||||||
|
// compute maxRead and max length.
|
||||||
|
var count [maxCodeLen]int
|
||||||
|
var min, max int
|
||||||
|
for _, n := range lengths {
|
||||||
|
if n == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if min == 0 || n < min {
|
||||||
|
min = n
|
||||||
|
}
|
||||||
|
if n > max {
|
||||||
|
max = n
|
||||||
|
}
|
||||||
|
count[n&maxCodeLenMask]++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty tree. The decompressor.huffSym function will fail later if the tree
|
||||||
|
// is used. Technically, an empty tree is only valid for the HDIST tree and
|
||||||
|
// not the HCLEN and HLIT tree. However, a stream with an empty HCLEN tree
|
||||||
|
// is guaranteed to fail since it will attempt to use the tree to decode the
|
||||||
|
// codes for the HLIT and HDIST trees. Similarly, an empty HLIT tree is
|
||||||
|
// guaranteed to fail later since the compressed data section must be
|
||||||
|
// composed of at least one symbol (the end-of-block marker).
|
||||||
|
if max == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
code := 0
|
||||||
|
var nextcode [maxCodeLen]int
|
||||||
|
for i := min; i <= max; i++ {
|
||||||
|
code <<= 1
|
||||||
|
nextcode[i&maxCodeLenMask] = code
|
||||||
|
code += count[i&maxCodeLenMask]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the coding is complete (i.e., that we've
|
||||||
|
// assigned all 2-to-the-max possible bit sequences).
|
||||||
|
// Exception: To be compatible with zlib, we also need to
|
||||||
|
// accept degenerate single-code codings. See also
|
||||||
|
// TestDegenerateHuffmanCoding.
|
||||||
|
if code != 1<<uint(max) && !(code == 1 && max == 1) {
|
||||||
|
if debugDecode {
|
||||||
|
fmt.Println("coding failed, code, max:", code, max, code == 1<<uint(max), code == 1 && max == 1, "(one should be true)")
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
h.maxRead = min
|
||||||
|
chunks := h.chunks[:]
|
||||||
|
for i := range chunks {
|
||||||
|
chunks[i] = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if max > huffmanChunkBits {
|
||||||
|
numLinks := 1 << (uint(max) - huffmanChunkBits)
|
||||||
|
h.linkMask = uint32(numLinks - 1)
|
||||||
|
|
||||||
|
// create link tables
|
||||||
|
link := nextcode[huffmanChunkBits+1] >> 1
|
||||||
|
if cap(h.links) < huffmanNumChunks-link {
|
||||||
|
h.links = make([][]uint16, huffmanNumChunks-link)
|
||||||
|
} else {
|
||||||
|
h.links = h.links[:huffmanNumChunks-link]
|
||||||
|
}
|
||||||
|
for j := uint(link); j < huffmanNumChunks; j++ {
|
||||||
|
reverse := int(bits.Reverse16(uint16(j)))
|
||||||
|
reverse >>= uint(16 - huffmanChunkBits)
|
||||||
|
off := j - uint(link)
|
||||||
|
if sanity && h.chunks[reverse] != 0 {
|
||||||
|
panic("impossible: overwriting existing chunk")
|
||||||
|
}
|
||||||
|
h.chunks[reverse] = uint16(off<<huffmanValueShift | (huffmanChunkBits + 1))
|
||||||
|
if cap(h.links[off]) < numLinks {
|
||||||
|
h.links[off] = make([]uint16, numLinks)
|
||||||
|
} else {
|
||||||
|
links := h.links[off][:0]
|
||||||
|
h.links[off] = links[:numLinks]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
h.links = h.links[:0]
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, n := range lengths {
|
||||||
|
if n == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
code := nextcode[n]
|
||||||
|
nextcode[n]++
|
||||||
|
chunk := uint16(i<<huffmanValueShift | n)
|
||||||
|
reverse := int(bits.Reverse16(uint16(code)))
|
||||||
|
reverse >>= uint(16 - n)
|
||||||
|
if n <= huffmanChunkBits {
|
||||||
|
for off := reverse; off < len(h.chunks); off += 1 << uint(n) {
|
||||||
|
// We should never need to overwrite
|
||||||
|
// an existing chunk. Also, 0 is
|
||||||
|
// never a valid chunk, because the
|
||||||
|
// lower 4 "count" bits should be
|
||||||
|
// between 1 and 15.
|
||||||
|
if sanity && h.chunks[off] != 0 {
|
||||||
|
panic("impossible: overwriting existing chunk")
|
||||||
|
}
|
||||||
|
h.chunks[off] = chunk
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
j := reverse & (huffmanNumChunks - 1)
|
||||||
|
if sanity && h.chunks[j]&huffmanCountMask != huffmanChunkBits+1 {
|
||||||
|
// Longer codes should have been
|
||||||
|
// associated with a link table above.
|
||||||
|
panic("impossible: not an indirect chunk")
|
||||||
|
}
|
||||||
|
value := h.chunks[j] >> huffmanValueShift
|
||||||
|
linktab := h.links[value]
|
||||||
|
reverse >>= huffmanChunkBits
|
||||||
|
for off := reverse; off < len(linktab); off += 1 << uint(n-huffmanChunkBits) {
|
||||||
|
if sanity && linktab[off] != 0 {
|
||||||
|
panic("impossible: overwriting existing chunk")
|
||||||
|
}
|
||||||
|
linktab[off] = chunk
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sanity {
|
||||||
|
// Above we've sanity checked that we never overwrote
|
||||||
|
// an existing entry. Here we additionally check that
|
||||||
|
// we filled the tables completely.
|
||||||
|
for i, chunk := range h.chunks {
|
||||||
|
if chunk == 0 {
|
||||||
|
// As an exception, in the degenerate
|
||||||
|
// single-code case, we allow odd
|
||||||
|
// chunks to be missing.
|
||||||
|
if code == 1 && i%2 == 1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
panic("impossible: missing chunk")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, linktab := range h.links {
|
||||||
|
for _, chunk := range linktab {
|
||||||
|
if chunk == 0 {
|
||||||
|
panic("impossible: missing chunk")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// The actual read interface needed by NewReader.
|
||||||
|
// If the passed in io.Reader does not also have ReadByte,
|
||||||
|
// the NewReader will introduce its own buffering.
|
||||||
|
type Reader interface {
|
||||||
|
io.Reader
|
||||||
|
io.ByteReader
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decompress state.
|
||||||
|
type decompressor struct {
|
||||||
|
// Input source.
|
||||||
|
r Reader
|
||||||
|
roffset int64
|
||||||
|
|
||||||
|
// Huffman decoders for literal/length, distance.
|
||||||
|
h1, h2 huffmanDecoder
|
||||||
|
|
||||||
|
// Length arrays used to define Huffman codes.
|
||||||
|
bits *[maxNumLit + maxNumDist]int
|
||||||
|
codebits *[numCodes]int
|
||||||
|
|
||||||
|
// Output history, buffer.
|
||||||
|
dict dictDecoder
|
||||||
|
|
||||||
|
// Next step in the decompression,
|
||||||
|
// and decompression state.
|
||||||
|
step func(*decompressor)
|
||||||
|
stepState int
|
||||||
|
err error
|
||||||
|
toRead []byte
|
||||||
|
hl, hd *huffmanDecoder
|
||||||
|
copyLen int
|
||||||
|
copyDist int
|
||||||
|
|
||||||
|
// Temporary buffer (avoids repeated allocation).
|
||||||
|
buf [4]byte
|
||||||
|
|
||||||
|
// Input bits, in top of b.
|
||||||
|
b uint32
|
||||||
|
|
||||||
|
nb uint
|
||||||
|
final bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *decompressor) nextBlock() {
|
||||||
|
for f.nb < 1+2 {
|
||||||
|
if f.err = f.moreBits(); f.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f.final = f.b&1 == 1
|
||||||
|
f.b >>= 1
|
||||||
|
typ := f.b & 3
|
||||||
|
f.b >>= 2
|
||||||
|
f.nb -= 1 + 2
|
||||||
|
switch typ {
|
||||||
|
case 0:
|
||||||
|
f.dataBlock()
|
||||||
|
if debugDecode {
|
||||||
|
fmt.Println("stored block")
|
||||||
|
}
|
||||||
|
case 1:
|
||||||
|
// compressed, fixed Huffman tables
|
||||||
|
f.hl = &fixedHuffmanDecoder
|
||||||
|
f.hd = nil
|
||||||
|
f.huffmanBlockDecoder()()
|
||||||
|
if debugDecode {
|
||||||
|
fmt.Println("predefinied huffman block")
|
||||||
|
}
|
||||||
|
case 2:
|
||||||
|
// compressed, dynamic Huffman tables
|
||||||
|
if f.err = f.readHuffman(); f.err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
f.hl = &f.h1
|
||||||
|
f.hd = &f.h2
|
||||||
|
f.huffmanBlockDecoder()()
|
||||||
|
if debugDecode {
|
||||||
|
fmt.Println("dynamic huffman block")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// 3 is reserved.
|
||||||
|
if debugDecode {
|
||||||
|
fmt.Println("reserved data block encountered")
|
||||||
|
}
|
||||||
|
f.err = CorruptInputError(f.roffset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *decompressor) Read(b []byte) (int, error) {
|
||||||
|
for {
|
||||||
|
if len(f.toRead) > 0 {
|
||||||
|
n := copy(b, f.toRead)
|
||||||
|
f.toRead = f.toRead[n:]
|
||||||
|
if len(f.toRead) == 0 {
|
||||||
|
return n, f.err
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
if f.err != nil {
|
||||||
|
return 0, f.err
|
||||||
|
}
|
||||||
|
f.step(f)
|
||||||
|
if f.err != nil && len(f.toRead) == 0 {
|
||||||
|
f.toRead = f.dict.readFlush() // Flush what's left in case of error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Support the io.WriteTo interface for io.Copy and friends.
|
||||||
|
func (f *decompressor) WriteTo(w io.Writer) (int64, error) {
|
||||||
|
total := int64(0)
|
||||||
|
flushed := false
|
||||||
|
for {
|
||||||
|
if len(f.toRead) > 0 {
|
||||||
|
n, err := w.Write(f.toRead)
|
||||||
|
total += int64(n)
|
||||||
|
if err != nil {
|
||||||
|
f.err = err
|
||||||
|
return total, err
|
||||||
|
}
|
||||||
|
if n != len(f.toRead) {
|
||||||
|
return total, io.ErrShortWrite
|
||||||
|
}
|
||||||
|
f.toRead = f.toRead[:0]
|
||||||
|
}
|
||||||
|
if f.err != nil && flushed {
|
||||||
|
if f.err == io.EOF {
|
||||||
|
return total, nil
|
||||||
|
}
|
||||||
|
return total, f.err
|
||||||
|
}
|
||||||
|
if f.err == nil {
|
||||||
|
f.step(f)
|
||||||
|
}
|
||||||
|
if len(f.toRead) == 0 && f.err != nil && !flushed {
|
||||||
|
f.toRead = f.dict.readFlush() // Flush what's left in case of error
|
||||||
|
flushed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *decompressor) Close() error {
|
||||||
|
if f.err == io.EOF {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return f.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// RFC 1951 section 3.2.7.
|
||||||
|
// Compression with dynamic Huffman codes
|
||||||
|
|
||||||
|
var codeOrder = [...]int{16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}
|
||||||
|
|
||||||
|
func (f *decompressor) readHuffman() error {
|
||||||
|
// HLIT[5], HDIST[5], HCLEN[4].
|
||||||
|
for f.nb < 5+5+4 {
|
||||||
|
if err := f.moreBits(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nlit := int(f.b&0x1F) + 257
|
||||||
|
if nlit > maxNumLit {
|
||||||
|
if debugDecode {
|
||||||
|
fmt.Println("nlit > maxNumLit", nlit)
|
||||||
|
}
|
||||||
|
return CorruptInputError(f.roffset)
|
||||||
|
}
|
||||||
|
f.b >>= 5
|
||||||
|
ndist := int(f.b&0x1F) + 1
|
||||||
|
if ndist > maxNumDist {
|
||||||
|
if debugDecode {
|
||||||
|
fmt.Println("ndist > maxNumDist", ndist)
|
||||||
|
}
|
||||||
|
return CorruptInputError(f.roffset)
|
||||||
|
}
|
||||||
|
f.b >>= 5
|
||||||
|
nclen := int(f.b&0xF) + 4
|
||||||
|
// numCodes is 19, so nclen is always valid.
|
||||||
|
f.b >>= 4
|
||||||
|
f.nb -= 5 + 5 + 4
|
||||||
|
|
||||||
|
// (HCLEN+4)*3 bits: code lengths in the magic codeOrder order.
|
||||||
|
for i := 0; i < nclen; i++ {
|
||||||
|
for f.nb < 3 {
|
||||||
|
if err := f.moreBits(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f.codebits[codeOrder[i]] = int(f.b & 0x7)
|
||||||
|
f.b >>= 3
|
||||||
|
f.nb -= 3
|
||||||
|
}
|
||||||
|
for i := nclen; i < len(codeOrder); i++ {
|
||||||
|
f.codebits[codeOrder[i]] = 0
|
||||||
|
}
|
||||||
|
if !f.h1.init(f.codebits[0:]) {
|
||||||
|
if debugDecode {
|
||||||
|
fmt.Println("init codebits failed")
|
||||||
|
}
|
||||||
|
return CorruptInputError(f.roffset)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HLIT + 257 code lengths, HDIST + 1 code lengths,
|
||||||
|
// using the code length Huffman code.
|
||||||
|
for i, n := 0, nlit+ndist; i < n; {
|
||||||
|
x, err := f.huffSym(&f.h1)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if x < 16 {
|
||||||
|
// Actual length.
|
||||||
|
f.bits[i] = x
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Repeat previous length or zero.
|
||||||
|
var rep int
|
||||||
|
var nb uint
|
||||||
|
var b int
|
||||||
|
switch x {
|
||||||
|
default:
|
||||||
|
return InternalError("unexpected length code")
|
||||||
|
case 16:
|
||||||
|
rep = 3
|
||||||
|
nb = 2
|
||||||
|
if i == 0 {
|
||||||
|
if debugDecode {
|
||||||
|
fmt.Println("i==0")
|
||||||
|
}
|
||||||
|
return CorruptInputError(f.roffset)
|
||||||
|
}
|
||||||
|
b = f.bits[i-1]
|
||||||
|
case 17:
|
||||||
|
rep = 3
|
||||||
|
nb = 3
|
||||||
|
b = 0
|
||||||
|
case 18:
|
||||||
|
rep = 11
|
||||||
|
nb = 7
|
||||||
|
b = 0
|
||||||
|
}
|
||||||
|
for f.nb < nb {
|
||||||
|
if err := f.moreBits(); err != nil {
|
||||||
|
if debugDecode {
|
||||||
|
fmt.Println("morebits:", err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rep += int(f.b & uint32(1<<(nb®SizeMaskUint32)-1))
|
||||||
|
f.b >>= nb & regSizeMaskUint32
|
||||||
|
f.nb -= nb
|
||||||
|
if i+rep > n {
|
||||||
|
if debugDecode {
|
||||||
|
fmt.Println("i+rep > n", i, rep, n)
|
||||||
|
}
|
||||||
|
return CorruptInputError(f.roffset)
|
||||||
|
}
|
||||||
|
for j := 0; j < rep; j++ {
|
||||||
|
f.bits[i] = b
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !f.h1.init(f.bits[0:nlit]) || !f.h2.init(f.bits[nlit:nlit+ndist]) {
|
||||||
|
if debugDecode {
|
||||||
|
fmt.Println("init2 failed")
|
||||||
|
}
|
||||||
|
return CorruptInputError(f.roffset)
|
||||||
|
}
|
||||||
|
|
||||||
|
// As an optimization, we can initialize the maxRead bits to read at a time
|
||||||
|
// for the HLIT tree to the length of the EOB marker since we know that
|
||||||
|
// every block must terminate with one. This preserves the property that
|
||||||
|
// we never read any extra bytes after the end of the DEFLATE stream.
|
||||||
|
if f.h1.maxRead < f.bits[endBlockMarker] {
|
||||||
|
f.h1.maxRead = f.bits[endBlockMarker]
|
||||||
|
}
|
||||||
|
if !f.final {
|
||||||
|
// If not the final block, the smallest block possible is
|
||||||
|
// a predefined table, BTYPE=01, with a single EOB marker.
|
||||||
|
// This will take up 3 + 7 bits.
|
||||||
|
f.h1.maxRead += 10
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy a single uncompressed data block from input to output.
|
||||||
|
func (f *decompressor) dataBlock() {
|
||||||
|
// Uncompressed.
|
||||||
|
// Discard current half-byte.
|
||||||
|
left := (f.nb) & 7
|
||||||
|
f.nb -= left
|
||||||
|
f.b >>= left
|
||||||
|
|
||||||
|
offBytes := f.nb >> 3
|
||||||
|
// Unfilled values will be overwritten.
|
||||||
|
f.buf[0] = uint8(f.b)
|
||||||
|
f.buf[1] = uint8(f.b >> 8)
|
||||||
|
f.buf[2] = uint8(f.b >> 16)
|
||||||
|
f.buf[3] = uint8(f.b >> 24)
|
||||||
|
|
||||||
|
f.roffset += int64(offBytes)
|
||||||
|
f.nb, f.b = 0, 0
|
||||||
|
|
||||||
|
// Length then ones-complement of length.
|
||||||
|
nr, err := io.ReadFull(f.r, f.buf[offBytes:4])
|
||||||
|
f.roffset += int64(nr)
|
||||||
|
if err != nil {
|
||||||
|
f.err = noEOF(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n := uint16(f.buf[0]) | uint16(f.buf[1])<<8
|
||||||
|
nn := uint16(f.buf[2]) | uint16(f.buf[3])<<8
|
||||||
|
if nn != ^n {
|
||||||
|
if debugDecode {
|
||||||
|
ncomp := ^n
|
||||||
|
fmt.Println("uint16(nn) != uint16(^n)", nn, ncomp)
|
||||||
|
}
|
||||||
|
f.err = CorruptInputError(f.roffset)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if n == 0 {
|
||||||
|
f.toRead = f.dict.readFlush()
|
||||||
|
f.finishBlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
f.copyLen = int(n)
|
||||||
|
f.copyData()
|
||||||
|
}
|
||||||
|
|
||||||
|
// copyData copies f.copyLen bytes from the underlying reader into f.hist.
|
||||||
|
// It pauses for reads when f.hist is full.
|
||||||
|
func (f *decompressor) copyData() {
|
||||||
|
buf := f.dict.writeSlice()
|
||||||
|
if len(buf) > f.copyLen {
|
||||||
|
buf = buf[:f.copyLen]
|
||||||
|
}
|
||||||
|
|
||||||
|
cnt, err := io.ReadFull(f.r, buf)
|
||||||
|
f.roffset += int64(cnt)
|
||||||
|
f.copyLen -= cnt
|
||||||
|
f.dict.writeMark(cnt)
|
||||||
|
if err != nil {
|
||||||
|
f.err = noEOF(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.dict.availWrite() == 0 || f.copyLen > 0 {
|
||||||
|
f.toRead = f.dict.readFlush()
|
||||||
|
f.step = (*decompressor).copyData
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.finishBlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *decompressor) finishBlock() {
|
||||||
|
if f.final {
|
||||||
|
if f.dict.availRead() > 0 {
|
||||||
|
f.toRead = f.dict.readFlush()
|
||||||
|
}
|
||||||
|
f.err = io.EOF
|
||||||
|
}
|
||||||
|
f.step = (*decompressor).nextBlock
|
||||||
|
}
|
||||||
|
|
||||||
|
// noEOF returns err, unless err == io.EOF, in which case it returns io.ErrUnexpectedEOF.
|
||||||
|
func noEOF(e error) error {
|
||||||
|
if e == io.EOF {
|
||||||
|
return io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *decompressor) moreBits() error {
|
||||||
|
c, err := f.r.ReadByte()
|
||||||
|
if err != nil {
|
||||||
|
return noEOF(err)
|
||||||
|
}
|
||||||
|
f.roffset++
|
||||||
|
f.b |= uint32(c) << (f.nb & regSizeMaskUint32)
|
||||||
|
f.nb += 8
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the next Huffman-encoded symbol from f according to h.
|
||||||
|
func (f *decompressor) huffSym(h *huffmanDecoder) (int, error) {
|
||||||
|
// Since a huffmanDecoder can be empty or be composed of a degenerate tree
|
||||||
|
// with single element, huffSym must error on these two edge cases. In both
|
||||||
|
// cases, the chunks slice will be 0 for the invalid sequence, leading it
|
||||||
|
// satisfy the n == 0 check below.
|
||||||
|
n := uint(h.maxRead)
|
||||||
|
// Optimization. Compiler isn't smart enough to keep f.b,f.nb in registers,
|
||||||
|
// but is smart enough to keep local variables in registers, so use nb and b,
|
||||||
|
// inline call to moreBits and reassign b,nb back to f on return.
|
||||||
|
nb, b := f.nb, f.b
|
||||||
|
for {
|
||||||
|
for nb < n {
|
||||||
|
c, err := f.r.ReadByte()
|
||||||
|
if err != nil {
|
||||||
|
f.b = b
|
||||||
|
f.nb = nb
|
||||||
|
return 0, noEOF(err)
|
||||||
|
}
|
||||||
|
f.roffset++
|
||||||
|
b |= uint32(c) << (nb & regSizeMaskUint32)
|
||||||
|
nb += 8
|
||||||
|
}
|
||||||
|
chunk := h.chunks[b&(huffmanNumChunks-1)]
|
||||||
|
n = uint(chunk & huffmanCountMask)
|
||||||
|
if n > huffmanChunkBits {
|
||||||
|
chunk = h.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&h.linkMask]
|
||||||
|
n = uint(chunk & huffmanCountMask)
|
||||||
|
}
|
||||||
|
if n <= nb {
|
||||||
|
if n == 0 {
|
||||||
|
f.b = b
|
||||||
|
f.nb = nb
|
||||||
|
if debugDecode {
|
||||||
|
fmt.Println("huffsym: n==0")
|
||||||
|
}
|
||||||
|
f.err = CorruptInputError(f.roffset)
|
||||||
|
return 0, f.err
|
||||||
|
}
|
||||||
|
f.b = b >> (n & regSizeMaskUint32)
|
||||||
|
f.nb = nb - n
|
||||||
|
return int(chunk >> huffmanValueShift), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeReader(r io.Reader) Reader {
|
||||||
|
if rr, ok := r.(Reader); ok {
|
||||||
|
return rr
|
||||||
|
}
|
||||||
|
return bufio.NewReader(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fixedHuffmanDecoderInit() {
|
||||||
|
fixedOnce.Do(func() {
|
||||||
|
// These come from the RFC section 3.2.6.
|
||||||
|
var bits [288]int
|
||||||
|
for i := 0; i < 144; i++ {
|
||||||
|
bits[i] = 8
|
||||||
|
}
|
||||||
|
for i := 144; i < 256; i++ {
|
||||||
|
bits[i] = 9
|
||||||
|
}
|
||||||
|
for i := 256; i < 280; i++ {
|
||||||
|
bits[i] = 7
|
||||||
|
}
|
||||||
|
for i := 280; i < 288; i++ {
|
||||||
|
bits[i] = 8
|
||||||
|
}
|
||||||
|
fixedHuffmanDecoder.init(bits[:])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *decompressor) Reset(r io.Reader, dict []byte) error {
|
||||||
|
*f = decompressor{
|
||||||
|
r: makeReader(r),
|
||||||
|
bits: f.bits,
|
||||||
|
codebits: f.codebits,
|
||||||
|
h1: f.h1,
|
||||||
|
h2: f.h2,
|
||||||
|
dict: f.dict,
|
||||||
|
step: (*decompressor).nextBlock,
|
||||||
|
}
|
||||||
|
f.dict.init(maxMatchOffset, dict)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewReader returns a new ReadCloser that can be used
|
||||||
|
// to read the uncompressed version of r.
|
||||||
|
// If r does not also implement io.ByteReader,
|
||||||
|
// the decompressor may read more data than necessary from r.
|
||||||
|
// It is the caller's responsibility to call Close on the ReadCloser
|
||||||
|
// when finished reading.
|
||||||
|
//
|
||||||
|
// The ReadCloser returned by NewReader also implements Resetter.
|
||||||
|
func NewReader(r io.Reader) io.ReadCloser {
|
||||||
|
fixedHuffmanDecoderInit()
|
||||||
|
|
||||||
|
var f decompressor
|
||||||
|
f.r = makeReader(r)
|
||||||
|
f.bits = new([maxNumLit + maxNumDist]int)
|
||||||
|
f.codebits = new([numCodes]int)
|
||||||
|
f.step = (*decompressor).nextBlock
|
||||||
|
f.dict.init(maxMatchOffset, nil)
|
||||||
|
return &f
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewReaderDict is like NewReader but initializes the reader
|
||||||
|
// with a preset dictionary. The returned Reader behaves as if
|
||||||
|
// the uncompressed data stream started with the given dictionary,
|
||||||
|
// which has already been read. NewReaderDict is typically used
|
||||||
|
// to read data compressed by NewWriterDict.
|
||||||
|
//
|
||||||
|
// The ReadCloser returned by NewReader also implements Resetter.
|
||||||
|
func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser {
|
||||||
|
fixedHuffmanDecoderInit()
|
||||||
|
|
||||||
|
var f decompressor
|
||||||
|
f.r = makeReader(r)
|
||||||
|
f.bits = new([maxNumLit + maxNumDist]int)
|
||||||
|
f.codebits = new([numCodes]int)
|
||||||
|
f.step = (*decompressor).nextBlock
|
||||||
|
f.dict.init(maxMatchOffset, dict)
|
||||||
|
return &f
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,241 @@
|
||||||
|
package flate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"math/bits"
|
||||||
|
)
|
||||||
|
|
||||||
|
// fastGen maintains the table for matches,
|
||||||
|
// and the previous byte block for level 2.
|
||||||
|
// This is the generic implementation.
|
||||||
|
type fastEncL1 struct {
|
||||||
|
fastGen
|
||||||
|
table [tableSize]tableEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeL1 uses a similar algorithm to level 1
|
||||||
|
func (e *fastEncL1) Encode(dst *tokens, src []byte) {
|
||||||
|
const (
|
||||||
|
inputMargin = 12 - 1
|
||||||
|
minNonLiteralBlockSize = 1 + 1 + inputMargin
|
||||||
|
hashBytes = 5
|
||||||
|
)
|
||||||
|
if debugDeflate && e.cur < 0 {
|
||||||
|
panic(fmt.Sprint("e.cur < 0: ", e.cur))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Protect against e.cur wraparound.
|
||||||
|
for e.cur >= bufferReset {
|
||||||
|
if len(e.hist) == 0 {
|
||||||
|
for i := range e.table[:] {
|
||||||
|
e.table[i] = tableEntry{}
|
||||||
|
}
|
||||||
|
e.cur = maxMatchOffset
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// Shift down everything in the table that isn't already too far away.
|
||||||
|
minOff := e.cur + int32(len(e.hist)) - maxMatchOffset
|
||||||
|
for i := range e.table[:] {
|
||||||
|
v := e.table[i].offset
|
||||||
|
if v <= minOff {
|
||||||
|
v = 0
|
||||||
|
} else {
|
||||||
|
v = v - e.cur + maxMatchOffset
|
||||||
|
}
|
||||||
|
e.table[i].offset = v
|
||||||
|
}
|
||||||
|
e.cur = maxMatchOffset
|
||||||
|
}
|
||||||
|
|
||||||
|
s := e.addBlock(src)
|
||||||
|
|
||||||
|
// This check isn't in the Snappy implementation, but there, the caller
|
||||||
|
// instead of the callee handles this case.
|
||||||
|
if len(src) < minNonLiteralBlockSize {
|
||||||
|
// We do not fill the token table.
|
||||||
|
// This will be picked up by caller.
|
||||||
|
dst.n = uint16(len(src))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override src
|
||||||
|
src = e.hist
|
||||||
|
nextEmit := s
|
||||||
|
|
||||||
|
// sLimit is when to stop looking for offset/length copies. The inputMargin
|
||||||
|
// lets us use a fast path for emitLiteral in the main loop, while we are
|
||||||
|
// looking for copies.
|
||||||
|
sLimit := int32(len(src) - inputMargin)
|
||||||
|
|
||||||
|
// nextEmit is where in src the next emitLiteral should start from.
|
||||||
|
cv := load6432(src, s)
|
||||||
|
|
||||||
|
for {
|
||||||
|
const skipLog = 5
|
||||||
|
const doEvery = 2
|
||||||
|
|
||||||
|
nextS := s
|
||||||
|
var candidate tableEntry
|
||||||
|
for {
|
||||||
|
nextHash := hashLen(cv, tableBits, hashBytes)
|
||||||
|
candidate = e.table[nextHash]
|
||||||
|
nextS = s + doEvery + (s-nextEmit)>>skipLog
|
||||||
|
if nextS > sLimit {
|
||||||
|
goto emitRemainder
|
||||||
|
}
|
||||||
|
|
||||||
|
now := load6432(src, nextS)
|
||||||
|
e.table[nextHash] = tableEntry{offset: s + e.cur}
|
||||||
|
nextHash = hashLen(now, tableBits, hashBytes)
|
||||||
|
|
||||||
|
offset := s - (candidate.offset - e.cur)
|
||||||
|
if offset < maxMatchOffset && uint32(cv) == load3232(src, candidate.offset-e.cur) {
|
||||||
|
e.table[nextHash] = tableEntry{offset: nextS + e.cur}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do one right away...
|
||||||
|
cv = now
|
||||||
|
s = nextS
|
||||||
|
nextS++
|
||||||
|
candidate = e.table[nextHash]
|
||||||
|
now >>= 8
|
||||||
|
e.table[nextHash] = tableEntry{offset: s + e.cur}
|
||||||
|
|
||||||
|
offset = s - (candidate.offset - e.cur)
|
||||||
|
if offset < maxMatchOffset && uint32(cv) == load3232(src, candidate.offset-e.cur) {
|
||||||
|
e.table[nextHash] = tableEntry{offset: nextS + e.cur}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
cv = now
|
||||||
|
s = nextS
|
||||||
|
}
|
||||||
|
|
||||||
|
// A 4-byte match has been found. We'll later see if more than 4 bytes
|
||||||
|
// match. But, prior to the match, src[nextEmit:s] are unmatched. Emit
|
||||||
|
// them as literal bytes.
|
||||||
|
for {
|
||||||
|
// Invariant: we have a 4-byte match at s, and no need to emit any
|
||||||
|
// literal bytes prior to s.
|
||||||
|
|
||||||
|
// Extend the 4-byte match as long as possible.
|
||||||
|
t := candidate.offset - e.cur
|
||||||
|
var l = int32(4)
|
||||||
|
if false {
|
||||||
|
l = e.matchlenLong(s+4, t+4, src) + 4
|
||||||
|
} else {
|
||||||
|
// inlined:
|
||||||
|
a := src[s+4:]
|
||||||
|
b := src[t+4:]
|
||||||
|
for len(a) >= 8 {
|
||||||
|
if diff := binary.LittleEndian.Uint64(a) ^ binary.LittleEndian.Uint64(b); diff != 0 {
|
||||||
|
l += int32(bits.TrailingZeros64(diff) >> 3)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
l += 8
|
||||||
|
a = a[8:]
|
||||||
|
b = b[8:]
|
||||||
|
}
|
||||||
|
if len(a) < 8 {
|
||||||
|
b = b[:len(a)]
|
||||||
|
for i := range a {
|
||||||
|
if a[i] != b[i] {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
l++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extend backwards
|
||||||
|
for t > 0 && s > nextEmit && src[t-1] == src[s-1] {
|
||||||
|
s--
|
||||||
|
t--
|
||||||
|
l++
|
||||||
|
}
|
||||||
|
if nextEmit < s {
|
||||||
|
if false {
|
||||||
|
emitLiteral(dst, src[nextEmit:s])
|
||||||
|
} else {
|
||||||
|
for _, v := range src[nextEmit:s] {
|
||||||
|
dst.tokens[dst.n] = token(v)
|
||||||
|
dst.litHist[v]++
|
||||||
|
dst.n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save the match found
|
||||||
|
if false {
|
||||||
|
dst.AddMatchLong(l, uint32(s-t-baseMatchOffset))
|
||||||
|
} else {
|
||||||
|
// Inlined...
|
||||||
|
xoffset := uint32(s - t - baseMatchOffset)
|
||||||
|
xlength := l
|
||||||
|
oc := offsetCode(xoffset)
|
||||||
|
xoffset |= oc << 16
|
||||||
|
for xlength > 0 {
|
||||||
|
xl := xlength
|
||||||
|
if xl > 258 {
|
||||||
|
if xl > 258+baseMatchLength {
|
||||||
|
xl = 258
|
||||||
|
} else {
|
||||||
|
xl = 258 - baseMatchLength
|
||||||
|
}
|
||||||
|
}
|
||||||
|
xlength -= xl
|
||||||
|
xl -= baseMatchLength
|
||||||
|
dst.extraHist[lengthCodes1[uint8(xl)]]++
|
||||||
|
dst.offHist[oc]++
|
||||||
|
dst.tokens[dst.n] = token(matchType | uint32(xl)<<lengthShift | xoffset)
|
||||||
|
dst.n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s += l
|
||||||
|
nextEmit = s
|
||||||
|
if nextS >= s {
|
||||||
|
s = nextS + 1
|
||||||
|
}
|
||||||
|
if s >= sLimit {
|
||||||
|
// Index first pair after match end.
|
||||||
|
if int(s+l+8) < len(src) {
|
||||||
|
cv := load6432(src, s)
|
||||||
|
e.table[hashLen(cv, tableBits, hashBytes)] = tableEntry{offset: s + e.cur}
|
||||||
|
}
|
||||||
|
goto emitRemainder
|
||||||
|
}
|
||||||
|
|
||||||
|
// We could immediately start working at s now, but to improve
|
||||||
|
// compression we first update the hash table at s-2 and at s. If
|
||||||
|
// another emitCopy is not our next move, also calculate nextHash
|
||||||
|
// at s+1. At least on GOARCH=amd64, these three hash calculations
|
||||||
|
// are faster as one load64 call (with some shifts) instead of
|
||||||
|
// three load32 calls.
|
||||||
|
x := load6432(src, s-2)
|
||||||
|
o := e.cur + s - 2
|
||||||
|
prevHash := hashLen(x, tableBits, hashBytes)
|
||||||
|
e.table[prevHash] = tableEntry{offset: o}
|
||||||
|
x >>= 16
|
||||||
|
currHash := hashLen(x, tableBits, hashBytes)
|
||||||
|
candidate = e.table[currHash]
|
||||||
|
e.table[currHash] = tableEntry{offset: o + 2}
|
||||||
|
|
||||||
|
offset := s - (candidate.offset - e.cur)
|
||||||
|
if offset > maxMatchOffset || uint32(x) != load3232(src, candidate.offset-e.cur) {
|
||||||
|
cv = x >> 8
|
||||||
|
s++
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
emitRemainder:
|
||||||
|
if int(nextEmit) < len(src) {
|
||||||
|
// If nothing was added, don't encode literals.
|
||||||
|
if dst.n == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
emitLiteral(dst, src[nextEmit:])
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,214 @@
|
||||||
|
package flate
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// fastGen maintains the table for matches,
|
||||||
|
// and the previous byte block for level 2.
|
||||||
|
// This is the generic implementation.
|
||||||
|
type fastEncL2 struct {
|
||||||
|
fastGen
|
||||||
|
table [bTableSize]tableEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeL2 uses a similar algorithm to level 1, but is capable
|
||||||
|
// of matching across blocks giving better compression at a small slowdown.
|
||||||
|
func (e *fastEncL2) Encode(dst *tokens, src []byte) {
|
||||||
|
const (
|
||||||
|
inputMargin = 12 - 1
|
||||||
|
minNonLiteralBlockSize = 1 + 1 + inputMargin
|
||||||
|
hashBytes = 5
|
||||||
|
)
|
||||||
|
|
||||||
|
if debugDeflate && e.cur < 0 {
|
||||||
|
panic(fmt.Sprint("e.cur < 0: ", e.cur))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Protect against e.cur wraparound.
|
||||||
|
for e.cur >= bufferReset {
|
||||||
|
if len(e.hist) == 0 {
|
||||||
|
for i := range e.table[:] {
|
||||||
|
e.table[i] = tableEntry{}
|
||||||
|
}
|
||||||
|
e.cur = maxMatchOffset
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// Shift down everything in the table that isn't already too far away.
|
||||||
|
minOff := e.cur + int32(len(e.hist)) - maxMatchOffset
|
||||||
|
for i := range e.table[:] {
|
||||||
|
v := e.table[i].offset
|
||||||
|
if v <= minOff {
|
||||||
|
v = 0
|
||||||
|
} else {
|
||||||
|
v = v - e.cur + maxMatchOffset
|
||||||
|
}
|
||||||
|
e.table[i].offset = v
|
||||||
|
}
|
||||||
|
e.cur = maxMatchOffset
|
||||||
|
}
|
||||||
|
|
||||||
|
s := e.addBlock(src)
|
||||||
|
|
||||||
|
// This check isn't in the Snappy implementation, but there, the caller
|
||||||
|
// instead of the callee handles this case.
|
||||||
|
if len(src) < minNonLiteralBlockSize {
|
||||||
|
// We do not fill the token table.
|
||||||
|
// This will be picked up by caller.
|
||||||
|
dst.n = uint16(len(src))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override src
|
||||||
|
src = e.hist
|
||||||
|
nextEmit := s
|
||||||
|
|
||||||
|
// sLimit is when to stop looking for offset/length copies. The inputMargin
|
||||||
|
// lets us use a fast path for emitLiteral in the main loop, while we are
|
||||||
|
// looking for copies.
|
||||||
|
sLimit := int32(len(src) - inputMargin)
|
||||||
|
|
||||||
|
// nextEmit is where in src the next emitLiteral should start from.
|
||||||
|
cv := load6432(src, s)
|
||||||
|
for {
|
||||||
|
// When should we start skipping if we haven't found matches in a long while.
|
||||||
|
const skipLog = 5
|
||||||
|
const doEvery = 2
|
||||||
|
|
||||||
|
nextS := s
|
||||||
|
var candidate tableEntry
|
||||||
|
for {
|
||||||
|
nextHash := hashLen(cv, bTableBits, hashBytes)
|
||||||
|
s = nextS
|
||||||
|
nextS = s + doEvery + (s-nextEmit)>>skipLog
|
||||||
|
if nextS > sLimit {
|
||||||
|
goto emitRemainder
|
||||||
|
}
|
||||||
|
candidate = e.table[nextHash]
|
||||||
|
now := load6432(src, nextS)
|
||||||
|
e.table[nextHash] = tableEntry{offset: s + e.cur}
|
||||||
|
nextHash = hashLen(now, bTableBits, hashBytes)
|
||||||
|
|
||||||
|
offset := s - (candidate.offset - e.cur)
|
||||||
|
if offset < maxMatchOffset && uint32(cv) == load3232(src, candidate.offset-e.cur) {
|
||||||
|
e.table[nextHash] = tableEntry{offset: nextS + e.cur}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do one right away...
|
||||||
|
cv = now
|
||||||
|
s = nextS
|
||||||
|
nextS++
|
||||||
|
candidate = e.table[nextHash]
|
||||||
|
now >>= 8
|
||||||
|
e.table[nextHash] = tableEntry{offset: s + e.cur}
|
||||||
|
|
||||||
|
offset = s - (candidate.offset - e.cur)
|
||||||
|
if offset < maxMatchOffset && uint32(cv) == load3232(src, candidate.offset-e.cur) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
cv = now
|
||||||
|
}
|
||||||
|
|
||||||
|
// A 4-byte match has been found. We'll later see if more than 4 bytes
|
||||||
|
// match. But, prior to the match, src[nextEmit:s] are unmatched. Emit
|
||||||
|
// them as literal bytes.
|
||||||
|
|
||||||
|
// Call emitCopy, and then see if another emitCopy could be our next
|
||||||
|
// move. Repeat until we find no match for the input immediately after
|
||||||
|
// what was consumed by the last emitCopy call.
|
||||||
|
//
|
||||||
|
// If we exit this loop normally then we need to call emitLiteral next,
|
||||||
|
// though we don't yet know how big the literal will be. We handle that
|
||||||
|
// by proceeding to the next iteration of the main loop. We also can
|
||||||
|
// exit this loop via goto if we get close to exhausting the input.
|
||||||
|
for {
|
||||||
|
// Invariant: we have a 4-byte match at s, and no need to emit any
|
||||||
|
// literal bytes prior to s.
|
||||||
|
|
||||||
|
// Extend the 4-byte match as long as possible.
|
||||||
|
t := candidate.offset - e.cur
|
||||||
|
l := e.matchlenLong(s+4, t+4, src) + 4
|
||||||
|
|
||||||
|
// Extend backwards
|
||||||
|
for t > 0 && s > nextEmit && src[t-1] == src[s-1] {
|
||||||
|
s--
|
||||||
|
t--
|
||||||
|
l++
|
||||||
|
}
|
||||||
|
if nextEmit < s {
|
||||||
|
if false {
|
||||||
|
emitLiteral(dst, src[nextEmit:s])
|
||||||
|
} else {
|
||||||
|
for _, v := range src[nextEmit:s] {
|
||||||
|
dst.tokens[dst.n] = token(v)
|
||||||
|
dst.litHist[v]++
|
||||||
|
dst.n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.AddMatchLong(l, uint32(s-t-baseMatchOffset))
|
||||||
|
s += l
|
||||||
|
nextEmit = s
|
||||||
|
if nextS >= s {
|
||||||
|
s = nextS + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if s >= sLimit {
|
||||||
|
// Index first pair after match end.
|
||||||
|
if int(s+l+8) < len(src) {
|
||||||
|
cv := load6432(src, s)
|
||||||
|
e.table[hashLen(cv, bTableBits, hashBytes)] = tableEntry{offset: s + e.cur}
|
||||||
|
}
|
||||||
|
goto emitRemainder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store every second hash in-between, but offset by 1.
|
||||||
|
for i := s - l + 2; i < s-5; i += 7 {
|
||||||
|
x := load6432(src, i)
|
||||||
|
nextHash := hashLen(x, bTableBits, hashBytes)
|
||||||
|
e.table[nextHash] = tableEntry{offset: e.cur + i}
|
||||||
|
// Skip one
|
||||||
|
x >>= 16
|
||||||
|
nextHash = hashLen(x, bTableBits, hashBytes)
|
||||||
|
e.table[nextHash] = tableEntry{offset: e.cur + i + 2}
|
||||||
|
// Skip one
|
||||||
|
x >>= 16
|
||||||
|
nextHash = hashLen(x, bTableBits, hashBytes)
|
||||||
|
e.table[nextHash] = tableEntry{offset: e.cur + i + 4}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We could immediately start working at s now, but to improve
|
||||||
|
// compression we first update the hash table at s-2 to s. If
|
||||||
|
// another emitCopy is not our next move, also calculate nextHash
|
||||||
|
// at s+1. At least on GOARCH=amd64, these three hash calculations
|
||||||
|
// are faster as one load64 call (with some shifts) instead of
|
||||||
|
// three load32 calls.
|
||||||
|
x := load6432(src, s-2)
|
||||||
|
o := e.cur + s - 2
|
||||||
|
prevHash := hashLen(x, bTableBits, hashBytes)
|
||||||
|
prevHash2 := hashLen(x>>8, bTableBits, hashBytes)
|
||||||
|
e.table[prevHash] = tableEntry{offset: o}
|
||||||
|
e.table[prevHash2] = tableEntry{offset: o + 1}
|
||||||
|
currHash := hashLen(x>>16, bTableBits, hashBytes)
|
||||||
|
candidate = e.table[currHash]
|
||||||
|
e.table[currHash] = tableEntry{offset: o + 2}
|
||||||
|
|
||||||
|
offset := s - (candidate.offset - e.cur)
|
||||||
|
if offset > maxMatchOffset || uint32(x>>16) != load3232(src, candidate.offset-e.cur) {
|
||||||
|
cv = x >> 24
|
||||||
|
s++
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
emitRemainder:
|
||||||
|
if int(nextEmit) < len(src) {
|
||||||
|
// If nothing was added, don't encode literals.
|
||||||
|
if dst.n == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
emitLiteral(dst, src[nextEmit:])
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,241 @@
|
||||||
|
package flate
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// fastEncL3
|
||||||
|
type fastEncL3 struct {
|
||||||
|
fastGen
|
||||||
|
table [1 << 16]tableEntryPrev
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode uses a similar algorithm to level 2, will check up to two candidates.
|
||||||
|
func (e *fastEncL3) Encode(dst *tokens, src []byte) {
|
||||||
|
const (
|
||||||
|
inputMargin = 12 - 1
|
||||||
|
minNonLiteralBlockSize = 1 + 1 + inputMargin
|
||||||
|
tableBits = 16
|
||||||
|
tableSize = 1 << tableBits
|
||||||
|
hashBytes = 5
|
||||||
|
)
|
||||||
|
|
||||||
|
if debugDeflate && e.cur < 0 {
|
||||||
|
panic(fmt.Sprint("e.cur < 0: ", e.cur))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Protect against e.cur wraparound.
|
||||||
|
for e.cur >= bufferReset {
|
||||||
|
if len(e.hist) == 0 {
|
||||||
|
for i := range e.table[:] {
|
||||||
|
e.table[i] = tableEntryPrev{}
|
||||||
|
}
|
||||||
|
e.cur = maxMatchOffset
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// Shift down everything in the table that isn't already too far away.
|
||||||
|
minOff := e.cur + int32(len(e.hist)) - maxMatchOffset
|
||||||
|
for i := range e.table[:] {
|
||||||
|
v := e.table[i]
|
||||||
|
if v.Cur.offset <= minOff {
|
||||||
|
v.Cur.offset = 0
|
||||||
|
} else {
|
||||||
|
v.Cur.offset = v.Cur.offset - e.cur + maxMatchOffset
|
||||||
|
}
|
||||||
|
if v.Prev.offset <= minOff {
|
||||||
|
v.Prev.offset = 0
|
||||||
|
} else {
|
||||||
|
v.Prev.offset = v.Prev.offset - e.cur + maxMatchOffset
|
||||||
|
}
|
||||||
|
e.table[i] = v
|
||||||
|
}
|
||||||
|
e.cur = maxMatchOffset
|
||||||
|
}
|
||||||
|
|
||||||
|
s := e.addBlock(src)
|
||||||
|
|
||||||
|
// Skip if too small.
|
||||||
|
if len(src) < minNonLiteralBlockSize {
|
||||||
|
// We do not fill the token table.
|
||||||
|
// This will be picked up by caller.
|
||||||
|
dst.n = uint16(len(src))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override src
|
||||||
|
src = e.hist
|
||||||
|
nextEmit := s
|
||||||
|
|
||||||
|
// sLimit is when to stop looking for offset/length copies. The inputMargin
|
||||||
|
// lets us use a fast path for emitLiteral in the main loop, while we are
|
||||||
|
// looking for copies.
|
||||||
|
sLimit := int32(len(src) - inputMargin)
|
||||||
|
|
||||||
|
// nextEmit is where in src the next emitLiteral should start from.
|
||||||
|
cv := load6432(src, s)
|
||||||
|
for {
|
||||||
|
const skipLog = 7
|
||||||
|
nextS := s
|
||||||
|
var candidate tableEntry
|
||||||
|
for {
|
||||||
|
nextHash := hashLen(cv, tableBits, hashBytes)
|
||||||
|
s = nextS
|
||||||
|
nextS = s + 1 + (s-nextEmit)>>skipLog
|
||||||
|
if nextS > sLimit {
|
||||||
|
goto emitRemainder
|
||||||
|
}
|
||||||
|
candidates := e.table[nextHash]
|
||||||
|
now := load6432(src, nextS)
|
||||||
|
|
||||||
|
// Safe offset distance until s + 4...
|
||||||
|
minOffset := e.cur + s - (maxMatchOffset - 4)
|
||||||
|
e.table[nextHash] = tableEntryPrev{Prev: candidates.Cur, Cur: tableEntry{offset: s + e.cur}}
|
||||||
|
|
||||||
|
// Check both candidates
|
||||||
|
candidate = candidates.Cur
|
||||||
|
if candidate.offset < minOffset {
|
||||||
|
cv = now
|
||||||
|
// Previous will also be invalid, we have nothing.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if uint32(cv) == load3232(src, candidate.offset-e.cur) {
|
||||||
|
if candidates.Prev.offset < minOffset || uint32(cv) != load3232(src, candidates.Prev.offset-e.cur) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// Both match and are valid, pick longest.
|
||||||
|
offset := s - (candidate.offset - e.cur)
|
||||||
|
o2 := s - (candidates.Prev.offset - e.cur)
|
||||||
|
l1, l2 := matchLen(src[s+4:], src[s-offset+4:]), matchLen(src[s+4:], src[s-o2+4:])
|
||||||
|
if l2 > l1 {
|
||||||
|
candidate = candidates.Prev
|
||||||
|
}
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
// We only check if value mismatches.
|
||||||
|
// Offset will always be invalid in other cases.
|
||||||
|
candidate = candidates.Prev
|
||||||
|
if candidate.offset > minOffset && uint32(cv) == load3232(src, candidate.offset-e.cur) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cv = now
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call emitCopy, and then see if another emitCopy could be our next
|
||||||
|
// move. Repeat until we find no match for the input immediately after
|
||||||
|
// what was consumed by the last emitCopy call.
|
||||||
|
//
|
||||||
|
// If we exit this loop normally then we need to call emitLiteral next,
|
||||||
|
// though we don't yet know how big the literal will be. We handle that
|
||||||
|
// by proceeding to the next iteration of the main loop. We also can
|
||||||
|
// exit this loop via goto if we get close to exhausting the input.
|
||||||
|
for {
|
||||||
|
// Invariant: we have a 4-byte match at s, and no need to emit any
|
||||||
|
// literal bytes prior to s.
|
||||||
|
|
||||||
|
// Extend the 4-byte match as long as possible.
|
||||||
|
//
|
||||||
|
t := candidate.offset - e.cur
|
||||||
|
l := e.matchlenLong(s+4, t+4, src) + 4
|
||||||
|
|
||||||
|
// Extend backwards
|
||||||
|
for t > 0 && s > nextEmit && src[t-1] == src[s-1] {
|
||||||
|
s--
|
||||||
|
t--
|
||||||
|
l++
|
||||||
|
}
|
||||||
|
if nextEmit < s {
|
||||||
|
if false {
|
||||||
|
emitLiteral(dst, src[nextEmit:s])
|
||||||
|
} else {
|
||||||
|
for _, v := range src[nextEmit:s] {
|
||||||
|
dst.tokens[dst.n] = token(v)
|
||||||
|
dst.litHist[v]++
|
||||||
|
dst.n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.AddMatchLong(l, uint32(s-t-baseMatchOffset))
|
||||||
|
s += l
|
||||||
|
nextEmit = s
|
||||||
|
if nextS >= s {
|
||||||
|
s = nextS + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if s >= sLimit {
|
||||||
|
t += l
|
||||||
|
// Index first pair after match end.
|
||||||
|
if int(t+8) < len(src) && t > 0 {
|
||||||
|
cv = load6432(src, t)
|
||||||
|
nextHash := hashLen(cv, tableBits, hashBytes)
|
||||||
|
e.table[nextHash] = tableEntryPrev{
|
||||||
|
Prev: e.table[nextHash].Cur,
|
||||||
|
Cur: tableEntry{offset: e.cur + t},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
goto emitRemainder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store every 5th hash in-between.
|
||||||
|
for i := s - l + 2; i < s-5; i += 6 {
|
||||||
|
nextHash := hashLen(load6432(src, i), tableBits, hashBytes)
|
||||||
|
e.table[nextHash] = tableEntryPrev{
|
||||||
|
Prev: e.table[nextHash].Cur,
|
||||||
|
Cur: tableEntry{offset: e.cur + i}}
|
||||||
|
}
|
||||||
|
// We could immediately start working at s now, but to improve
|
||||||
|
// compression we first update the hash table at s-2 to s.
|
||||||
|
x := load6432(src, s-2)
|
||||||
|
prevHash := hashLen(x, tableBits, hashBytes)
|
||||||
|
|
||||||
|
e.table[prevHash] = tableEntryPrev{
|
||||||
|
Prev: e.table[prevHash].Cur,
|
||||||
|
Cur: tableEntry{offset: e.cur + s - 2},
|
||||||
|
}
|
||||||
|
x >>= 8
|
||||||
|
prevHash = hashLen(x, tableBits, hashBytes)
|
||||||
|
|
||||||
|
e.table[prevHash] = tableEntryPrev{
|
||||||
|
Prev: e.table[prevHash].Cur,
|
||||||
|
Cur: tableEntry{offset: e.cur + s - 1},
|
||||||
|
}
|
||||||
|
x >>= 8
|
||||||
|
currHash := hashLen(x, tableBits, hashBytes)
|
||||||
|
candidates := e.table[currHash]
|
||||||
|
cv = x
|
||||||
|
e.table[currHash] = tableEntryPrev{
|
||||||
|
Prev: candidates.Cur,
|
||||||
|
Cur: tableEntry{offset: s + e.cur},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check both candidates
|
||||||
|
candidate = candidates.Cur
|
||||||
|
minOffset := e.cur + s - (maxMatchOffset - 4)
|
||||||
|
|
||||||
|
if candidate.offset > minOffset {
|
||||||
|
if uint32(cv) == load3232(src, candidate.offset-e.cur) {
|
||||||
|
// Found a match...
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
candidate = candidates.Prev
|
||||||
|
if candidate.offset > minOffset && uint32(cv) == load3232(src, candidate.offset-e.cur) {
|
||||||
|
// Match at prev...
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cv = x >> 8
|
||||||
|
s++
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
emitRemainder:
|
||||||
|
if int(nextEmit) < len(src) {
|
||||||
|
// If nothing was added, don't encode literals.
|
||||||
|
if dst.n == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
emitLiteral(dst, src[nextEmit:])
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,221 @@
|
||||||
|
package flate
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type fastEncL4 struct {
|
||||||
|
fastGen
|
||||||
|
table [tableSize]tableEntry
|
||||||
|
bTable [tableSize]tableEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *fastEncL4) Encode(dst *tokens, src []byte) {
|
||||||
|
const (
|
||||||
|
inputMargin = 12 - 1
|
||||||
|
minNonLiteralBlockSize = 1 + 1 + inputMargin
|
||||||
|
hashShortBytes = 4
|
||||||
|
)
|
||||||
|
if debugDeflate && e.cur < 0 {
|
||||||
|
panic(fmt.Sprint("e.cur < 0: ", e.cur))
|
||||||
|
}
|
||||||
|
// Protect against e.cur wraparound.
|
||||||
|
for e.cur >= bufferReset {
|
||||||
|
if len(e.hist) == 0 {
|
||||||
|
for i := range e.table[:] {
|
||||||
|
e.table[i] = tableEntry{}
|
||||||
|
}
|
||||||
|
for i := range e.bTable[:] {
|
||||||
|
e.bTable[i] = tableEntry{}
|
||||||
|
}
|
||||||
|
e.cur = maxMatchOffset
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// Shift down everything in the table that isn't already too far away.
|
||||||
|
minOff := e.cur + int32(len(e.hist)) - maxMatchOffset
|
||||||
|
for i := range e.table[:] {
|
||||||
|
v := e.table[i].offset
|
||||||
|
if v <= minOff {
|
||||||
|
v = 0
|
||||||
|
} else {
|
||||||
|
v = v - e.cur + maxMatchOffset
|
||||||
|
}
|
||||||
|
e.table[i].offset = v
|
||||||
|
}
|
||||||
|
for i := range e.bTable[:] {
|
||||||
|
v := e.bTable[i].offset
|
||||||
|
if v <= minOff {
|
||||||
|
v = 0
|
||||||
|
} else {
|
||||||
|
v = v - e.cur + maxMatchOffset
|
||||||
|
}
|
||||||
|
e.bTable[i].offset = v
|
||||||
|
}
|
||||||
|
e.cur = maxMatchOffset
|
||||||
|
}
|
||||||
|
|
||||||
|
s := e.addBlock(src)
|
||||||
|
|
||||||
|
// This check isn't in the Snappy implementation, but there, the caller
|
||||||
|
// instead of the callee handles this case.
|
||||||
|
if len(src) < minNonLiteralBlockSize {
|
||||||
|
// We do not fill the token table.
|
||||||
|
// This will be picked up by caller.
|
||||||
|
dst.n = uint16(len(src))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override src
|
||||||
|
src = e.hist
|
||||||
|
nextEmit := s
|
||||||
|
|
||||||
|
// sLimit is when to stop looking for offset/length copies. The inputMargin
|
||||||
|
// lets us use a fast path for emitLiteral in the main loop, while we are
|
||||||
|
// looking for copies.
|
||||||
|
sLimit := int32(len(src) - inputMargin)
|
||||||
|
|
||||||
|
// nextEmit is where in src the next emitLiteral should start from.
|
||||||
|
cv := load6432(src, s)
|
||||||
|
for {
|
||||||
|
const skipLog = 6
|
||||||
|
const doEvery = 1
|
||||||
|
|
||||||
|
nextS := s
|
||||||
|
var t int32
|
||||||
|
for {
|
||||||
|
nextHashS := hashLen(cv, tableBits, hashShortBytes)
|
||||||
|
nextHashL := hash7(cv, tableBits)
|
||||||
|
|
||||||
|
s = nextS
|
||||||
|
nextS = s + doEvery + (s-nextEmit)>>skipLog
|
||||||
|
if nextS > sLimit {
|
||||||
|
goto emitRemainder
|
||||||
|
}
|
||||||
|
// Fetch a short+long candidate
|
||||||
|
sCandidate := e.table[nextHashS]
|
||||||
|
lCandidate := e.bTable[nextHashL]
|
||||||
|
next := load6432(src, nextS)
|
||||||
|
entry := tableEntry{offset: s + e.cur}
|
||||||
|
e.table[nextHashS] = entry
|
||||||
|
e.bTable[nextHashL] = entry
|
||||||
|
|
||||||
|
t = lCandidate.offset - e.cur
|
||||||
|
if s-t < maxMatchOffset && uint32(cv) == load3232(src, lCandidate.offset-e.cur) {
|
||||||
|
// We got a long match. Use that.
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
t = sCandidate.offset - e.cur
|
||||||
|
if s-t < maxMatchOffset && uint32(cv) == load3232(src, sCandidate.offset-e.cur) {
|
||||||
|
// Found a 4 match...
|
||||||
|
lCandidate = e.bTable[hash7(next, tableBits)]
|
||||||
|
|
||||||
|
// If the next long is a candidate, check if we should use that instead...
|
||||||
|
lOff := nextS - (lCandidate.offset - e.cur)
|
||||||
|
if lOff < maxMatchOffset && load3232(src, lCandidate.offset-e.cur) == uint32(next) {
|
||||||
|
l1, l2 := matchLen(src[s+4:], src[t+4:]), matchLen(src[nextS+4:], src[nextS-lOff+4:])
|
||||||
|
if l2 > l1 {
|
||||||
|
s = nextS
|
||||||
|
t = lCandidate.offset - e.cur
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
cv = next
|
||||||
|
}
|
||||||
|
|
||||||
|
// A 4-byte match has been found. We'll later see if more than 4 bytes
|
||||||
|
// match. But, prior to the match, src[nextEmit:s] are unmatched. Emit
|
||||||
|
// them as literal bytes.
|
||||||
|
|
||||||
|
// Extend the 4-byte match as long as possible.
|
||||||
|
l := e.matchlenLong(s+4, t+4, src) + 4
|
||||||
|
|
||||||
|
// Extend backwards
|
||||||
|
for t > 0 && s > nextEmit && src[t-1] == src[s-1] {
|
||||||
|
s--
|
||||||
|
t--
|
||||||
|
l++
|
||||||
|
}
|
||||||
|
if nextEmit < s {
|
||||||
|
if false {
|
||||||
|
emitLiteral(dst, src[nextEmit:s])
|
||||||
|
} else {
|
||||||
|
for _, v := range src[nextEmit:s] {
|
||||||
|
dst.tokens[dst.n] = token(v)
|
||||||
|
dst.litHist[v]++
|
||||||
|
dst.n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if debugDeflate {
|
||||||
|
if t >= s {
|
||||||
|
panic("s-t")
|
||||||
|
}
|
||||||
|
if (s - t) > maxMatchOffset {
|
||||||
|
panic(fmt.Sprintln("mmo", t))
|
||||||
|
}
|
||||||
|
if l < baseMatchLength {
|
||||||
|
panic("bml")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.AddMatchLong(l, uint32(s-t-baseMatchOffset))
|
||||||
|
s += l
|
||||||
|
nextEmit = s
|
||||||
|
if nextS >= s {
|
||||||
|
s = nextS + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if s >= sLimit {
|
||||||
|
// Index first pair after match end.
|
||||||
|
if int(s+8) < len(src) {
|
||||||
|
cv := load6432(src, s)
|
||||||
|
e.table[hashLen(cv, tableBits, hashShortBytes)] = tableEntry{offset: s + e.cur}
|
||||||
|
e.bTable[hash7(cv, tableBits)] = tableEntry{offset: s + e.cur}
|
||||||
|
}
|
||||||
|
goto emitRemainder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store every 3rd hash in-between
|
||||||
|
if true {
|
||||||
|
i := nextS
|
||||||
|
if i < s-1 {
|
||||||
|
cv := load6432(src, i)
|
||||||
|
t := tableEntry{offset: i + e.cur}
|
||||||
|
t2 := tableEntry{offset: t.offset + 1}
|
||||||
|
e.bTable[hash7(cv, tableBits)] = t
|
||||||
|
e.bTable[hash7(cv>>8, tableBits)] = t2
|
||||||
|
e.table[hashLen(cv>>8, tableBits, hashShortBytes)] = t2
|
||||||
|
|
||||||
|
i += 3
|
||||||
|
for ; i < s-1; i += 3 {
|
||||||
|
cv := load6432(src, i)
|
||||||
|
t := tableEntry{offset: i + e.cur}
|
||||||
|
t2 := tableEntry{offset: t.offset + 1}
|
||||||
|
e.bTable[hash7(cv, tableBits)] = t
|
||||||
|
e.bTable[hash7(cv>>8, tableBits)] = t2
|
||||||
|
e.table[hashLen(cv>>8, tableBits, hashShortBytes)] = t2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We could immediately start working at s now, but to improve
|
||||||
|
// compression we first update the hash table at s-1 and at s.
|
||||||
|
x := load6432(src, s-1)
|
||||||
|
o := e.cur + s - 1
|
||||||
|
prevHashS := hashLen(x, tableBits, hashShortBytes)
|
||||||
|
prevHashL := hash7(x, tableBits)
|
||||||
|
e.table[prevHashS] = tableEntry{offset: o}
|
||||||
|
e.bTable[prevHashL] = tableEntry{offset: o}
|
||||||
|
cv = x >> 8
|
||||||
|
}
|
||||||
|
|
||||||
|
emitRemainder:
|
||||||
|
if int(nextEmit) < len(src) {
|
||||||
|
// If nothing was added, don't encode literals.
|
||||||
|
if dst.n == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
emitLiteral(dst, src[nextEmit:])
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,310 @@
|
||||||
|
package flate
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type fastEncL5 struct {
|
||||||
|
fastGen
|
||||||
|
table [tableSize]tableEntry
|
||||||
|
bTable [tableSize]tableEntryPrev
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *fastEncL5) Encode(dst *tokens, src []byte) {
|
||||||
|
const (
|
||||||
|
inputMargin = 12 - 1
|
||||||
|
minNonLiteralBlockSize = 1 + 1 + inputMargin
|
||||||
|
hashShortBytes = 4
|
||||||
|
)
|
||||||
|
if debugDeflate && e.cur < 0 {
|
||||||
|
panic(fmt.Sprint("e.cur < 0: ", e.cur))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Protect against e.cur wraparound.
|
||||||
|
for e.cur >= bufferReset {
|
||||||
|
if len(e.hist) == 0 {
|
||||||
|
for i := range e.table[:] {
|
||||||
|
e.table[i] = tableEntry{}
|
||||||
|
}
|
||||||
|
for i := range e.bTable[:] {
|
||||||
|
e.bTable[i] = tableEntryPrev{}
|
||||||
|
}
|
||||||
|
e.cur = maxMatchOffset
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// Shift down everything in the table that isn't already too far away.
|
||||||
|
minOff := e.cur + int32(len(e.hist)) - maxMatchOffset
|
||||||
|
for i := range e.table[:] {
|
||||||
|
v := e.table[i].offset
|
||||||
|
if v <= minOff {
|
||||||
|
v = 0
|
||||||
|
} else {
|
||||||
|
v = v - e.cur + maxMatchOffset
|
||||||
|
}
|
||||||
|
e.table[i].offset = v
|
||||||
|
}
|
||||||
|
for i := range e.bTable[:] {
|
||||||
|
v := e.bTable[i]
|
||||||
|
if v.Cur.offset <= minOff {
|
||||||
|
v.Cur.offset = 0
|
||||||
|
v.Prev.offset = 0
|
||||||
|
} else {
|
||||||
|
v.Cur.offset = v.Cur.offset - e.cur + maxMatchOffset
|
||||||
|
if v.Prev.offset <= minOff {
|
||||||
|
v.Prev.offset = 0
|
||||||
|
} else {
|
||||||
|
v.Prev.offset = v.Prev.offset - e.cur + maxMatchOffset
|
||||||
|
}
|
||||||
|
}
|
||||||
|
e.bTable[i] = v
|
||||||
|
}
|
||||||
|
e.cur = maxMatchOffset
|
||||||
|
}
|
||||||
|
|
||||||
|
s := e.addBlock(src)
|
||||||
|
|
||||||
|
// This check isn't in the Snappy implementation, but there, the caller
|
||||||
|
// instead of the callee handles this case.
|
||||||
|
if len(src) < minNonLiteralBlockSize {
|
||||||
|
// We do not fill the token table.
|
||||||
|
// This will be picked up by caller.
|
||||||
|
dst.n = uint16(len(src))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override src
|
||||||
|
src = e.hist
|
||||||
|
nextEmit := s
|
||||||
|
|
||||||
|
// sLimit is when to stop looking for offset/length copies. The inputMargin
|
||||||
|
// lets us use a fast path for emitLiteral in the main loop, while we are
|
||||||
|
// looking for copies.
|
||||||
|
sLimit := int32(len(src) - inputMargin)
|
||||||
|
|
||||||
|
// nextEmit is where in src the next emitLiteral should start from.
|
||||||
|
cv := load6432(src, s)
|
||||||
|
for {
|
||||||
|
const skipLog = 6
|
||||||
|
const doEvery = 1
|
||||||
|
|
||||||
|
nextS := s
|
||||||
|
var l int32
|
||||||
|
var t int32
|
||||||
|
for {
|
||||||
|
nextHashS := hashLen(cv, tableBits, hashShortBytes)
|
||||||
|
nextHashL := hash7(cv, tableBits)
|
||||||
|
|
||||||
|
s = nextS
|
||||||
|
nextS = s + doEvery + (s-nextEmit)>>skipLog
|
||||||
|
if nextS > sLimit {
|
||||||
|
goto emitRemainder
|
||||||
|
}
|
||||||
|
// Fetch a short+long candidate
|
||||||
|
sCandidate := e.table[nextHashS]
|
||||||
|
lCandidate := e.bTable[nextHashL]
|
||||||
|
next := load6432(src, nextS)
|
||||||
|
entry := tableEntry{offset: s + e.cur}
|
||||||
|
e.table[nextHashS] = entry
|
||||||
|
eLong := &e.bTable[nextHashL]
|
||||||
|
eLong.Cur, eLong.Prev = entry, eLong.Cur
|
||||||
|
|
||||||
|
nextHashS = hashLen(next, tableBits, hashShortBytes)
|
||||||
|
nextHashL = hash7(next, tableBits)
|
||||||
|
|
||||||
|
t = lCandidate.Cur.offset - e.cur
|
||||||
|
if s-t < maxMatchOffset {
|
||||||
|
if uint32(cv) == load3232(src, lCandidate.Cur.offset-e.cur) {
|
||||||
|
// Store the next match
|
||||||
|
e.table[nextHashS] = tableEntry{offset: nextS + e.cur}
|
||||||
|
eLong := &e.bTable[nextHashL]
|
||||||
|
eLong.Cur, eLong.Prev = tableEntry{offset: nextS + e.cur}, eLong.Cur
|
||||||
|
|
||||||
|
t2 := lCandidate.Prev.offset - e.cur
|
||||||
|
if s-t2 < maxMatchOffset && uint32(cv) == load3232(src, lCandidate.Prev.offset-e.cur) {
|
||||||
|
l = e.matchlen(s+4, t+4, src) + 4
|
||||||
|
ml1 := e.matchlen(s+4, t2+4, src) + 4
|
||||||
|
if ml1 > l {
|
||||||
|
t = t2
|
||||||
|
l = ml1
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
t = lCandidate.Prev.offset - e.cur
|
||||||
|
if s-t < maxMatchOffset && uint32(cv) == load3232(src, lCandidate.Prev.offset-e.cur) {
|
||||||
|
// Store the next match
|
||||||
|
e.table[nextHashS] = tableEntry{offset: nextS + e.cur}
|
||||||
|
eLong := &e.bTable[nextHashL]
|
||||||
|
eLong.Cur, eLong.Prev = tableEntry{offset: nextS + e.cur}, eLong.Cur
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t = sCandidate.offset - e.cur
|
||||||
|
if s-t < maxMatchOffset && uint32(cv) == load3232(src, sCandidate.offset-e.cur) {
|
||||||
|
// Found a 4 match...
|
||||||
|
l = e.matchlen(s+4, t+4, src) + 4
|
||||||
|
lCandidate = e.bTable[nextHashL]
|
||||||
|
// Store the next match
|
||||||
|
|
||||||
|
e.table[nextHashS] = tableEntry{offset: nextS + e.cur}
|
||||||
|
eLong := &e.bTable[nextHashL]
|
||||||
|
eLong.Cur, eLong.Prev = tableEntry{offset: nextS + e.cur}, eLong.Cur
|
||||||
|
|
||||||
|
// If the next long is a candidate, use that...
|
||||||
|
t2 := lCandidate.Cur.offset - e.cur
|
||||||
|
if nextS-t2 < maxMatchOffset {
|
||||||
|
if load3232(src, lCandidate.Cur.offset-e.cur) == uint32(next) {
|
||||||
|
ml := e.matchlen(nextS+4, t2+4, src) + 4
|
||||||
|
if ml > l {
|
||||||
|
t = t2
|
||||||
|
s = nextS
|
||||||
|
l = ml
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If the previous long is a candidate, use that...
|
||||||
|
t2 = lCandidate.Prev.offset - e.cur
|
||||||
|
if nextS-t2 < maxMatchOffset && load3232(src, lCandidate.Prev.offset-e.cur) == uint32(next) {
|
||||||
|
ml := e.matchlen(nextS+4, t2+4, src) + 4
|
||||||
|
if ml > l {
|
||||||
|
t = t2
|
||||||
|
s = nextS
|
||||||
|
l = ml
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
cv = next
|
||||||
|
}
|
||||||
|
|
||||||
|
// A 4-byte match has been found. We'll later see if more than 4 bytes
|
||||||
|
// match. But, prior to the match, src[nextEmit:s] are unmatched. Emit
|
||||||
|
// them as literal bytes.
|
||||||
|
|
||||||
|
if l == 0 {
|
||||||
|
// Extend the 4-byte match as long as possible.
|
||||||
|
l = e.matchlenLong(s+4, t+4, src) + 4
|
||||||
|
} else if l == maxMatchLength {
|
||||||
|
l += e.matchlenLong(s+l, t+l, src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to locate a better match by checking the end of best match...
|
||||||
|
if sAt := s + l; l < 30 && sAt < sLimit {
|
||||||
|
// Allow some bytes at the beginning to mismatch.
|
||||||
|
// Sweet spot is 2/3 bytes depending on input.
|
||||||
|
// 3 is only a little better when it is but sometimes a lot worse.
|
||||||
|
// The skipped bytes are tested in Extend backwards,
|
||||||
|
// and still picked up as part of the match if they do.
|
||||||
|
const skipBeginning = 2
|
||||||
|
eLong := e.bTable[hash7(load6432(src, sAt), tableBits)].Cur.offset
|
||||||
|
t2 := eLong - e.cur - l + skipBeginning
|
||||||
|
s2 := s + skipBeginning
|
||||||
|
off := s2 - t2
|
||||||
|
if t2 >= 0 && off < maxMatchOffset && off > 0 {
|
||||||
|
if l2 := e.matchlenLong(s2, t2, src); l2 > l {
|
||||||
|
t = t2
|
||||||
|
l = l2
|
||||||
|
s = s2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extend backwards
|
||||||
|
for t > 0 && s > nextEmit && src[t-1] == src[s-1] {
|
||||||
|
s--
|
||||||
|
t--
|
||||||
|
l++
|
||||||
|
}
|
||||||
|
if nextEmit < s {
|
||||||
|
if false {
|
||||||
|
emitLiteral(dst, src[nextEmit:s])
|
||||||
|
} else {
|
||||||
|
for _, v := range src[nextEmit:s] {
|
||||||
|
dst.tokens[dst.n] = token(v)
|
||||||
|
dst.litHist[v]++
|
||||||
|
dst.n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if debugDeflate {
|
||||||
|
if t >= s {
|
||||||
|
panic(fmt.Sprintln("s-t", s, t))
|
||||||
|
}
|
||||||
|
if (s - t) > maxMatchOffset {
|
||||||
|
panic(fmt.Sprintln("mmo", s-t))
|
||||||
|
}
|
||||||
|
if l < baseMatchLength {
|
||||||
|
panic("bml")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.AddMatchLong(l, uint32(s-t-baseMatchOffset))
|
||||||
|
s += l
|
||||||
|
nextEmit = s
|
||||||
|
if nextS >= s {
|
||||||
|
s = nextS + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if s >= sLimit {
|
||||||
|
goto emitRemainder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store every 3rd hash in-between.
|
||||||
|
if true {
|
||||||
|
const hashEvery = 3
|
||||||
|
i := s - l + 1
|
||||||
|
if i < s-1 {
|
||||||
|
cv := load6432(src, i)
|
||||||
|
t := tableEntry{offset: i + e.cur}
|
||||||
|
e.table[hashLen(cv, tableBits, hashShortBytes)] = t
|
||||||
|
eLong := &e.bTable[hash7(cv, tableBits)]
|
||||||
|
eLong.Cur, eLong.Prev = t, eLong.Cur
|
||||||
|
|
||||||
|
// Do an long at i+1
|
||||||
|
cv >>= 8
|
||||||
|
t = tableEntry{offset: t.offset + 1}
|
||||||
|
eLong = &e.bTable[hash7(cv, tableBits)]
|
||||||
|
eLong.Cur, eLong.Prev = t, eLong.Cur
|
||||||
|
|
||||||
|
// We only have enough bits for a short entry at i+2
|
||||||
|
cv >>= 8
|
||||||
|
t = tableEntry{offset: t.offset + 1}
|
||||||
|
e.table[hashLen(cv, tableBits, hashShortBytes)] = t
|
||||||
|
|
||||||
|
// Skip one - otherwise we risk hitting 's'
|
||||||
|
i += 4
|
||||||
|
for ; i < s-1; i += hashEvery {
|
||||||
|
cv := load6432(src, i)
|
||||||
|
t := tableEntry{offset: i + e.cur}
|
||||||
|
t2 := tableEntry{offset: t.offset + 1}
|
||||||
|
eLong := &e.bTable[hash7(cv, tableBits)]
|
||||||
|
eLong.Cur, eLong.Prev = t, eLong.Cur
|
||||||
|
e.table[hashLen(cv>>8, tableBits, hashShortBytes)] = t2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We could immediately start working at s now, but to improve
|
||||||
|
// compression we first update the hash table at s-1 and at s.
|
||||||
|
x := load6432(src, s-1)
|
||||||
|
o := e.cur + s - 1
|
||||||
|
prevHashS := hashLen(x, tableBits, hashShortBytes)
|
||||||
|
prevHashL := hash7(x, tableBits)
|
||||||
|
e.table[prevHashS] = tableEntry{offset: o}
|
||||||
|
eLong := &e.bTable[prevHashL]
|
||||||
|
eLong.Cur, eLong.Prev = tableEntry{offset: o}, eLong.Cur
|
||||||
|
cv = x >> 8
|
||||||
|
}
|
||||||
|
|
||||||
|
emitRemainder:
|
||||||
|
if int(nextEmit) < len(src) {
|
||||||
|
// If nothing was added, don't encode literals.
|
||||||
|
if dst.n == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
emitLiteral(dst, src[nextEmit:])
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,325 @@
|
||||||
|
package flate
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type fastEncL6 struct {
|
||||||
|
fastGen
|
||||||
|
table [tableSize]tableEntry
|
||||||
|
bTable [tableSize]tableEntryPrev
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *fastEncL6) Encode(dst *tokens, src []byte) {
|
||||||
|
const (
|
||||||
|
inputMargin = 12 - 1
|
||||||
|
minNonLiteralBlockSize = 1 + 1 + inputMargin
|
||||||
|
hashShortBytes = 4
|
||||||
|
)
|
||||||
|
if debugDeflate && e.cur < 0 {
|
||||||
|
panic(fmt.Sprint("e.cur < 0: ", e.cur))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Protect against e.cur wraparound.
|
||||||
|
for e.cur >= bufferReset {
|
||||||
|
if len(e.hist) == 0 {
|
||||||
|
for i := range e.table[:] {
|
||||||
|
e.table[i] = tableEntry{}
|
||||||
|
}
|
||||||
|
for i := range e.bTable[:] {
|
||||||
|
e.bTable[i] = tableEntryPrev{}
|
||||||
|
}
|
||||||
|
e.cur = maxMatchOffset
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// Shift down everything in the table that isn't already too far away.
|
||||||
|
minOff := e.cur + int32(len(e.hist)) - maxMatchOffset
|
||||||
|
for i := range e.table[:] {
|
||||||
|
v := e.table[i].offset
|
||||||
|
if v <= minOff {
|
||||||
|
v = 0
|
||||||
|
} else {
|
||||||
|
v = v - e.cur + maxMatchOffset
|
||||||
|
}
|
||||||
|
e.table[i].offset = v
|
||||||
|
}
|
||||||
|
for i := range e.bTable[:] {
|
||||||
|
v := e.bTable[i]
|
||||||
|
if v.Cur.offset <= minOff {
|
||||||
|
v.Cur.offset = 0
|
||||||
|
v.Prev.offset = 0
|
||||||
|
} else {
|
||||||
|
v.Cur.offset = v.Cur.offset - e.cur + maxMatchOffset
|
||||||
|
if v.Prev.offset <= minOff {
|
||||||
|
v.Prev.offset = 0
|
||||||
|
} else {
|
||||||
|
v.Prev.offset = v.Prev.offset - e.cur + maxMatchOffset
|
||||||
|
}
|
||||||
|
}
|
||||||
|
e.bTable[i] = v
|
||||||
|
}
|
||||||
|
e.cur = maxMatchOffset
|
||||||
|
}
|
||||||
|
|
||||||
|
s := e.addBlock(src)
|
||||||
|
|
||||||
|
// This check isn't in the Snappy implementation, but there, the caller
|
||||||
|
// instead of the callee handles this case.
|
||||||
|
if len(src) < minNonLiteralBlockSize {
|
||||||
|
// We do not fill the token table.
|
||||||
|
// This will be picked up by caller.
|
||||||
|
dst.n = uint16(len(src))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override src
|
||||||
|
src = e.hist
|
||||||
|
nextEmit := s
|
||||||
|
|
||||||
|
// sLimit is when to stop looking for offset/length copies. The inputMargin
|
||||||
|
// lets us use a fast path for emitLiteral in the main loop, while we are
|
||||||
|
// looking for copies.
|
||||||
|
sLimit := int32(len(src) - inputMargin)
|
||||||
|
|
||||||
|
// nextEmit is where in src the next emitLiteral should start from.
|
||||||
|
cv := load6432(src, s)
|
||||||
|
// Repeat MUST be > 1 and within range
|
||||||
|
repeat := int32(1)
|
||||||
|
for {
|
||||||
|
const skipLog = 7
|
||||||
|
const doEvery = 1
|
||||||
|
|
||||||
|
nextS := s
|
||||||
|
var l int32
|
||||||
|
var t int32
|
||||||
|
for {
|
||||||
|
nextHashS := hashLen(cv, tableBits, hashShortBytes)
|
||||||
|
nextHashL := hash7(cv, tableBits)
|
||||||
|
s = nextS
|
||||||
|
nextS = s + doEvery + (s-nextEmit)>>skipLog
|
||||||
|
if nextS > sLimit {
|
||||||
|
goto emitRemainder
|
||||||
|
}
|
||||||
|
// Fetch a short+long candidate
|
||||||
|
sCandidate := e.table[nextHashS]
|
||||||
|
lCandidate := e.bTable[nextHashL]
|
||||||
|
next := load6432(src, nextS)
|
||||||
|
entry := tableEntry{offset: s + e.cur}
|
||||||
|
e.table[nextHashS] = entry
|
||||||
|
eLong := &e.bTable[nextHashL]
|
||||||
|
eLong.Cur, eLong.Prev = entry, eLong.Cur
|
||||||
|
|
||||||
|
// Calculate hashes of 'next'
|
||||||
|
nextHashS = hashLen(next, tableBits, hashShortBytes)
|
||||||
|
nextHashL = hash7(next, tableBits)
|
||||||
|
|
||||||
|
t = lCandidate.Cur.offset - e.cur
|
||||||
|
if s-t < maxMatchOffset {
|
||||||
|
if uint32(cv) == load3232(src, lCandidate.Cur.offset-e.cur) {
|
||||||
|
// Long candidate matches at least 4 bytes.
|
||||||
|
|
||||||
|
// Store the next match
|
||||||
|
e.table[nextHashS] = tableEntry{offset: nextS + e.cur}
|
||||||
|
eLong := &e.bTable[nextHashL]
|
||||||
|
eLong.Cur, eLong.Prev = tableEntry{offset: nextS + e.cur}, eLong.Cur
|
||||||
|
|
||||||
|
// Check the previous long candidate as well.
|
||||||
|
t2 := lCandidate.Prev.offset - e.cur
|
||||||
|
if s-t2 < maxMatchOffset && uint32(cv) == load3232(src, lCandidate.Prev.offset-e.cur) {
|
||||||
|
l = e.matchlen(s+4, t+4, src) + 4
|
||||||
|
ml1 := e.matchlen(s+4, t2+4, src) + 4
|
||||||
|
if ml1 > l {
|
||||||
|
t = t2
|
||||||
|
l = ml1
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// Current value did not match, but check if previous long value does.
|
||||||
|
t = lCandidate.Prev.offset - e.cur
|
||||||
|
if s-t < maxMatchOffset && uint32(cv) == load3232(src, lCandidate.Prev.offset-e.cur) {
|
||||||
|
// Store the next match
|
||||||
|
e.table[nextHashS] = tableEntry{offset: nextS + e.cur}
|
||||||
|
eLong := &e.bTable[nextHashL]
|
||||||
|
eLong.Cur, eLong.Prev = tableEntry{offset: nextS + e.cur}, eLong.Cur
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t = sCandidate.offset - e.cur
|
||||||
|
if s-t < maxMatchOffset && uint32(cv) == load3232(src, sCandidate.offset-e.cur) {
|
||||||
|
// Found a 4 match...
|
||||||
|
l = e.matchlen(s+4, t+4, src) + 4
|
||||||
|
|
||||||
|
// Look up next long candidate (at nextS)
|
||||||
|
lCandidate = e.bTable[nextHashL]
|
||||||
|
|
||||||
|
// Store the next match
|
||||||
|
e.table[nextHashS] = tableEntry{offset: nextS + e.cur}
|
||||||
|
eLong := &e.bTable[nextHashL]
|
||||||
|
eLong.Cur, eLong.Prev = tableEntry{offset: nextS + e.cur}, eLong.Cur
|
||||||
|
|
||||||
|
// Check repeat at s + repOff
|
||||||
|
const repOff = 1
|
||||||
|
t2 := s - repeat + repOff
|
||||||
|
if load3232(src, t2) == uint32(cv>>(8*repOff)) {
|
||||||
|
ml := e.matchlen(s+4+repOff, t2+4, src) + 4
|
||||||
|
if ml > l {
|
||||||
|
t = t2
|
||||||
|
l = ml
|
||||||
|
s += repOff
|
||||||
|
// Not worth checking more.
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the next long is a candidate, use that...
|
||||||
|
t2 = lCandidate.Cur.offset - e.cur
|
||||||
|
if nextS-t2 < maxMatchOffset {
|
||||||
|
if load3232(src, lCandidate.Cur.offset-e.cur) == uint32(next) {
|
||||||
|
ml := e.matchlen(nextS+4, t2+4, src) + 4
|
||||||
|
if ml > l {
|
||||||
|
t = t2
|
||||||
|
s = nextS
|
||||||
|
l = ml
|
||||||
|
// This is ok, but check previous as well.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If the previous long is a candidate, use that...
|
||||||
|
t2 = lCandidate.Prev.offset - e.cur
|
||||||
|
if nextS-t2 < maxMatchOffset && load3232(src, lCandidate.Prev.offset-e.cur) == uint32(next) {
|
||||||
|
ml := e.matchlen(nextS+4, t2+4, src) + 4
|
||||||
|
if ml > l {
|
||||||
|
t = t2
|
||||||
|
s = nextS
|
||||||
|
l = ml
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
cv = next
|
||||||
|
}
|
||||||
|
|
||||||
|
// A 4-byte match has been found. We'll later see if more than 4 bytes
|
||||||
|
// match. But, prior to the match, src[nextEmit:s] are unmatched. Emit
|
||||||
|
// them as literal bytes.
|
||||||
|
|
||||||
|
// Extend the 4-byte match as long as possible.
|
||||||
|
if l == 0 {
|
||||||
|
l = e.matchlenLong(s+4, t+4, src) + 4
|
||||||
|
} else if l == maxMatchLength {
|
||||||
|
l += e.matchlenLong(s+l, t+l, src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to locate a better match by checking the end-of-match...
|
||||||
|
if sAt := s + l; sAt < sLimit {
|
||||||
|
// Allow some bytes at the beginning to mismatch.
|
||||||
|
// Sweet spot is 2/3 bytes depending on input.
|
||||||
|
// 3 is only a little better when it is but sometimes a lot worse.
|
||||||
|
// The skipped bytes are tested in Extend backwards,
|
||||||
|
// and still picked up as part of the match if they do.
|
||||||
|
const skipBeginning = 2
|
||||||
|
eLong := &e.bTable[hash7(load6432(src, sAt), tableBits)]
|
||||||
|
// Test current
|
||||||
|
t2 := eLong.Cur.offset - e.cur - l + skipBeginning
|
||||||
|
s2 := s + skipBeginning
|
||||||
|
off := s2 - t2
|
||||||
|
if off < maxMatchOffset {
|
||||||
|
if off > 0 && t2 >= 0 {
|
||||||
|
if l2 := e.matchlenLong(s2, t2, src); l2 > l {
|
||||||
|
t = t2
|
||||||
|
l = l2
|
||||||
|
s = s2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Test next:
|
||||||
|
t2 = eLong.Prev.offset - e.cur - l + skipBeginning
|
||||||
|
off := s2 - t2
|
||||||
|
if off > 0 && off < maxMatchOffset && t2 >= 0 {
|
||||||
|
if l2 := e.matchlenLong(s2, t2, src); l2 > l {
|
||||||
|
t = t2
|
||||||
|
l = l2
|
||||||
|
s = s2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extend backwards
|
||||||
|
for t > 0 && s > nextEmit && src[t-1] == src[s-1] {
|
||||||
|
s--
|
||||||
|
t--
|
||||||
|
l++
|
||||||
|
}
|
||||||
|
if nextEmit < s {
|
||||||
|
if false {
|
||||||
|
emitLiteral(dst, src[nextEmit:s])
|
||||||
|
} else {
|
||||||
|
for _, v := range src[nextEmit:s] {
|
||||||
|
dst.tokens[dst.n] = token(v)
|
||||||
|
dst.litHist[v]++
|
||||||
|
dst.n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if false {
|
||||||
|
if t >= s {
|
||||||
|
panic(fmt.Sprintln("s-t", s, t))
|
||||||
|
}
|
||||||
|
if (s - t) > maxMatchOffset {
|
||||||
|
panic(fmt.Sprintln("mmo", s-t))
|
||||||
|
}
|
||||||
|
if l < baseMatchLength {
|
||||||
|
panic("bml")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.AddMatchLong(l, uint32(s-t-baseMatchOffset))
|
||||||
|
repeat = s - t
|
||||||
|
s += l
|
||||||
|
nextEmit = s
|
||||||
|
if nextS >= s {
|
||||||
|
s = nextS + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if s >= sLimit {
|
||||||
|
// Index after match end.
|
||||||
|
for i := nextS + 1; i < int32(len(src))-8; i += 2 {
|
||||||
|
cv := load6432(src, i)
|
||||||
|
e.table[hashLen(cv, tableBits, hashShortBytes)] = tableEntry{offset: i + e.cur}
|
||||||
|
eLong := &e.bTable[hash7(cv, tableBits)]
|
||||||
|
eLong.Cur, eLong.Prev = tableEntry{offset: i + e.cur}, eLong.Cur
|
||||||
|
}
|
||||||
|
goto emitRemainder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store every long hash in-between and every second short.
|
||||||
|
if true {
|
||||||
|
for i := nextS + 1; i < s-1; i += 2 {
|
||||||
|
cv := load6432(src, i)
|
||||||
|
t := tableEntry{offset: i + e.cur}
|
||||||
|
t2 := tableEntry{offset: t.offset + 1}
|
||||||
|
eLong := &e.bTable[hash7(cv, tableBits)]
|
||||||
|
eLong2 := &e.bTable[hash7(cv>>8, tableBits)]
|
||||||
|
e.table[hashLen(cv, tableBits, hashShortBytes)] = t
|
||||||
|
eLong.Cur, eLong.Prev = t, eLong.Cur
|
||||||
|
eLong2.Cur, eLong2.Prev = t2, eLong2.Cur
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We could immediately start working at s now, but to improve
|
||||||
|
// compression we first update the hash table at s-1 and at s.
|
||||||
|
cv = load6432(src, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
emitRemainder:
|
||||||
|
if int(nextEmit) < len(src) {
|
||||||
|
// If nothing was added, don't encode literals.
|
||||||
|
if dst.n == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
emitLiteral(dst, src[nextEmit:])
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,37 @@
|
||||||
|
package flate
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Masks for shifts with register sizes of the shift value.
|
||||||
|
// This can be used to work around the x86 design of shifting by mod register size.
|
||||||
|
// It can be used when a variable shift is always smaller than the register size.
|
||||||
|
|
||||||
|
// reg8SizeMaskX - shift value is 8 bits, shifted is X
|
||||||
|
reg8SizeMask8 = 7
|
||||||
|
reg8SizeMask16 = 15
|
||||||
|
reg8SizeMask32 = 31
|
||||||
|
reg8SizeMask64 = 63
|
||||||
|
|
||||||
|
// reg16SizeMaskX - shift value is 16 bits, shifted is X
|
||||||
|
reg16SizeMask8 = reg8SizeMask8
|
||||||
|
reg16SizeMask16 = reg8SizeMask16
|
||||||
|
reg16SizeMask32 = reg8SizeMask32
|
||||||
|
reg16SizeMask64 = reg8SizeMask64
|
||||||
|
|
||||||
|
// reg32SizeMaskX - shift value is 32 bits, shifted is X
|
||||||
|
reg32SizeMask8 = reg8SizeMask8
|
||||||
|
reg32SizeMask16 = reg8SizeMask16
|
||||||
|
reg32SizeMask32 = reg8SizeMask32
|
||||||
|
reg32SizeMask64 = reg8SizeMask64
|
||||||
|
|
||||||
|
// reg64SizeMaskX - shift value is 64 bits, shifted is X
|
||||||
|
reg64SizeMask8 = reg8SizeMask8
|
||||||
|
reg64SizeMask16 = reg8SizeMask16
|
||||||
|
reg64SizeMask32 = reg8SizeMask32
|
||||||
|
reg64SizeMask64 = reg8SizeMask64
|
||||||
|
|
||||||
|
// regSizeMaskUintX - shift value is uint, shifted is X
|
||||||
|
regSizeMaskUint8 = reg8SizeMask8
|
||||||
|
regSizeMaskUint16 = reg8SizeMask16
|
||||||
|
regSizeMaskUint32 = reg8SizeMask32
|
||||||
|
regSizeMaskUint64 = reg8SizeMask64
|
||||||
|
)
|
|
@ -0,0 +1,40 @@
|
||||||
|
//go:build !amd64
|
||||||
|
// +build !amd64
|
||||||
|
|
||||||
|
package flate
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Masks for shifts with register sizes of the shift value.
|
||||||
|
// This can be used to work around the x86 design of shifting by mod register size.
|
||||||
|
// It can be used when a variable shift is always smaller than the register size.
|
||||||
|
|
||||||
|
// reg8SizeMaskX - shift value is 8 bits, shifted is X
|
||||||
|
reg8SizeMask8 = 0xff
|
||||||
|
reg8SizeMask16 = 0xff
|
||||||
|
reg8SizeMask32 = 0xff
|
||||||
|
reg8SizeMask64 = 0xff
|
||||||
|
|
||||||
|
// reg16SizeMaskX - shift value is 16 bits, shifted is X
|
||||||
|
reg16SizeMask8 = 0xffff
|
||||||
|
reg16SizeMask16 = 0xffff
|
||||||
|
reg16SizeMask32 = 0xffff
|
||||||
|
reg16SizeMask64 = 0xffff
|
||||||
|
|
||||||
|
// reg32SizeMaskX - shift value is 32 bits, shifted is X
|
||||||
|
reg32SizeMask8 = 0xffffffff
|
||||||
|
reg32SizeMask16 = 0xffffffff
|
||||||
|
reg32SizeMask32 = 0xffffffff
|
||||||
|
reg32SizeMask64 = 0xffffffff
|
||||||
|
|
||||||
|
// reg64SizeMaskX - shift value is 64 bits, shifted is X
|
||||||
|
reg64SizeMask8 = 0xffffffffffffffff
|
||||||
|
reg64SizeMask16 = 0xffffffffffffffff
|
||||||
|
reg64SizeMask32 = 0xffffffffffffffff
|
||||||
|
reg64SizeMask64 = 0xffffffffffffffff
|
||||||
|
|
||||||
|
// regSizeMaskUintX - shift value is uint, shifted is X
|
||||||
|
regSizeMaskUint8 = ^uint(0)
|
||||||
|
regSizeMaskUint16 = ^uint(0)
|
||||||
|
regSizeMaskUint32 = ^uint(0)
|
||||||
|
regSizeMaskUint64 = ^uint(0)
|
||||||
|
)
|
|
@ -0,0 +1,305 @@
|
||||||
|
package flate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxStatelessBlock = math.MaxInt16
|
||||||
|
// dictionary will be taken from maxStatelessBlock, so limit it.
|
||||||
|
maxStatelessDict = 8 << 10
|
||||||
|
|
||||||
|
slTableBits = 13
|
||||||
|
slTableSize = 1 << slTableBits
|
||||||
|
slTableShift = 32 - slTableBits
|
||||||
|
)
|
||||||
|
|
||||||
|
type statelessWriter struct {
|
||||||
|
dst io.Writer
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *statelessWriter) Close() error {
|
||||||
|
if s.closed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s.closed = true
|
||||||
|
// Emit EOF block
|
||||||
|
return StatelessDeflate(s.dst, nil, true, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *statelessWriter) Write(p []byte) (n int, err error) {
|
||||||
|
err = StatelessDeflate(s.dst, p, false, nil)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *statelessWriter) Reset(w io.Writer) {
|
||||||
|
s.dst = w
|
||||||
|
s.closed = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStatelessWriter will do compression but without maintaining any state
|
||||||
|
// between Write calls.
|
||||||
|
// There will be no memory kept between Write calls,
|
||||||
|
// but compression and speed will be suboptimal.
|
||||||
|
// Because of this, the size of actual Write calls will affect output size.
|
||||||
|
func NewStatelessWriter(dst io.Writer) io.WriteCloser {
|
||||||
|
return &statelessWriter{dst: dst}
|
||||||
|
}
|
||||||
|
|
||||||
|
// bitWriterPool contains bit writers that can be reused.
|
||||||
|
var bitWriterPool = sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return newHuffmanBitWriter(nil)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatelessDeflate allows compressing directly to a Writer without retaining state.
|
||||||
|
// When returning everything will be flushed.
|
||||||
|
// Up to 8KB of an optional dictionary can be given which is presumed to precede the block.
|
||||||
|
// Longer dictionaries will be truncated and will still produce valid output.
|
||||||
|
// Sending nil dictionary is perfectly fine.
|
||||||
|
func StatelessDeflate(out io.Writer, in []byte, eof bool, dict []byte) error {
|
||||||
|
var dst tokens
|
||||||
|
bw := bitWriterPool.Get().(*huffmanBitWriter)
|
||||||
|
bw.reset(out)
|
||||||
|
defer func() {
|
||||||
|
// don't keep a reference to our output
|
||||||
|
bw.reset(nil)
|
||||||
|
bitWriterPool.Put(bw)
|
||||||
|
}()
|
||||||
|
if eof && len(in) == 0 {
|
||||||
|
// Just write an EOF block.
|
||||||
|
// Could be faster...
|
||||||
|
bw.writeStoredHeader(0, true)
|
||||||
|
bw.flush()
|
||||||
|
return bw.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Truncate dict
|
||||||
|
if len(dict) > maxStatelessDict {
|
||||||
|
dict = dict[len(dict)-maxStatelessDict:]
|
||||||
|
}
|
||||||
|
|
||||||
|
for len(in) > 0 {
|
||||||
|
todo := in
|
||||||
|
if len(todo) > maxStatelessBlock-len(dict) {
|
||||||
|
todo = todo[:maxStatelessBlock-len(dict)]
|
||||||
|
}
|
||||||
|
in = in[len(todo):]
|
||||||
|
uncompressed := todo
|
||||||
|
if len(dict) > 0 {
|
||||||
|
// combine dict and source
|
||||||
|
bufLen := len(todo) + len(dict)
|
||||||
|
combined := make([]byte, bufLen)
|
||||||
|
copy(combined, dict)
|
||||||
|
copy(combined[len(dict):], todo)
|
||||||
|
todo = combined
|
||||||
|
}
|
||||||
|
// Compress
|
||||||
|
statelessEnc(&dst, todo, int16(len(dict)))
|
||||||
|
isEof := eof && len(in) == 0
|
||||||
|
|
||||||
|
if dst.n == 0 {
|
||||||
|
bw.writeStoredHeader(len(uncompressed), isEof)
|
||||||
|
if bw.err != nil {
|
||||||
|
return bw.err
|
||||||
|
}
|
||||||
|
bw.writeBytes(uncompressed)
|
||||||
|
} else if int(dst.n) > len(uncompressed)-len(uncompressed)>>4 {
|
||||||
|
// If we removed less than 1/16th, huffman compress the block.
|
||||||
|
bw.writeBlockHuff(isEof, uncompressed, len(in) == 0)
|
||||||
|
} else {
|
||||||
|
bw.writeBlockDynamic(&dst, isEof, uncompressed, len(in) == 0)
|
||||||
|
}
|
||||||
|
if len(in) > 0 {
|
||||||
|
// Retain a dict if we have more
|
||||||
|
dict = todo[len(todo)-maxStatelessDict:]
|
||||||
|
dst.Reset()
|
||||||
|
}
|
||||||
|
if bw.err != nil {
|
||||||
|
return bw.err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !eof {
|
||||||
|
// Align, only a stored block can do that.
|
||||||
|
bw.writeStoredHeader(0, false)
|
||||||
|
}
|
||||||
|
bw.flush()
|
||||||
|
return bw.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func hashSL(u uint32) uint32 {
|
||||||
|
return (u * 0x1e35a7bd) >> slTableShift
|
||||||
|
}
|
||||||
|
|
||||||
|
func load3216(b []byte, i int16) uint32 {
|
||||||
|
// Help the compiler eliminate bounds checks on the read so it can be done in a single read.
|
||||||
|
b = b[i:]
|
||||||
|
b = b[:4]
|
||||||
|
return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
|
||||||
|
}
|
||||||
|
|
||||||
|
func load6416(b []byte, i int16) uint64 {
|
||||||
|
// Help the compiler eliminate bounds checks on the read so it can be done in a single read.
|
||||||
|
b = b[i:]
|
||||||
|
b = b[:8]
|
||||||
|
return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 |
|
||||||
|
uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
|
||||||
|
}
|
||||||
|
|
||||||
|
func statelessEnc(dst *tokens, src []byte, startAt int16) {
|
||||||
|
const (
|
||||||
|
inputMargin = 12 - 1
|
||||||
|
minNonLiteralBlockSize = 1 + 1 + inputMargin
|
||||||
|
)
|
||||||
|
|
||||||
|
type tableEntry struct {
|
||||||
|
offset int16
|
||||||
|
}
|
||||||
|
|
||||||
|
var table [slTableSize]tableEntry
|
||||||
|
|
||||||
|
// This check isn't in the Snappy implementation, but there, the caller
|
||||||
|
// instead of the callee handles this case.
|
||||||
|
if len(src)-int(startAt) < minNonLiteralBlockSize {
|
||||||
|
// We do not fill the token table.
|
||||||
|
// This will be picked up by caller.
|
||||||
|
dst.n = 0
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Index until startAt
|
||||||
|
if startAt > 0 {
|
||||||
|
cv := load3232(src, 0)
|
||||||
|
for i := int16(0); i < startAt; i++ {
|
||||||
|
table[hashSL(cv)] = tableEntry{offset: i}
|
||||||
|
cv = (cv >> 8) | (uint32(src[i+4]) << 24)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s := startAt + 1
|
||||||
|
nextEmit := startAt
|
||||||
|
// sLimit is when to stop looking for offset/length copies. The inputMargin
|
||||||
|
// lets us use a fast path for emitLiteral in the main loop, while we are
|
||||||
|
// looking for copies.
|
||||||
|
sLimit := int16(len(src) - inputMargin)
|
||||||
|
|
||||||
|
// nextEmit is where in src the next emitLiteral should start from.
|
||||||
|
cv := load3216(src, s)
|
||||||
|
|
||||||
|
for {
|
||||||
|
const skipLog = 5
|
||||||
|
const doEvery = 2
|
||||||
|
|
||||||
|
nextS := s
|
||||||
|
var candidate tableEntry
|
||||||
|
for {
|
||||||
|
nextHash := hashSL(cv)
|
||||||
|
candidate = table[nextHash]
|
||||||
|
nextS = s + doEvery + (s-nextEmit)>>skipLog
|
||||||
|
if nextS > sLimit || nextS <= 0 {
|
||||||
|
goto emitRemainder
|
||||||
|
}
|
||||||
|
|
||||||
|
now := load6416(src, nextS)
|
||||||
|
table[nextHash] = tableEntry{offset: s}
|
||||||
|
nextHash = hashSL(uint32(now))
|
||||||
|
|
||||||
|
if cv == load3216(src, candidate.offset) {
|
||||||
|
table[nextHash] = tableEntry{offset: nextS}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do one right away...
|
||||||
|
cv = uint32(now)
|
||||||
|
s = nextS
|
||||||
|
nextS++
|
||||||
|
candidate = table[nextHash]
|
||||||
|
now >>= 8
|
||||||
|
table[nextHash] = tableEntry{offset: s}
|
||||||
|
|
||||||
|
if cv == load3216(src, candidate.offset) {
|
||||||
|
table[nextHash] = tableEntry{offset: nextS}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
cv = uint32(now)
|
||||||
|
s = nextS
|
||||||
|
}
|
||||||
|
|
||||||
|
// A 4-byte match has been found. We'll later see if more than 4 bytes
|
||||||
|
// match. But, prior to the match, src[nextEmit:s] are unmatched. Emit
|
||||||
|
// them as literal bytes.
|
||||||
|
for {
|
||||||
|
// Invariant: we have a 4-byte match at s, and no need to emit any
|
||||||
|
// literal bytes prior to s.
|
||||||
|
|
||||||
|
// Extend the 4-byte match as long as possible.
|
||||||
|
t := candidate.offset
|
||||||
|
l := int16(matchLen(src[s+4:], src[t+4:]) + 4)
|
||||||
|
|
||||||
|
// Extend backwards
|
||||||
|
for t > 0 && s > nextEmit && src[t-1] == src[s-1] {
|
||||||
|
s--
|
||||||
|
t--
|
||||||
|
l++
|
||||||
|
}
|
||||||
|
if nextEmit < s {
|
||||||
|
if false {
|
||||||
|
emitLiteral(dst, src[nextEmit:s])
|
||||||
|
} else {
|
||||||
|
for _, v := range src[nextEmit:s] {
|
||||||
|
dst.tokens[dst.n] = token(v)
|
||||||
|
dst.litHist[v]++
|
||||||
|
dst.n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save the match found
|
||||||
|
dst.AddMatchLong(int32(l), uint32(s-t-baseMatchOffset))
|
||||||
|
s += l
|
||||||
|
nextEmit = s
|
||||||
|
if nextS >= s {
|
||||||
|
s = nextS + 1
|
||||||
|
}
|
||||||
|
if s >= sLimit {
|
||||||
|
goto emitRemainder
|
||||||
|
}
|
||||||
|
|
||||||
|
// We could immediately start working at s now, but to improve
|
||||||
|
// compression we first update the hash table at s-2 and at s. If
|
||||||
|
// another emitCopy is not our next move, also calculate nextHash
|
||||||
|
// at s+1. At least on GOARCH=amd64, these three hash calculations
|
||||||
|
// are faster as one load64 call (with some shifts) instead of
|
||||||
|
// three load32 calls.
|
||||||
|
x := load6416(src, s-2)
|
||||||
|
o := s - 2
|
||||||
|
prevHash := hashSL(uint32(x))
|
||||||
|
table[prevHash] = tableEntry{offset: o}
|
||||||
|
x >>= 16
|
||||||
|
currHash := hashSL(uint32(x))
|
||||||
|
candidate = table[currHash]
|
||||||
|
table[currHash] = tableEntry{offset: o + 2}
|
||||||
|
|
||||||
|
if uint32(x) != load3216(src, candidate.offset) {
|
||||||
|
cv = uint32(x >> 8)
|
||||||
|
s++
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
emitRemainder:
|
||||||
|
if int(nextEmit) < len(src) {
|
||||||
|
// If nothing was added, don't encode literals.
|
||||||
|
if dst.n == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
emitLiteral(dst, src[nextEmit:])
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,379 @@
|
||||||
|
// Copyright 2009 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package flate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// bits 0-16 xoffset = offset - MIN_OFFSET_SIZE, or literal - 16 bits
|
||||||
|
// bits 16-22 offsetcode - 5 bits
|
||||||
|
// bits 22-30 xlength = length - MIN_MATCH_LENGTH - 8 bits
|
||||||
|
// bits 30-32 type 0 = literal 1=EOF 2=Match 3=Unused - 2 bits
|
||||||
|
lengthShift = 22
|
||||||
|
offsetMask = 1<<lengthShift - 1
|
||||||
|
typeMask = 3 << 30
|
||||||
|
literalType = 0 << 30
|
||||||
|
matchType = 1 << 30
|
||||||
|
matchOffsetOnlyMask = 0xffff
|
||||||
|
)
|
||||||
|
|
||||||
|
// The length code for length X (MIN_MATCH_LENGTH <= X <= MAX_MATCH_LENGTH)
|
||||||
|
// is lengthCodes[length - MIN_MATCH_LENGTH]
|
||||||
|
var lengthCodes = [256]uint8{
|
||||||
|
0, 1, 2, 3, 4, 5, 6, 7, 8, 8,
|
||||||
|
9, 9, 10, 10, 11, 11, 12, 12, 12, 12,
|
||||||
|
13, 13, 13, 13, 14, 14, 14, 14, 15, 15,
|
||||||
|
15, 15, 16, 16, 16, 16, 16, 16, 16, 16,
|
||||||
|
17, 17, 17, 17, 17, 17, 17, 17, 18, 18,
|
||||||
|
18, 18, 18, 18, 18, 18, 19, 19, 19, 19,
|
||||||
|
19, 19, 19, 19, 20, 20, 20, 20, 20, 20,
|
||||||
|
20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
|
||||||
|
21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
|
||||||
|
21, 21, 21, 21, 21, 21, 22, 22, 22, 22,
|
||||||
|
22, 22, 22, 22, 22, 22, 22, 22, 22, 22,
|
||||||
|
22, 22, 23, 23, 23, 23, 23, 23, 23, 23,
|
||||||
|
23, 23, 23, 23, 23, 23, 23, 23, 24, 24,
|
||||||
|
24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
|
||||||
|
24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
|
||||||
|
24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
|
||||||
|
25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
|
||||||
|
25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
|
||||||
|
25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
|
||||||
|
25, 25, 26, 26, 26, 26, 26, 26, 26, 26,
|
||||||
|
26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
|
||||||
|
26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
|
||||||
|
26, 26, 26, 26, 27, 27, 27, 27, 27, 27,
|
||||||
|
27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
|
||||||
|
27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
|
||||||
|
27, 27, 27, 27, 27, 28,
|
||||||
|
}
|
||||||
|
|
||||||
|
// lengthCodes1 is length codes, but starting at 1.
|
||||||
|
var lengthCodes1 = [256]uint8{
|
||||||
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 9,
|
||||||
|
10, 10, 11, 11, 12, 12, 13, 13, 13, 13,
|
||||||
|
14, 14, 14, 14, 15, 15, 15, 15, 16, 16,
|
||||||
|
16, 16, 17, 17, 17, 17, 17, 17, 17, 17,
|
||||||
|
18, 18, 18, 18, 18, 18, 18, 18, 19, 19,
|
||||||
|
19, 19, 19, 19, 19, 19, 20, 20, 20, 20,
|
||||||
|
20, 20, 20, 20, 21, 21, 21, 21, 21, 21,
|
||||||
|
21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
|
||||||
|
22, 22, 22, 22, 22, 22, 22, 22, 22, 22,
|
||||||
|
22, 22, 22, 22, 22, 22, 23, 23, 23, 23,
|
||||||
|
23, 23, 23, 23, 23, 23, 23, 23, 23, 23,
|
||||||
|
23, 23, 24, 24, 24, 24, 24, 24, 24, 24,
|
||||||
|
24, 24, 24, 24, 24, 24, 24, 24, 25, 25,
|
||||||
|
25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
|
||||||
|
25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
|
||||||
|
25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
|
||||||
|
26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
|
||||||
|
26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
|
||||||
|
26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
|
||||||
|
26, 26, 27, 27, 27, 27, 27, 27, 27, 27,
|
||||||
|
27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
|
||||||
|
27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
|
||||||
|
27, 27, 27, 27, 28, 28, 28, 28, 28, 28,
|
||||||
|
28, 28, 28, 28, 28, 28, 28, 28, 28, 28,
|
||||||
|
28, 28, 28, 28, 28, 28, 28, 28, 28, 28,
|
||||||
|
28, 28, 28, 28, 28, 29,
|
||||||
|
}
|
||||||
|
|
||||||
|
var offsetCodes = [256]uint32{
|
||||||
|
0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7,
|
||||||
|
8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
|
||||||
|
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
|
||||||
|
11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11,
|
||||||
|
12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
|
||||||
|
12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
|
||||||
|
13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
|
||||||
|
13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
|
||||||
|
14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
|
||||||
|
14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
|
||||||
|
14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
|
||||||
|
14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
|
||||||
|
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
|
||||||
|
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
|
||||||
|
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
|
||||||
|
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
|
||||||
|
}
|
||||||
|
|
||||||
|
// offsetCodes14 are offsetCodes, but with 14 added.
|
||||||
|
var offsetCodes14 = [256]uint32{
|
||||||
|
14, 15, 16, 17, 18, 18, 19, 19, 20, 20, 20, 20, 21, 21, 21, 21,
|
||||||
|
22, 22, 22, 22, 22, 22, 22, 22, 23, 23, 23, 23, 23, 23, 23, 23,
|
||||||
|
24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
|
||||||
|
25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
|
||||||
|
26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
|
||||||
|
26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
|
||||||
|
27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
|
||||||
|
27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
|
||||||
|
28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28,
|
||||||
|
28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28,
|
||||||
|
28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28,
|
||||||
|
28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28,
|
||||||
|
29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
|
||||||
|
29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
|
||||||
|
29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
|
||||||
|
29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
|
||||||
|
}
|
||||||
|
|
||||||
|
type token uint32
|
||||||
|
|
||||||
|
type tokens struct {
|
||||||
|
extraHist [32]uint16 // codes 256->maxnumlit
|
||||||
|
offHist [32]uint16 // offset codes
|
||||||
|
litHist [256]uint16 // codes 0->255
|
||||||
|
nFilled int
|
||||||
|
n uint16 // Must be able to contain maxStoreBlockSize
|
||||||
|
tokens [maxStoreBlockSize + 1]token
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tokens) Reset() {
|
||||||
|
if t.n == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.n = 0
|
||||||
|
t.nFilled = 0
|
||||||
|
for i := range t.litHist[:] {
|
||||||
|
t.litHist[i] = 0
|
||||||
|
}
|
||||||
|
for i := range t.extraHist[:] {
|
||||||
|
t.extraHist[i] = 0
|
||||||
|
}
|
||||||
|
for i := range t.offHist[:] {
|
||||||
|
t.offHist[i] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tokens) Fill() {
|
||||||
|
if t.n == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i, v := range t.litHist[:] {
|
||||||
|
if v == 0 {
|
||||||
|
t.litHist[i] = 1
|
||||||
|
t.nFilled++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i, v := range t.extraHist[:literalCount-256] {
|
||||||
|
if v == 0 {
|
||||||
|
t.nFilled++
|
||||||
|
t.extraHist[i] = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i, v := range t.offHist[:offsetCodeCount] {
|
||||||
|
if v == 0 {
|
||||||
|
t.offHist[i] = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func indexTokens(in []token) tokens {
|
||||||
|
var t tokens
|
||||||
|
t.indexTokens(in)
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tokens) indexTokens(in []token) {
|
||||||
|
t.Reset()
|
||||||
|
for _, tok := range in {
|
||||||
|
if tok < matchType {
|
||||||
|
t.AddLiteral(tok.literal())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
t.AddMatch(uint32(tok.length()), tok.offset()&matchOffsetOnlyMask)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// emitLiteral writes a literal chunk and returns the number of bytes written.
|
||||||
|
func emitLiteral(dst *tokens, lit []byte) {
|
||||||
|
for _, v := range lit {
|
||||||
|
dst.tokens[dst.n] = token(v)
|
||||||
|
dst.litHist[v]++
|
||||||
|
dst.n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tokens) AddLiteral(lit byte) {
|
||||||
|
t.tokens[t.n] = token(lit)
|
||||||
|
t.litHist[lit]++
|
||||||
|
t.n++
|
||||||
|
}
|
||||||
|
|
||||||
|
// from https://stackoverflow.com/a/28730362
|
||||||
|
func mFastLog2(val float32) float32 {
|
||||||
|
ux := int32(math.Float32bits(val))
|
||||||
|
log2 := (float32)(((ux >> 23) & 255) - 128)
|
||||||
|
ux &= -0x7f800001
|
||||||
|
ux += 127 << 23
|
||||||
|
uval := math.Float32frombits(uint32(ux))
|
||||||
|
log2 += ((-0.34484843)*uval+2.02466578)*uval - 0.67487759
|
||||||
|
return log2
|
||||||
|
}
|
||||||
|
|
||||||
|
// EstimatedBits will return an minimum size estimated by an *optimal*
|
||||||
|
// compression of the block.
|
||||||
|
// The size of the block
|
||||||
|
func (t *tokens) EstimatedBits() int {
|
||||||
|
shannon := float32(0)
|
||||||
|
bits := int(0)
|
||||||
|
nMatches := 0
|
||||||
|
total := int(t.n) + t.nFilled
|
||||||
|
if total > 0 {
|
||||||
|
invTotal := 1.0 / float32(total)
|
||||||
|
for _, v := range t.litHist[:] {
|
||||||
|
if v > 0 {
|
||||||
|
n := float32(v)
|
||||||
|
shannon += atLeastOne(-mFastLog2(n*invTotal)) * n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Just add 15 for EOB
|
||||||
|
shannon += 15
|
||||||
|
for i, v := range t.extraHist[1 : literalCount-256] {
|
||||||
|
if v > 0 {
|
||||||
|
n := float32(v)
|
||||||
|
shannon += atLeastOne(-mFastLog2(n*invTotal)) * n
|
||||||
|
bits += int(lengthExtraBits[i&31]) * int(v)
|
||||||
|
nMatches += int(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if nMatches > 0 {
|
||||||
|
invTotal := 1.0 / float32(nMatches)
|
||||||
|
for i, v := range t.offHist[:offsetCodeCount] {
|
||||||
|
if v > 0 {
|
||||||
|
n := float32(v)
|
||||||
|
shannon += atLeastOne(-mFastLog2(n*invTotal)) * n
|
||||||
|
bits += int(offsetExtraBits[i&31]) * int(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return int(shannon) + bits
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddMatch adds a match to the tokens.
|
||||||
|
// This function is very sensitive to inlining and right on the border.
|
||||||
|
func (t *tokens) AddMatch(xlength uint32, xoffset uint32) {
|
||||||
|
if debugDeflate {
|
||||||
|
if xlength >= maxMatchLength+baseMatchLength {
|
||||||
|
panic(fmt.Errorf("invalid length: %v", xlength))
|
||||||
|
}
|
||||||
|
if xoffset >= maxMatchOffset+baseMatchOffset {
|
||||||
|
panic(fmt.Errorf("invalid offset: %v", xoffset))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
oCode := offsetCode(xoffset)
|
||||||
|
xoffset |= oCode << 16
|
||||||
|
|
||||||
|
t.extraHist[lengthCodes1[uint8(xlength)]]++
|
||||||
|
t.offHist[oCode&31]++
|
||||||
|
t.tokens[t.n] = token(matchType | xlength<<lengthShift | xoffset)
|
||||||
|
t.n++
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddMatchLong adds a match to the tokens, potentially longer than max match length.
|
||||||
|
// Length should NOT have the base subtracted, only offset should.
|
||||||
|
func (t *tokens) AddMatchLong(xlength int32, xoffset uint32) {
|
||||||
|
if debugDeflate {
|
||||||
|
if xoffset >= maxMatchOffset+baseMatchOffset {
|
||||||
|
panic(fmt.Errorf("invalid offset: %v", xoffset))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
oc := offsetCode(xoffset)
|
||||||
|
xoffset |= oc << 16
|
||||||
|
for xlength > 0 {
|
||||||
|
xl := xlength
|
||||||
|
if xl > 258 {
|
||||||
|
// We need to have at least baseMatchLength left over for next loop.
|
||||||
|
if xl > 258+baseMatchLength {
|
||||||
|
xl = 258
|
||||||
|
} else {
|
||||||
|
xl = 258 - baseMatchLength
|
||||||
|
}
|
||||||
|
}
|
||||||
|
xlength -= xl
|
||||||
|
xl -= baseMatchLength
|
||||||
|
t.extraHist[lengthCodes1[uint8(xl)]]++
|
||||||
|
t.offHist[oc&31]++
|
||||||
|
t.tokens[t.n] = token(matchType | uint32(xl)<<lengthShift | xoffset)
|
||||||
|
t.n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tokens) AddEOB() {
|
||||||
|
t.tokens[t.n] = token(endBlockMarker)
|
||||||
|
t.extraHist[0]++
|
||||||
|
t.n++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tokens) Slice() []token {
|
||||||
|
return t.tokens[:t.n]
|
||||||
|
}
|
||||||
|
|
||||||
|
// VarInt returns the tokens as varint encoded bytes.
|
||||||
|
func (t *tokens) VarInt() []byte {
|
||||||
|
var b = make([]byte, binary.MaxVarintLen32*int(t.n))
|
||||||
|
var off int
|
||||||
|
for _, v := range t.tokens[:t.n] {
|
||||||
|
off += binary.PutUvarint(b[off:], uint64(v))
|
||||||
|
}
|
||||||
|
return b[:off]
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromVarInt restores t to the varint encoded tokens provided.
|
||||||
|
// Any data in t is removed.
|
||||||
|
func (t *tokens) FromVarInt(b []byte) error {
|
||||||
|
var buf = bytes.NewReader(b)
|
||||||
|
var toks []token
|
||||||
|
for {
|
||||||
|
r, err := binary.ReadUvarint(buf)
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
toks = append(toks, token(r))
|
||||||
|
}
|
||||||
|
t.indexTokens(toks)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the type of a token
|
||||||
|
func (t token) typ() uint32 { return uint32(t) & typeMask }
|
||||||
|
|
||||||
|
// Returns the literal of a literal token
|
||||||
|
func (t token) literal() uint8 { return uint8(t) }
|
||||||
|
|
||||||
|
// Returns the extra offset of a match token
|
||||||
|
func (t token) offset() uint32 { return uint32(t) & offsetMask }
|
||||||
|
|
||||||
|
func (t token) length() uint8 { return uint8(t >> lengthShift) }
|
||||||
|
|
||||||
|
// Convert length to code.
|
||||||
|
func lengthCode(len uint8) uint8 { return lengthCodes[len] }
|
||||||
|
|
||||||
|
// Returns the offset code corresponding to a specific offset
|
||||||
|
func offsetCode(off uint32) uint32 {
|
||||||
|
if false {
|
||||||
|
if off < uint32(len(offsetCodes)) {
|
||||||
|
return offsetCodes[off&255]
|
||||||
|
} else if off>>7 < uint32(len(offsetCodes)) {
|
||||||
|
return offsetCodes[(off>>7)&255] + 14
|
||||||
|
} else {
|
||||||
|
return offsetCodes[(off>>14)&255] + 28
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if off < uint32(len(offsetCodes)) {
|
||||||
|
return offsetCodes[uint8(off)]
|
||||||
|
}
|
||||||
|
return offsetCodes14[uint8(off>>7)]
|
||||||
|
}
|
|
@ -183,6 +183,9 @@ github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc
|
||||||
# github.com/json-iterator/go v1.1.12
|
# github.com/json-iterator/go v1.1.12
|
||||||
## explicit; go 1.12
|
## explicit; go 1.12
|
||||||
github.com/json-iterator/go
|
github.com/json-iterator/go
|
||||||
|
# github.com/klauspost/compress v1.15.11
|
||||||
|
## explicit; go 1.17
|
||||||
|
github.com/klauspost/compress/flate
|
||||||
# github.com/kr/text v0.2.0
|
# github.com/kr/text v0.2.0
|
||||||
## explicit
|
## explicit
|
||||||
# github.com/kylelemons/godebug v1.1.0
|
# github.com/kylelemons/godebug v1.1.0
|
||||||
|
@ -582,6 +585,13 @@ gopkg.in/yaml.v2
|
||||||
# gopkg.in/yaml.v3 v3.0.1 => gopkg.in/yaml.v3 v3.0.1
|
# gopkg.in/yaml.v3 v3.0.1 => gopkg.in/yaml.v3 v3.0.1
|
||||||
## explicit
|
## explicit
|
||||||
gopkg.in/yaml.v3
|
gopkg.in/yaml.v3
|
||||||
|
# nhooyr.io/websocket v1.8.7
|
||||||
|
## explicit; go 1.13
|
||||||
|
nhooyr.io/websocket
|
||||||
|
nhooyr.io/websocket/internal/bpool
|
||||||
|
nhooyr.io/websocket/internal/errd
|
||||||
|
nhooyr.io/websocket/internal/wsjs
|
||||||
|
nhooyr.io/websocket/internal/xsync
|
||||||
# zombiezen.com/go/capnproto2 v2.18.0+incompatible
|
# zombiezen.com/go/capnproto2 v2.18.0+incompatible
|
||||||
## explicit
|
## explicit
|
||||||
zombiezen.com/go/capnproto2
|
zombiezen.com/go/capnproto2
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
websocket.test
|
|
@ -0,0 +1,21 @@
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2018 Anmol Sethi
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
|
@ -0,0 +1,132 @@
|
||||||
|
# websocket
|
||||||
|
|
||||||
|
[![godoc](https://godoc.org/nhooyr.io/websocket?status.svg)](https://pkg.go.dev/nhooyr.io/websocket)
|
||||||
|
[![coverage](https://img.shields.io/badge/coverage-88%25-success)](https://nhooyrio-websocket-coverage.netlify.app)
|
||||||
|
|
||||||
|
websocket is a minimal and idiomatic WebSocket library for Go.
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go get nhooyr.io/websocket
|
||||||
|
```
|
||||||
|
|
||||||
|
## Highlights
|
||||||
|
|
||||||
|
- Minimal and idiomatic API
|
||||||
|
- First class [context.Context](https://blog.golang.org/context) support
|
||||||
|
- Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite)
|
||||||
|
- [Single dependency](https://pkg.go.dev/nhooyr.io/websocket?tab=imports)
|
||||||
|
- JSON and protobuf helpers in the [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages
|
||||||
|
- Zero alloc reads and writes
|
||||||
|
- Concurrent writes
|
||||||
|
- [Close handshake](https://pkg.go.dev/nhooyr.io/websocket#Conn.Close)
|
||||||
|
- [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper
|
||||||
|
- [Ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API
|
||||||
|
- [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression
|
||||||
|
- Compile to [Wasm](https://pkg.go.dev/nhooyr.io/websocket#hdr-Wasm)
|
||||||
|
|
||||||
|
## Roadmap
|
||||||
|
|
||||||
|
- [ ] HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4)
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
For a production quality example that demonstrates the complete API, see the
|
||||||
|
[echo example](./examples/echo).
|
||||||
|
|
||||||
|
For a full stack example, see the [chat example](./examples/chat).
|
||||||
|
|
||||||
|
### Server
|
||||||
|
|
||||||
|
```go
|
||||||
|
http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) {
|
||||||
|
c, err := websocket.Accept(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
defer c.Close(websocket.StatusInternalError, "the sky is falling")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), time.Second*10)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var v interface{}
|
||||||
|
err = wsjson.Read(ctx, c, &v)
|
||||||
|
if err != nil {
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("received: %v", v)
|
||||||
|
|
||||||
|
c.Close(websocket.StatusNormalClosure, "")
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Client
|
||||||
|
|
||||||
|
```go
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil)
|
||||||
|
if err != nil {
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
defer c.Close(websocket.StatusInternalError, "the sky is falling")
|
||||||
|
|
||||||
|
err = wsjson.Write(ctx, c, "hi")
|
||||||
|
if err != nil {
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Close(websocket.StatusNormalClosure, "")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Comparison
|
||||||
|
|
||||||
|
### gorilla/websocket
|
||||||
|
|
||||||
|
Advantages of [gorilla/websocket](https://github.com/gorilla/websocket):
|
||||||
|
|
||||||
|
- Mature and widely used
|
||||||
|
- [Prepared writes](https://pkg.go.dev/github.com/gorilla/websocket#PreparedMessage)
|
||||||
|
- Configurable [buffer sizes](https://pkg.go.dev/github.com/gorilla/websocket#hdr-Buffers)
|
||||||
|
|
||||||
|
Advantages of nhooyr.io/websocket:
|
||||||
|
|
||||||
|
- Minimal and idiomatic API
|
||||||
|
- Compare godoc of [nhooyr.io/websocket](https://pkg.go.dev/nhooyr.io/websocket) with [gorilla/websocket](https://pkg.go.dev/github.com/gorilla/websocket) side by side.
|
||||||
|
- [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper
|
||||||
|
- Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535))
|
||||||
|
- Full [context.Context](https://blog.golang.org/context) support
|
||||||
|
- Dial uses [net/http.Client](https://golang.org/pkg/net/http/#Client)
|
||||||
|
- Will enable easy HTTP/2 support in the future
|
||||||
|
- Gorilla writes directly to a net.Conn and so duplicates features of net/http.Client.
|
||||||
|
- Concurrent writes
|
||||||
|
- Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448))
|
||||||
|
- Idiomatic [ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API
|
||||||
|
- Gorilla requires registering a pong callback before sending a Ping
|
||||||
|
- Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432))
|
||||||
|
- Transparent message buffer reuse with [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages
|
||||||
|
- [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go
|
||||||
|
- Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/).
|
||||||
|
- Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support
|
||||||
|
- Gorilla only supports no context takeover mode
|
||||||
|
- We use [klauspost/compress](https://github.com/klauspost/compress) for much lower memory usage ([gorilla/websocket#203](https://github.com/gorilla/websocket/issues/203))
|
||||||
|
- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492))
|
||||||
|
- Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370))
|
||||||
|
|
||||||
|
#### golang.org/x/net/websocket
|
||||||
|
|
||||||
|
[golang.org/x/net/websocket](https://pkg.go.dev/golang.org/x/net/websocket) is deprecated.
|
||||||
|
See [golang/go/issues/18152](https://github.com/golang/go/issues/18152).
|
||||||
|
|
||||||
|
The [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) can help in transitioning
|
||||||
|
to nhooyr.io/websocket.
|
||||||
|
|
||||||
|
#### gobwas/ws
|
||||||
|
|
||||||
|
[gobwas/ws](https://github.com/gobwas/ws) has an extremely flexible API that allows it to be used
|
||||||
|
in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb).
|
||||||
|
|
||||||
|
However when writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use.
|
|
@ -0,0 +1,370 @@
|
||||||
|
// +build !js
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/sha1"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/textproto"
|
||||||
|
"net/url"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"nhooyr.io/websocket/internal/errd"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AcceptOptions represents Accept's options.
|
||||||
|
type AcceptOptions struct {
|
||||||
|
// Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client.
|
||||||
|
// The empty subprotocol will always be negotiated as per RFC 6455. If you would like to
|
||||||
|
// reject it, close the connection when c.Subprotocol() == "".
|
||||||
|
Subprotocols []string
|
||||||
|
|
||||||
|
// InsecureSkipVerify is used to disable Accept's origin verification behaviour.
|
||||||
|
//
|
||||||
|
// You probably want to use OriginPatterns instead.
|
||||||
|
InsecureSkipVerify bool
|
||||||
|
|
||||||
|
// OriginPatterns lists the host patterns for authorized origins.
|
||||||
|
// The request host is always authorized.
|
||||||
|
// Use this to enable cross origin WebSockets.
|
||||||
|
//
|
||||||
|
// i.e javascript running on example.com wants to access a WebSocket server at chat.example.com.
|
||||||
|
// In such a case, example.com is the origin and chat.example.com is the request host.
|
||||||
|
// One would set this field to []string{"example.com"} to authorize example.com to connect.
|
||||||
|
//
|
||||||
|
// Each pattern is matched case insensitively against the request origin host
|
||||||
|
// with filepath.Match.
|
||||||
|
// See https://golang.org/pkg/path/filepath/#Match
|
||||||
|
//
|
||||||
|
// Please ensure you understand the ramifications of enabling this.
|
||||||
|
// If used incorrectly your WebSocket server will be open to CSRF attacks.
|
||||||
|
//
|
||||||
|
// Do not use * as a pattern to allow any origin, prefer to use InsecureSkipVerify instead
|
||||||
|
// to bring attention to the danger of such a setting.
|
||||||
|
OriginPatterns []string
|
||||||
|
|
||||||
|
// CompressionMode controls the compression mode.
|
||||||
|
// Defaults to CompressionNoContextTakeover.
|
||||||
|
//
|
||||||
|
// See docs on CompressionMode for details.
|
||||||
|
CompressionMode CompressionMode
|
||||||
|
|
||||||
|
// CompressionThreshold controls the minimum size of a message before compression is applied.
|
||||||
|
//
|
||||||
|
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
|
||||||
|
// for CompressionContextTakeover.
|
||||||
|
CompressionThreshold int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accept accepts a WebSocket handshake from a client and upgrades the
|
||||||
|
// the connection to a WebSocket.
|
||||||
|
//
|
||||||
|
// Accept will not allow cross origin requests by default.
|
||||||
|
// See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests.
|
||||||
|
//
|
||||||
|
// Accept will write a response to w on all errors.
|
||||||
|
func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
|
||||||
|
return accept(w, r, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) {
|
||||||
|
defer errd.Wrap(&err, "failed to accept WebSocket connection")
|
||||||
|
|
||||||
|
if opts == nil {
|
||||||
|
opts = &AcceptOptions{}
|
||||||
|
}
|
||||||
|
opts = &*opts
|
||||||
|
|
||||||
|
errCode, err := verifyClientRequest(w, r)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), errCode)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !opts.InsecureSkipVerify {
|
||||||
|
err = authenticateOrigin(r, opts.OriginPatterns)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, filepath.ErrBadPattern) {
|
||||||
|
log.Printf("websocket: %v", err)
|
||||||
|
err = errors.New(http.StatusText(http.StatusForbidden))
|
||||||
|
}
|
||||||
|
http.Error(w, err.Error(), http.StatusForbidden)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hj, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
err = errors.New("http.ResponseWriter does not implement http.Hijacker")
|
||||||
|
http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Upgrade", "websocket")
|
||||||
|
w.Header().Set("Connection", "Upgrade")
|
||||||
|
|
||||||
|
key := r.Header.Get("Sec-WebSocket-Key")
|
||||||
|
w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
|
||||||
|
|
||||||
|
subproto := selectSubprotocol(r, opts.Subprotocols)
|
||||||
|
if subproto != "" {
|
||||||
|
w.Header().Set("Sec-WebSocket-Protocol", subproto)
|
||||||
|
}
|
||||||
|
|
||||||
|
copts, err := acceptCompression(r, w, opts.CompressionMode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusSwitchingProtocols)
|
||||||
|
// See https://github.com/nhooyr/websocket/issues/166
|
||||||
|
if ginWriter, ok := w.(interface {
|
||||||
|
WriteHeaderNow()
|
||||||
|
}); ok {
|
||||||
|
ginWriter.WriteHeaderNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
netConn, brw, err := hj.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to hijack connection: %w", err)
|
||||||
|
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/golang/go/issues/32314
|
||||||
|
b, _ := brw.Reader.Peek(brw.Reader.Buffered())
|
||||||
|
brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))
|
||||||
|
|
||||||
|
return newConn(connConfig{
|
||||||
|
subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
|
||||||
|
rwc: netConn,
|
||||||
|
client: false,
|
||||||
|
copts: copts,
|
||||||
|
flateThreshold: opts.CompressionThreshold,
|
||||||
|
|
||||||
|
br: brw.Reader,
|
||||||
|
bw: brw.Writer,
|
||||||
|
}), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) {
|
||||||
|
if !r.ProtoAtLeast(1, 1) {
|
||||||
|
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") {
|
||||||
|
w.Header().Set("Connection", "Upgrade")
|
||||||
|
w.Header().Set("Upgrade", "websocket")
|
||||||
|
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") {
|
||||||
|
w.Header().Set("Connection", "Upgrade")
|
||||||
|
w.Header().Set("Upgrade", "websocket")
|
||||||
|
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Method != "GET" {
|
||||||
|
return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Header.Get("Sec-WebSocket-Version") != "13" {
|
||||||
|
w.Header().Set("Sec-WebSocket-Version", "13")
|
||||||
|
return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Header.Get("Sec-WebSocket-Key") == "" {
|
||||||
|
return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key")
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func authenticateOrigin(r *http.Request, originHosts []string) error {
|
||||||
|
origin := r.Header.Get("Origin")
|
||||||
|
if origin == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := url.Parse(origin)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.EqualFold(r.Host, u.Host) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, hostPattern := range originHosts {
|
||||||
|
matched, err := match(hostPattern, u.Host)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err)
|
||||||
|
}
|
||||||
|
if matched {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
func match(pattern, s string) (bool, error) {
|
||||||
|
return filepath.Match(strings.ToLower(pattern), strings.ToLower(s))
|
||||||
|
}
|
||||||
|
|
||||||
|
func selectSubprotocol(r *http.Request, subprotocols []string) string {
|
||||||
|
cps := headerTokens(r.Header, "Sec-WebSocket-Protocol")
|
||||||
|
for _, sp := range subprotocols {
|
||||||
|
for _, cp := range cps {
|
||||||
|
if strings.EqualFold(sp, cp) {
|
||||||
|
return cp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) {
|
||||||
|
if mode == CompressionDisabled {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ext := range websocketExtensions(r.Header) {
|
||||||
|
switch ext.name {
|
||||||
|
case "permessage-deflate":
|
||||||
|
return acceptDeflate(w, ext, mode)
|
||||||
|
// Disabled for now, see https://github.com/nhooyr/websocket/issues/218
|
||||||
|
// case "x-webkit-deflate-frame":
|
||||||
|
// return acceptWebkitDeflate(w, ext, mode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
|
||||||
|
copts := mode.opts()
|
||||||
|
|
||||||
|
for _, p := range ext.params {
|
||||||
|
switch p {
|
||||||
|
case "client_no_context_takeover":
|
||||||
|
copts.clientNoContextTakeover = true
|
||||||
|
continue
|
||||||
|
case "server_no_context_takeover":
|
||||||
|
copts.serverNoContextTakeover = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(p, "client_max_window_bits") {
|
||||||
|
// We cannot adjust the read sliding window so cannot make use of this.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
copts.setHeader(w.Header())
|
||||||
|
|
||||||
|
return copts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
|
||||||
|
copts := mode.opts()
|
||||||
|
// The peer must explicitly request it.
|
||||||
|
copts.serverNoContextTakeover = false
|
||||||
|
|
||||||
|
for _, p := range ext.params {
|
||||||
|
if p == "no_context_takeover" {
|
||||||
|
copts.serverNoContextTakeover = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead
|
||||||
|
// of ignoring it as the draft spec is unclear. It says the server can ignore it
|
||||||
|
// but the server has no way of signalling to the client it was ignored as the parameters
|
||||||
|
// are set one way.
|
||||||
|
// Thus us ignoring it would make the client think we understood it which would cause issues.
|
||||||
|
// See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1
|
||||||
|
//
|
||||||
|
// Either way, we're only implementing this for webkit which never sends the max_window_bits
|
||||||
|
// parameter so we don't need to worry about it.
|
||||||
|
err := fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p)
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
s := "x-webkit-deflate-frame"
|
||||||
|
if copts.clientNoContextTakeover {
|
||||||
|
s += "; no_context_takeover"
|
||||||
|
}
|
||||||
|
w.Header().Set("Sec-WebSocket-Extensions", s)
|
||||||
|
|
||||||
|
return copts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool {
|
||||||
|
for _, t := range headerTokens(h, key) {
|
||||||
|
if strings.EqualFold(t, token) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
type websocketExtension struct {
|
||||||
|
name string
|
||||||
|
params []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func websocketExtensions(h http.Header) []websocketExtension {
|
||||||
|
var exts []websocketExtension
|
||||||
|
extStrs := headerTokens(h, "Sec-WebSocket-Extensions")
|
||||||
|
for _, extStr := range extStrs {
|
||||||
|
if extStr == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
vals := strings.Split(extStr, ";")
|
||||||
|
for i := range vals {
|
||||||
|
vals[i] = strings.TrimSpace(vals[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
e := websocketExtension{
|
||||||
|
name: vals[0],
|
||||||
|
params: vals[1:],
|
||||||
|
}
|
||||||
|
|
||||||
|
exts = append(exts, e)
|
||||||
|
}
|
||||||
|
return exts
|
||||||
|
}
|
||||||
|
|
||||||
|
func headerTokens(h http.Header, key string) []string {
|
||||||
|
key = textproto.CanonicalMIMEHeaderKey(key)
|
||||||
|
var tokens []string
|
||||||
|
for _, v := range h[key] {
|
||||||
|
v = strings.TrimSpace(v)
|
||||||
|
for _, t := range strings.Split(v, ",") {
|
||||||
|
t = strings.TrimSpace(t)
|
||||||
|
tokens = append(tokens, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||||
|
|
||||||
|
func secWebSocketAccept(secWebSocketKey string) string {
|
||||||
|
h := sha1.New()
|
||||||
|
h.Write([]byte(secWebSocketKey))
|
||||||
|
h.Write(keyGUID)
|
||||||
|
|
||||||
|
return base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||||
|
}
|
|
@ -0,0 +1,20 @@
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AcceptOptions represents Accept's options.
|
||||||
|
type AcceptOptions struct {
|
||||||
|
Subprotocols []string
|
||||||
|
InsecureSkipVerify bool
|
||||||
|
OriginPatterns []string
|
||||||
|
CompressionMode CompressionMode
|
||||||
|
CompressionThreshold int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accept is stubbed out for Wasm.
|
||||||
|
func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
|
||||||
|
return nil, errors.New("unimplemented")
|
||||||
|
}
|
|
@ -0,0 +1,76 @@
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StatusCode represents a WebSocket status code.
|
||||||
|
// https://tools.ietf.org/html/rfc6455#section-7.4
|
||||||
|
type StatusCode int
|
||||||
|
|
||||||
|
// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
|
||||||
|
//
|
||||||
|
// These are only the status codes defined by the protocol.
|
||||||
|
//
|
||||||
|
// You can define custom codes in the 3000-4999 range.
|
||||||
|
// The 3000-3999 range is reserved for use by libraries, frameworks and applications.
|
||||||
|
// The 4000-4999 range is reserved for private use.
|
||||||
|
const (
|
||||||
|
StatusNormalClosure StatusCode = 1000
|
||||||
|
StatusGoingAway StatusCode = 1001
|
||||||
|
StatusProtocolError StatusCode = 1002
|
||||||
|
StatusUnsupportedData StatusCode = 1003
|
||||||
|
|
||||||
|
// 1004 is reserved and so unexported.
|
||||||
|
statusReserved StatusCode = 1004
|
||||||
|
|
||||||
|
// StatusNoStatusRcvd cannot be sent in a close message.
|
||||||
|
// It is reserved for when a close message is received without
|
||||||
|
// a status code.
|
||||||
|
StatusNoStatusRcvd StatusCode = 1005
|
||||||
|
|
||||||
|
// StatusAbnormalClosure is exported for use only with Wasm.
|
||||||
|
// In non Wasm Go, the returned error will indicate whether the
|
||||||
|
// connection was closed abnormally.
|
||||||
|
StatusAbnormalClosure StatusCode = 1006
|
||||||
|
|
||||||
|
StatusInvalidFramePayloadData StatusCode = 1007
|
||||||
|
StatusPolicyViolation StatusCode = 1008
|
||||||
|
StatusMessageTooBig StatusCode = 1009
|
||||||
|
StatusMandatoryExtension StatusCode = 1010
|
||||||
|
StatusInternalError StatusCode = 1011
|
||||||
|
StatusServiceRestart StatusCode = 1012
|
||||||
|
StatusTryAgainLater StatusCode = 1013
|
||||||
|
StatusBadGateway StatusCode = 1014
|
||||||
|
|
||||||
|
// StatusTLSHandshake is only exported for use with Wasm.
|
||||||
|
// In non Wasm Go, the returned error will indicate whether there was
|
||||||
|
// a TLS handshake failure.
|
||||||
|
StatusTLSHandshake StatusCode = 1015
|
||||||
|
)
|
||||||
|
|
||||||
|
// CloseError is returned when the connection is closed with a status and reason.
|
||||||
|
//
|
||||||
|
// Use Go 1.13's errors.As to check for this error.
|
||||||
|
// Also see the CloseStatus helper.
|
||||||
|
type CloseError struct {
|
||||||
|
Code StatusCode
|
||||||
|
Reason string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ce CloseError) Error() string {
|
||||||
|
return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab
|
||||||
|
// the status code from a CloseError.
|
||||||
|
//
|
||||||
|
// -1 will be returned if the passed error is nil or not a CloseError.
|
||||||
|
func CloseStatus(err error) StatusCode {
|
||||||
|
var ce CloseError
|
||||||
|
if errors.As(err, &ce) {
|
||||||
|
return ce.Code
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
|
@ -0,0 +1,211 @@
|
||||||
|
// +build !js
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"nhooyr.io/websocket/internal/errd"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Close performs the WebSocket close handshake with the given status code and reason.
|
||||||
|
//
|
||||||
|
// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for
|
||||||
|
// the peer to send a close frame.
|
||||||
|
// All data messages received from the peer during the close handshake will be discarded.
|
||||||
|
//
|
||||||
|
// The connection can only be closed once. Additional calls to Close
|
||||||
|
// are no-ops.
|
||||||
|
//
|
||||||
|
// The maximum length of reason must be 125 bytes. Avoid
|
||||||
|
// sending a dynamic reason.
|
||||||
|
//
|
||||||
|
// Close will unblock all goroutines interacting with the connection once
|
||||||
|
// complete.
|
||||||
|
func (c *Conn) Close(code StatusCode, reason string) error {
|
||||||
|
return c.closeHandshake(code, reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) {
|
||||||
|
defer errd.Wrap(&err, "failed to close WebSocket")
|
||||||
|
|
||||||
|
writeErr := c.writeClose(code, reason)
|
||||||
|
closeHandshakeErr := c.waitCloseHandshake()
|
||||||
|
|
||||||
|
if writeErr != nil {
|
||||||
|
return writeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
if CloseStatus(closeHandshakeErr) == -1 {
|
||||||
|
return closeHandshakeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var errAlreadyWroteClose = errors.New("already wrote close")
|
||||||
|
|
||||||
|
func (c *Conn) writeClose(code StatusCode, reason string) error {
|
||||||
|
c.closeMu.Lock()
|
||||||
|
wroteClose := c.wroteClose
|
||||||
|
c.wroteClose = true
|
||||||
|
c.closeMu.Unlock()
|
||||||
|
if wroteClose {
|
||||||
|
return errAlreadyWroteClose
|
||||||
|
}
|
||||||
|
|
||||||
|
ce := CloseError{
|
||||||
|
Code: code,
|
||||||
|
Reason: reason,
|
||||||
|
}
|
||||||
|
|
||||||
|
var p []byte
|
||||||
|
var marshalErr error
|
||||||
|
if ce.Code != StatusNoStatusRcvd {
|
||||||
|
p, marshalErr = ce.bytes()
|
||||||
|
if marshalErr != nil {
|
||||||
|
log.Printf("websocket: %v", marshalErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
writeErr := c.writeControl(context.Background(), opClose, p)
|
||||||
|
if CloseStatus(writeErr) != -1 {
|
||||||
|
// Not a real error if it's due to a close frame being received.
|
||||||
|
writeErr = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// We do this after in case there was an error writing the close frame.
|
||||||
|
c.setCloseErr(fmt.Errorf("sent close frame: %w", ce))
|
||||||
|
|
||||||
|
if marshalErr != nil {
|
||||||
|
return marshalErr
|
||||||
|
}
|
||||||
|
return writeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) waitCloseHandshake() error {
|
||||||
|
defer c.close(nil)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := c.readMu.lock(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer c.readMu.unlock()
|
||||||
|
|
||||||
|
if c.readCloseFrameErr != nil {
|
||||||
|
return c.readCloseFrameErr
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
h, err := c.readLoop(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := int64(0); i < h.payloadLength; i++ {
|
||||||
|
_, err := c.br.ReadByte()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseClosePayload(p []byte) (CloseError, error) {
|
||||||
|
if len(p) == 0 {
|
||||||
|
return CloseError{
|
||||||
|
Code: StatusNoStatusRcvd,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(p) < 2 {
|
||||||
|
return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p)
|
||||||
|
}
|
||||||
|
|
||||||
|
ce := CloseError{
|
||||||
|
Code: StatusCode(binary.BigEndian.Uint16(p)),
|
||||||
|
Reason: string(p[2:]),
|
||||||
|
}
|
||||||
|
|
||||||
|
if !validWireCloseCode(ce.Code) {
|
||||||
|
return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ce, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
|
||||||
|
// and https://tools.ietf.org/html/rfc6455#section-7.4.1
|
||||||
|
func validWireCloseCode(code StatusCode) bool {
|
||||||
|
switch code {
|
||||||
|
case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if code >= StatusNormalClosure && code <= StatusBadGateway {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if code >= 3000 && code <= 4999 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ce CloseError) bytes() ([]byte, error) {
|
||||||
|
p, err := ce.bytesErr()
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to marshal close frame: %w", err)
|
||||||
|
ce = CloseError{
|
||||||
|
Code: StatusInternalError,
|
||||||
|
}
|
||||||
|
p, _ = ce.bytesErr()
|
||||||
|
}
|
||||||
|
return p, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxCloseReason = maxControlPayload - 2
|
||||||
|
|
||||||
|
func (ce CloseError) bytesErr() ([]byte, error) {
|
||||||
|
if len(ce.Reason) > maxCloseReason {
|
||||||
|
return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !validWireCloseCode(ce.Code) {
|
||||||
|
return nil, fmt.Errorf("status code %v cannot be set", ce.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 2+len(ce.Reason))
|
||||||
|
binary.BigEndian.PutUint16(buf, uint16(ce.Code))
|
||||||
|
copy(buf[2:], ce.Reason)
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) setCloseErr(err error) {
|
||||||
|
c.closeMu.Lock()
|
||||||
|
c.setCloseErrLocked(err)
|
||||||
|
c.closeMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) setCloseErrLocked(err error) {
|
||||||
|
if c.closeErr == nil {
|
||||||
|
c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) isClosed() bool {
|
||||||
|
select {
|
||||||
|
case <-c.closed:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,39 @@
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
// CompressionMode represents the modes available to the deflate extension.
|
||||||
|
// See https://tools.ietf.org/html/rfc7692
|
||||||
|
//
|
||||||
|
// A compatibility layer is implemented for the older deflate-frame extension used
|
||||||
|
// by safari. See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06
|
||||||
|
// It will work the same in every way except that we cannot signal to the peer we
|
||||||
|
// want to use no context takeover on our side, we can only signal that they should.
|
||||||
|
// It is however currently disabled due to Safari bugs. See https://github.com/nhooyr/websocket/issues/218
|
||||||
|
type CompressionMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed
|
||||||
|
// for every message. This applies to both server and client side.
|
||||||
|
//
|
||||||
|
// This means less efficient compression as the sliding window from previous messages
|
||||||
|
// will not be used but the memory overhead will be lower if the connections
|
||||||
|
// are long lived and seldom used.
|
||||||
|
//
|
||||||
|
// The message will only be compressed if greater than 512 bytes.
|
||||||
|
CompressionNoContextTakeover CompressionMode = iota
|
||||||
|
|
||||||
|
// CompressionContextTakeover uses a flate.Reader and flate.Writer per connection.
|
||||||
|
// This enables reusing the sliding window from previous messages.
|
||||||
|
// As most WebSocket protocols are repetitive, this can be very efficient.
|
||||||
|
// It carries an overhead of 8 kB for every connection compared to CompressionNoContextTakeover.
|
||||||
|
//
|
||||||
|
// If the peer negotiates NoContextTakeover on the client or server side, it will be
|
||||||
|
// used instead as this is required by the RFC.
|
||||||
|
CompressionContextTakeover
|
||||||
|
|
||||||
|
// CompressionDisabled disables the deflate extension.
|
||||||
|
//
|
||||||
|
// Use this if you are using a predominantly binary protocol with very
|
||||||
|
// little duplication in between messages or CPU and memory are more
|
||||||
|
// important than bandwidth.
|
||||||
|
CompressionDisabled
|
||||||
|
)
|
|
@ -0,0 +1,181 @@
|
||||||
|
// +build !js
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/klauspost/compress/flate"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (m CompressionMode) opts() *compressionOptions {
|
||||||
|
return &compressionOptions{
|
||||||
|
clientNoContextTakeover: m == CompressionNoContextTakeover,
|
||||||
|
serverNoContextTakeover: m == CompressionNoContextTakeover,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type compressionOptions struct {
|
||||||
|
clientNoContextTakeover bool
|
||||||
|
serverNoContextTakeover bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (copts *compressionOptions) setHeader(h http.Header) {
|
||||||
|
s := "permessage-deflate"
|
||||||
|
if copts.clientNoContextTakeover {
|
||||||
|
s += "; client_no_context_takeover"
|
||||||
|
}
|
||||||
|
if copts.serverNoContextTakeover {
|
||||||
|
s += "; server_no_context_takeover"
|
||||||
|
}
|
||||||
|
h.Set("Sec-WebSocket-Extensions", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// These bytes are required to get flate.Reader to return.
|
||||||
|
// They are removed when sending to avoid the overhead as
|
||||||
|
// WebSocket framing tell's when the message has ended but then
|
||||||
|
// we need to add them back otherwise flate.Reader keeps
|
||||||
|
// trying to return more bytes.
|
||||||
|
const deflateMessageTail = "\x00\x00\xff\xff"
|
||||||
|
|
||||||
|
type trimLastFourBytesWriter struct {
|
||||||
|
w io.Writer
|
||||||
|
tail []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tw *trimLastFourBytesWriter) reset() {
|
||||||
|
if tw != nil && tw.tail != nil {
|
||||||
|
tw.tail = tw.tail[:0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) {
|
||||||
|
if tw.tail == nil {
|
||||||
|
tw.tail = make([]byte, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
extra := len(tw.tail) + len(p) - 4
|
||||||
|
|
||||||
|
if extra <= 0 {
|
||||||
|
tw.tail = append(tw.tail, p...)
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now we need to write as many extra bytes as we can from the previous tail.
|
||||||
|
if extra > len(tw.tail) {
|
||||||
|
extra = len(tw.tail)
|
||||||
|
}
|
||||||
|
if extra > 0 {
|
||||||
|
_, err := tw.w.Write(tw.tail[:extra])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shift remaining bytes in tail over.
|
||||||
|
n := copy(tw.tail, tw.tail[extra:])
|
||||||
|
tw.tail = tw.tail[:n]
|
||||||
|
}
|
||||||
|
|
||||||
|
// If p is less than or equal to 4 bytes,
|
||||||
|
// all of it is is part of the tail.
|
||||||
|
if len(p) <= 4 {
|
||||||
|
tw.tail = append(tw.tail, p...)
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, only the last 4 bytes are.
|
||||||
|
tw.tail = append(tw.tail, p[len(p)-4:]...)
|
||||||
|
|
||||||
|
p = p[:len(p)-4]
|
||||||
|
n, err := tw.w.Write(p)
|
||||||
|
return n + 4, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var flateReaderPool sync.Pool
|
||||||
|
|
||||||
|
func getFlateReader(r io.Reader, dict []byte) io.Reader {
|
||||||
|
fr, ok := flateReaderPool.Get().(io.Reader)
|
||||||
|
if !ok {
|
||||||
|
return flate.NewReaderDict(r, dict)
|
||||||
|
}
|
||||||
|
fr.(flate.Resetter).Reset(r, dict)
|
||||||
|
return fr
|
||||||
|
}
|
||||||
|
|
||||||
|
func putFlateReader(fr io.Reader) {
|
||||||
|
flateReaderPool.Put(fr)
|
||||||
|
}
|
||||||
|
|
||||||
|
type slidingWindow struct {
|
||||||
|
buf []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
var swPoolMu sync.RWMutex
|
||||||
|
var swPool = map[int]*sync.Pool{}
|
||||||
|
|
||||||
|
func slidingWindowPool(n int) *sync.Pool {
|
||||||
|
swPoolMu.RLock()
|
||||||
|
p, ok := swPool[n]
|
||||||
|
swPoolMu.RUnlock()
|
||||||
|
if ok {
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
p = &sync.Pool{}
|
||||||
|
|
||||||
|
swPoolMu.Lock()
|
||||||
|
swPool[n] = p
|
||||||
|
swPoolMu.Unlock()
|
||||||
|
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sw *slidingWindow) init(n int) {
|
||||||
|
if sw.buf != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if n == 0 {
|
||||||
|
n = 32768
|
||||||
|
}
|
||||||
|
|
||||||
|
p := slidingWindowPool(n)
|
||||||
|
buf, ok := p.Get().([]byte)
|
||||||
|
if ok {
|
||||||
|
sw.buf = buf[:0]
|
||||||
|
} else {
|
||||||
|
sw.buf = make([]byte, 0, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sw *slidingWindow) close() {
|
||||||
|
if sw.buf == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
swPoolMu.Lock()
|
||||||
|
swPool[cap(sw.buf)].Put(sw.buf)
|
||||||
|
swPoolMu.Unlock()
|
||||||
|
sw.buf = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sw *slidingWindow) write(p []byte) {
|
||||||
|
if len(p) >= cap(sw.buf) {
|
||||||
|
sw.buf = sw.buf[:cap(sw.buf)]
|
||||||
|
p = p[len(p)-cap(sw.buf):]
|
||||||
|
copy(sw.buf, p)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
left := cap(sw.buf) - len(sw.buf)
|
||||||
|
if left < len(p) {
|
||||||
|
// We need to shift spaceNeeded bytes from the end to make room for p at the end.
|
||||||
|
spaceNeeded := len(p) - left
|
||||||
|
copy(sw.buf, sw.buf[spaceNeeded:])
|
||||||
|
sw.buf = sw.buf[:len(sw.buf)-spaceNeeded]
|
||||||
|
}
|
||||||
|
|
||||||
|
sw.buf = append(sw.buf, p...)
|
||||||
|
}
|
|
@ -0,0 +1,13 @@
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
// MessageType represents the type of a WebSocket message.
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-5.6
|
||||||
|
type MessageType int
|
||||||
|
|
||||||
|
// MessageType constants.
|
||||||
|
const (
|
||||||
|
// MessageText is for UTF-8 encoded text messages like JSON.
|
||||||
|
MessageText MessageType = iota + 1
|
||||||
|
// MessageBinary is for binary messages like protobufs.
|
||||||
|
MessageBinary
|
||||||
|
)
|
|
@ -0,0 +1,265 @@
|
||||||
|
// +build !js
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Conn represents a WebSocket connection.
|
||||||
|
// All methods may be called concurrently except for Reader and Read.
|
||||||
|
//
|
||||||
|
// You must always read from the connection. Otherwise control
|
||||||
|
// frames will not be handled. See Reader and CloseRead.
|
||||||
|
//
|
||||||
|
// Be sure to call Close on the connection when you
|
||||||
|
// are finished with it to release associated resources.
|
||||||
|
//
|
||||||
|
// On any error from any method, the connection is closed
|
||||||
|
// with an appropriate reason.
|
||||||
|
type Conn struct {
|
||||||
|
subprotocol string
|
||||||
|
rwc io.ReadWriteCloser
|
||||||
|
client bool
|
||||||
|
copts *compressionOptions
|
||||||
|
flateThreshold int
|
||||||
|
br *bufio.Reader
|
||||||
|
bw *bufio.Writer
|
||||||
|
|
||||||
|
readTimeout chan context.Context
|
||||||
|
writeTimeout chan context.Context
|
||||||
|
|
||||||
|
// Read state.
|
||||||
|
readMu *mu
|
||||||
|
readHeaderBuf [8]byte
|
||||||
|
readControlBuf [maxControlPayload]byte
|
||||||
|
msgReader *msgReader
|
||||||
|
readCloseFrameErr error
|
||||||
|
|
||||||
|
// Write state.
|
||||||
|
msgWriterState *msgWriterState
|
||||||
|
writeFrameMu *mu
|
||||||
|
writeBuf []byte
|
||||||
|
writeHeaderBuf [8]byte
|
||||||
|
writeHeader header
|
||||||
|
|
||||||
|
closed chan struct{}
|
||||||
|
closeMu sync.Mutex
|
||||||
|
closeErr error
|
||||||
|
wroteClose bool
|
||||||
|
|
||||||
|
pingCounter int32
|
||||||
|
activePingsMu sync.Mutex
|
||||||
|
activePings map[string]chan<- struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type connConfig struct {
|
||||||
|
subprotocol string
|
||||||
|
rwc io.ReadWriteCloser
|
||||||
|
client bool
|
||||||
|
copts *compressionOptions
|
||||||
|
flateThreshold int
|
||||||
|
|
||||||
|
br *bufio.Reader
|
||||||
|
bw *bufio.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func newConn(cfg connConfig) *Conn {
|
||||||
|
c := &Conn{
|
||||||
|
subprotocol: cfg.subprotocol,
|
||||||
|
rwc: cfg.rwc,
|
||||||
|
client: cfg.client,
|
||||||
|
copts: cfg.copts,
|
||||||
|
flateThreshold: cfg.flateThreshold,
|
||||||
|
|
||||||
|
br: cfg.br,
|
||||||
|
bw: cfg.bw,
|
||||||
|
|
||||||
|
readTimeout: make(chan context.Context),
|
||||||
|
writeTimeout: make(chan context.Context),
|
||||||
|
|
||||||
|
closed: make(chan struct{}),
|
||||||
|
activePings: make(map[string]chan<- struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
c.readMu = newMu(c)
|
||||||
|
c.writeFrameMu = newMu(c)
|
||||||
|
|
||||||
|
c.msgReader = newMsgReader(c)
|
||||||
|
|
||||||
|
c.msgWriterState = newMsgWriterState(c)
|
||||||
|
if c.client {
|
||||||
|
c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.flate() && c.flateThreshold == 0 {
|
||||||
|
c.flateThreshold = 128
|
||||||
|
if !c.msgWriterState.flateContextTakeover() {
|
||||||
|
c.flateThreshold = 512
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
runtime.SetFinalizer(c, func(c *Conn) {
|
||||||
|
c.close(errors.New("connection garbage collected"))
|
||||||
|
})
|
||||||
|
|
||||||
|
go c.timeoutLoop()
|
||||||
|
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subprotocol returns the negotiated subprotocol.
|
||||||
|
// An empty string means the default protocol.
|
||||||
|
func (c *Conn) Subprotocol() string {
|
||||||
|
return c.subprotocol
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) close(err error) {
|
||||||
|
c.closeMu.Lock()
|
||||||
|
defer c.closeMu.Unlock()
|
||||||
|
|
||||||
|
if c.isClosed() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.setCloseErrLocked(err)
|
||||||
|
close(c.closed)
|
||||||
|
runtime.SetFinalizer(c, nil)
|
||||||
|
|
||||||
|
// Have to close after c.closed is closed to ensure any goroutine that wakes up
|
||||||
|
// from the connection being closed also sees that c.closed is closed and returns
|
||||||
|
// closeErr.
|
||||||
|
c.rwc.Close()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
c.msgWriterState.close()
|
||||||
|
|
||||||
|
c.msgReader.close()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) timeoutLoop() {
|
||||||
|
readCtx := context.Background()
|
||||||
|
writeCtx := context.Background()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.closed:
|
||||||
|
return
|
||||||
|
|
||||||
|
case writeCtx = <-c.writeTimeout:
|
||||||
|
case readCtx = <-c.readTimeout:
|
||||||
|
|
||||||
|
case <-readCtx.Done():
|
||||||
|
c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err()))
|
||||||
|
go c.writeError(StatusPolicyViolation, errors.New("timed out"))
|
||||||
|
case <-writeCtx.Done():
|
||||||
|
c.close(fmt.Errorf("write timed out: %w", writeCtx.Err()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) flate() bool {
|
||||||
|
return c.copts != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ping sends a ping to the peer and waits for a pong.
|
||||||
|
// Use this to measure latency or ensure the peer is responsive.
|
||||||
|
// Ping must be called concurrently with Reader as it does
|
||||||
|
// not read from the connection but instead waits for a Reader call
|
||||||
|
// to read the pong.
|
||||||
|
//
|
||||||
|
// TCP Keepalives should suffice for most use cases.
|
||||||
|
func (c *Conn) Ping(ctx context.Context) error {
|
||||||
|
p := atomic.AddInt32(&c.pingCounter, 1)
|
||||||
|
|
||||||
|
err := c.ping(ctx, strconv.Itoa(int(p)))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to ping: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) ping(ctx context.Context, p string) error {
|
||||||
|
pong := make(chan struct{}, 1)
|
||||||
|
|
||||||
|
c.activePingsMu.Lock()
|
||||||
|
c.activePings[p] = pong
|
||||||
|
c.activePingsMu.Unlock()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
c.activePingsMu.Lock()
|
||||||
|
delete(c.activePings, p)
|
||||||
|
c.activePingsMu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := c.writeControl(ctx, opPing, []byte(p))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-c.closed:
|
||||||
|
return c.closeErr
|
||||||
|
case <-ctx.Done():
|
||||||
|
err := fmt.Errorf("failed to wait for pong: %w", ctx.Err())
|
||||||
|
c.close(err)
|
||||||
|
return err
|
||||||
|
case <-pong:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mu struct {
|
||||||
|
c *Conn
|
||||||
|
ch chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMu(c *Conn) *mu {
|
||||||
|
return &mu{
|
||||||
|
c: c,
|
||||||
|
ch: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mu) forceLock() {
|
||||||
|
m.ch <- struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mu) lock(ctx context.Context) error {
|
||||||
|
select {
|
||||||
|
case <-m.c.closed:
|
||||||
|
return m.c.closeErr
|
||||||
|
case <-ctx.Done():
|
||||||
|
err := fmt.Errorf("failed to acquire lock: %w", ctx.Err())
|
||||||
|
m.c.close(err)
|
||||||
|
return err
|
||||||
|
case m.ch <- struct{}{}:
|
||||||
|
// To make sure the connection is certainly alive.
|
||||||
|
// As it's possible the send on m.ch was selected
|
||||||
|
// over the receive on closed.
|
||||||
|
select {
|
||||||
|
case <-m.c.closed:
|
||||||
|
// Make sure to release.
|
||||||
|
m.unlock()
|
||||||
|
return m.c.closeErr
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mu) unlock() {
|
||||||
|
select {
|
||||||
|
case <-m.ch:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,292 @@
|
||||||
|
// +build !js
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"nhooyr.io/websocket/internal/errd"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DialOptions represents Dial's options.
|
||||||
|
type DialOptions struct {
|
||||||
|
// HTTPClient is used for the connection.
|
||||||
|
// Its Transport must return writable bodies for WebSocket handshakes.
|
||||||
|
// http.Transport does beginning with Go 1.12.
|
||||||
|
HTTPClient *http.Client
|
||||||
|
|
||||||
|
// HTTPHeader specifies the HTTP headers included in the handshake request.
|
||||||
|
HTTPHeader http.Header
|
||||||
|
|
||||||
|
// Subprotocols lists the WebSocket subprotocols to negotiate with the server.
|
||||||
|
Subprotocols []string
|
||||||
|
|
||||||
|
// CompressionMode controls the compression mode.
|
||||||
|
// Defaults to CompressionNoContextTakeover.
|
||||||
|
//
|
||||||
|
// See docs on CompressionMode for details.
|
||||||
|
CompressionMode CompressionMode
|
||||||
|
|
||||||
|
// CompressionThreshold controls the minimum size of a message before compression is applied.
|
||||||
|
//
|
||||||
|
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
|
||||||
|
// for CompressionContextTakeover.
|
||||||
|
CompressionThreshold int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial performs a WebSocket handshake on url.
|
||||||
|
//
|
||||||
|
// The response is the WebSocket handshake response from the server.
|
||||||
|
// You never need to close resp.Body yourself.
|
||||||
|
//
|
||||||
|
// If an error occurs, the returned response may be non nil.
|
||||||
|
// However, you can only read the first 1024 bytes of the body.
|
||||||
|
//
|
||||||
|
// This function requires at least Go 1.12 as it uses a new feature
|
||||||
|
// in net/http to perform WebSocket handshakes.
|
||||||
|
// See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861
|
||||||
|
//
|
||||||
|
// URLs with http/https schemes will work and are interpreted as ws/wss.
|
||||||
|
func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) {
|
||||||
|
return dial(ctx, u, opts, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) {
|
||||||
|
defer errd.Wrap(&err, "failed to WebSocket dial")
|
||||||
|
|
||||||
|
if opts == nil {
|
||||||
|
opts = &DialOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
opts = &*opts
|
||||||
|
if opts.HTTPClient == nil {
|
||||||
|
opts.HTTPClient = http.DefaultClient
|
||||||
|
} else if opts.HTTPClient.Timeout > 0 {
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
|
||||||
|
ctx, cancel = context.WithTimeout(ctx, opts.HTTPClient.Timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
newClient := *opts.HTTPClient
|
||||||
|
newClient.Timeout = 0
|
||||||
|
opts.HTTPClient = &newClient
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.HTTPHeader == nil {
|
||||||
|
opts.HTTPHeader = http.Header{}
|
||||||
|
}
|
||||||
|
|
||||||
|
secWebSocketKey, err := secWebSocketKey(rand)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var copts *compressionOptions
|
||||||
|
if opts.CompressionMode != CompressionDisabled {
|
||||||
|
copts = opts.CompressionMode.opts()
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, resp, err
|
||||||
|
}
|
||||||
|
respBody := resp.Body
|
||||||
|
resp.Body = nil
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
// We read a bit of the body for easier debugging.
|
||||||
|
r := io.LimitReader(respBody, 1024)
|
||||||
|
|
||||||
|
timer := time.AfterFunc(time.Second*3, func() {
|
||||||
|
respBody.Close()
|
||||||
|
})
|
||||||
|
defer timer.Stop()
|
||||||
|
|
||||||
|
b, _ := ioutil.ReadAll(r)
|
||||||
|
respBody.Close()
|
||||||
|
resp.Body = ioutil.NopCloser(bytes.NewReader(b))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rwc, ok := respBody.(io.ReadWriteCloser)
|
||||||
|
if !ok {
|
||||||
|
return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
return newConn(connConfig{
|
||||||
|
subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"),
|
||||||
|
rwc: rwc,
|
||||||
|
client: true,
|
||||||
|
copts: copts,
|
||||||
|
flateThreshold: opts.CompressionThreshold,
|
||||||
|
br: getBufioReader(rwc),
|
||||||
|
bw: getBufioWriter(rwc),
|
||||||
|
}), resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) {
|
||||||
|
u, err := url.Parse(urls)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse url: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch u.Scheme {
|
||||||
|
case "ws":
|
||||||
|
u.Scheme = "http"
|
||||||
|
case "wss":
|
||||||
|
u.Scheme = "https"
|
||||||
|
case "http", "https":
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
|
||||||
|
req.Header = opts.HTTPHeader.Clone()
|
||||||
|
req.Header.Set("Connection", "Upgrade")
|
||||||
|
req.Header.Set("Upgrade", "websocket")
|
||||||
|
req.Header.Set("Sec-WebSocket-Version", "13")
|
||||||
|
req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
|
||||||
|
if len(opts.Subprotocols) > 0 {
|
||||||
|
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
|
||||||
|
}
|
||||||
|
if copts != nil {
|
||||||
|
copts.setHeader(req.Header)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := opts.HTTPClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to send handshake request: %w", err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func secWebSocketKey(rr io.Reader) (string, error) {
|
||||||
|
if rr == nil {
|
||||||
|
rr = rand.Reader
|
||||||
|
}
|
||||||
|
b := make([]byte, 16)
|
||||||
|
_, err := io.ReadFull(rr, b)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err)
|
||||||
|
}
|
||||||
|
return base64.StdEncoding.EncodeToString(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
|
||||||
|
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||||
|
return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") {
|
||||||
|
return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") {
|
||||||
|
return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) {
|
||||||
|
return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
|
||||||
|
resp.Header.Get("Sec-WebSocket-Accept"),
|
||||||
|
secWebSocketKey,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := verifySubprotocol(opts.Subprotocols, resp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return verifyServerExtensions(copts, resp.Header)
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifySubprotocol(subprotos []string, resp *http.Response) error {
|
||||||
|
proto := resp.Header.Get("Sec-WebSocket-Protocol")
|
||||||
|
if proto == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sp2 := range subprotos {
|
||||||
|
if strings.EqualFold(sp2, proto) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) {
|
||||||
|
exts := websocketExtensions(h)
|
||||||
|
if len(exts) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ext := exts[0]
|
||||||
|
if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil {
|
||||||
|
return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
|
||||||
|
}
|
||||||
|
|
||||||
|
copts = &*copts
|
||||||
|
|
||||||
|
for _, p := range ext.params {
|
||||||
|
switch p {
|
||||||
|
case "client_no_context_takeover":
|
||||||
|
copts.clientNoContextTakeover = true
|
||||||
|
continue
|
||||||
|
case "server_no_context_takeover":
|
||||||
|
copts.serverNoContextTakeover = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
|
||||||
|
}
|
||||||
|
|
||||||
|
return copts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var bufioReaderPool sync.Pool
|
||||||
|
|
||||||
|
func getBufioReader(r io.Reader) *bufio.Reader {
|
||||||
|
br, ok := bufioReaderPool.Get().(*bufio.Reader)
|
||||||
|
if !ok {
|
||||||
|
return bufio.NewReader(r)
|
||||||
|
}
|
||||||
|
br.Reset(r)
|
||||||
|
return br
|
||||||
|
}
|
||||||
|
|
||||||
|
func putBufioReader(br *bufio.Reader) {
|
||||||
|
bufioReaderPool.Put(br)
|
||||||
|
}
|
||||||
|
|
||||||
|
var bufioWriterPool sync.Pool
|
||||||
|
|
||||||
|
func getBufioWriter(w io.Writer) *bufio.Writer {
|
||||||
|
bw, ok := bufioWriterPool.Get().(*bufio.Writer)
|
||||||
|
if !ok {
|
||||||
|
return bufio.NewWriter(w)
|
||||||
|
}
|
||||||
|
bw.Reset(w)
|
||||||
|
return bw
|
||||||
|
}
|
||||||
|
|
||||||
|
func putBufioWriter(bw *bufio.Writer) {
|
||||||
|
bufioWriterPool.Put(bw)
|
||||||
|
}
|
|
@ -0,0 +1,32 @@
|
||||||
|
// +build !js
|
||||||
|
|
||||||
|
// Package websocket implements the RFC 6455 WebSocket protocol.
|
||||||
|
//
|
||||||
|
// https://tools.ietf.org/html/rfc6455
|
||||||
|
//
|
||||||
|
// Use Dial to dial a WebSocket server.
|
||||||
|
//
|
||||||
|
// Use Accept to accept a WebSocket client.
|
||||||
|
//
|
||||||
|
// Conn represents the resulting WebSocket connection.
|
||||||
|
//
|
||||||
|
// The examples are the best way to understand how to correctly use the library.
|
||||||
|
//
|
||||||
|
// The wsjson and wspb subpackages contain helpers for JSON and protobuf messages.
|
||||||
|
//
|
||||||
|
// More documentation at https://nhooyr.io/websocket.
|
||||||
|
//
|
||||||
|
// Wasm
|
||||||
|
//
|
||||||
|
// The client side supports compiling to Wasm.
|
||||||
|
// It wraps the WebSocket browser API.
|
||||||
|
//
|
||||||
|
// See https://developer.mozilla.org/en-US/docs/Web/API/WebSocket
|
||||||
|
//
|
||||||
|
// Some important caveats to be aware of:
|
||||||
|
//
|
||||||
|
// - Accept always errors out
|
||||||
|
// - Conn.Ping is no-op
|
||||||
|
// - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-op
|
||||||
|
// - *http.Response from Dial is &http.Response{} with a 101 status code on success
|
||||||
|
package websocket // import "nhooyr.io/websocket"
|
|
@ -0,0 +1,294 @@
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"math/bits"
|
||||||
|
|
||||||
|
"nhooyr.io/websocket/internal/errd"
|
||||||
|
)
|
||||||
|
|
||||||
|
// opcode represents a WebSocket opcode.
|
||||||
|
type opcode int
|
||||||
|
|
||||||
|
// https://tools.ietf.org/html/rfc6455#section-11.8.
|
||||||
|
const (
|
||||||
|
opContinuation opcode = iota
|
||||||
|
opText
|
||||||
|
opBinary
|
||||||
|
// 3 - 7 are reserved for further non-control frames.
|
||||||
|
_
|
||||||
|
_
|
||||||
|
_
|
||||||
|
_
|
||||||
|
_
|
||||||
|
opClose
|
||||||
|
opPing
|
||||||
|
opPong
|
||||||
|
// 11-16 are reserved for further control frames.
|
||||||
|
)
|
||||||
|
|
||||||
|
// header represents a WebSocket frame header.
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-5.2.
|
||||||
|
type header struct {
|
||||||
|
fin bool
|
||||||
|
rsv1 bool
|
||||||
|
rsv2 bool
|
||||||
|
rsv3 bool
|
||||||
|
opcode opcode
|
||||||
|
|
||||||
|
payloadLength int64
|
||||||
|
|
||||||
|
masked bool
|
||||||
|
maskKey uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// readFrameHeader reads a header from the reader.
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-5.2.
|
||||||
|
func readFrameHeader(r *bufio.Reader, readBuf []byte) (h header, err error) {
|
||||||
|
defer errd.Wrap(&err, "failed to read frame header")
|
||||||
|
|
||||||
|
b, err := r.ReadByte()
|
||||||
|
if err != nil {
|
||||||
|
return header{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
h.fin = b&(1<<7) != 0
|
||||||
|
h.rsv1 = b&(1<<6) != 0
|
||||||
|
h.rsv2 = b&(1<<5) != 0
|
||||||
|
h.rsv3 = b&(1<<4) != 0
|
||||||
|
|
||||||
|
h.opcode = opcode(b & 0xf)
|
||||||
|
|
||||||
|
b, err = r.ReadByte()
|
||||||
|
if err != nil {
|
||||||
|
return header{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
h.masked = b&(1<<7) != 0
|
||||||
|
|
||||||
|
payloadLength := b &^ (1 << 7)
|
||||||
|
switch {
|
||||||
|
case payloadLength < 126:
|
||||||
|
h.payloadLength = int64(payloadLength)
|
||||||
|
case payloadLength == 126:
|
||||||
|
_, err = io.ReadFull(r, readBuf[:2])
|
||||||
|
h.payloadLength = int64(binary.BigEndian.Uint16(readBuf))
|
||||||
|
case payloadLength == 127:
|
||||||
|
_, err = io.ReadFull(r, readBuf)
|
||||||
|
h.payloadLength = int64(binary.BigEndian.Uint64(readBuf))
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return header{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.payloadLength < 0 {
|
||||||
|
return header{}, fmt.Errorf("received negative payload length: %v", h.payloadLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.masked {
|
||||||
|
_, err = io.ReadFull(r, readBuf[:4])
|
||||||
|
if err != nil {
|
||||||
|
return header{}, err
|
||||||
|
}
|
||||||
|
h.maskKey = binary.LittleEndian.Uint32(readBuf)
|
||||||
|
}
|
||||||
|
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// maxControlPayload is the maximum length of a control frame payload.
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-5.5.
|
||||||
|
const maxControlPayload = 125
|
||||||
|
|
||||||
|
// writeFrameHeader writes the bytes of the header to w.
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-5.2
|
||||||
|
func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) {
|
||||||
|
defer errd.Wrap(&err, "failed to write frame header")
|
||||||
|
|
||||||
|
var b byte
|
||||||
|
if h.fin {
|
||||||
|
b |= 1 << 7
|
||||||
|
}
|
||||||
|
if h.rsv1 {
|
||||||
|
b |= 1 << 6
|
||||||
|
}
|
||||||
|
if h.rsv2 {
|
||||||
|
b |= 1 << 5
|
||||||
|
}
|
||||||
|
if h.rsv3 {
|
||||||
|
b |= 1 << 4
|
||||||
|
}
|
||||||
|
|
||||||
|
b |= byte(h.opcode)
|
||||||
|
|
||||||
|
err = w.WriteByte(b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
lengthByte := byte(0)
|
||||||
|
if h.masked {
|
||||||
|
lengthByte |= 1 << 7
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case h.payloadLength > math.MaxUint16:
|
||||||
|
lengthByte |= 127
|
||||||
|
case h.payloadLength > 125:
|
||||||
|
lengthByte |= 126
|
||||||
|
case h.payloadLength >= 0:
|
||||||
|
lengthByte |= byte(h.payloadLength)
|
||||||
|
}
|
||||||
|
err = w.WriteByte(lengthByte)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case h.payloadLength > math.MaxUint16:
|
||||||
|
binary.BigEndian.PutUint64(buf, uint64(h.payloadLength))
|
||||||
|
_, err = w.Write(buf)
|
||||||
|
case h.payloadLength > 125:
|
||||||
|
binary.BigEndian.PutUint16(buf, uint16(h.payloadLength))
|
||||||
|
_, err = w.Write(buf[:2])
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.masked {
|
||||||
|
binary.LittleEndian.PutUint32(buf, h.maskKey)
|
||||||
|
_, err = w.Write(buf[:4])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mask applies the WebSocket masking algorithm to p
|
||||||
|
// with the given key.
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-5.3
|
||||||
|
//
|
||||||
|
// The returned value is the correctly rotated key to
|
||||||
|
// to continue to mask/unmask the message.
|
||||||
|
//
|
||||||
|
// It is optimized for LittleEndian and expects the key
|
||||||
|
// to be in little endian.
|
||||||
|
//
|
||||||
|
// See https://github.com/golang/go/issues/31586
|
||||||
|
func mask(key uint32, b []byte) uint32 {
|
||||||
|
if len(b) >= 8 {
|
||||||
|
key64 := uint64(key)<<32 | uint64(key)
|
||||||
|
|
||||||
|
// At some point in the future we can clean these unrolled loops up.
|
||||||
|
// See https://github.com/golang/go/issues/31586#issuecomment-487436401
|
||||||
|
|
||||||
|
// Then we xor until b is less than 128 bytes.
|
||||||
|
for len(b) >= 128 {
|
||||||
|
v := binary.LittleEndian.Uint64(b)
|
||||||
|
binary.LittleEndian.PutUint64(b, v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[8:16])
|
||||||
|
binary.LittleEndian.PutUint64(b[8:16], v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[16:24])
|
||||||
|
binary.LittleEndian.PutUint64(b[16:24], v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[24:32])
|
||||||
|
binary.LittleEndian.PutUint64(b[24:32], v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[32:40])
|
||||||
|
binary.LittleEndian.PutUint64(b[32:40], v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[40:48])
|
||||||
|
binary.LittleEndian.PutUint64(b[40:48], v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[48:56])
|
||||||
|
binary.LittleEndian.PutUint64(b[48:56], v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[56:64])
|
||||||
|
binary.LittleEndian.PutUint64(b[56:64], v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[64:72])
|
||||||
|
binary.LittleEndian.PutUint64(b[64:72], v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[72:80])
|
||||||
|
binary.LittleEndian.PutUint64(b[72:80], v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[80:88])
|
||||||
|
binary.LittleEndian.PutUint64(b[80:88], v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[88:96])
|
||||||
|
binary.LittleEndian.PutUint64(b[88:96], v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[96:104])
|
||||||
|
binary.LittleEndian.PutUint64(b[96:104], v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[104:112])
|
||||||
|
binary.LittleEndian.PutUint64(b[104:112], v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[112:120])
|
||||||
|
binary.LittleEndian.PutUint64(b[112:120], v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[120:128])
|
||||||
|
binary.LittleEndian.PutUint64(b[120:128], v^key64)
|
||||||
|
b = b[128:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then we xor until b is less than 64 bytes.
|
||||||
|
for len(b) >= 64 {
|
||||||
|
v := binary.LittleEndian.Uint64(b)
|
||||||
|
binary.LittleEndian.PutUint64(b, v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[8:16])
|
||||||
|
binary.LittleEndian.PutUint64(b[8:16], v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[16:24])
|
||||||
|
binary.LittleEndian.PutUint64(b[16:24], v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[24:32])
|
||||||
|
binary.LittleEndian.PutUint64(b[24:32], v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[32:40])
|
||||||
|
binary.LittleEndian.PutUint64(b[32:40], v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[40:48])
|
||||||
|
binary.LittleEndian.PutUint64(b[40:48], v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[48:56])
|
||||||
|
binary.LittleEndian.PutUint64(b[48:56], v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[56:64])
|
||||||
|
binary.LittleEndian.PutUint64(b[56:64], v^key64)
|
||||||
|
b = b[64:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then we xor until b is less than 32 bytes.
|
||||||
|
for len(b) >= 32 {
|
||||||
|
v := binary.LittleEndian.Uint64(b)
|
||||||
|
binary.LittleEndian.PutUint64(b, v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[8:16])
|
||||||
|
binary.LittleEndian.PutUint64(b[8:16], v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[16:24])
|
||||||
|
binary.LittleEndian.PutUint64(b[16:24], v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[24:32])
|
||||||
|
binary.LittleEndian.PutUint64(b[24:32], v^key64)
|
||||||
|
b = b[32:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then we xor until b is less than 16 bytes.
|
||||||
|
for len(b) >= 16 {
|
||||||
|
v := binary.LittleEndian.Uint64(b)
|
||||||
|
binary.LittleEndian.PutUint64(b, v^key64)
|
||||||
|
v = binary.LittleEndian.Uint64(b[8:16])
|
||||||
|
binary.LittleEndian.PutUint64(b[8:16], v^key64)
|
||||||
|
b = b[16:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then we xor until b is less than 8 bytes.
|
||||||
|
for len(b) >= 8 {
|
||||||
|
v := binary.LittleEndian.Uint64(b)
|
||||||
|
binary.LittleEndian.PutUint64(b, v^key64)
|
||||||
|
b = b[8:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then we xor until b is less than 4 bytes.
|
||||||
|
for len(b) >= 4 {
|
||||||
|
v := binary.LittleEndian.Uint32(b)
|
||||||
|
binary.LittleEndian.PutUint32(b, v^key)
|
||||||
|
b = b[4:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// xor remaining bytes.
|
||||||
|
for i := range b {
|
||||||
|
b[i] ^= byte(key)
|
||||||
|
key = bits.RotateLeft32(key, -8)
|
||||||
|
}
|
||||||
|
|
||||||
|
return key
|
||||||
|
}
|
|
@ -0,0 +1,24 @@
|
||||||
|
package bpool
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var bpool sync.Pool
|
||||||
|
|
||||||
|
// Get returns a buffer from the pool or creates a new one if
|
||||||
|
// the pool is empty.
|
||||||
|
func Get() *bytes.Buffer {
|
||||||
|
b := bpool.Get()
|
||||||
|
if b == nil {
|
||||||
|
return &bytes.Buffer{}
|
||||||
|
}
|
||||||
|
return b.(*bytes.Buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put returns a buffer into the pool.
|
||||||
|
func Put(b *bytes.Buffer) {
|
||||||
|
b.Reset()
|
||||||
|
bpool.Put(b)
|
||||||
|
}
|
|
@ -0,0 +1,14 @@
|
||||||
|
package errd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Wrap wraps err with fmt.Errorf if err is non nil.
|
||||||
|
// Intended for use with defer and a named error return.
|
||||||
|
// Inspired by https://github.com/golang/go/issues/32676.
|
||||||
|
func Wrap(err *error, f string, v ...interface{}) {
|
||||||
|
if *err != nil {
|
||||||
|
*err = fmt.Errorf(f+": %w", append(v, *err)...)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,170 @@
|
||||||
|
// +build js
|
||||||
|
|
||||||
|
// Package wsjs implements typed access to the browser javascript WebSocket API.
|
||||||
|
//
|
||||||
|
// https://developer.mozilla.org/en-US/docs/Web/API/WebSocket
|
||||||
|
package wsjs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall/js"
|
||||||
|
)
|
||||||
|
|
||||||
|
func handleJSError(err *error, onErr func()) {
|
||||||
|
r := recover()
|
||||||
|
|
||||||
|
if jsErr, ok := r.(js.Error); ok {
|
||||||
|
*err = jsErr
|
||||||
|
|
||||||
|
if onErr != nil {
|
||||||
|
onErr()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r != nil {
|
||||||
|
panic(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// New is a wrapper around the javascript WebSocket constructor.
|
||||||
|
func New(url string, protocols []string) (c WebSocket, err error) {
|
||||||
|
defer handleJSError(&err, func() {
|
||||||
|
c = WebSocket{}
|
||||||
|
})
|
||||||
|
|
||||||
|
jsProtocols := make([]interface{}, len(protocols))
|
||||||
|
for i, p := range protocols {
|
||||||
|
jsProtocols[i] = p
|
||||||
|
}
|
||||||
|
|
||||||
|
c = WebSocket{
|
||||||
|
v: js.Global().Get("WebSocket").New(url, jsProtocols),
|
||||||
|
}
|
||||||
|
|
||||||
|
c.setBinaryType("arraybuffer")
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebSocket is a wrapper around a javascript WebSocket object.
|
||||||
|
type WebSocket struct {
|
||||||
|
v js.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c WebSocket) setBinaryType(typ string) {
|
||||||
|
c.v.Set("binaryType", string(typ))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c WebSocket) addEventListener(eventType string, fn func(e js.Value)) func() {
|
||||||
|
f := js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||||
|
fn(args[0])
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
c.v.Call("addEventListener", eventType, f)
|
||||||
|
|
||||||
|
return func() {
|
||||||
|
c.v.Call("removeEventListener", eventType, f)
|
||||||
|
f.Release()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseEvent is the type passed to a WebSocket close handler.
|
||||||
|
type CloseEvent struct {
|
||||||
|
Code uint16
|
||||||
|
Reason string
|
||||||
|
WasClean bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnClose registers a function to be called when the WebSocket is closed.
|
||||||
|
func (c WebSocket) OnClose(fn func(CloseEvent)) (remove func()) {
|
||||||
|
return c.addEventListener("close", func(e js.Value) {
|
||||||
|
ce := CloseEvent{
|
||||||
|
Code: uint16(e.Get("code").Int()),
|
||||||
|
Reason: e.Get("reason").String(),
|
||||||
|
WasClean: e.Get("wasClean").Bool(),
|
||||||
|
}
|
||||||
|
fn(ce)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnError registers a function to be called when there is an error
|
||||||
|
// with the WebSocket.
|
||||||
|
func (c WebSocket) OnError(fn func(e js.Value)) (remove func()) {
|
||||||
|
return c.addEventListener("error", fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessageEvent is the type passed to a message handler.
|
||||||
|
type MessageEvent struct {
|
||||||
|
// string or []byte.
|
||||||
|
Data interface{}
|
||||||
|
|
||||||
|
// There are more fields to the interface but we don't use them.
|
||||||
|
// See https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnMessage registers a function to be called when the WebSocket receives a message.
|
||||||
|
func (c WebSocket) OnMessage(fn func(m MessageEvent)) (remove func()) {
|
||||||
|
return c.addEventListener("message", func(e js.Value) {
|
||||||
|
var data interface{}
|
||||||
|
|
||||||
|
arrayBuffer := e.Get("data")
|
||||||
|
if arrayBuffer.Type() == js.TypeString {
|
||||||
|
data = arrayBuffer.String()
|
||||||
|
} else {
|
||||||
|
data = extractArrayBuffer(arrayBuffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
me := MessageEvent{
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
fn(me)
|
||||||
|
|
||||||
|
return
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subprotocol returns the WebSocket subprotocol in use.
|
||||||
|
func (c WebSocket) Subprotocol() string {
|
||||||
|
return c.v.Get("protocol").String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnOpen registers a function to be called when the WebSocket is opened.
|
||||||
|
func (c WebSocket) OnOpen(fn func(e js.Value)) (remove func()) {
|
||||||
|
return c.addEventListener("open", fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the WebSocket with the given code and reason.
|
||||||
|
func (c WebSocket) Close(code int, reason string) (err error) {
|
||||||
|
defer handleJSError(&err, nil)
|
||||||
|
c.v.Call("close", code, reason)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendText sends the given string as a text message
|
||||||
|
// on the WebSocket.
|
||||||
|
func (c WebSocket) SendText(v string) (err error) {
|
||||||
|
defer handleJSError(&err, nil)
|
||||||
|
c.v.Call("send", v)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendBytes sends the given message as a binary message
|
||||||
|
// on the WebSocket.
|
||||||
|
func (c WebSocket) SendBytes(v []byte) (err error) {
|
||||||
|
defer handleJSError(&err, nil)
|
||||||
|
c.v.Call("send", uint8Array(v))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractArrayBuffer(arrayBuffer js.Value) []byte {
|
||||||
|
uint8Array := js.Global().Get("Uint8Array").New(arrayBuffer)
|
||||||
|
dst := make([]byte, uint8Array.Length())
|
||||||
|
js.CopyBytesToGo(dst, uint8Array)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
func uint8Array(src []byte) js.Value {
|
||||||
|
uint8Array := js.Global().Get("Uint8Array").New(len(src))
|
||||||
|
js.CopyBytesToJS(uint8Array, src)
|
||||||
|
return uint8Array
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
package xsync
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Go allows running a function in another goroutine
|
||||||
|
// and waiting for its error.
|
||||||
|
func Go(fn func() error) <-chan error {
|
||||||
|
errs := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
r := recover()
|
||||||
|
if r != nil {
|
||||||
|
select {
|
||||||
|
case errs <- fmt.Errorf("panic in go fn: %v", r):
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
errs <- fn()
|
||||||
|
}()
|
||||||
|
|
||||||
|
return errs
|
||||||
|
}
|
|
@ -0,0 +1,23 @@
|
||||||
|
package xsync
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Int64 represents an atomic int64.
|
||||||
|
type Int64 struct {
|
||||||
|
// We do not use atomic.Load/StoreInt64 since it does not
|
||||||
|
// work on 32 bit computers but we need 64 bit integers.
|
||||||
|
i atomic.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load loads the int64.
|
||||||
|
func (v *Int64) Load() int64 {
|
||||||
|
i, _ := v.i.Load().(int64)
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store stores the int64.
|
||||||
|
func (v *Int64) Store(i int64) {
|
||||||
|
v.i.Store(i)
|
||||||
|
}
|
|
@ -0,0 +1,166 @@
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NetConn converts a *websocket.Conn into a net.Conn.
|
||||||
|
//
|
||||||
|
// It's for tunneling arbitrary protocols over WebSockets.
|
||||||
|
// Few users of the library will need this but it's tricky to implement
|
||||||
|
// correctly and so provided in the library.
|
||||||
|
// See https://github.com/nhooyr/websocket/issues/100.
|
||||||
|
//
|
||||||
|
// Every Write to the net.Conn will correspond to a message write of
|
||||||
|
// the given type on *websocket.Conn.
|
||||||
|
//
|
||||||
|
// The passed ctx bounds the lifetime of the net.Conn. If cancelled,
|
||||||
|
// all reads and writes on the net.Conn will be cancelled.
|
||||||
|
//
|
||||||
|
// If a message is read that is not of the correct type, the connection
|
||||||
|
// will be closed with StatusUnsupportedData and an error will be returned.
|
||||||
|
//
|
||||||
|
// Close will close the *websocket.Conn with StatusNormalClosure.
|
||||||
|
//
|
||||||
|
// When a deadline is hit, the connection will be closed. This is
|
||||||
|
// different from most net.Conn implementations where only the
|
||||||
|
// reading/writing goroutines are interrupted but the connection is kept alive.
|
||||||
|
//
|
||||||
|
// The Addr methods will return a mock net.Addr that returns "websocket" for Network
|
||||||
|
// and "websocket/unknown-addr" for String.
|
||||||
|
//
|
||||||
|
// A received StatusNormalClosure or StatusGoingAway close frame will be translated to
|
||||||
|
// io.EOF when reading.
|
||||||
|
func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
|
||||||
|
nc := &netConn{
|
||||||
|
c: c,
|
||||||
|
msgType: msgType,
|
||||||
|
}
|
||||||
|
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
nc.writeContext, cancel = context.WithCancel(ctx)
|
||||||
|
nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel)
|
||||||
|
if !nc.writeTimer.Stop() {
|
||||||
|
<-nc.writeTimer.C
|
||||||
|
}
|
||||||
|
|
||||||
|
nc.readContext, cancel = context.WithCancel(ctx)
|
||||||
|
nc.readTimer = time.AfterFunc(math.MaxInt64, cancel)
|
||||||
|
if !nc.readTimer.Stop() {
|
||||||
|
<-nc.readTimer.C
|
||||||
|
}
|
||||||
|
|
||||||
|
return nc
|
||||||
|
}
|
||||||
|
|
||||||
|
type netConn struct {
|
||||||
|
c *Conn
|
||||||
|
msgType MessageType
|
||||||
|
|
||||||
|
writeTimer *time.Timer
|
||||||
|
writeContext context.Context
|
||||||
|
|
||||||
|
readTimer *time.Timer
|
||||||
|
readContext context.Context
|
||||||
|
|
||||||
|
readMu sync.Mutex
|
||||||
|
eofed bool
|
||||||
|
reader io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ net.Conn = &netConn{}
|
||||||
|
|
||||||
|
func (c *netConn) Close() error {
|
||||||
|
return c.c.Close(StatusNormalClosure, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *netConn) Write(p []byte) (int, error) {
|
||||||
|
err := c.c.Write(c.writeContext, c.msgType, p)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *netConn) Read(p []byte) (int, error) {
|
||||||
|
c.readMu.Lock()
|
||||||
|
defer c.readMu.Unlock()
|
||||||
|
|
||||||
|
if c.eofed {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.reader == nil {
|
||||||
|
typ, r, err := c.c.Reader(c.readContext)
|
||||||
|
if err != nil {
|
||||||
|
switch CloseStatus(err) {
|
||||||
|
case StatusNormalClosure, StatusGoingAway:
|
||||||
|
c.eofed = true
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if typ != c.msgType {
|
||||||
|
err := fmt.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ)
|
||||||
|
c.c.Close(StatusUnsupportedData, err.Error())
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
c.reader = r
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := c.reader.Read(p)
|
||||||
|
if err == io.EOF {
|
||||||
|
c.reader = nil
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type websocketAddr struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a websocketAddr) Network() string {
|
||||||
|
return "websocket"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a websocketAddr) String() string {
|
||||||
|
return "websocket/unknown-addr"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *netConn) RemoteAddr() net.Addr {
|
||||||
|
return websocketAddr{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *netConn) LocalAddr() net.Addr {
|
||||||
|
return websocketAddr{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *netConn) SetDeadline(t time.Time) error {
|
||||||
|
c.SetWriteDeadline(t)
|
||||||
|
c.SetReadDeadline(t)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *netConn) SetWriteDeadline(t time.Time) error {
|
||||||
|
if t.IsZero() {
|
||||||
|
c.writeTimer.Stop()
|
||||||
|
} else {
|
||||||
|
c.writeTimer.Reset(t.Sub(time.Now()))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *netConn) SetReadDeadline(t time.Time) error {
|
||||||
|
if t.IsZero() {
|
||||||
|
c.readTimer.Stop()
|
||||||
|
} else {
|
||||||
|
c.readTimer.Reset(t.Sub(time.Now()))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,474 @@
|
||||||
|
// +build !js
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"nhooyr.io/websocket/internal/errd"
|
||||||
|
"nhooyr.io/websocket/internal/xsync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Reader reads from the connection until until there is a WebSocket
|
||||||
|
// data message to be read. It will handle ping, pong and close frames as appropriate.
|
||||||
|
//
|
||||||
|
// It returns the type of the message and an io.Reader to read it.
|
||||||
|
// The passed context will also bound the reader.
|
||||||
|
// Ensure you read to EOF otherwise the connection will hang.
|
||||||
|
//
|
||||||
|
// Call CloseRead if you do not expect any data messages from the peer.
|
||||||
|
//
|
||||||
|
// Only one Reader may be open at a time.
|
||||||
|
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
|
||||||
|
return c.reader(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read is a convenience method around Reader to read a single message
|
||||||
|
// from the connection.
|
||||||
|
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
|
||||||
|
typ, r, err := c.Reader(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
b, err := ioutil.ReadAll(r)
|
||||||
|
return typ, b, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseRead starts a goroutine to read from the connection until it is closed
|
||||||
|
// or a data message is received.
|
||||||
|
//
|
||||||
|
// Once CloseRead is called you cannot read any messages from the connection.
|
||||||
|
// The returned context will be cancelled when the connection is closed.
|
||||||
|
//
|
||||||
|
// If a data message is received, the connection will be closed with StatusPolicyViolation.
|
||||||
|
//
|
||||||
|
// Call CloseRead when you do not expect to read any more messages.
|
||||||
|
// Since it actively reads from the connection, it will ensure that ping, pong and close
|
||||||
|
// frames are responded to. This means c.Ping and c.Close will still work as expected.
|
||||||
|
func (c *Conn) CloseRead(ctx context.Context) context.Context {
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
go func() {
|
||||||
|
defer cancel()
|
||||||
|
c.Reader(ctx)
|
||||||
|
c.Close(StatusPolicyViolation, "unexpected data message")
|
||||||
|
}()
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadLimit sets the max number of bytes to read for a single message.
|
||||||
|
// It applies to the Reader and Read methods.
|
||||||
|
//
|
||||||
|
// By default, the connection has a message read limit of 32768 bytes.
|
||||||
|
//
|
||||||
|
// When the limit is hit, the connection will be closed with StatusMessageTooBig.
|
||||||
|
func (c *Conn) SetReadLimit(n int64) {
|
||||||
|
// We add read one more byte than the limit in case
|
||||||
|
// there is a fin frame that needs to be read.
|
||||||
|
c.msgReader.limitReader.limit.Store(n + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
const defaultReadLimit = 32768
|
||||||
|
|
||||||
|
func newMsgReader(c *Conn) *msgReader {
|
||||||
|
mr := &msgReader{
|
||||||
|
c: c,
|
||||||
|
fin: true,
|
||||||
|
}
|
||||||
|
mr.readFunc = mr.read
|
||||||
|
|
||||||
|
mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1)
|
||||||
|
return mr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mr *msgReader) resetFlate() {
|
||||||
|
if mr.flateContextTakeover() {
|
||||||
|
mr.dict.init(32768)
|
||||||
|
}
|
||||||
|
if mr.flateBufio == nil {
|
||||||
|
mr.flateBufio = getBufioReader(mr.readFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf)
|
||||||
|
mr.limitReader.r = mr.flateReader
|
||||||
|
mr.flateTail.Reset(deflateMessageTail)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mr *msgReader) putFlateReader() {
|
||||||
|
if mr.flateReader != nil {
|
||||||
|
putFlateReader(mr.flateReader)
|
||||||
|
mr.flateReader = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mr *msgReader) close() {
|
||||||
|
mr.c.readMu.forceLock()
|
||||||
|
mr.putFlateReader()
|
||||||
|
mr.dict.close()
|
||||||
|
if mr.flateBufio != nil {
|
||||||
|
putBufioReader(mr.flateBufio)
|
||||||
|
}
|
||||||
|
|
||||||
|
if mr.c.client {
|
||||||
|
putBufioReader(mr.c.br)
|
||||||
|
mr.c.br = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mr *msgReader) flateContextTakeover() bool {
|
||||||
|
if mr.c.client {
|
||||||
|
return !mr.c.copts.serverNoContextTakeover
|
||||||
|
}
|
||||||
|
return !mr.c.copts.clientNoContextTakeover
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) readRSV1Illegal(h header) bool {
|
||||||
|
// If compression is disabled, rsv1 is illegal.
|
||||||
|
if !c.flate() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// rsv1 is only allowed on data frames beginning messages.
|
||||||
|
if h.opcode != opText && h.opcode != opBinary {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) readLoop(ctx context.Context) (header, error) {
|
||||||
|
for {
|
||||||
|
h, err := c.readFrameHeader(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return header{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 {
|
||||||
|
err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
|
||||||
|
c.writeError(StatusProtocolError, err)
|
||||||
|
return header{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !c.client && !h.masked {
|
||||||
|
return header{}, errors.New("received unmasked frame from client")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch h.opcode {
|
||||||
|
case opClose, opPing, opPong:
|
||||||
|
err = c.handleControl(ctx, h)
|
||||||
|
if err != nil {
|
||||||
|
// Pass through CloseErrors when receiving a close frame.
|
||||||
|
if h.opcode == opClose && CloseStatus(err) != -1 {
|
||||||
|
return header{}, err
|
||||||
|
}
|
||||||
|
return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
|
||||||
|
}
|
||||||
|
case opContinuation, opText, opBinary:
|
||||||
|
return h, nil
|
||||||
|
default:
|
||||||
|
err := fmt.Errorf("received unknown opcode %v", h.opcode)
|
||||||
|
c.writeError(StatusProtocolError, err)
|
||||||
|
return header{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
|
||||||
|
select {
|
||||||
|
case <-c.closed:
|
||||||
|
return header{}, c.closeErr
|
||||||
|
case c.readTimeout <- ctx:
|
||||||
|
}
|
||||||
|
|
||||||
|
h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
|
||||||
|
if err != nil {
|
||||||
|
select {
|
||||||
|
case <-c.closed:
|
||||||
|
return header{}, c.closeErr
|
||||||
|
case <-ctx.Done():
|
||||||
|
return header{}, ctx.Err()
|
||||||
|
default:
|
||||||
|
c.close(err)
|
||||||
|
return header{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-c.closed:
|
||||||
|
return header{}, c.closeErr
|
||||||
|
case c.readTimeout <- context.Background():
|
||||||
|
}
|
||||||
|
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
|
||||||
|
select {
|
||||||
|
case <-c.closed:
|
||||||
|
return 0, c.closeErr
|
||||||
|
case c.readTimeout <- ctx:
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := io.ReadFull(c.br, p)
|
||||||
|
if err != nil {
|
||||||
|
select {
|
||||||
|
case <-c.closed:
|
||||||
|
return n, c.closeErr
|
||||||
|
case <-ctx.Done():
|
||||||
|
return n, ctx.Err()
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("failed to read frame payload: %w", err)
|
||||||
|
c.close(err)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-c.closed:
|
||||||
|
return n, c.closeErr
|
||||||
|
case c.readTimeout <- context.Background():
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
|
||||||
|
if h.payloadLength < 0 || h.payloadLength > maxControlPayload {
|
||||||
|
err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength)
|
||||||
|
c.writeError(StatusProtocolError, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.fin {
|
||||||
|
err := errors.New("received fragmented control frame")
|
||||||
|
c.writeError(StatusProtocolError, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
b := c.readControlBuf[:h.payloadLength]
|
||||||
|
_, err = c.readFramePayload(ctx, b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.masked {
|
||||||
|
mask(h.maskKey, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch h.opcode {
|
||||||
|
case opPing:
|
||||||
|
return c.writeControl(ctx, opPong, b)
|
||||||
|
case opPong:
|
||||||
|
c.activePingsMu.Lock()
|
||||||
|
pong, ok := c.activePings[string(b)]
|
||||||
|
c.activePingsMu.Unlock()
|
||||||
|
if ok {
|
||||||
|
select {
|
||||||
|
case pong <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
c.readCloseFrameErr = err
|
||||||
|
}()
|
||||||
|
|
||||||
|
ce, err := parseClosePayload(b)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("received invalid close payload: %w", err)
|
||||||
|
c.writeError(StatusProtocolError, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = fmt.Errorf("received close frame: %w", ce)
|
||||||
|
c.setCloseErr(err)
|
||||||
|
c.writeClose(ce.Code, ce.Reason)
|
||||||
|
c.close(err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
|
||||||
|
defer errd.Wrap(&err, "failed to get reader")
|
||||||
|
|
||||||
|
err = c.readMu.lock(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, err
|
||||||
|
}
|
||||||
|
defer c.readMu.unlock()
|
||||||
|
|
||||||
|
if !c.msgReader.fin {
|
||||||
|
err = errors.New("previous message not read to completion")
|
||||||
|
c.close(fmt.Errorf("failed to get reader: %w", err))
|
||||||
|
return 0, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
h, err := c.readLoop(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.opcode == opContinuation {
|
||||||
|
err := errors.New("received continuation frame without text or binary frame")
|
||||||
|
c.writeError(StatusProtocolError, err)
|
||||||
|
return 0, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.msgReader.reset(ctx, h)
|
||||||
|
|
||||||
|
return MessageType(h.opcode), c.msgReader, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type msgReader struct {
|
||||||
|
c *Conn
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
flate bool
|
||||||
|
flateReader io.Reader
|
||||||
|
flateBufio *bufio.Reader
|
||||||
|
flateTail strings.Reader
|
||||||
|
limitReader *limitReader
|
||||||
|
dict slidingWindow
|
||||||
|
|
||||||
|
fin bool
|
||||||
|
payloadLength int64
|
||||||
|
maskKey uint32
|
||||||
|
|
||||||
|
// readerFunc(mr.Read) to avoid continuous allocations.
|
||||||
|
readFunc readerFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mr *msgReader) reset(ctx context.Context, h header) {
|
||||||
|
mr.ctx = ctx
|
||||||
|
mr.flate = h.rsv1
|
||||||
|
mr.limitReader.reset(mr.readFunc)
|
||||||
|
|
||||||
|
if mr.flate {
|
||||||
|
mr.resetFlate()
|
||||||
|
}
|
||||||
|
|
||||||
|
mr.setFrame(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mr *msgReader) setFrame(h header) {
|
||||||
|
mr.fin = h.fin
|
||||||
|
mr.payloadLength = h.payloadLength
|
||||||
|
mr.maskKey = h.maskKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mr *msgReader) Read(p []byte) (n int, err error) {
|
||||||
|
err = mr.c.readMu.lock(mr.ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to read: %w", err)
|
||||||
|
}
|
||||||
|
defer mr.c.readMu.unlock()
|
||||||
|
|
||||||
|
n, err = mr.limitReader.Read(p)
|
||||||
|
if mr.flate && mr.flateContextTakeover() {
|
||||||
|
p = p[:n]
|
||||||
|
mr.dict.write(p)
|
||||||
|
}
|
||||||
|
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
|
||||||
|
mr.putFlateReader()
|
||||||
|
return n, io.EOF
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to read: %w", err)
|
||||||
|
mr.c.close(err)
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mr *msgReader) read(p []byte) (int, error) {
|
||||||
|
for {
|
||||||
|
if mr.payloadLength == 0 {
|
||||||
|
if mr.fin {
|
||||||
|
if mr.flate {
|
||||||
|
return mr.flateTail.Read(p)
|
||||||
|
}
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
h, err := mr.c.readLoop(mr.ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if h.opcode != opContinuation {
|
||||||
|
err := errors.New("received new data message without finishing the previous message")
|
||||||
|
mr.c.writeError(StatusProtocolError, err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
mr.setFrame(h)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if int64(len(p)) > mr.payloadLength {
|
||||||
|
p = p[:mr.payloadLength]
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := mr.c.readFramePayload(mr.ctx, p)
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
mr.payloadLength -= int64(n)
|
||||||
|
|
||||||
|
if !mr.c.client {
|
||||||
|
mr.maskKey = mask(mr.maskKey, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type limitReader struct {
|
||||||
|
c *Conn
|
||||||
|
r io.Reader
|
||||||
|
limit xsync.Int64
|
||||||
|
n int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader {
|
||||||
|
lr := &limitReader{
|
||||||
|
c: c,
|
||||||
|
}
|
||||||
|
lr.limit.Store(limit)
|
||||||
|
lr.reset(r)
|
||||||
|
return lr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lr *limitReader) reset(r io.Reader) {
|
||||||
|
lr.n = lr.limit.Load()
|
||||||
|
lr.r = r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lr *limitReader) Read(p []byte) (int, error) {
|
||||||
|
if lr.n <= 0 {
|
||||||
|
err := fmt.Errorf("read limited at %v bytes", lr.limit.Load())
|
||||||
|
lr.c.writeError(StatusMessageTooBig, err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if int64(len(p)) > lr.n {
|
||||||
|
p = p[:lr.n]
|
||||||
|
}
|
||||||
|
n, err := lr.r.Read(p)
|
||||||
|
lr.n -= int64(n)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type readerFunc func(p []byte) (int, error)
|
||||||
|
|
||||||
|
func (f readerFunc) Read(p []byte) (int, error) {
|
||||||
|
return f(p)
|
||||||
|
}
|
|
@ -0,0 +1,91 @@
|
||||||
|
// Code generated by "stringer -type=opcode,MessageType,StatusCode -output=stringer.go"; DO NOT EDIT.
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import "strconv"
|
||||||
|
|
||||||
|
func _() {
|
||||||
|
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||||
|
// Re-run the stringer command to generate them again.
|
||||||
|
var x [1]struct{}
|
||||||
|
_ = x[opContinuation-0]
|
||||||
|
_ = x[opText-1]
|
||||||
|
_ = x[opBinary-2]
|
||||||
|
_ = x[opClose-8]
|
||||||
|
_ = x[opPing-9]
|
||||||
|
_ = x[opPong-10]
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
_opcode_name_0 = "opContinuationopTextopBinary"
|
||||||
|
_opcode_name_1 = "opCloseopPingopPong"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
_opcode_index_0 = [...]uint8{0, 14, 20, 28}
|
||||||
|
_opcode_index_1 = [...]uint8{0, 7, 13, 19}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (i opcode) String() string {
|
||||||
|
switch {
|
||||||
|
case 0 <= i && i <= 2:
|
||||||
|
return _opcode_name_0[_opcode_index_0[i]:_opcode_index_0[i+1]]
|
||||||
|
case 8 <= i && i <= 10:
|
||||||
|
i -= 8
|
||||||
|
return _opcode_name_1[_opcode_index_1[i]:_opcode_index_1[i+1]]
|
||||||
|
default:
|
||||||
|
return "opcode(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func _() {
|
||||||
|
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||||
|
// Re-run the stringer command to generate them again.
|
||||||
|
var x [1]struct{}
|
||||||
|
_ = x[MessageText-1]
|
||||||
|
_ = x[MessageBinary-2]
|
||||||
|
}
|
||||||
|
|
||||||
|
const _MessageType_name = "MessageTextMessageBinary"
|
||||||
|
|
||||||
|
var _MessageType_index = [...]uint8{0, 11, 24}
|
||||||
|
|
||||||
|
func (i MessageType) String() string {
|
||||||
|
i -= 1
|
||||||
|
if i < 0 || i >= MessageType(len(_MessageType_index)-1) {
|
||||||
|
return "MessageType(" + strconv.FormatInt(int64(i+1), 10) + ")"
|
||||||
|
}
|
||||||
|
return _MessageType_name[_MessageType_index[i]:_MessageType_index[i+1]]
|
||||||
|
}
|
||||||
|
func _() {
|
||||||
|
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||||
|
// Re-run the stringer command to generate them again.
|
||||||
|
var x [1]struct{}
|
||||||
|
_ = x[StatusNormalClosure-1000]
|
||||||
|
_ = x[StatusGoingAway-1001]
|
||||||
|
_ = x[StatusProtocolError-1002]
|
||||||
|
_ = x[StatusUnsupportedData-1003]
|
||||||
|
_ = x[statusReserved-1004]
|
||||||
|
_ = x[StatusNoStatusRcvd-1005]
|
||||||
|
_ = x[StatusAbnormalClosure-1006]
|
||||||
|
_ = x[StatusInvalidFramePayloadData-1007]
|
||||||
|
_ = x[StatusPolicyViolation-1008]
|
||||||
|
_ = x[StatusMessageTooBig-1009]
|
||||||
|
_ = x[StatusMandatoryExtension-1010]
|
||||||
|
_ = x[StatusInternalError-1011]
|
||||||
|
_ = x[StatusServiceRestart-1012]
|
||||||
|
_ = x[StatusTryAgainLater-1013]
|
||||||
|
_ = x[StatusBadGateway-1014]
|
||||||
|
_ = x[StatusTLSHandshake-1015]
|
||||||
|
}
|
||||||
|
|
||||||
|
const _StatusCode_name = "StatusNormalClosureStatusGoingAwayStatusProtocolErrorStatusUnsupportedDatastatusReservedStatusNoStatusRcvdStatusAbnormalClosureStatusInvalidFramePayloadDataStatusPolicyViolationStatusMessageTooBigStatusMandatoryExtensionStatusInternalErrorStatusServiceRestartStatusTryAgainLaterStatusBadGatewayStatusTLSHandshake"
|
||||||
|
|
||||||
|
var _StatusCode_index = [...]uint16{0, 19, 34, 53, 74, 88, 106, 127, 156, 177, 196, 220, 239, 259, 278, 294, 312}
|
||||||
|
|
||||||
|
func (i StatusCode) String() string {
|
||||||
|
i -= 1000
|
||||||
|
if i < 0 || i >= StatusCode(len(_StatusCode_index)-1) {
|
||||||
|
return "StatusCode(" + strconv.FormatInt(int64(i+1000), 10) + ")"
|
||||||
|
}
|
||||||
|
return _StatusCode_name[_StatusCode_index[i]:_StatusCode_index[i+1]]
|
||||||
|
}
|
|
@ -0,0 +1,397 @@
|
||||||
|
// +build !js
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/klauspost/compress/flate"
|
||||||
|
|
||||||
|
"nhooyr.io/websocket/internal/errd"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Writer returns a writer bounded by the context that will write
|
||||||
|
// a WebSocket message of type dataType to the connection.
|
||||||
|
//
|
||||||
|
// You must close the writer once you have written the entire message.
|
||||||
|
//
|
||||||
|
// Only one writer can be open at a time, multiple calls will block until the previous writer
|
||||||
|
// is closed.
|
||||||
|
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
|
||||||
|
w, err := c.writer(ctx, typ)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get writer: %w", err)
|
||||||
|
}
|
||||||
|
return w, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write writes a message to the connection.
|
||||||
|
//
|
||||||
|
// See the Writer method if you want to stream a message.
|
||||||
|
//
|
||||||
|
// If compression is disabled or the threshold is not met, then it
|
||||||
|
// will write the message in a single frame.
|
||||||
|
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
|
||||||
|
_, err := c.write(ctx, typ, p)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write msg: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type msgWriter struct {
|
||||||
|
mw *msgWriterState
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mw *msgWriter) Write(p []byte) (int, error) {
|
||||||
|
if mw.closed {
|
||||||
|
return 0, errors.New("cannot use closed writer")
|
||||||
|
}
|
||||||
|
return mw.mw.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mw *msgWriter) Close() error {
|
||||||
|
if mw.closed {
|
||||||
|
return errors.New("cannot use closed writer")
|
||||||
|
}
|
||||||
|
mw.closed = true
|
||||||
|
return mw.mw.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
type msgWriterState struct {
|
||||||
|
c *Conn
|
||||||
|
|
||||||
|
mu *mu
|
||||||
|
writeMu *mu
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
opcode opcode
|
||||||
|
flate bool
|
||||||
|
|
||||||
|
trimWriter *trimLastFourBytesWriter
|
||||||
|
dict slidingWindow
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMsgWriterState(c *Conn) *msgWriterState {
|
||||||
|
mw := &msgWriterState{
|
||||||
|
c: c,
|
||||||
|
mu: newMu(c),
|
||||||
|
writeMu: newMu(c),
|
||||||
|
}
|
||||||
|
return mw
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mw *msgWriterState) ensureFlate() {
|
||||||
|
if mw.trimWriter == nil {
|
||||||
|
mw.trimWriter = &trimLastFourBytesWriter{
|
||||||
|
w: writerFunc(mw.write),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mw.dict.init(8192)
|
||||||
|
mw.flate = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mw *msgWriterState) flateContextTakeover() bool {
|
||||||
|
if mw.c.client {
|
||||||
|
return !mw.c.copts.clientNoContextTakeover
|
||||||
|
}
|
||||||
|
return !mw.c.copts.serverNoContextTakeover
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
|
||||||
|
err := c.msgWriterState.reset(ctx, typ)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &msgWriter{
|
||||||
|
mw: c.msgWriterState,
|
||||||
|
closed: false,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
|
||||||
|
mw, err := c.writer(ctx, typ)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !c.flate() {
|
||||||
|
defer c.msgWriterState.mu.unlock()
|
||||||
|
return c.writeFrame(ctx, true, false, c.msgWriterState.opcode, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := mw.Write(p)
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = mw.Close()
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {
|
||||||
|
err := mw.mu.lock(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
mw.ctx = ctx
|
||||||
|
mw.opcode = opcode(typ)
|
||||||
|
mw.flate = false
|
||||||
|
|
||||||
|
mw.trimWriter.reset()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write writes the given bytes to the WebSocket connection.
|
||||||
|
func (mw *msgWriterState) Write(p []byte) (_ int, err error) {
|
||||||
|
err = mw.writeMu.lock(mw.ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to write: %w", err)
|
||||||
|
}
|
||||||
|
defer mw.writeMu.unlock()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to write: %w", err)
|
||||||
|
mw.c.close(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if mw.c.flate() {
|
||||||
|
// Only enables flate if the length crosses the
|
||||||
|
// threshold on the first frame
|
||||||
|
if mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold {
|
||||||
|
mw.ensureFlate()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if mw.flate {
|
||||||
|
err = flate.StatelessDeflate(mw.trimWriter, p, false, mw.dict.buf)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
mw.dict.write(p)
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return mw.write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mw *msgWriterState) write(p []byte) (int, error) {
|
||||||
|
n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p)
|
||||||
|
if err != nil {
|
||||||
|
return n, fmt.Errorf("failed to write data frame: %w", err)
|
||||||
|
}
|
||||||
|
mw.opcode = opContinuation
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close flushes the frame to the connection.
|
||||||
|
func (mw *msgWriterState) Close() (err error) {
|
||||||
|
defer errd.Wrap(&err, "failed to close writer")
|
||||||
|
|
||||||
|
err = mw.writeMu.lock(mw.ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer mw.writeMu.unlock()
|
||||||
|
|
||||||
|
_, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write fin frame: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if mw.flate && !mw.flateContextTakeover() {
|
||||||
|
mw.dict.close()
|
||||||
|
}
|
||||||
|
mw.mu.unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mw *msgWriterState) close() {
|
||||||
|
if mw.c.client {
|
||||||
|
mw.c.writeFrameMu.forceLock()
|
||||||
|
putBufioWriter(mw.c.bw)
|
||||||
|
}
|
||||||
|
|
||||||
|
mw.writeMu.forceLock()
|
||||||
|
mw.dict.close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, err := c.writeFrame(ctx, true, false, opcode, p)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// frame handles all writes to the connection.
|
||||||
|
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
|
||||||
|
err = c.writeFrameMu.lock(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer c.writeFrameMu.unlock()
|
||||||
|
|
||||||
|
// If the state says a close has already been written, we wait until
|
||||||
|
// the connection is closed and return that error.
|
||||||
|
//
|
||||||
|
// However, if the frame being written is a close, that means its the close from
|
||||||
|
// the state being set so we let it go through.
|
||||||
|
c.closeMu.Lock()
|
||||||
|
wroteClose := c.wroteClose
|
||||||
|
c.closeMu.Unlock()
|
||||||
|
if wroteClose && opcode != opClose {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return 0, ctx.Err()
|
||||||
|
case <-c.closed:
|
||||||
|
return 0, c.closeErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-c.closed:
|
||||||
|
return 0, c.closeErr
|
||||||
|
case c.writeTimeout <- ctx:
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
select {
|
||||||
|
case <-c.closed:
|
||||||
|
err = c.closeErr
|
||||||
|
case <-ctx.Done():
|
||||||
|
err = ctx.Err()
|
||||||
|
}
|
||||||
|
c.close(err)
|
||||||
|
err = fmt.Errorf("failed to write frame: %w", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
c.writeHeader.fin = fin
|
||||||
|
c.writeHeader.opcode = opcode
|
||||||
|
c.writeHeader.payloadLength = int64(len(p))
|
||||||
|
|
||||||
|
if c.client {
|
||||||
|
c.writeHeader.masked = true
|
||||||
|
_, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4])
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to generate masking key: %w", err)
|
||||||
|
}
|
||||||
|
c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
c.writeHeader.rsv1 = false
|
||||||
|
if flate && (opcode == opText || opcode == opBinary) {
|
||||||
|
c.writeHeader.rsv1 = true
|
||||||
|
}
|
||||||
|
|
||||||
|
err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := c.writeFramePayload(p)
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.writeHeader.fin {
|
||||||
|
err = c.bw.Flush()
|
||||||
|
if err != nil {
|
||||||
|
return n, fmt.Errorf("failed to flush: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-c.closed:
|
||||||
|
return n, c.closeErr
|
||||||
|
case c.writeTimeout <- context.Background():
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) writeFramePayload(p []byte) (n int, err error) {
|
||||||
|
defer errd.Wrap(&err, "failed to write frame payload")
|
||||||
|
|
||||||
|
if !c.writeHeader.masked {
|
||||||
|
return c.bw.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
maskKey := c.writeHeader.maskKey
|
||||||
|
for len(p) > 0 {
|
||||||
|
// If the buffer is full, we need to flush.
|
||||||
|
if c.bw.Available() == 0 {
|
||||||
|
err = c.bw.Flush()
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start of next write in the buffer.
|
||||||
|
i := c.bw.Buffered()
|
||||||
|
|
||||||
|
j := len(p)
|
||||||
|
if j > c.bw.Available() {
|
||||||
|
j = c.bw.Available()
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := c.bw.Write(p[:j])
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
maskKey = mask(maskKey, c.writeBuf[i:c.bw.Buffered()])
|
||||||
|
|
||||||
|
p = p[j:]
|
||||||
|
n += j
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type writerFunc func(p []byte) (int, error)
|
||||||
|
|
||||||
|
func (f writerFunc) Write(p []byte) (int, error) {
|
||||||
|
return f(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
|
||||||
|
// and returns it.
|
||||||
|
func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte {
|
||||||
|
var writeBuf []byte
|
||||||
|
bw.Reset(writerFunc(func(p2 []byte) (int, error) {
|
||||||
|
writeBuf = p2[:cap(p2)]
|
||||||
|
return len(p2), nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
bw.WriteByte(0)
|
||||||
|
bw.Flush()
|
||||||
|
|
||||||
|
bw.Reset(w)
|
||||||
|
|
||||||
|
return writeBuf
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) writeError(code StatusCode, err error) {
|
||||||
|
c.setCloseErr(err)
|
||||||
|
c.writeClose(code, err.Error())
|
||||||
|
c.close(nil)
|
||||||
|
}
|
|
@ -0,0 +1,379 @@
|
||||||
|
package websocket // import "nhooyr.io/websocket"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"syscall/js"
|
||||||
|
|
||||||
|
"nhooyr.io/websocket/internal/bpool"
|
||||||
|
"nhooyr.io/websocket/internal/wsjs"
|
||||||
|
"nhooyr.io/websocket/internal/xsync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Conn provides a wrapper around the browser WebSocket API.
|
||||||
|
type Conn struct {
|
||||||
|
ws wsjs.WebSocket
|
||||||
|
|
||||||
|
// read limit for a message in bytes.
|
||||||
|
msgReadLimit xsync.Int64
|
||||||
|
|
||||||
|
closingMu sync.Mutex
|
||||||
|
isReadClosed xsync.Int64
|
||||||
|
closeOnce sync.Once
|
||||||
|
closed chan struct{}
|
||||||
|
closeErrOnce sync.Once
|
||||||
|
closeErr error
|
||||||
|
closeWasClean bool
|
||||||
|
|
||||||
|
releaseOnClose func()
|
||||||
|
releaseOnMessage func()
|
||||||
|
|
||||||
|
readSignal chan struct{}
|
||||||
|
readBufMu sync.Mutex
|
||||||
|
readBuf []wsjs.MessageEvent
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) close(err error, wasClean bool) {
|
||||||
|
c.closeOnce.Do(func() {
|
||||||
|
runtime.SetFinalizer(c, nil)
|
||||||
|
|
||||||
|
if !wasClean {
|
||||||
|
err = fmt.Errorf("unclean connection close: %w", err)
|
||||||
|
}
|
||||||
|
c.setCloseErr(err)
|
||||||
|
c.closeWasClean = wasClean
|
||||||
|
close(c.closed)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) init() {
|
||||||
|
c.closed = make(chan struct{})
|
||||||
|
c.readSignal = make(chan struct{}, 1)
|
||||||
|
|
||||||
|
c.msgReadLimit.Store(32768)
|
||||||
|
|
||||||
|
c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) {
|
||||||
|
err := CloseError{
|
||||||
|
Code: StatusCode(e.Code),
|
||||||
|
Reason: e.Reason,
|
||||||
|
}
|
||||||
|
// We do not know if we sent or received this close as
|
||||||
|
// its possible the browser triggered it without us
|
||||||
|
// explicitly sending it.
|
||||||
|
c.close(err, e.WasClean)
|
||||||
|
|
||||||
|
c.releaseOnClose()
|
||||||
|
c.releaseOnMessage()
|
||||||
|
})
|
||||||
|
|
||||||
|
c.releaseOnMessage = c.ws.OnMessage(func(e wsjs.MessageEvent) {
|
||||||
|
c.readBufMu.Lock()
|
||||||
|
defer c.readBufMu.Unlock()
|
||||||
|
|
||||||
|
c.readBuf = append(c.readBuf, e)
|
||||||
|
|
||||||
|
// Lets the read goroutine know there is definitely something in readBuf.
|
||||||
|
select {
|
||||||
|
case c.readSignal <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
runtime.SetFinalizer(c, func(c *Conn) {
|
||||||
|
c.setCloseErr(errors.New("connection garbage collected"))
|
||||||
|
c.closeWithInternal()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) closeWithInternal() {
|
||||||
|
c.Close(StatusInternalError, "something went wrong")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read attempts to read a message from the connection.
|
||||||
|
// The maximum time spent waiting is bounded by the context.
|
||||||
|
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
|
||||||
|
if c.isReadClosed.Load() == 1 {
|
||||||
|
return 0, nil, errors.New("WebSocket connection read closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
typ, p, err := c.read(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, fmt.Errorf("failed to read: %w", err)
|
||||||
|
}
|
||||||
|
if int64(len(p)) > c.msgReadLimit.Load() {
|
||||||
|
err := fmt.Errorf("read limited at %v bytes", c.msgReadLimit.Load())
|
||||||
|
c.Close(StatusMessageTooBig, err.Error())
|
||||||
|
return 0, nil, err
|
||||||
|
}
|
||||||
|
return typ, p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
c.Close(StatusPolicyViolation, "read timed out")
|
||||||
|
return 0, nil, ctx.Err()
|
||||||
|
case <-c.readSignal:
|
||||||
|
case <-c.closed:
|
||||||
|
return 0, nil, c.closeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
c.readBufMu.Lock()
|
||||||
|
defer c.readBufMu.Unlock()
|
||||||
|
|
||||||
|
me := c.readBuf[0]
|
||||||
|
// We copy the messages forward and decrease the size
|
||||||
|
// of the slice to avoid reallocating.
|
||||||
|
copy(c.readBuf, c.readBuf[1:])
|
||||||
|
c.readBuf = c.readBuf[:len(c.readBuf)-1]
|
||||||
|
|
||||||
|
if len(c.readBuf) > 0 {
|
||||||
|
// Next time we read, we'll grab the message.
|
||||||
|
select {
|
||||||
|
case c.readSignal <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch p := me.Data.(type) {
|
||||||
|
case string:
|
||||||
|
return MessageText, []byte(p), nil
|
||||||
|
case []byte:
|
||||||
|
return MessageBinary, p, nil
|
||||||
|
default:
|
||||||
|
panic("websocket: unexpected data type from wsjs OnMessage: " + reflect.TypeOf(me.Data).String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ping is mocked out for Wasm.
|
||||||
|
func (c *Conn) Ping(ctx context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write writes a message of the given type to the connection.
|
||||||
|
// Always non blocking.
|
||||||
|
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
|
||||||
|
err := c.write(ctx, typ, p)
|
||||||
|
if err != nil {
|
||||||
|
// Have to ensure the WebSocket is closed after a write error
|
||||||
|
// to match the Go API. It can only error if the message type
|
||||||
|
// is unexpected or the passed bytes contain invalid UTF-8 for
|
||||||
|
// MessageText.
|
||||||
|
err := fmt.Errorf("failed to write: %w", err)
|
||||||
|
c.setCloseErr(err)
|
||||||
|
c.closeWithInternal()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error {
|
||||||
|
if c.isClosed() {
|
||||||
|
return c.closeErr
|
||||||
|
}
|
||||||
|
switch typ {
|
||||||
|
case MessageBinary:
|
||||||
|
return c.ws.SendBytes(p)
|
||||||
|
case MessageText:
|
||||||
|
return c.ws.SendText(string(p))
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unexpected message type: %v", typ)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the WebSocket with the given code and reason.
|
||||||
|
// It will wait until the peer responds with a close frame
|
||||||
|
// or the connection is closed.
|
||||||
|
// It thus performs the full WebSocket close handshake.
|
||||||
|
func (c *Conn) Close(code StatusCode, reason string) error {
|
||||||
|
err := c.exportedClose(code, reason)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to close WebSocket: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) exportedClose(code StatusCode, reason string) error {
|
||||||
|
c.closingMu.Lock()
|
||||||
|
defer c.closingMu.Unlock()
|
||||||
|
|
||||||
|
ce := fmt.Errorf("sent close: %w", CloseError{
|
||||||
|
Code: code,
|
||||||
|
Reason: reason,
|
||||||
|
})
|
||||||
|
|
||||||
|
if c.isClosed() {
|
||||||
|
return fmt.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.setCloseErr(ce)
|
||||||
|
err := c.ws.Close(int(code), reason)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
<-c.closed
|
||||||
|
if !c.closeWasClean {
|
||||||
|
return c.closeErr
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subprotocol returns the negotiated subprotocol.
|
||||||
|
// An empty string means the default protocol.
|
||||||
|
func (c *Conn) Subprotocol() string {
|
||||||
|
return c.ws.Subprotocol()
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialOptions represents the options available to pass to Dial.
|
||||||
|
type DialOptions struct {
|
||||||
|
// Subprotocols lists the subprotocols to negotiate with the server.
|
||||||
|
Subprotocols []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial creates a new WebSocket connection to the given url with the given options.
|
||||||
|
// The passed context bounds the maximum time spent waiting for the connection to open.
|
||||||
|
// The returned *http.Response is always nil or a mock. It's only in the signature
|
||||||
|
// to match the core API.
|
||||||
|
func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) {
|
||||||
|
c, resp, err := dial(ctx, url, opts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("failed to WebSocket dial %q: %w", url, err)
|
||||||
|
}
|
||||||
|
return c, resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) {
|
||||||
|
if opts == nil {
|
||||||
|
opts = &DialOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
url = strings.Replace(url, "http://", "ws://", 1)
|
||||||
|
url = strings.Replace(url, "https://", "wss://", 1)
|
||||||
|
|
||||||
|
ws, err := wsjs.New(url, opts.Subprotocols)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c := &Conn{
|
||||||
|
ws: ws,
|
||||||
|
}
|
||||||
|
c.init()
|
||||||
|
|
||||||
|
opench := make(chan struct{})
|
||||||
|
releaseOpen := ws.OnOpen(func(e js.Value) {
|
||||||
|
close(opench)
|
||||||
|
})
|
||||||
|
defer releaseOpen()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
c.Close(StatusPolicyViolation, "dial timed out")
|
||||||
|
return nil, nil, ctx.Err()
|
||||||
|
case <-opench:
|
||||||
|
return c, &http.Response{
|
||||||
|
StatusCode: http.StatusSwitchingProtocols,
|
||||||
|
}, nil
|
||||||
|
case <-c.closed:
|
||||||
|
return nil, nil, c.closeErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reader attempts to read a message from the connection.
|
||||||
|
// The maximum time spent waiting is bounded by the context.
|
||||||
|
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
|
||||||
|
typ, p, err := c.Read(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, err
|
||||||
|
}
|
||||||
|
return typ, bytes.NewReader(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Writer returns a writer to write a WebSocket data message to the connection.
|
||||||
|
// It buffers the entire message in memory and then sends it when the writer
|
||||||
|
// is closed.
|
||||||
|
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
|
||||||
|
return writer{
|
||||||
|
c: c,
|
||||||
|
ctx: ctx,
|
||||||
|
typ: typ,
|
||||||
|
b: bpool.Get(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type writer struct {
|
||||||
|
closed bool
|
||||||
|
|
||||||
|
c *Conn
|
||||||
|
ctx context.Context
|
||||||
|
typ MessageType
|
||||||
|
|
||||||
|
b *bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w writer) Write(p []byte) (int, error) {
|
||||||
|
if w.closed {
|
||||||
|
return 0, errors.New("cannot write to closed writer")
|
||||||
|
}
|
||||||
|
n, err := w.b.Write(p)
|
||||||
|
if err != nil {
|
||||||
|
return n, fmt.Errorf("failed to write message: %w", err)
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w writer) Close() error {
|
||||||
|
if w.closed {
|
||||||
|
return errors.New("cannot close closed writer")
|
||||||
|
}
|
||||||
|
w.closed = true
|
||||||
|
defer bpool.Put(w.b)
|
||||||
|
|
||||||
|
err := w.c.Write(w.ctx, w.typ, w.b.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to close writer: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseRead implements *Conn.CloseRead for wasm.
|
||||||
|
func (c *Conn) CloseRead(ctx context.Context) context.Context {
|
||||||
|
c.isReadClosed.Store(1)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
go func() {
|
||||||
|
defer cancel()
|
||||||
|
c.read(ctx)
|
||||||
|
c.Close(StatusPolicyViolation, "unexpected data message")
|
||||||
|
}()
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadLimit implements *Conn.SetReadLimit for wasm.
|
||||||
|
func (c *Conn) SetReadLimit(n int64) {
|
||||||
|
c.msgReadLimit.Store(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) setCloseErr(err error) {
|
||||||
|
c.closeErrOnce.Do(func() {
|
||||||
|
c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) isClosed() bool {
|
||||||
|
select {
|
||||||
|
case <-c.closed:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue