TUN-7373: Streaming logs override for same actor

To better accommodate web browser interactions over websockets, when a
streaming logs session is requested for an actor that already has an
active session being served on a separate request, the original request
will be closed and the new request will start streaming logs instead.
This should help with rogue sessions holding on for too long with no
client on the other side (before the idle timeout or connection close
kicks in).
Devin Carr 2023-04-21 11:54:37 -07:00
parent ee5e447d44
commit 38cd455e4d
109 changed files with 12691 additions and 1798 deletions
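For orientation, the heart of the change is the canStartStream check added in management/management.go further down. The following condensed sketch (not part of the diff; logging and websocket close codes trimmed) captures the rule it implements:

// Condensed sketch of the same-actor override: if the requesting actor
// already has an active session, stop and remove that session and cancel
// its connection, then allow the new session to stream.
func (m *ManagementService) canStartStream(newSession *session) bool {
    m.streamingMut.Lock()
    defer m.streamingMut.Unlock()
    if m.logger.ActiveSessions() > 0 {
        if existing := m.logger.ActiveSession(newSession.actor); existing != nil {
            existing.Stop()           // halt streaming on the old session
            m.logger.Remove(existing) // stop delivering log events to it
            existing.cancel()         // close the old request's context/connection
        } else {
            // A different actor is already streaming; keep the one-session limit.
            return false
        }
    }
    return true
}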

component-tests/test_tail.py

@@ -3,6 +3,7 @@ import asyncio
 import json
 import pytest
 import requests
+import websockets
 from websockets.client import connect, WebSocketClientProtocol
 from conftest import CfdModes
 from constants import MAX_RETRIES, BACKOFF_SECS
@@ -88,6 +89,46 @@ class TestTail:
         # send stop_streaming
         await websocket.send('{"type": "stop_streaming"}')

+    @pytest.mark.asyncio
+    async def test_streaming_logs_actor_override(self, tmp_path, component_tests_config):
+        """
+        Validates that a streaming logs session can be overridden by the same actor
+        """
+        print("test_streaming_logs_actor_override")
+        config = component_tests_config(cfd_mode=CfdModes.NAMED, run_proxy_dns=False, provide_ingress=False)
+        LOGGER.debug(config)
+        headers = {}
+        headers["Content-Type"] = "application/json"
+        config_path = write_config(tmp_path, config.full_config)
+        with start_cloudflared(tmp_path, config, cfd_args=["run", "--hello-world"], new_process=True):
+            wait_tunnel_ready(tunnel_url=config.get_url(),
+                              require_min_connections=1)
+            cfd_cli = CloudflaredCli(config, config_path, LOGGER)
+            url = cfd_cli.get_management_wsurl("logs", config, config_path)
+            task = asyncio.ensure_future(start_streaming_to_be_remotely_closed(url))
+            override_task = asyncio.ensure_future(start_streaming_override(url))
+            await asyncio.wait([task, override_task])
+            assert task.exception() == None, task.exception()
+            assert override_task.exception() == None, override_task.exception()
+
+async def start_streaming_to_be_remotely_closed(url):
+    async with connect(url, open_timeout=5, close_timeout=5) as websocket:
+        try:
+            await websocket.send(json.dumps({"type": "start_streaming"}))
+            await asyncio.sleep(10)
+            assert websocket.closed, "expected this request to be forcibly closed by the override"
+        except websockets.ConnectionClosed:
+            # we expect the request to be closed
+            pass
+
+async def start_streaming_override(url):
+    # wait for the first connection to be established
+    await asyncio.sleep(1)
+    async with connect(url, open_timeout=5, close_timeout=5) as websocket:
+        await websocket.send(json.dumps({"type": "start_streaming"}))
+        await asyncio.sleep(1)
+        await websocket.send(json.dumps({"type": "stop_streaming"}))
+        await asyncio.sleep(1)

 # Every http request has two log lines sent
 async def generate_and_validate_log_event(websocket: WebSocketClientProtocol, url: str):

go.mod

@@ -13,6 +13,7 @@ require (
 	github.com/getsentry/raven-go v0.2.0
 	github.com/getsentry/sentry-go v0.16.0
 	github.com/go-chi/chi/v5 v5.0.8
+	github.com/go-jose/go-jose/v3 v3.0.0
 	github.com/gobwas/ws v1.0.4
 	github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3
 	github.com/google/gopacket v1.1.19
@@ -36,8 +37,8 @@ require (
 	go.opentelemetry.io/otel/trace v1.6.3
 	go.opentelemetry.io/proto/otlp v0.15.0
 	go.uber.org/automaxprocs v1.4.0
-	golang.org/x/crypto v0.5.0
-	golang.org/x/net v0.7.0
+	golang.org/x/crypto v0.8.0
+	golang.org/x/net v0.9.0
 	golang.org/x/sync v0.1.0
 	golang.org/x/sys v0.7.0
 	golang.org/x/term v0.7.0
@@ -96,8 +97,8 @@ require (
 	github.com/russross/blackfriday/v2 v2.1.0 // indirect
 	golang.org/x/mod v0.8.0 // indirect
 	golang.org/x/oauth2 v0.4.0 // indirect
-	golang.org/x/text v0.7.0 // indirect
-	golang.org/x/tools v0.1.12 // indirect
+	golang.org/x/text v0.9.0 // indirect
+	golang.org/x/tools v0.6.0 // indirect
 	google.golang.org/appengine v1.6.7 // indirect
 	google.golang.org/genproto v0.0.0-20221202195650-67e5cbc046fd // indirect
 	google.golang.org/grpc v1.51.0 // indirect

go.sum

@@ -178,6 +178,8 @@ github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxI
 github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
 github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
 github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
+github.com/go-jose/go-jose/v3 v3.0.0 h1:s6rrhirfEP/CGIoc6p+PZAeogN2SxKav6Wp7+dyMWVo=
+github.com/go-jose/go-jose/v3 v3.0.0/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8=
 github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -542,12 +544,13 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
 golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
-golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE=
-golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU=
+golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ=
+golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
 golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -640,8 +643,8 @@ golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su
 golang.org/x/net v0.0.0-20220607020251-c690dde0001d/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
 golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
 golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
-golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
-golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
+golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
+golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
 golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
 golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
 golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
@@ -777,8 +780,8 @@ golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
-golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
-golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
+golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
+golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
 golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@@ -839,8 +842,8 @@ golang.org/x/tools v0.1.3/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
 golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
 golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
 golang.org/x/tools v0.1.6-0.20210726203631-07bc1bf47fb2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
-golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU=
-golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
+golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=
+golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

management/logger.go

@@ -11,16 +11,9 @@ import (
 var json = jsoniter.ConfigFastest

-const (
-	// Indicates how many log messages the listener will hold before dropping.
-	// Provides a throttling mechanism to drop latest messages if the sender
-	// can't keep up with the influx of log messages.
-	logWindow = 30
-)
-
 // Logger manages the number of management streaming log sessions
 type Logger struct {
-	sessions []*Session
+	sessions []*session
 	mu       sync.RWMutex
 	// Unique logger that isn't a io.Writer of the list of zerolog writers. This helps prevent management log
@@ -40,69 +33,47 @@ func NewLogger() *Logger {
 }

 type LoggerListener interface {
-	Listen(*StreamingFilters) *Session
-	Close(*Session)
+	// ActiveSession returns the first active session for the requested actor.
+	ActiveSession(actor) *session
+	// ActiveSessions returns the count of active sessions.
+	ActiveSessions() int
+	// Listen appends the session to the list of sessions that receive log events.
+	Listen(*session)
+	// Remove a session from the available sessions that were receiving log events.
+	Remove(*session)
 }

-type Session struct {
-	// Buffered channel that holds the recent log events
-	listener chan *Log
-	// Types of log events that this session will provide through the listener
-	filters *StreamingFilters
-}
-
-func newSession(size int, filters *StreamingFilters) *Session {
-	s := &Session{
-		listener: make(chan *Log, size),
-	}
-	if filters != nil {
-		s.filters = filters
-	} else {
-		s.filters = &StreamingFilters{}
-	}
-	return s
-}
-
-// Insert attempts to insert the log to the session. If the log event matches the provided session filters, it
-// will be applied to the listener.
-func (s *Session) Insert(log *Log) {
-	// Level filters are optional
-	if s.filters.Level != nil {
-		if *s.filters.Level > log.Level {
-			return
-		}
-	}
-	// Event filters are optional
-	if len(s.filters.Events) != 0 && !contains(s.filters.Events, log.Event) {
-		return
-	}
-	select {
-	case s.listener <- log:
-	default:
-		// buffer is full, discard
-	}
-}
-
-func contains(array []LogEventType, t LogEventType) bool {
-	for _, v := range array {
-		if v == t {
-			return true
-		}
-	}
-	return false
-}
-
-// Listen creates a new Session that will append filtered log events as they are created.
-func (l *Logger) Listen(filters *StreamingFilters) *Session {
+func (l *Logger) ActiveSession(actor actor) *session {
+	l.mu.RLock()
+	defer l.mu.RUnlock()
+	for _, session := range l.sessions {
+		if session.actor.ID == actor.ID && session.active.Load() {
+			return session
+		}
+	}
+	return nil
+}
+
+func (l *Logger) ActiveSessions() int {
+	l.mu.RLock()
+	defer l.mu.RUnlock()
+	count := 0
+	for _, session := range l.sessions {
+		if session.active.Load() {
+			count += 1
+		}
+	}
+	return count
+}
+
+func (l *Logger) Listen(session *session) {
 	l.mu.Lock()
 	defer l.mu.Unlock()
-	listener := newSession(logWindow, filters)
-	l.sessions = append(l.sessions, listener)
-	return listener
+	session.active.Store(true)
+	l.sessions = append(l.sessions, session)
 }

-// Close will remove a Session from the available sessions that were receiving log events.
-func (l *Logger) Close(session *Session) {
+func (l *Logger) Remove(session *session) {
 	l.mu.Lock()
 	defer l.mu.Unlock()
 	index := -1

management/logger_test.go

@@ -1,8 +1,8 @@
 package management

 import (
+	"context"
 	"testing"
-	"time"

 	"github.com/rs/zerolog"
 	"github.com/stretchr/testify/assert"
@@ -21,9 +21,14 @@ func TestLoggerWrite_NoSessions(t *testing.T) {
 func TestLoggerWrite_OneSession(t *testing.T) {
 	logger := NewLogger()
 	zlog := zerolog.New(logger).With().Timestamp().Logger().Level(zerolog.InfoLevel)
+	_, cancel := context.WithCancel(context.Background())
+	defer cancel()

-	session := logger.Listen(nil)
-	defer logger.Close(session)
+	session := newSession(logWindow, actor{ID: actorID}, cancel)
+	logger.Listen(session)
+	defer logger.Remove(session)
+	assert.Equal(t, 1, logger.ActiveSessions())
+	assert.Equal(t, session, logger.ActiveSession(actor{ID: actorID}))
 	zlog.Info().Int(EventTypeKey, int(HTTP)).Msg("hello")
 	select {
 	case event := <-session.listener:
@@ -40,12 +45,20 @@ func TestLoggerWrite_OneSession(t *testing.T) {
 func TestLoggerWrite_MultipleSessions(t *testing.T) {
 	logger := NewLogger()
 	zlog := zerolog.New(logger).With().Timestamp().Logger().Level(zerolog.InfoLevel)
+	_, cancel := context.WithCancel(context.Background())
+	defer cancel()

-	session1 := logger.Listen(nil)
-	defer logger.Close(session1)
-	session2 := logger.Listen(nil)
+	session1 := newSession(logWindow, actor{}, cancel)
+	logger.Listen(session1)
+	defer logger.Remove(session1)
+	assert.Equal(t, 1, logger.ActiveSessions())
+
+	session2 := newSession(logWindow, actor{}, cancel)
+	logger.Listen(session2)
+	assert.Equal(t, 2, logger.ActiveSessions())
 	zlog.Info().Int(EventTypeKey, int(HTTP)).Msg("hello")
-	for _, session := range []*Session{session1, session2} {
+	for _, session := range []*session{session1, session2} {
 		select {
 		case event := <-session.listener:
 			assert.NotEmpty(t, event.Time)
@@ -58,7 +71,7 @@ func TestLoggerWrite_MultipleSessions(t *testing.T) {
 	}

 	// Close session2 and make sure session1 still receives events
-	logger.Close(session2)
+	logger.Remove(session2)
 	zlog.Info().Int(EventTypeKey, int(HTTP)).Msg("hello2")
 	select {
 	case event := <-session1.listener:
@@ -79,104 +92,6 @@ func TestLoggerWrite_MultipleSessions(t *testing.T) {
 	}
 }

-// Validate that the session filters events
-func TestSession_Insert(t *testing.T) {
-	infoLevel := new(LogLevel)
-	*infoLevel = Info
-	warnLevel := new(LogLevel)
-	*warnLevel = Warn
-	for _, test := range []struct {
-		name      string
-		filters   StreamingFilters
-		expectLog bool
-	}{
-		{
-			name:      "none",
-			expectLog: true,
-		},
-		{
-			name: "level",
-			filters: StreamingFilters{
-				Level: infoLevel,
-			},
-			expectLog: true,
-		},
-		{
-			name: "filtered out level",
-			filters: StreamingFilters{
-				Level: warnLevel,
-			},
-			expectLog: false,
-		},
-		{
-			name: "events",
-			filters: StreamingFilters{
-				Events: []LogEventType{HTTP},
-			},
-			expectLog: true,
-		},
-		{
-			name: "filtered out event",
-			filters: StreamingFilters{
-				Events: []LogEventType{Cloudflared},
-			},
-			expectLog: false,
-		},
-		{
-			name: "filter and event",
-			filters: StreamingFilters{
-				Level:  infoLevel,
-				Events: []LogEventType{HTTP},
-			},
-			expectLog: true,
-		},
-	} {
-		t.Run(test.name, func(t *testing.T) {
-			session := newSession(4, &test.filters)
-			log := Log{
-				Time:    time.Now().UTC().Format(time.RFC3339),
-				Event:   HTTP,
-				Level:   Info,
-				Message: "test",
-			}
-			session.Insert(&log)
-			select {
-			case <-session.listener:
-				require.True(t, test.expectLog)
-			default:
-				require.False(t, test.expectLog)
-			}
-		})
-	}
-}
-
-// Validate that the session has a max amount of events to hold
-func TestSession_InsertOverflow(t *testing.T) {
-	session := newSession(1, nil)
-	log := Log{
-		Time:    time.Now().UTC().Format(time.RFC3339),
-		Event:   HTTP,
-		Level:   Info,
-		Message: "test",
-	}
-	// Insert 2 but only max channel size for 1
-	session.Insert(&log)
-	session.Insert(&log)
-	select {
-	case <-session.listener:
-		// pass
-	default:
-		require.Fail(t, "expected one log event")
-	}
-	// Second dequeue should fail
-	select {
-	case <-session.listener:
-		require.Fail(t, "expected no more remaining log events")
-	default:
-		// pass
-	}
-}

 type mockWriter struct {
 	event *Log
 	err   error

management/middleware.go

@@ -0,0 +1,76 @@
package management
import (
"context"
"fmt"
"net/http"
)
type ctxKey int
const (
accessClaimsCtxKey ctxKey = iota
)
const (
connectorIDQuery = "connector_id"
accessTokenQuery = "access_token"
)
var (
errMissingAccessToken = managementError{Code: 1001, Message: "missing access_token query parameter"}
)
// HTTP middleware setting the parsed access_token claims in the request context
func ValidateAccessTokenQueryMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Validate access token
accessToken := r.URL.Query().Get("access_token")
if accessToken == "" {
writeHTTPErrorResponse(w, errMissingAccessToken)
return
}
token, err := parseToken(accessToken)
if err != nil {
writeHTTPErrorResponse(w, errMissingAccessToken)
return
}
r = r.WithContext(context.WithValue(r.Context(), accessClaimsCtxKey, token))
next.ServeHTTP(w, r)
})
}
// Middleware validation error struct for returning to the eyeball
type managementError struct {
Code int `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
func (m *managementError) Error() string {
return m.Message
}
// Middleware validation error HTTP response JSON for returning to the eyeball
type managementErrorResponse struct {
Success bool `json:"success,omitempty"`
Errors []managementError `json:"errors,omitempty"`
}
// writeHTTPErrorResponse will respond to the eyeball with a basic HTTP JSON payload carrying validation failure information
func writeHTTPErrorResponse(w http.ResponseWriter, errResp managementError) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
err := json.NewEncoder(w).Encode(managementErrorResponse{
Success: false,
Errors: []managementError{errResp},
})
// we have already written the header, so write a basic error response if unable to encode the error
if err != nil {
// fallback to text message
http.Error(w, fmt.Sprintf(
"%d %s",
http.StatusBadRequest,
http.StatusText(http.StatusBadRequest),
), http.StatusBadRequest)
}
}
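For context, a hypothetical client-side sketch (not part of the commit; the hostname and token are placeholders) of how the access_token expected by this middleware is supplied on the management URL:

// Hypothetical client-side helper: the management endpoints expect the JWT
// in the access_token query parameter that the middleware above validates.
package main

import (
    "fmt"
    "net/url"
)

func managementLogsURL(host, token string) string {
    u := url.URL{
        Scheme:   "wss",
        Host:     host, // placeholder management hostname
        Path:     "/logs",
        RawQuery: url.Values{"access_token": {token}}.Encode(),
    }
    return u.String()
}

func main() {
    fmt.Println(managementLogsURL("management.example.com", "<jwt>"))
}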

management/middleware_test.go

@@ -0,0 +1,71 @@
package management
import (
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
)
func TestValidateAccessTokenQueryMiddleware(t *testing.T) {
r := chi.NewRouter()
r.Use(ValidateAccessTokenQueryMiddleware)
r.Get("/valid", func(w http.ResponseWriter, r *http.Request) {
claims, ok := r.Context().Value(accessClaimsCtxKey).(*managementTokenClaims)
assert.True(t, ok)
assert.True(t, claims.verify())
w.WriteHeader(http.StatusOK)
})
r.Get("/invalid", func(w http.ResponseWriter, r *http.Request) {
_, ok := r.Context().Value(accessClaimsCtxKey).(*managementTokenClaims)
assert.False(t, ok)
w.WriteHeader(http.StatusOK)
})
ts := httptest.NewServer(r)
defer ts.Close()
// valid: with access_token query param
path := "/valid?access_token=" + validToken
resp, _ := testRequest(t, ts, "GET", path, nil)
assert.Equal(t, http.StatusOK, resp.StatusCode)
// invalid: unset token
path = "/invalid"
resp, err := testRequest(t, ts, "GET", path, nil)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
assert.NotNil(t, err)
assert.Equal(t, errMissingAccessToken, err.Errors[0])
// invalid: invalid token
path = "/invalid?access_token=eyJ"
resp, err = testRequest(t, ts, "GET", path, nil)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
assert.NotNil(t, err)
assert.Equal(t, errMissingAccessToken, err.Errors[0])
}
func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, *managementErrorResponse) {
req, err := http.NewRequest(method, ts.URL+path, body)
if err != nil {
t.Fatal(err)
return nil, nil
}
resp, err := ts.Client().Do(req)
if err != nil {
t.Fatal(err)
return nil, nil
}
var claims managementErrorResponse
err = json.NewDecoder(resp.Body).Decode(&claims)
if err != nil {
return resp, nil
}
defer resp.Body.Close()
return resp, &claims
}

management/management.go

@@ -7,7 +7,6 @@ import (
 	"net/http"
 	"os"
 	"sync"
-	"sync/atomic"
 	"time"

 	"github.com/go-chi/chi/v5"
@@ -24,7 +23,7 @@ const (
 	// value will return this error to incoming requests.
 	StatusSessionLimitExceeded websocket.StatusCode = 4002
 	reasonSessionLimitExceeded = "limit exceeded for streaming sessions"
+	// There is a limited idle time while not actively serving a session for a request before dropping the connection.
 	StatusIdleLimitExceeded websocket.StatusCode = 4003
 	reasonIdleLimitExceeded = "session was idle for too long"
 )
@@ -41,9 +40,6 @@ type ManagementService struct {
 	log    *zerolog.Logger
 	router chi.Router

-	// streaming signifies if the service is already streaming logs. Helps limit the number of active users streaming logs
-	// from this cloudflared instance.
-	streaming atomic.Bool
 	// streamingMut is a lock to prevent concurrent requests to start streaming. Utilizing the atomic.Bool is not
 	// sufficient to complete this operation since many other checks during an incoming new request are needed
 	// to validate this before setting streaming to true.
@@ -67,6 +63,7 @@ func New(managementHostname string,
 		label:    label,
 	}
 	r := chi.NewRouter()
+	r.Use(ValidateAccessTokenQueryMiddleware)
 	r.Get("/ping", ping)
 	r.Head("/ping", ping)
 	r.Get("/logs", s.logs)
@@ -92,7 +89,6 @@ type getHostDetailsResponse struct {
 }

 func (m *ManagementService) getHostDetails(w http.ResponseWriter, r *http.Request) {
 	var getHostDetailsResponse = getHostDetailsResponse{
 		ClientID: m.clientID.String(),
 	}
@@ -157,12 +153,11 @@ func (m *ManagementService) readEvents(c *websocket.Conn, ctx context.Context, e
 	}
 }

 // streamLogs will begin the process of reading from the Session listener and write the log events to the client.
-func (m *ManagementService) streamLogs(c *websocket.Conn, ctx context.Context, session *Session) {
-	defer m.logger.Close(session)
-	for m.streaming.Load() {
+func (m *ManagementService) streamLogs(c *websocket.Conn, ctx context.Context, session *session) {
+	for session.Active() {
 		select {
 		case <-ctx.Done():
-			m.streaming.Store(false)
+			session.Stop()
 			return
 		case event := <-session.listener:
 			err := WriteEvent(c, ctx, &EventLog{
@@ -176,7 +171,7 @@ func (m *ManagementService) streamLogs(c *websocket.Conn, ctx context.Context, s
 				m.log.Err(c.Close(websocket.StatusInternalError, err.Error())).Send()
 			}
 			// Any errors when writing the messages to the client will stop streaming and close the connection
-			m.streaming.Store(false)
+			session.Stop()
 			return
 		}
 	default:
@@ -185,28 +180,38 @@ func (m *ManagementService) streamLogs(c *websocket.Conn, ctx context.Context, s
 	}
 }

-// startStreaming will check the conditions of the request and begin streaming or close the connection for invalid
-// requests.
-func (m *ManagementService) startStreaming(c *websocket.Conn, ctx context.Context, event *ClientEvent) {
+// canStartStream will check the conditions of the request and return if the session can begin streaming.
+func (m *ManagementService) canStartStream(session *session) bool {
 	m.streamingMut.Lock()
 	defer m.streamingMut.Unlock()
-	// Limits to one user for streaming logs
-	if m.streaming.Load() {
-		m.log.Warn().
-			Msgf("Another management session request was attempted but one session already being served; there is a limit of streaming log sessions to reduce overall performance impact.")
-		m.log.Err(c.Close(StatusSessionLimitExceeded, reasonSessionLimitExceeded)).Send()
-		return
+	// Limits to one actor for streaming logs
+	if m.logger.ActiveSessions() > 0 {
+		// Allow the same user to preempt their existing session to disconnect their old session and start streaming
+		// with this new session instead.
+		if existingSession := m.logger.ActiveSession(session.actor); existingSession != nil {
+			m.log.Info().
+				Msgf("Another management session request for the same actor was requested; the other session will be disconnected to handle the new request.")
+			existingSession.Stop()
+			m.logger.Remove(existingSession)
+			existingSession.cancel()
+		} else {
+			m.log.Warn().
+				Msgf("Another management session request was attempted but one session already being served; there is a limit of streaming log sessions to reduce overall performance impact.")
+			return false
+		}
 	}
+	return true
+}

+// parseFilters will check the ClientEvent for start_streaming and assign filters if provided to the session
+func (m *ManagementService) parseFilters(c *websocket.Conn, event *ClientEvent, session *session) bool {
 	// Expect the first incoming request
 	startEvent, ok := IntoClientEvent[EventStartStreaming](event, StartStreaming)
 	if !ok {
-		m.log.Warn().Err(c.Close(StatusInvalidCommand, reasonInvalidCommand)).Msgf("expected start_streaming as first received event")
-		return
+		return false
 	}
-	m.streaming.Store(true)
-	listener := m.logger.Listen(startEvent.Filters)
-	m.log.Debug().Msgf("Streaming logs")
-	go m.streamLogs(c, ctx, listener)
+	session.Filters(startEvent.Filters)
+	return true
 }
// Management Streaming Logs accept handler // Management Streaming Logs accept handler
@@ -227,11 +232,23 @@ func (m *ManagementService) logs(w http.ResponseWriter, r *http.Request) {
 	ping := time.NewTicker(15 * time.Second)
 	defer ping.Stop()

-	// Close the connection if no operation has occurred after the idle timeout.
+	// Close the connection if no operation has occurred after the idle timeout. The timeout is halted
+	// when streaming logs is active.
 	idleTimeout := 5 * time.Minute
 	idle := time.NewTimer(idleTimeout)
 	defer idle.Stop()

+	// Fetch the claims from the request context to acquire the actor
+	claims, ok := ctx.Value(accessClaimsCtxKey).(*managementTokenClaims)
+	if !ok || claims == nil {
+		// Typically should never happen as it is provided in the context from the middleware
+		m.log.Err(c.Close(websocket.StatusInternalError, "missing access_token")).Send()
+		return
+	}
+
+	session := newSession(logWindow, claims.Actor, cancel)
+	defer m.logger.Remove(session)
+
 	for {
 		select {
 		case <-ctx.Done():
@@ -242,12 +259,28 @@ func (m *ManagementService) logs(w http.ResponseWriter, r *http.Request) {
 		switch event.Type {
 		case StartStreaming:
 			idle.Stop()
-			m.startStreaming(c, ctx, event)
+			// Expect the first incoming request
+			startEvent, ok := IntoClientEvent[EventStartStreaming](event, StartStreaming)
+			if !ok {
+				m.log.Warn().Msgf("expected start_streaming as first received event")
+				m.log.Err(c.Close(StatusInvalidCommand, reasonInvalidCommand)).Send()
+				return
+			}
+			// Make sure the session can start
+			if !m.canStartStream(session) {
+				m.log.Err(c.Close(StatusSessionLimitExceeded, reasonSessionLimitExceeded)).Send()
+				return
+			}
+			session.Filters(startEvent.Filters)
+			m.logger.Listen(session)
+			m.log.Debug().Msgf("Streaming logs")
+			go m.streamLogs(c, ctx, session)
 			continue
 		case StopStreaming:
 			idle.Reset(idleTimeout)
-			// TODO: limit StopStreaming to only halt streaming for clients that are already streaming
-			m.streaming.Store(false)
+			// Stop the current session for the current actor who requested it
+			session.Stop()
+			m.logger.Remove(session)
 		case UnknownClientEventType:
 			fallthrough
 		default:

management/management_test.go

@@ -7,6 +7,7 @@ import (
 	"time"

 	"github.com/rs/zerolog"
+	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"nhooyr.io/websocket"
@@ -59,3 +60,67 @@ func TestReadEventsLoop_ContextCancelled(t *testing.T) {
 	m.readEvents(server, ctx, events)
 	server.Close(websocket.StatusInternalError, "")
 }
+
+func TestCanStartStream_NoSessions(t *testing.T) {
+	m := ManagementService{
+		log: &noopLogger,
+		logger: &Logger{
+			Log: &noopLogger,
+		},
+	}
+	_, cancel := context.WithCancel(context.Background())
+	session := newSession(0, actor{}, cancel)
+	assert.True(t, m.canStartStream(session))
+}
+
+func TestCanStartStream_ExistingSessionDifferentActor(t *testing.T) {
+	m := ManagementService{
+		log: &noopLogger,
+		logger: &Logger{
+			Log: &noopLogger,
+		},
+	}
+	_, cancel := context.WithCancel(context.Background())
+	session1 := newSession(0, actor{ID: "test"}, cancel)
+	assert.True(t, m.canStartStream(session1))
+	m.logger.Listen(session1)
+	assert.True(t, session1.Active())
+	// Try another session
+	session2 := newSession(0, actor{ID: "test2"}, cancel)
+	assert.Equal(t, 1, m.logger.ActiveSessions())
+	assert.False(t, m.canStartStream(session2))
+	// Close session1
+	m.logger.Remove(session1)
+	assert.True(t, session1.Active()) // Remove doesn't stop a session
+	session1.Stop()
+	assert.False(t, session1.Active())
+	assert.Equal(t, 0, m.logger.ActiveSessions())
+	// Try session2 again
+	assert.True(t, m.canStartStream(session2))
+}
+
+func TestCanStartStream_ExistingSessionSameActor(t *testing.T) {
+	m := ManagementService{
+		log: &noopLogger,
+		logger: &Logger{
+			Log: &noopLogger,
+		},
+	}
+	actor := actor{ID: "test"}
+	_, cancel := context.WithCancel(context.Background())
+	session1 := newSession(0, actor, cancel)
+	assert.True(t, m.canStartStream(session1))
+	m.logger.Listen(session1)
+	assert.True(t, session1.Active())
+	// Try another session
+	session2 := newSession(0, actor, cancel)
+	assert.Equal(t, 1, m.logger.ActiveSessions())
+	assert.True(t, m.canStartStream(session2))
+	// session1 is removed and stopped
+	assert.Equal(t, 0, m.logger.ActiveSessions())
+	assert.False(t, session1.Active())
+}

management/session.go

@@ -0,0 +1,88 @@
package management
import (
"context"
"sync/atomic"
)
const (
// Indicates how many log messages the listener will hold before dropping.
// Provides a throttling mechanism to drop latest messages if the sender
// can't keep up with the influx of log messages.
logWindow = 30
)
// session captures a streaming logs session for a connection of an actor.
type session struct {
// Indicates if the session is streaming or not. Modifying this will affect the active session.
active atomic.Bool
// Allows the session to control the context of the underlying connection to close it out when done. Mostly
// used by the LoggerListener to close out and cleanup a session.
cancel context.CancelFunc
// Actor who started the session
actor actor
// Buffered channel that holds the recent log events
listener chan *Log
// Types of log events that this session will provide through the listener
filters *StreamingFilters
}
// newSession creates a new session.
func newSession(size int, actor actor, cancel context.CancelFunc) *session {
s := &session{
active: atomic.Bool{},
cancel: cancel,
actor: actor,
listener: make(chan *Log, size),
filters: &StreamingFilters{},
}
return s
}
// Filters assigns the StreamingFilters to the session
func (s *session) Filters(filters *StreamingFilters) {
if filters != nil {
s.filters = filters
} else {
s.filters = &StreamingFilters{}
}
}
// Insert attempts to insert the log to the session. If the log event matches the provided session filters, it
// will be applied to the listener.
func (s *session) Insert(log *Log) {
// Level filters are optional
if s.filters.Level != nil {
if *s.filters.Level > log.Level {
return
}
}
// Event filters are optional
if len(s.filters.Events) != 0 && !contains(s.filters.Events, log.Event) {
return
}
select {
case s.listener <- log:
default:
// buffer is full, discard
}
}
// Active returns if the session is active
func (s *session) Active() bool {
return s.active.Load()
}
// Stop will halt the session
func (s *session) Stop() {
s.active.Store(false)
}
func contains(array []LogEventType, t LogEventType) bool {
for _, v := range array {
if v == t {
return true
}
}
return false
}

management/session_test.go

@@ -0,0 +1,126 @@
package management
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Validate the active states of the session
func TestSession_ActiveControl(t *testing.T) {
_, cancel := context.WithCancel(context.Background())
defer cancel()
session := newSession(4, actor{}, cancel)
// session starts out not active
assert.False(t, session.Active())
session.active.Store(true)
assert.True(t, session.Active())
session.Stop()
assert.False(t, session.Active())
}
// Validate that the session filters events
func TestSession_Insert(t *testing.T) {
_, cancel := context.WithCancel(context.Background())
defer cancel()
infoLevel := new(LogLevel)
*infoLevel = Info
warnLevel := new(LogLevel)
*warnLevel = Warn
for _, test := range []struct {
name string
filters StreamingFilters
expectLog bool
}{
{
name: "none",
expectLog: true,
},
{
name: "level",
filters: StreamingFilters{
Level: infoLevel,
},
expectLog: true,
},
{
name: "filtered out level",
filters: StreamingFilters{
Level: warnLevel,
},
expectLog: false,
},
{
name: "events",
filters: StreamingFilters{
Events: []LogEventType{HTTP},
},
expectLog: true,
},
{
name: "filtered out event",
filters: StreamingFilters{
Events: []LogEventType{Cloudflared},
},
expectLog: false,
},
{
name: "filter and event",
filters: StreamingFilters{
Level: infoLevel,
Events: []LogEventType{HTTP},
},
expectLog: true,
},
} {
t.Run(test.name, func(t *testing.T) {
session := newSession(4, actor{}, cancel)
session.Filters(&test.filters)
log := Log{
Time: time.Now().UTC().Format(time.RFC3339),
Event: HTTP,
Level: Info,
Message: "test",
}
session.Insert(&log)
select {
case <-session.listener:
require.True(t, test.expectLog)
default:
require.False(t, test.expectLog)
}
})
}
}
// Validate that the session has a max amount of events to hold
func TestSession_InsertOverflow(t *testing.T) {
_, cancel := context.WithCancel(context.Background())
defer cancel()
session := newSession(1, actor{}, cancel)
log := Log{
Time: time.Now().UTC().Format(time.RFC3339),
Event: HTTP,
Level: Info,
Message: "test",
}
// Insert 2 but only max channel size for 1
session.Insert(&log)
session.Insert(&log)
select {
case <-session.listener:
// pass
default:
require.Fail(t, "expected one log event")
}
// Second dequeue should fail
select {
case <-session.listener:
require.Fail(t, "expected no more remaining log events")
default:
// pass
}
}

management/token.go

@@ -0,0 +1,55 @@
package management
import (
"fmt"
"github.com/go-jose/go-jose/v3/jwt"
)
type managementTokenClaims struct {
Tunnel tunnel `json:"tun"`
Actor actor `json:"actor"`
}
// verify checks that the tun and actor claims aren't empty
func (c *managementTokenClaims) verify() bool {
return c.Tunnel.verify() && c.Actor.verify()
}
type tunnel struct {
ID string `json:"id"`
AccountTag string `json:"account_tag"`
}
// verify checks that the tun claim fields aren't empty
func (t *tunnel) verify() bool {
return t.AccountTag != "" && t.ID != ""
}
type actor struct {
ID string `json:"id"`
Support bool `json:"support"`
}
// verify checks the ID claim isn't empty
func (t *actor) verify() bool {
return t.ID != ""
}
func parseToken(token string) (*managementTokenClaims, error) {
jwt, err := jwt.ParseSigned(token)
if err != nil {
return nil, fmt.Errorf("malformed jwt: %v", err)
}
var claims managementTokenClaims
// This is actually safe because we verify the token in the edge before it reaches cloudflared
err = jwt.UnsafeClaimsWithoutVerification(&claims)
if err != nil {
return nil, fmt.Errorf("malformed jwt: %v", err)
}
if !claims.verify() {
return nil, fmt.Errorf("invalid management token format provided")
}
return &claims, nil
}
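A minimal in-package usage sketch (illustrative only, not part of the commit; assumes the session plumbing above and a "context" import) showing how the parsed claims tie a streaming session to its actor:

// Illustrative only: how the parsed claims feed a session's identity.
func sessionForToken(token string, cancel context.CancelFunc) (*session, error) {
    claims, err := parseToken(token)
    if err != nil {
        return nil, err // malformed token or incomplete claims
    }
    // The actor from the JWT is what Logger.ActiveSession matches on when
    // deciding whether a new request may preempt an existing session.
    return newSession(logWindow, claims.Actor, cancel), nil
}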

management/token_test.go

@@ -0,0 +1,130 @@
package management
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"errors"
"testing"
"github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2"
)
const (
validToken = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6IjEifQ.eyJ0dW4iOnsiaWQiOiI3YjA5ODE0OS01MWZlLTRlZTUtYTY4Ny0zZTM3NDQ2NmVmYzciLCJhY2NvdW50X3RhZyI6ImNkMzkxZTljMDYyNmE4Zjc2Y2IxZjY3MGY2NTkxYjA1In0sImFjdG9yIjp7ImlkIjoiZGNhcnJAY2xvdWRmbGFyZS5jb20iLCJzdXBwb3J0IjpmYWxzZX0sInJlcyI6WyJsb2dzIl0sImV4cCI6MTY3NzExNzY5NiwiaWF0IjoxNjc3MTE0MDk2LCJpc3MiOiJ0dW5uZWxzdG9yZSJ9.mKenOdOy3Xi4O-grldFnAAemdlE9WajEpTDC_FwezXQTstWiRTLwU65P5jt4vNsIiZA4OJRq7bH-QYID9wf9NA"
accountTag = "cd391e9c0626a8f76cb1f670f6591b05"
tunnelID = "7b098149-51fe-4ee5-a687-3e374466efc7"
actorID = "45d2751e-6b59-45a9-814d-f630786bd0cd"
)
type invalidManagementTokenClaims struct {
Invalid string `json:"invalid"`
}
func TestParseToken(t *testing.T) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
for _, test := range []struct {
name string
claims any
err error
}{
{
name: "Valid",
claims: managementTokenClaims{
Tunnel: tunnel{
ID: tunnelID,
AccountTag: accountTag,
},
Actor: actor{
ID: actorID,
},
},
},
{
name: "Invalid claims",
claims: invalidManagementTokenClaims{Invalid: "invalid"},
err: errors.New("invalid management token format provided"),
},
{
name: "Missing Tunnel",
claims: managementTokenClaims{
Actor: actor{
ID: actorID,
},
},
err: errors.New("invalid management token format provided"),
},
{
name: "Missing Tunnel ID",
claims: managementTokenClaims{
Tunnel: tunnel{
AccountTag: accountTag,
},
Actor: actor{
ID: actorID,
},
},
err: errors.New("invalid management token format provided"),
},
{
name: "Missing Account Tag",
claims: managementTokenClaims{
Tunnel: tunnel{
ID: tunnelID,
},
Actor: actor{
ID: actorID,
},
},
err: errors.New("invalid management token format provided"),
},
{
name: "Missing Actor",
claims: managementTokenClaims{
Tunnel: tunnel{
ID: tunnelID,
AccountTag: accountTag,
},
},
err: errors.New("invalid management token format provided"),
},
{
name: "Missing Actor ID",
claims: managementTokenClaims{
Tunnel: tunnel{
ID: tunnelID,
},
Actor: actor{},
},
err: errors.New("invalid management token format provided"),
},
} {
t.Run(test.name, func(t *testing.T) {
jwt := signToken(t, test.claims, key)
claims, err := parseToken(jwt)
if test.err != nil {
require.EqualError(t, err, test.err.Error())
return
}
require.Nil(t, err)
require.Equal(t, test.claims, *claims)
})
}
}
func signToken(t *testing.T, token any, key *ecdsa.PrivateKey) string {
opts := (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", "1")
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: key}, opts)
require.NoError(t, err)
payload, err := json.Marshal(token)
require.NoError(t, err)
jws, err := signer.Sign(payload)
require.NoError(t, err)
jwt, err := jws.CompactSerialize()
require.NoError(t, err)
return jwt
}

vendor/github.com/go-jose/go-jose/v3/.gitignore

@@ -0,0 +1,2 @@
jose-util/jose-util
jose-util.t.err

vendor/github.com/go-jose/go-jose/v3/.golangci.yml

@@ -0,0 +1,53 @@
# https://github.com/golangci/golangci-lint

run:
  skip-files:
    - doc_test.go
  modules-download-mode: readonly

linters:
  enable-all: true
  disable:
    - gochecknoglobals
    - goconst
    - lll
    - maligned
    - nakedret
    - scopelint
    - unparam
    - funlen # added in 1.18 (requires go-jose changes before it can be enabled)

linters-settings:
  gocyclo:
    min-complexity: 35

issues:
  exclude-rules:
    - text: "don't use ALL_CAPS in Go names"
      linters:
        - golint
    - text: "hardcoded credentials"
      linters:
        - gosec
    - text: "weak cryptographic primitive"
      linters:
        - gosec
    - path: json/
      linters:
        - dupl
        - errcheck
        - gocritic
        - gocyclo
        - golint
        - govet
        - ineffassign
        - staticcheck
        - structcheck
        - stylecheck
        - unused
    - path: _test\.go
      linters:
        - scopelint
    - path: jwk.go
      linters:
        - gocyclo
vendor/github.com/go-jose/go-jose/v3/.travis.yml

@@ -0,0 +1,33 @@
language: go

matrix:
  fast_finish: true
  allow_failures:
    - go: tip

go:
  - "1.13.x"
  - "1.14.x"
  - tip

before_script:
  - export PATH=$HOME/.local/bin:$PATH

before_install:
  - go get -u github.com/mattn/goveralls github.com/wadey/gocovmerge
  - curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b $(go env GOPATH)/bin v1.18.0
  - pip install cram --user

script:
  - go test -v -covermode=count -coverprofile=profile.cov .
  - go test -v -covermode=count -coverprofile=cryptosigner/profile.cov ./cryptosigner
  - go test -v -covermode=count -coverprofile=cipher/profile.cov ./cipher
  - go test -v -covermode=count -coverprofile=jwt/profile.cov ./jwt
  - go test -v ./json # no coverage for forked encoding/json package
  - golangci-lint run
  - cd jose-util && go build && PATH=$PWD:$PATH cram -v jose-util.t # cram tests jose-util
  - cd ..

after_success:
  - gocovmerge *.cov */*.cov > merged.coverprofile
  - goveralls -coverprofile merged.coverprofile -service=travis-ci

vendor/github.com/go-jose/go-jose/v3/BUG-BOUNTY.md

@@ -0,0 +1,10 @@
Serious about security
======================
Square recognizes the important contributions the security research community
can make. We therefore encourage reporting security issues with the code
contained in this repository.
If you believe you have discovered a security vulnerability, please follow the
guidelines at <https://bugcrowd.com/squareopensource>.

vendor/github.com/go-jose/go-jose/v3/CONTRIBUTING.md

@@ -0,0 +1,15 @@
# Contributing
If you would like to contribute code to go-jose you can do so through GitHub by
forking the repository and sending a pull request.
When submitting code, please make every effort to follow existing conventions
and style in order to keep the code as readable as possible. Please also make
sure all tests pass by running `go test`, and format your code with `go fmt`.
We also recommend using `golint` and `errcheck`.
Before your code can be accepted into the project you must also sign the
Individual Contributor License Agreement. We use [cla-assistant.io][1] and you
will be prompted to sign once a pull request is opened.
[1]: https://cla-assistant.io/

vendor/github.com/go-jose/go-jose/v3/LICENSE

@@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
vendor/github.com/go-jose/go-jose/v3/README.md generated vendored Normal file
@ -0,0 +1,122 @@
# Go JOSE
[![godoc](http://img.shields.io/badge/godoc-jose_package-blue.svg?style=flat)](https://godoc.org/gopkg.in/go-jose/go-jose.v2)
[![godoc](http://img.shields.io/badge/godoc-jwt_package-blue.svg?style=flat)](https://godoc.org/gopkg.in/go-jose/go-jose.v2/jwt)
[![license](http://img.shields.io/badge/license-apache_2.0-blue.svg?style=flat)](https://raw.githubusercontent.com/go-jose/go-jose/master/LICENSE)
[![build](https://travis-ci.org/go-jose/go-jose.svg?branch=master)](https://travis-ci.org/go-jose/go-jose)
[![coverage](https://coveralls.io/repos/github/go-jose/go-jose/badge.svg?branch=master)](https://coveralls.io/r/go-jose/go-jose)
Package jose aims to provide an implementation of the Javascript Object Signing
and Encryption set of standards. This includes support for JSON Web Encryption,
JSON Web Signature, and JSON Web Token standards.
**Disclaimer**: This library contains encryption software that is subject to
the U.S. Export Administration Regulations. You may not export, re-export,
transfer or download this code or any part of it in violation of any United
States law, directive or regulation. In particular this software may not be
exported or re-exported in any form or on any media to Iran, North Sudan,
Syria, Cuba, or North Korea, or to denied persons or entities mentioned on any
US maintained blocked list.
## Overview
The implementation follows the
[JSON Web Encryption](http://dx.doi.org/10.17487/RFC7516) (RFC 7516),
[JSON Web Signature](http://dx.doi.org/10.17487/RFC7515) (RFC 7515), and
[JSON Web Token](http://dx.doi.org/10.17487/RFC7519) (RFC 7519) specifications.
Tables of supported algorithms are shown below. The library supports both
the compact and JWS/JWE JSON Serialization formats, and has optional support for
multiple recipients. It also comes with a small command-line utility
([`jose-util`](https://github.com/go-jose/go-jose/tree/master/jose-util))
for dealing with JOSE messages in a shell.
**Note**: We use a forked version of the `encoding/json` package from the Go
standard library which uses case-sensitive matching for member names (instead
of [case-insensitive matching](https://www.ietf.org/mail-archive/web/json/current/msg03763.html)).
This is to avoid differences in interpretation of messages between go-jose and
libraries in other languages.
### Versions
[Version 2](https://gopkg.in/go-jose/go-jose.v2)
([branch](https://github.com/go-jose/go-jose/tree/v2),
[doc](https://godoc.org/gopkg.in/go-jose/go-jose.v2)) is the current stable version:
import "gopkg.in/go-jose/go-jose.v2"
[Version 3](https://github.com/go-jose/go-jose)
([branch](https://github.com/go-jose/go-jose/tree/master),
[doc](https://godoc.org/github.com/go-jose/go-jose)) is the under-development/unstable version (not yet released):
import "github.com/go-jose/go-jose/v3"
All new feature development takes place on the `master` branch, which we are
preparing to release as version 3 soon. Version 2 will continue to receive
critical bug and security fixes. Note that starting with version 3 we are
using Go modules for versioning instead of `gopkg.in` as before. Version 3 will also require Go 1.13 or higher.
Version 1 (on the `v1` branch) is frozen and not supported anymore.
### Supported algorithms
See below for a table of supported algorithms. Algorithm identifiers match
the names in the [JSON Web Algorithms](http://dx.doi.org/10.17487/RFC7518)
standard where possible. The Godoc reference has a list of constants.
Key encryption | Algorithm identifier(s)
:------------------------- | :------------------------------
RSA-PKCS#1v1.5 | RSA1_5
RSA-OAEP | RSA-OAEP, RSA-OAEP-256
AES key wrap | A128KW, A192KW, A256KW
AES-GCM key wrap | A128GCMKW, A192GCMKW, A256GCMKW
ECDH-ES + AES key wrap | ECDH-ES+A128KW, ECDH-ES+A192KW, ECDH-ES+A256KW
ECDH-ES (direct) | ECDH-ES<sup>1</sup>
Direct encryption | dir<sup>1</sup>
<sup>1. Not supported in multi-recipient mode</sup>
Signing / MAC | Algorithm identifier(s)
:------------------------- | :------------------------------
RSASSA-PKCS#1v1.5 | RS256, RS384, RS512
RSASSA-PSS | PS256, PS384, PS512
HMAC | HS256, HS384, HS512
ECDSA | ES256, ES384, ES512
Ed25519 | EdDSA<sup>2</sup>
<sup>2. Only available in version 2 of the package</sup>
Content encryption | Algorithm identifier(s)
:------------------------- | :------------------------------
AES-CBC+HMAC | A128CBC-HS256, A192CBC-HS384, A256CBC-HS512
AES-GCM | A128GCM, A192GCM, A256GCM
Compression | Algorithm identifier(s)
:------------------------- | :------------------------------
DEFLATE (RFC 1951) | DEF
### Supported key types
See below for a table of supported key types. These are understood by the
library, and can be passed to corresponding functions such as `NewEncrypter` or
`NewSigner`. Each of these keys can also be wrapped in a JWK if desired, which
allows attaching a key id.
Algorithm(s) | Corresponding types
:------------------------- | :------------------------------
RSA | *[rsa.PublicKey](http://golang.org/pkg/crypto/rsa/#PublicKey), *[rsa.PrivateKey](http://golang.org/pkg/crypto/rsa/#PrivateKey)
ECDH, ECDSA | *[ecdsa.PublicKey](http://golang.org/pkg/crypto/ecdsa/#PublicKey), *[ecdsa.PrivateKey](http://golang.org/pkg/crypto/ecdsa/#PrivateKey)
EdDSA<sup>1</sup> | [ed25519.PublicKey](https://godoc.org/pkg/crypto/ed25519#PublicKey), [ed25519.PrivateKey](https://godoc.org/pkg/crypto/ed25519#PrivateKey)
AES, HMAC | []byte
<sup>1. Only available in version 2 or later of the package</sup>
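Each of these raw keys can be attached to a key id by wrapping it in a `JSONWebKey` before passing it to `NewSigner` or `NewEncrypter`. A minimal sketch (the key id value here is made up; error handling elided):

    import (
        "crypto/rand"
        "crypto/rsa"

        jose "gopkg.in/go-jose/go-jose.v2"
    )

    priv, _ := rsa.GenerateKey(rand.Reader, 2048)
    jwk := jose.JSONWebKey{Key: priv, KeyID: "example-key-1"}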
## Examples
[![godoc](http://img.shields.io/badge/godoc-jose_package-blue.svg?style=flat)](https://godoc.org/gopkg.in/go-jose/go-jose.v2)
[![godoc](http://img.shields.io/badge/godoc-jwt_package-blue.svg?style=flat)](https://godoc.org/gopkg.in/go-jose/go-jose.v2/jwt)
Examples can be found in the Godoc
reference for this package. The
[`jose-util`](https://github.com/go-jose/go-jose/tree/master/jose-util)
subdirectory also contains a small command-line utility which might be useful
as an example as well.
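For orientation, a sign/verify round trip with the v2 package looks roughly like this (a sketch based on the package documentation, reusing the `priv` key from the sketch above; error handling elided):

    signer, _ := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: priv}, nil)
    object, _ := signer.Sign([]byte("Lorem ipsum dolor sit amet"))
    serialized, _ := object.CompactSerialize()
    parsed, _ := jose.ParseSigned(serialized)
    payload, _ := parsed.Verify(&priv.PublicKey) // payload matches the original input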
vendor/github.com/go-jose/go-jose/v3/asymmetric.go generated vendored Normal file
@ -0,0 +1,592 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jose
import (
"crypto"
"crypto/aes"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/sha256"
"errors"
"fmt"
"math/big"
josecipher "github.com/go-jose/go-jose/v3/cipher"
"github.com/go-jose/go-jose/v3/json"
)
// A generic RSA-based encrypter/verifier
type rsaEncrypterVerifier struct {
publicKey *rsa.PublicKey
}
// A generic RSA-based decrypter/signer
type rsaDecrypterSigner struct {
privateKey *rsa.PrivateKey
}
// A generic EC-based encrypter/verifier
type ecEncrypterVerifier struct {
publicKey *ecdsa.PublicKey
}
type edEncrypterVerifier struct {
publicKey ed25519.PublicKey
}
// A key generator for ECDH-ES
type ecKeyGenerator struct {
size int
algID string
publicKey *ecdsa.PublicKey
}
// A generic EC-based decrypter/signer
type ecDecrypterSigner struct {
privateKey *ecdsa.PrivateKey
}
type edDecrypterSigner struct {
privateKey ed25519.PrivateKey
}
// newRSARecipient creates recipientKeyInfo based on the given key.
func newRSARecipient(keyAlg KeyAlgorithm, publicKey *rsa.PublicKey) (recipientKeyInfo, error) {
// Verify that key management algorithm is supported by this encrypter
switch keyAlg {
case RSA1_5, RSA_OAEP, RSA_OAEP_256:
default:
return recipientKeyInfo{}, ErrUnsupportedAlgorithm
}
if publicKey == nil {
return recipientKeyInfo{}, errors.New("invalid public key")
}
return recipientKeyInfo{
keyAlg: keyAlg,
keyEncrypter: &rsaEncrypterVerifier{
publicKey: publicKey,
},
}, nil
}
// newRSASigner creates a recipientSigInfo based on the given key.
func newRSASigner(sigAlg SignatureAlgorithm, privateKey *rsa.PrivateKey) (recipientSigInfo, error) {
// Verify that key management algorithm is supported by this encrypter
switch sigAlg {
case RS256, RS384, RS512, PS256, PS384, PS512:
default:
return recipientSigInfo{}, ErrUnsupportedAlgorithm
}
if privateKey == nil {
return recipientSigInfo{}, errors.New("invalid private key")
}
return recipientSigInfo{
sigAlg: sigAlg,
publicKey: staticPublicKey(&JSONWebKey{
Key: privateKey.Public(),
}),
signer: &rsaDecrypterSigner{
privateKey: privateKey,
},
}, nil
}
func newEd25519Signer(sigAlg SignatureAlgorithm, privateKey ed25519.PrivateKey) (recipientSigInfo, error) {
if sigAlg != EdDSA {
return recipientSigInfo{}, ErrUnsupportedAlgorithm
}
if privateKey == nil {
return recipientSigInfo{}, errors.New("invalid private key")
}
return recipientSigInfo{
sigAlg: sigAlg,
publicKey: staticPublicKey(&JSONWebKey{
Key: privateKey.Public(),
}),
signer: &edDecrypterSigner{
privateKey: privateKey,
},
}, nil
}
// newECDHRecipient creates recipientKeyInfo based on the given key.
func newECDHRecipient(keyAlg KeyAlgorithm, publicKey *ecdsa.PublicKey) (recipientKeyInfo, error) {
// Verify that key management algorithm is supported by this encrypter
switch keyAlg {
case ECDH_ES, ECDH_ES_A128KW, ECDH_ES_A192KW, ECDH_ES_A256KW:
default:
return recipientKeyInfo{}, ErrUnsupportedAlgorithm
}
if publicKey == nil || !publicKey.Curve.IsOnCurve(publicKey.X, publicKey.Y) {
return recipientKeyInfo{}, errors.New("invalid public key")
}
return recipientKeyInfo{
keyAlg: keyAlg,
keyEncrypter: &ecEncrypterVerifier{
publicKey: publicKey,
},
}, nil
}
// newECDSASigner creates a recipientSigInfo based on the given key.
func newECDSASigner(sigAlg SignatureAlgorithm, privateKey *ecdsa.PrivateKey) (recipientSigInfo, error) {
// Verify that key management algorithm is supported by this encrypter
switch sigAlg {
case ES256, ES384, ES512:
default:
return recipientSigInfo{}, ErrUnsupportedAlgorithm
}
if privateKey == nil {
return recipientSigInfo{}, errors.New("invalid private key")
}
return recipientSigInfo{
sigAlg: sigAlg,
publicKey: staticPublicKey(&JSONWebKey{
Key: privateKey.Public(),
}),
signer: &ecDecrypterSigner{
privateKey: privateKey,
},
}, nil
}
// Encrypt the given payload and update the object.
func (ctx rsaEncrypterVerifier) encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) {
encryptedKey, err := ctx.encrypt(cek, alg)
if err != nil {
return recipientInfo{}, err
}
return recipientInfo{
encryptedKey: encryptedKey,
header: &rawHeader{},
}, nil
}
// Encrypt the given payload. Based on the key encryption algorithm,
// this will either use RSA-PKCS1v1.5 or RSA-OAEP (with SHA-1 or SHA-256).
func (ctx rsaEncrypterVerifier) encrypt(cek []byte, alg KeyAlgorithm) ([]byte, error) {
switch alg {
case RSA1_5:
return rsa.EncryptPKCS1v15(RandReader, ctx.publicKey, cek)
case RSA_OAEP:
return rsa.EncryptOAEP(sha1.New(), RandReader, ctx.publicKey, cek, []byte{})
case RSA_OAEP_256:
return rsa.EncryptOAEP(sha256.New(), RandReader, ctx.publicKey, cek, []byte{})
}
return nil, ErrUnsupportedAlgorithm
}
// Decrypt the given payload and return the content encryption key.
func (ctx rsaDecrypterSigner) decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) {
return ctx.decrypt(recipient.encryptedKey, headers.getAlgorithm(), generator)
}
// Decrypt the given payload. Based on the key encryption algorithm,
// this will either use RSA-PKCS1v1.5 or RSA-OAEP (with SHA-1 or SHA-256).
func (ctx rsaDecrypterSigner) decrypt(jek []byte, alg KeyAlgorithm, generator keyGenerator) ([]byte, error) {
// Note: The random reader on decrypt operations is only used for blinding,
// so stubbing is meaningless (hence the direct use of rand.Reader).
switch alg {
case RSA1_5:
defer func() {
// DecryptPKCS1v15SessionKey sometimes panics on an invalid payload
// because of an index out of bounds error, which we want to ignore.
// This has been fixed in Go 1.3.1 (released 2014/08/13), the recover()
// only exists for preventing crashes with unpatched versions.
// See: https://groups.google.com/forum/#!topic/golang-dev/7ihX6Y6kx9k
// See: https://code.google.com/p/go/source/detail?r=58ee390ff31602edb66af41ed10901ec95904d33
_ = recover()
}()
// Perform some input validation.
keyBytes := ctx.privateKey.PublicKey.N.BitLen() / 8
if keyBytes != len(jek) {
// Input size is incorrect, the encrypted payload should always match
// the size of the public modulus (e.g. using a 2048 bit key will
// produce 256 bytes of output). Reject this since it's invalid input.
return nil, ErrCryptoFailure
}
cek, _, err := generator.genKey()
if err != nil {
return nil, ErrCryptoFailure
}
// When decrypting an RSA-PKCS1v1.5 payload, we must take precautions to
// prevent chosen-ciphertext attacks as described in RFC 3218, "Preventing
// the Million Message Attack on Cryptographic Message Syntax". We are
// therefore deliberately ignoring errors here.
_ = rsa.DecryptPKCS1v15SessionKey(rand.Reader, ctx.privateKey, jek, cek)
return cek, nil
case RSA_OAEP:
// Use rand.Reader for RSA blinding
return rsa.DecryptOAEP(sha1.New(), rand.Reader, ctx.privateKey, jek, []byte{})
case RSA_OAEP_256:
// Use rand.Reader for RSA blinding
return rsa.DecryptOAEP(sha256.New(), rand.Reader, ctx.privateKey, jek, []byte{})
}
return nil, ErrUnsupportedAlgorithm
}
// Sign the given payload
func (ctx rsaDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error) {
var hash crypto.Hash
switch alg {
case RS256, PS256:
hash = crypto.SHA256
case RS384, PS384:
hash = crypto.SHA384
case RS512, PS512:
hash = crypto.SHA512
default:
return Signature{}, ErrUnsupportedAlgorithm
}
hasher := hash.New()
// According to documentation, Write() on hash never fails
_, _ = hasher.Write(payload)
hashed := hasher.Sum(nil)
var out []byte
var err error
switch alg {
case RS256, RS384, RS512:
out, err = rsa.SignPKCS1v15(RandReader, ctx.privateKey, hash, hashed)
case PS256, PS384, PS512:
out, err = rsa.SignPSS(RandReader, ctx.privateKey, hash, hashed, &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthEqualsHash,
})
}
if err != nil {
return Signature{}, err
}
return Signature{
Signature: out,
protected: &rawHeader{},
}, nil
}
// Verify the given payload
func (ctx rsaEncrypterVerifier) verifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error {
var hash crypto.Hash
switch alg {
case RS256, PS256:
hash = crypto.SHA256
case RS384, PS384:
hash = crypto.SHA384
case RS512, PS512:
hash = crypto.SHA512
default:
return ErrUnsupportedAlgorithm
}
hasher := hash.New()
// According to documentation, Write() on hash never fails
_, _ = hasher.Write(payload)
hashed := hasher.Sum(nil)
switch alg {
case RS256, RS384, RS512:
return rsa.VerifyPKCS1v15(ctx.publicKey, hash, hashed, signature)
case PS256, PS384, PS512:
return rsa.VerifyPSS(ctx.publicKey, hash, hashed, signature, nil)
}
return ErrUnsupportedAlgorithm
}
// Encrypt the given payload and update the object.
func (ctx ecEncrypterVerifier) encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) {
switch alg {
case ECDH_ES:
// ECDH-ES mode doesn't wrap a key, the shared secret is used directly as the key.
return recipientInfo{
header: &rawHeader{},
}, nil
case ECDH_ES_A128KW, ECDH_ES_A192KW, ECDH_ES_A256KW:
default:
return recipientInfo{}, ErrUnsupportedAlgorithm
}
generator := ecKeyGenerator{
algID: string(alg),
publicKey: ctx.publicKey,
}
switch alg {
case ECDH_ES_A128KW:
generator.size = 16
case ECDH_ES_A192KW:
generator.size = 24
case ECDH_ES_A256KW:
generator.size = 32
}
kek, header, err := generator.genKey()
if err != nil {
return recipientInfo{}, err
}
block, err := aes.NewCipher(kek)
if err != nil {
return recipientInfo{}, err
}
jek, err := josecipher.KeyWrap(block, cek)
if err != nil {
return recipientInfo{}, err
}
return recipientInfo{
encryptedKey: jek,
header: &header,
}, nil
}
// Get key size for EC key generator
func (ctx ecKeyGenerator) keySize() int {
return ctx.size
}
// Get a content encryption key for ECDH-ES
func (ctx ecKeyGenerator) genKey() ([]byte, rawHeader, error) {
priv, err := ecdsa.GenerateKey(ctx.publicKey.Curve, RandReader)
if err != nil {
return nil, rawHeader{}, err
}
out := josecipher.DeriveECDHES(ctx.algID, []byte{}, []byte{}, priv, ctx.publicKey, ctx.size)
b, err := json.Marshal(&JSONWebKey{
Key: &priv.PublicKey,
})
if err != nil {
return nil, nil, err
}
headers := rawHeader{
headerEPK: makeRawMessage(b),
}
return out, headers, nil
}
// Decrypt the given payload and return the content encryption key.
func (ctx ecDecrypterSigner) decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) {
epk, err := headers.getEPK()
if err != nil {
return nil, errors.New("go-jose/go-jose: invalid epk header")
}
if epk == nil {
return nil, errors.New("go-jose/go-jose: missing epk header")
}
publicKey, ok := epk.Key.(*ecdsa.PublicKey)
if publicKey == nil || !ok {
return nil, errors.New("go-jose/go-jose: invalid epk header")
}
if !ctx.privateKey.Curve.IsOnCurve(publicKey.X, publicKey.Y) {
return nil, errors.New("go-jose/go-jose: invalid public key in epk header")
}
apuData, err := headers.getAPU()
if err != nil {
return nil, errors.New("go-jose/go-jose: invalid apu header")
}
apvData, err := headers.getAPV()
if err != nil {
return nil, errors.New("go-jose/go-jose: invalid apv header")
}
deriveKey := func(algID string, size int) []byte {
return josecipher.DeriveECDHES(algID, apuData.bytes(), apvData.bytes(), ctx.privateKey, publicKey, size)
}
var keySize int
algorithm := headers.getAlgorithm()
switch algorithm {
case ECDH_ES:
// ECDH-ES uses direct key agreement, no key unwrapping necessary.
return deriveKey(string(headers.getEncryption()), generator.keySize()), nil
case ECDH_ES_A128KW:
keySize = 16
case ECDH_ES_A192KW:
keySize = 24
case ECDH_ES_A256KW:
keySize = 32
default:
return nil, ErrUnsupportedAlgorithm
}
key := deriveKey(string(algorithm), keySize)
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return josecipher.KeyUnwrap(block, recipient.encryptedKey)
}
func (ctx edDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error) {
if alg != EdDSA {
return Signature{}, ErrUnsupportedAlgorithm
}
sig, err := ctx.privateKey.Sign(RandReader, payload, crypto.Hash(0))
if err != nil {
return Signature{}, err
}
return Signature{
Signature: sig,
protected: &rawHeader{},
}, nil
}
func (ctx edEncrypterVerifier) verifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error {
if alg != EdDSA {
return ErrUnsupportedAlgorithm
}
ok := ed25519.Verify(ctx.publicKey, payload, signature)
if !ok {
return errors.New("go-jose/go-jose: ed25519 signature failed to verify")
}
return nil
}
// Sign the given payload
func (ctx ecDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error) {
var expectedBitSize int
var hash crypto.Hash
switch alg {
case ES256:
expectedBitSize = 256
hash = crypto.SHA256
case ES384:
expectedBitSize = 384
hash = crypto.SHA384
case ES512:
expectedBitSize = 521
hash = crypto.SHA512
}
curveBits := ctx.privateKey.Curve.Params().BitSize
if expectedBitSize != curveBits {
return Signature{}, fmt.Errorf("go-jose/go-jose: expected %d bit key, got %d bits instead", expectedBitSize, curveBits)
}
hasher := hash.New()
// According to documentation, Write() on hash never fails
_, _ = hasher.Write(payload)
hashed := hasher.Sum(nil)
r, s, err := ecdsa.Sign(RandReader, ctx.privateKey, hashed)
if err != nil {
return Signature{}, err
}
keyBytes := curveBits / 8
if curveBits%8 > 0 {
keyBytes++
}
// We serialize the outputs (r and s) into big-endian byte arrays and pad
// them with zeros on the left to make sure the sizes work out. Both arrays
// must be keyBytes long, and the output must be 2*keyBytes long.
rBytes := r.Bytes()
rBytesPadded := make([]byte, keyBytes)
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
sBytes := s.Bytes()
sBytesPadded := make([]byte, keyBytes)
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
out := append(rBytesPadded, sBytesPadded...)
return Signature{
Signature: out,
protected: &rawHeader{},
}, nil
}
// Verify the given payload
func (ctx ecEncrypterVerifier) verifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error {
var keySize int
var hash crypto.Hash
switch alg {
case ES256:
keySize = 32
hash = crypto.SHA256
case ES384:
keySize = 48
hash = crypto.SHA384
case ES512:
keySize = 66
hash = crypto.SHA512
default:
return ErrUnsupportedAlgorithm
}
if len(signature) != 2*keySize {
return fmt.Errorf("go-jose/go-jose: invalid signature size, have %d bytes, wanted %d", len(signature), 2*keySize)
}
hasher := hash.New()
// According to documentation, Write() on hash never fails
_, _ = hasher.Write(payload)
hashed := hasher.Sum(nil)
r := big.NewInt(0).SetBytes(signature[:keySize])
s := big.NewInt(0).SetBytes(signature[keySize:])
match := ecdsa.Verify(ctx.publicKey, hashed, r, s)
if !match {
return errors.New("go-jose/go-jose: ecdsa signature failed to verify")
}
return nil
}
vendor/github.com/go-jose/go-jose/v3/cipher/cbc_hmac.go generated vendored Normal file
@ -0,0 +1,196 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package josecipher
import (
"bytes"
"crypto/cipher"
"crypto/hmac"
"crypto/sha256"
"crypto/sha512"
"crypto/subtle"
"encoding/binary"
"errors"
"hash"
)
const (
nonceBytes = 16
)
// NewCBCHMAC instantiates a new AEAD based on CBC+HMAC.
func NewCBCHMAC(key []byte, newBlockCipher func([]byte) (cipher.Block, error)) (cipher.AEAD, error) {
keySize := len(key) / 2
integrityKey := key[:keySize]
encryptionKey := key[keySize:]
blockCipher, err := newBlockCipher(encryptionKey)
if err != nil {
return nil, err
}
var hash func() hash.Hash
switch keySize {
case 16:
hash = sha256.New
case 24:
hash = sha512.New384
case 32:
hash = sha512.New
}
return &cbcAEAD{
hash: hash,
blockCipher: blockCipher,
authtagBytes: keySize,
integrityKey: integrityKey,
}, nil
}
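// Illustrative usage sketch: a 32-byte combined key selects SHA-256 with a
// 16-byte AES key, i.e. A128CBC-HS256. The zero-filled key below is a
// placeholder only and assumes "crypto/aes" is imported.
//
//	aead, err := NewCBCHMAC(make([]byte, 32), aes.NewCipher)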
// An AEAD based on CBC+HMAC
type cbcAEAD struct {
hash func() hash.Hash
authtagBytes int
integrityKey []byte
blockCipher cipher.Block
}
func (ctx *cbcAEAD) NonceSize() int {
return nonceBytes
}
func (ctx *cbcAEAD) Overhead() int {
// Maximum overhead is block size (for padding) plus auth tag length, where
// the length of the auth tag is equivalent to the key size.
return ctx.blockCipher.BlockSize() + ctx.authtagBytes
}
// Seal encrypts and authenticates the plaintext.
func (ctx *cbcAEAD) Seal(dst, nonce, plaintext, data []byte) []byte {
// Output buffer -- must take care not to mangle plaintext input.
ciphertext := make([]byte, uint64(len(plaintext))+uint64(ctx.Overhead()))[:len(plaintext)]
copy(ciphertext, plaintext)
ciphertext = padBuffer(ciphertext, ctx.blockCipher.BlockSize())
cbc := cipher.NewCBCEncrypter(ctx.blockCipher, nonce)
cbc.CryptBlocks(ciphertext, ciphertext)
authtag := ctx.computeAuthTag(data, nonce, ciphertext)
ret, out := resize(dst, uint64(len(dst))+uint64(len(ciphertext))+uint64(len(authtag)))
copy(out, ciphertext)
copy(out[len(ciphertext):], authtag)
return ret
}
// Open decrypts and authenticates the ciphertext.
func (ctx *cbcAEAD) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) {
if len(ciphertext) < ctx.authtagBytes {
return nil, errors.New("go-jose/go-jose: invalid ciphertext (too short)")
}
offset := len(ciphertext) - ctx.authtagBytes
expectedTag := ctx.computeAuthTag(data, nonce, ciphertext[:offset])
match := subtle.ConstantTimeCompare(expectedTag, ciphertext[offset:])
if match != 1 {
return nil, errors.New("go-jose/go-jose: invalid ciphertext (auth tag mismatch)")
}
cbc := cipher.NewCBCDecrypter(ctx.blockCipher, nonce)
// Make copy of ciphertext buffer, don't want to modify in place
buffer := append([]byte{}, ciphertext[:offset]...)
if len(buffer)%ctx.blockCipher.BlockSize() > 0 {
return nil, errors.New("go-jose/go-jose: invalid ciphertext (invalid length)")
}
cbc.CryptBlocks(buffer, buffer)
// Remove padding
plaintext, err := unpadBuffer(buffer, ctx.blockCipher.BlockSize())
if err != nil {
return nil, err
}
ret, out := resize(dst, uint64(len(dst))+uint64(len(plaintext)))
copy(out, plaintext)
return ret, nil
}
// Compute an authentication tag
func (ctx *cbcAEAD) computeAuthTag(aad, nonce, ciphertext []byte) []byte {
buffer := make([]byte, uint64(len(aad))+uint64(len(nonce))+uint64(len(ciphertext))+8)
n := 0
n += copy(buffer, aad)
n += copy(buffer[n:], nonce)
n += copy(buffer[n:], ciphertext)
binary.BigEndian.PutUint64(buffer[n:], uint64(len(aad))*8)
// According to documentation, Write() on hash.Hash never fails.
hmac := hmac.New(ctx.hash, ctx.integrityKey)
_, _ = hmac.Write(buffer)
return hmac.Sum(nil)[:ctx.authtagBytes]
}
// resize ensures that the given slice has a capacity of at least n bytes.
// If the capacity of the slice is less than n, a new slice is allocated
// and the existing data will be copied.
func resize(in []byte, n uint64) (head, tail []byte) {
if uint64(cap(in)) >= n {
head = in[:n]
} else {
head = make([]byte, n)
copy(head, in)
}
tail = head[len(in):]
return
}
// Apply padding
func padBuffer(buffer []byte, blockSize int) []byte {
missing := blockSize - (len(buffer) % blockSize)
ret, out := resize(buffer, uint64(len(buffer))+uint64(missing))
padding := bytes.Repeat([]byte{byte(missing)}, missing)
copy(out, padding)
return ret
}
// Remove padding
func unpadBuffer(buffer []byte, blockSize int) ([]byte, error) {
if len(buffer)%blockSize != 0 {
return nil, errors.New("go-jose/go-jose: invalid padding")
}
last := buffer[len(buffer)-1]
count := int(last)
if count == 0 || count > blockSize || count > len(buffer) {
return nil, errors.New("go-jose/go-jose: invalid padding")
}
padding := bytes.Repeat([]byte{last}, count)
if !bytes.HasSuffix(buffer, padding) {
return nil, errors.New("go-jose/go-jose: invalid padding")
}
return buffer[:len(buffer)-count], nil
}
vendor/github.com/go-jose/go-jose/v3/cipher/concat_kdf.go generated vendored Normal file
@ -0,0 +1,75 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package josecipher
import (
"crypto"
"encoding/binary"
"hash"
"io"
)
type concatKDF struct {
z, info []byte
i uint32
cache []byte
hasher hash.Hash
}
// NewConcatKDF builds a KDF reader based on the given inputs.
func NewConcatKDF(hash crypto.Hash, z, algID, ptyUInfo, ptyVInfo, supPubInfo, supPrivInfo []byte) io.Reader {
buffer := make([]byte, uint64(len(algID))+uint64(len(ptyUInfo))+uint64(len(ptyVInfo))+uint64(len(supPubInfo))+uint64(len(supPrivInfo)))
n := 0
n += copy(buffer, algID)
n += copy(buffer[n:], ptyUInfo)
n += copy(buffer[n:], ptyVInfo)
n += copy(buffer[n:], supPubInfo)
copy(buffer[n:], supPrivInfo)
hasher := hash.New()
return &concatKDF{
z: z,
info: buffer,
hasher: hasher,
cache: []byte{},
i: 1,
}
}
func (ctx *concatKDF) Read(out []byte) (int, error) {
copied := copy(out, ctx.cache)
ctx.cache = ctx.cache[copied:]
for copied < len(out) {
ctx.hasher.Reset()
// Write on a hash.Hash never fails
_ = binary.Write(ctx.hasher, binary.BigEndian, ctx.i)
_, _ = ctx.hasher.Write(ctx.z)
_, _ = ctx.hasher.Write(ctx.info)
hash := ctx.hasher.Sum(nil)
chunkCopied := copy(out[copied:], hash)
copied += chunkCopied
ctx.cache = hash[chunkCopied:]
ctx.i++
}
return copied, nil
}
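// Illustrative sketch: derive 32 bytes from a shared secret z with
// hypothetical KDF inputs (all []byte values).
//
//	kdf := NewConcatKDF(crypto.SHA256, z, algID, ptyUInfo, ptyVInfo, supPubInfo, nil)
//	out := make([]byte, 32)
//	_, _ = kdf.Read(out) // Read on the KDF never fails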
vendor/github.com/go-jose/go-jose/v3/cipher/ecdh_es.go generated vendored Normal file
@ -0,0 +1,86 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package josecipher
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"encoding/binary"
)
// DeriveECDHES derives a shared encryption key using ECDH/ConcatKDF as described in JWE/JWA.
// It is an error to call this function with a private/public key that are not on the same
// curve. Callers must ensure that the keys are valid before calling this function. Output
// size may be at most 1<<16 bytes (64 KiB).
func DeriveECDHES(alg string, apuData, apvData []byte, priv *ecdsa.PrivateKey, pub *ecdsa.PublicKey, size int) []byte {
if size > 1<<16 {
panic("ECDH-ES output size too large, must be less than or equal to 1<<16")
}
// algId, partyUInfo, partyVInfo inputs must be prefixed with the length
algID := lengthPrefixed([]byte(alg))
ptyUInfo := lengthPrefixed(apuData)
ptyVInfo := lengthPrefixed(apvData)
// suppPubInfo is the encoded length of the output size in bits
supPubInfo := make([]byte, 4)
binary.BigEndian.PutUint32(supPubInfo, uint32(size)*8)
if !priv.PublicKey.Curve.IsOnCurve(pub.X, pub.Y) {
panic("public key not on same curve as private key")
}
z, _ := priv.Curve.ScalarMult(pub.X, pub.Y, priv.D.Bytes())
zBytes := z.Bytes()
// Note that calling z.Bytes() on a big.Int may strip leading zero bytes from
// the returned byte array. This can lead to a problem where zBytes will be
// shorter than expected which breaks the key derivation. Therefore we must pad
// to the full length of the expected coordinate here before calling the KDF.
octSize := dSize(priv.Curve)
if len(zBytes) != octSize {
zBytes = append(bytes.Repeat([]byte{0}, octSize-len(zBytes)), zBytes...)
}
reader := NewConcatKDF(crypto.SHA256, zBytes, algID, ptyUInfo, ptyVInfo, supPubInfo, []byte{})
key := make([]byte, size)
// Read on the KDF will never fail
_, _ = reader.Read(key)
return key
}
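// Illustrative sketch: derive a 16-byte key-encryption key for
// ECDH-ES+A128KW, assuming priv and pub are valid keys on the same curve.
//
//	kek := DeriveECDHES("ECDH-ES+A128KW", nil, nil, priv, pub, 16)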
// dSize returns the size in octets for a coordinate on an elliptic curve.
func dSize(curve elliptic.Curve) int {
order := curve.Params().P
bitLen := order.BitLen()
size := bitLen / 8
if bitLen%8 != 0 {
size++
}
return size
}
func lengthPrefixed(data []byte) []byte {
out := make([]byte, len(data)+4)
binary.BigEndian.PutUint32(out, uint32(len(data)))
copy(out[4:], data)
return out
}
vendor/github.com/go-jose/go-jose/v3/cipher/key_wrap.go generated vendored Normal file
@ -0,0 +1,109 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package josecipher
import (
"crypto/cipher"
"crypto/subtle"
"encoding/binary"
"errors"
)
var defaultIV = []byte{0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6}
// KeyWrap implements NIST key wrapping; it wraps a content encryption key (cek) with the given block cipher.
func KeyWrap(block cipher.Block, cek []byte) ([]byte, error) {
if len(cek)%8 != 0 {
return nil, errors.New("go-jose/go-jose: key wrap input must be 8 byte blocks")
}
n := len(cek) / 8
r := make([][]byte, n)
for i := range r {
r[i] = make([]byte, 8)
copy(r[i], cek[i*8:])
}
buffer := make([]byte, 16)
tBytes := make([]byte, 8)
copy(buffer, defaultIV)
for t := 0; t < 6*n; t++ {
copy(buffer[8:], r[t%n])
block.Encrypt(buffer, buffer)
binary.BigEndian.PutUint64(tBytes, uint64(t+1))
for i := 0; i < 8; i++ {
buffer[i] ^= tBytes[i]
}
copy(r[t%n], buffer[8:])
}
out := make([]byte, (n+1)*8)
copy(out, buffer[:8])
for i := range r {
copy(out[(i+1)*8:], r[i])
}
return out, nil
}
// KeyUnwrap implements NIST key unwrapping; it unwraps a content encryption key (cek) with the given block cipher.
func KeyUnwrap(block cipher.Block, ciphertext []byte) ([]byte, error) {
if len(ciphertext)%8 != 0 {
return nil, errors.New("go-jose/go-jose: key wrap input must be 8 byte blocks")
}
n := (len(ciphertext) / 8) - 1
r := make([][]byte, n)
for i := range r {
r[i] = make([]byte, 8)
copy(r[i], ciphertext[(i+1)*8:])
}
buffer := make([]byte, 16)
tBytes := make([]byte, 8)
copy(buffer[:8], ciphertext[:8])
for t := 6*n - 1; t >= 0; t-- {
binary.BigEndian.PutUint64(tBytes, uint64(t+1))
for i := 0; i < 8; i++ {
buffer[i] ^= tBytes[i]
}
copy(buffer[8:], r[t%n])
block.Decrypt(buffer, buffer)
copy(r[t%n], buffer[8:])
}
if subtle.ConstantTimeCompare(buffer[:8], defaultIV) == 0 {
return nil, errors.New("go-jose/go-jose: failed to unwrap key")
}
out := make([]byte, n*8)
for i := range r {
copy(out[i*8:], r[i])
}
return out, nil
}
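// Illustrative round-trip sketch; kek and cek are hypothetical keys of a
// supported AES size (cek length must be a multiple of 8).
//
//	block, _ := aes.NewCipher(kek)
//	wrapped, _ := KeyWrap(block, cek)
//	unwrapped, _ := KeyUnwrap(block, wrapped) // unwrapped equals cek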
vendor/github.com/go-jose/go-jose/v3/crypter.go generated vendored Normal file
@ -0,0 +1,544 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jose
import (
"crypto/ecdsa"
"crypto/rsa"
"errors"
"fmt"
"reflect"
"github.com/go-jose/go-jose/v3/json"
)
// Encrypter represents an encrypter which produces an encrypted JWE object.
type Encrypter interface {
Encrypt(plaintext []byte) (*JSONWebEncryption, error)
EncryptWithAuthData(plaintext []byte, aad []byte) (*JSONWebEncryption, error)
Options() EncrypterOptions
}
// A generic content cipher
type contentCipher interface {
keySize() int
encrypt(cek []byte, aad, plaintext []byte) (*aeadParts, error)
decrypt(cek []byte, aad []byte, parts *aeadParts) ([]byte, error)
}
// A key generator (for generating/getting a CEK)
type keyGenerator interface {
keySize() int
genKey() ([]byte, rawHeader, error)
}
// A generic key encrypter
type keyEncrypter interface {
encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) // Encrypt a key
}
// A generic key decrypter
type keyDecrypter interface {
decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) // Decrypt a key
}
// A generic encrypter based on the given key encrypter and content cipher.
type genericEncrypter struct {
contentAlg ContentEncryption
compressionAlg CompressionAlgorithm
cipher contentCipher
recipients []recipientKeyInfo
keyGenerator keyGenerator
extraHeaders map[HeaderKey]interface{}
}
type recipientKeyInfo struct {
keyID string
keyAlg KeyAlgorithm
keyEncrypter keyEncrypter
}
// EncrypterOptions represents options that can be set on new encrypters.
type EncrypterOptions struct {
Compression CompressionAlgorithm
// Optional map of additional keys to be inserted into the protected header
// of a JWE object. Some specifications which make use of JWE like to insert
// additional values here. All values must be JSON-serializable.
ExtraHeaders map[HeaderKey]interface{}
}
// WithHeader adds an arbitrary value to the ExtraHeaders map, initializing it
// if necessary. It returns itself and so can be used in a fluent style.
func (eo *EncrypterOptions) WithHeader(k HeaderKey, v interface{}) *EncrypterOptions {
if eo.ExtraHeaders == nil {
eo.ExtraHeaders = map[HeaderKey]interface{}{}
}
eo.ExtraHeaders[k] = v
return eo
}
// WithContentType adds a content type ("cty") header and returns the updated
// EncrypterOptions.
func (eo *EncrypterOptions) WithContentType(contentType ContentType) *EncrypterOptions {
return eo.WithHeader(HeaderContentType, contentType)
}
// WithType adds a type ("typ") header and returns the updated EncrypterOptions.
func (eo *EncrypterOptions) WithType(typ ContentType) *EncrypterOptions {
return eo.WithHeader(HeaderType, typ)
}
// Recipient represents an algorithm/key to encrypt messages to.
//
// PBES2Count and PBES2Salt correspond with the "p2c" and "p2s" headers used
// on the password-based encryption algorithms PBES2-HS256+A128KW,
// PBES2-HS384+A192KW, and PBES2-HS512+A256KW. If they are not provided, a
// safe default of 100000 will be used for the count, and a 128-bit random
// salt will be generated.
type Recipient struct {
Algorithm KeyAlgorithm
Key interface{}
KeyID string
PBES2Count int
PBES2Salt []byte
}
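// Illustrative sketch: a password-based recipient with an explicit PBES2
// iteration count (password and count are hypothetical values).
//
//	enc, err := NewEncrypter(A128GCM, Recipient{
//		Algorithm:  PBES2_HS256_A128KW,
//		Key:        []byte("password"),
//		PBES2Count: 150000,
//	}, nil)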
// NewEncrypter creates an appropriate encrypter based on the key type
func NewEncrypter(enc ContentEncryption, rcpt Recipient, opts *EncrypterOptions) (Encrypter, error) {
encrypter := &genericEncrypter{
contentAlg: enc,
recipients: []recipientKeyInfo{},
cipher: getContentCipher(enc),
}
if opts != nil {
encrypter.compressionAlg = opts.Compression
encrypter.extraHeaders = opts.ExtraHeaders
}
if encrypter.cipher == nil {
return nil, ErrUnsupportedAlgorithm
}
var keyID string
var rawKey interface{}
switch encryptionKey := rcpt.Key.(type) {
case JSONWebKey:
keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key
case *JSONWebKey:
keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key
case OpaqueKeyEncrypter:
keyID, rawKey = encryptionKey.KeyID(), encryptionKey
default:
rawKey = encryptionKey
}
switch rcpt.Algorithm {
case DIRECT:
// Direct encryption mode must be treated differently
if reflect.TypeOf(rawKey) != reflect.TypeOf([]byte{}) {
return nil, ErrUnsupportedKeyType
}
if encrypter.cipher.keySize() != len(rawKey.([]byte)) {
return nil, ErrInvalidKeySize
}
encrypter.keyGenerator = staticKeyGenerator{
key: rawKey.([]byte),
}
recipientInfo, _ := newSymmetricRecipient(rcpt.Algorithm, rawKey.([]byte))
recipientInfo.keyID = keyID
if rcpt.KeyID != "" {
recipientInfo.keyID = rcpt.KeyID
}
encrypter.recipients = []recipientKeyInfo{recipientInfo}
return encrypter, nil
case ECDH_ES:
// ECDH-ES (w/o key wrapping) is similar to DIRECT mode
typeOf := reflect.TypeOf(rawKey)
if typeOf != reflect.TypeOf(&ecdsa.PublicKey{}) {
return nil, ErrUnsupportedKeyType
}
encrypter.keyGenerator = ecKeyGenerator{
size: encrypter.cipher.keySize(),
algID: string(enc),
publicKey: rawKey.(*ecdsa.PublicKey),
}
recipientInfo, _ := newECDHRecipient(rcpt.Algorithm, rawKey.(*ecdsa.PublicKey))
recipientInfo.keyID = keyID
if rcpt.KeyID != "" {
recipientInfo.keyID = rcpt.KeyID
}
encrypter.recipients = []recipientKeyInfo{recipientInfo}
return encrypter, nil
default:
// Can just add a standard recipient
encrypter.keyGenerator = randomKeyGenerator{
size: encrypter.cipher.keySize(),
}
err := encrypter.addRecipient(rcpt)
return encrypter, err
}
}
// NewMultiEncrypter creates a multi-encrypter based on the given parameters
func NewMultiEncrypter(enc ContentEncryption, rcpts []Recipient, opts *EncrypterOptions) (Encrypter, error) {
cipher := getContentCipher(enc)
if cipher == nil {
return nil, ErrUnsupportedAlgorithm
}
if len(rcpts) == 0 {
return nil, fmt.Errorf("go-jose/go-jose: recipients is nil or empty")
}
encrypter := &genericEncrypter{
contentAlg: enc,
recipients: []recipientKeyInfo{},
cipher: cipher,
keyGenerator: randomKeyGenerator{
size: cipher.keySize(),
},
}
if opts != nil {
encrypter.compressionAlg = opts.Compression
encrypter.extraHeaders = opts.ExtraHeaders
}
for _, recipient := range rcpts {
err := encrypter.addRecipient(recipient)
if err != nil {
return nil, err
}
}
return encrypter, nil
}
func (ctx *genericEncrypter) addRecipient(recipient Recipient) (err error) {
var recipientInfo recipientKeyInfo
switch recipient.Algorithm {
case DIRECT, ECDH_ES:
return fmt.Errorf("go-jose/go-jose: key algorithm '%s' not supported in multi-recipient mode", recipient.Algorithm)
}
recipientInfo, err = makeJWERecipient(recipient.Algorithm, recipient.Key)
if recipient.KeyID != "" {
recipientInfo.keyID = recipient.KeyID
}
switch recipient.Algorithm {
case PBES2_HS256_A128KW, PBES2_HS384_A192KW, PBES2_HS512_A256KW:
if sr, ok := recipientInfo.keyEncrypter.(*symmetricKeyCipher); ok {
sr.p2c = recipient.PBES2Count
sr.p2s = recipient.PBES2Salt
}
}
if err == nil {
ctx.recipients = append(ctx.recipients, recipientInfo)
}
return err
}
func makeJWERecipient(alg KeyAlgorithm, encryptionKey interface{}) (recipientKeyInfo, error) {
switch encryptionKey := encryptionKey.(type) {
case *rsa.PublicKey:
return newRSARecipient(alg, encryptionKey)
case *ecdsa.PublicKey:
return newECDHRecipient(alg, encryptionKey)
case []byte:
return newSymmetricRecipient(alg, encryptionKey)
case string:
return newSymmetricRecipient(alg, []byte(encryptionKey))
case *JSONWebKey:
recipient, err := makeJWERecipient(alg, encryptionKey.Key)
recipient.keyID = encryptionKey.KeyID
return recipient, err
}
if encrypter, ok := encryptionKey.(OpaqueKeyEncrypter); ok {
return newOpaqueKeyEncrypter(alg, encrypter)
}
return recipientKeyInfo{}, ErrUnsupportedKeyType
}
// newDecrypter creates an appropriate decrypter based on the key type
func newDecrypter(decryptionKey interface{}) (keyDecrypter, error) {
switch decryptionKey := decryptionKey.(type) {
case *rsa.PrivateKey:
return &rsaDecrypterSigner{
privateKey: decryptionKey,
}, nil
case *ecdsa.PrivateKey:
return &ecDecrypterSigner{
privateKey: decryptionKey,
}, nil
case []byte:
return &symmetricKeyCipher{
key: decryptionKey,
}, nil
case string:
return &symmetricKeyCipher{
key: []byte(decryptionKey),
}, nil
case JSONWebKey:
return newDecrypter(decryptionKey.Key)
case *JSONWebKey:
return newDecrypter(decryptionKey.Key)
}
if okd, ok := decryptionKey.(OpaqueKeyDecrypter); ok {
return &opaqueKeyDecrypter{decrypter: okd}, nil
}
return nil, ErrUnsupportedKeyType
}
// Implementation of encrypt method producing a JWE object.
func (ctx *genericEncrypter) Encrypt(plaintext []byte) (*JSONWebEncryption, error) {
return ctx.EncryptWithAuthData(plaintext, nil)
}
// Implementation of encrypt method producing a JWE object.
func (ctx *genericEncrypter) EncryptWithAuthData(plaintext, aad []byte) (*JSONWebEncryption, error) {
obj := &JSONWebEncryption{}
obj.aad = aad
obj.protected = &rawHeader{}
err := obj.protected.set(headerEncryption, ctx.contentAlg)
if err != nil {
return nil, err
}
obj.recipients = make([]recipientInfo, len(ctx.recipients))
if len(ctx.recipients) == 0 {
return nil, fmt.Errorf("go-jose/go-jose: no recipients to encrypt to")
}
cek, headers, err := ctx.keyGenerator.genKey()
if err != nil {
return nil, err
}
obj.protected.merge(&headers)
for i, info := range ctx.recipients {
recipient, err := info.keyEncrypter.encryptKey(cek, info.keyAlg)
if err != nil {
return nil, err
}
err = recipient.header.set(headerAlgorithm, info.keyAlg)
if err != nil {
return nil, err
}
if info.keyID != "" {
err = recipient.header.set(headerKeyID, info.keyID)
if err != nil {
return nil, err
}
}
obj.recipients[i] = recipient
}
if len(ctx.recipients) == 1 {
// Move per-recipient headers into main protected header if there's
// only a single recipient.
obj.protected.merge(obj.recipients[0].header)
obj.recipients[0].header = nil
}
if ctx.compressionAlg != NONE {
plaintext, err = compress(ctx.compressionAlg, plaintext)
if err != nil {
return nil, err
}
err = obj.protected.set(headerCompression, ctx.compressionAlg)
if err != nil {
return nil, err
}
}
for k, v := range ctx.extraHeaders {
b, err := json.Marshal(v)
if err != nil {
return nil, err
}
(*obj.protected)[k] = makeRawMessage(b)
}
authData := obj.computeAuthData()
parts, err := ctx.cipher.encrypt(cek, authData, plaintext)
if err != nil {
return nil, err
}
obj.iv = parts.iv
obj.ciphertext = parts.ciphertext
obj.tag = parts.tag
return obj, nil
}
func (ctx *genericEncrypter) Options() EncrypterOptions {
return EncrypterOptions{
Compression: ctx.compressionAlg,
ExtraHeaders: ctx.extraHeaders,
}
}
// Decrypt and validate the object and return the plaintext. Note that this
// function does not support multi-recipient objects; if you need
// multi-recipient decryption, use DecryptMulti instead.
func (obj JSONWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error) {
headers := obj.mergedHeaders(nil)
if len(obj.recipients) > 1 {
return nil, errors.New("go-jose/go-jose: too many recipients in payload; expecting only one")
}
critical, err := headers.getCritical()
if err != nil {
return nil, fmt.Errorf("go-jose/go-jose: invalid crit header")
}
if len(critical) > 0 {
return nil, fmt.Errorf("go-jose/go-jose: unsupported crit header")
}
key := tryJWKS(decryptionKey, obj.Header)
decrypter, err := newDecrypter(key)
if err != nil {
return nil, err
}
cipher := getContentCipher(headers.getEncryption())
if cipher == nil {
return nil, fmt.Errorf("go-jose/go-jose: unsupported enc value '%s'", string(headers.getEncryption()))
}
generator := randomKeyGenerator{
size: cipher.keySize(),
}
parts := &aeadParts{
iv: obj.iv,
ciphertext: obj.ciphertext,
tag: obj.tag,
}
authData := obj.computeAuthData()
var plaintext []byte
recipient := obj.recipients[0]
recipientHeaders := obj.mergedHeaders(&recipient)
cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator)
if err == nil {
// Found a valid CEK -- let's try to decrypt.
plaintext, err = cipher.decrypt(cek, authData, parts)
}
if plaintext == nil {
return nil, ErrCryptoFailure
}
// The "zip" header parameter may only be present in the protected header.
if comp := obj.protected.getCompression(); comp != "" {
plaintext, err = decompress(comp, plaintext)
}
return plaintext, err
}
// DecryptMulti decrypts and validates the object and returns the plaintexts,
// with support for multiple recipients. It returns the index of the recipient
// for which the decryption was successful, the merged headers for that recipient,
// and the plaintext.
func (obj JSONWebEncryption) DecryptMulti(decryptionKey interface{}) (int, Header, []byte, error) {
globalHeaders := obj.mergedHeaders(nil)
critical, err := globalHeaders.getCritical()
if err != nil {
return -1, Header{}, nil, fmt.Errorf("go-jose/go-jose: invalid crit header")
}
if len(critical) > 0 {
return -1, Header{}, nil, fmt.Errorf("go-jose/go-jose: unsupported crit header")
}
key := tryJWKS(decryptionKey, obj.Header)
decrypter, err := newDecrypter(key)
if err != nil {
return -1, Header{}, nil, err
}
encryption := globalHeaders.getEncryption()
cipher := getContentCipher(encryption)
if cipher == nil {
return -1, Header{}, nil, fmt.Errorf("go-jose/go-jose: unsupported enc value '%s'", string(encryption))
}
generator := randomKeyGenerator{
size: cipher.keySize(),
}
parts := &aeadParts{
iv: obj.iv,
ciphertext: obj.ciphertext,
tag: obj.tag,
}
authData := obj.computeAuthData()
index := -1
var plaintext []byte
var headers rawHeader
for i, recipient := range obj.recipients {
recipientHeaders := obj.mergedHeaders(&recipient)
cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator)
if err == nil {
// Found a valid CEK -- let's try to decrypt.
plaintext, err = cipher.decrypt(cek, authData, parts)
if err == nil {
index = i
headers = recipientHeaders
break
}
}
}
if plaintext == nil {
return -1, Header{}, nil, ErrCryptoFailure
}
// The "zip" header parameter may only be present in the protected header.
if comp := obj.protected.getCompression(); comp != "" {
plaintext, _ = decompress(comp, plaintext)
}
sanitized, err := headers.sanitized()
if err != nil {
return -1, Header{}, nil, fmt.Errorf("go-jose/go-jose: failed to sanitize header: %v", err)
}
return index, sanitized, plaintext, err
}
vendor/github.com/go-jose/go-jose/v3/doc.go generated vendored Normal file
@ -0,0 +1,27 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
Package jose aims to provide an implementation of the Javascript Object Signing
and Encryption set of standards. It implements encryption and signing based on
the JSON Web Encryption and JSON Web Signature standards, with optional JSON Web
Token support available in a sub-package. The library supports both the compact
and JWS/JWE JSON Serialization formats, and has optional support for multiple
recipients.
*/
package jose
vendor/github.com/go-jose/go-jose/v3/encoding.go generated vendored Normal file
@ -0,0 +1,191 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jose
import (
"bytes"
"compress/flate"
"encoding/base64"
"encoding/binary"
"io"
"math/big"
"strings"
"unicode"
"github.com/go-jose/go-jose/v3/json"
)
// Helper function to serialize known-good objects.
// Precondition: value is not a nil pointer.
func mustSerializeJSON(value interface{}) []byte {
out, err := json.Marshal(value)
if err != nil {
panic(err)
}
// We never want to serialize the top-level value "null," since it's not a
// valid JOSE message. But if a caller passes in a nil pointer to this method,
// MarshalJSON will happily serialize it as the top-level value "null". If
// that value is then embedded in another operation, for instance by being
// base64-encoded and fed as input to a signing algorithm
// (https://github.com/go-jose/go-jose/issues/22), the result will be
// incorrect. Because this method is intended for known-good objects, and a nil
// pointer is not a known-good object, we are free to panic in this case.
// Note: It's not possible to directly check whether the data pointed at by an
// interface is a nil pointer, so we do this hacky workaround.
// https://groups.google.com/forum/#!topic/golang-nuts/wnH302gBa4I
if string(out) == "null" {
panic("Tried to serialize a nil pointer.")
}
return out
}
// Strip all newlines and whitespace
func stripWhitespace(data string) string {
buf := strings.Builder{}
buf.Grow(len(data))
for _, r := range data {
if !unicode.IsSpace(r) {
buf.WriteRune(r)
}
}
return buf.String()
}
// Perform compression based on algorithm
func compress(algorithm CompressionAlgorithm, input []byte) ([]byte, error) {
switch algorithm {
case DEFLATE:
return deflate(input)
default:
return nil, ErrUnsupportedAlgorithm
}
}
// Perform decompression based on algorithm
func decompress(algorithm CompressionAlgorithm, input []byte) ([]byte, error) {
switch algorithm {
case DEFLATE:
return inflate(input)
default:
return nil, ErrUnsupportedAlgorithm
}
}
// Compress with DEFLATE
func deflate(input []byte) ([]byte, error) {
output := new(bytes.Buffer)
// Writing to byte buffer, err is always nil
writer, _ := flate.NewWriter(output, 1)
_, _ = io.Copy(writer, bytes.NewBuffer(input))
err := writer.Close()
return output.Bytes(), err
}
// Decompress with DEFLATE
func inflate(input []byte) ([]byte, error) {
output := new(bytes.Buffer)
reader := flate.NewReader(bytes.NewBuffer(input))
_, err := io.Copy(output, reader)
if err != nil {
return nil, err
}
err = reader.Close()
return output.Bytes(), err
}
// byteBuffer represents a slice of bytes that can be serialized to url-safe base64.
type byteBuffer struct {
data []byte
}
func newBuffer(data []byte) *byteBuffer {
if data == nil {
return nil
}
return &byteBuffer{
data: data,
}
}
func newFixedSizeBuffer(data []byte, length int) *byteBuffer {
if len(data) > length {
panic("go-jose/go-jose: invalid call to newFixedSizeBuffer (len(data) > length)")
}
pad := make([]byte, length-len(data))
return newBuffer(append(pad, data...))
}
func newBufferFromInt(num uint64) *byteBuffer {
data := make([]byte, 8)
binary.BigEndian.PutUint64(data, num)
return newBuffer(bytes.TrimLeft(data, "\x00"))
}
func (b *byteBuffer) MarshalJSON() ([]byte, error) {
return json.Marshal(b.base64())
}
func (b *byteBuffer) UnmarshalJSON(data []byte) error {
var encoded string
err := json.Unmarshal(data, &encoded)
if err != nil {
return err
}
if encoded == "" {
return nil
}
decoded, err := base64URLDecode(encoded)
if err != nil {
return err
}
*b = *newBuffer(decoded)
return nil
}
func (b *byteBuffer) base64() string {
return base64.RawURLEncoding.EncodeToString(b.data)
}
func (b *byteBuffer) bytes() []byte {
// Handling nil here allows us to transparently handle nil slices when serializing.
if b == nil {
return nil
}
return b.data
}
func (b byteBuffer) bigInt() *big.Int {
return new(big.Int).SetBytes(b.data)
}
func (b byteBuffer) toInt() int {
return int(b.bigInt().Int64())
}
// base64URLDecode is implemented as defined in https://www.rfc-editor.org/rfc/rfc7515.html#appendix-C
func base64URLDecode(value string) ([]byte, error) {
value = strings.TrimRight(value, "=")
return base64.RawURLEncoding.DecodeString(value)
}

27
vendor/github.com/go-jose/go-jose/v3/json/LICENSE generated vendored Normal file

@ -0,0 +1,27 @@
Copyright (c) 2012 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

13
vendor/github.com/go-jose/go-jose/v3/json/README.md generated vendored Normal file

@ -0,0 +1,13 @@
# Safe JSON
This repository contains a fork of the `encoding/json` package from Go 1.6.
The following changes were made:
* Object deserialization uses case-sensitive member name matching instead of
[case-insensitive matching](https://www.ietf.org/mail-archive/web/json/current/msg03763.html).
This is to avoid differences in the interpretation of JOSE messages between
go-jose and libraries written in other languages.
* When deserializing a JSON object, we check for duplicate keys and reject the
input whenever we detect a duplicate. Rather than trying to work with malformed
data, we prefer to reject it right away.
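As an illustration of the duplicate-key behavior described above, a small sketch using this fork's import path (assuming it is compiled within this module; the standard library's encoding/json would silently accept the same input):

package main

import (
	"fmt"

	"github.com/go-jose/go-jose/v3/json"
)

func main() {
	var v map[string]interface{}
	// The forked decoder rejects duplicate object keys outright.
	err := json.Unmarshal([]byte(`{"alg":"RS256","alg":"none"}`), &v)
	fmt.Println(err) // expected: a non-nil duplicate-key error
}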

1217
vendor/github.com/go-jose/go-jose/v3/json/decode.go generated vendored Normal file

File diff suppressed because it is too large

1197
vendor/github.com/go-jose/go-jose/v3/json/encode.go generated vendored Normal file

File diff suppressed because it is too large

141
vendor/github.com/go-jose/go-jose/v3/json/indent.go generated vendored Normal file

@ -0,0 +1,141 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import "bytes"
// Compact appends to dst the JSON-encoded src with
// insignificant space characters elided.
func Compact(dst *bytes.Buffer, src []byte) error {
return compact(dst, src, false)
}
func compact(dst *bytes.Buffer, src []byte, escape bool) error {
origLen := dst.Len()
var scan scanner
scan.reset()
start := 0
for i, c := range src {
if escape && (c == '<' || c == '>' || c == '&') {
if start < i {
dst.Write(src[start:i])
}
dst.WriteString(`\u00`)
dst.WriteByte(hex[c>>4])
dst.WriteByte(hex[c&0xF])
start = i + 1
}
// Convert U+2028 and U+2029 (E2 80 A8 and E2 80 A9).
if c == 0xE2 && i+2 < len(src) && src[i+1] == 0x80 && src[i+2]&^1 == 0xA8 {
if start < i {
dst.Write(src[start:i])
}
dst.WriteString(`\u202`)
dst.WriteByte(hex[src[i+2]&0xF])
start = i + 3
}
v := scan.step(&scan, c)
if v >= scanSkipSpace {
if v == scanError {
break
}
if start < i {
dst.Write(src[start:i])
}
start = i + 1
}
}
if scan.eof() == scanError {
dst.Truncate(origLen)
return scan.err
}
if start < len(src) {
dst.Write(src[start:])
}
return nil
}
func newline(dst *bytes.Buffer, prefix, indent string, depth int) {
dst.WriteByte('\n')
dst.WriteString(prefix)
for i := 0; i < depth; i++ {
dst.WriteString(indent)
}
}
// Indent appends to dst an indented form of the JSON-encoded src.
// Each element in a JSON object or array begins on a new,
// indented line beginning with prefix followed by one or more
// copies of indent according to the indentation nesting.
// The data appended to dst does not begin with the prefix nor
// any indentation, to make it easier to embed inside other formatted JSON data.
// Although leading space characters (space, tab, carriage return, newline)
// at the beginning of src are dropped, trailing space characters
// at the end of src are preserved and copied to dst.
// For example, if src has no trailing spaces, neither will dst;
// if src ends in a trailing newline, so will dst.
func Indent(dst *bytes.Buffer, src []byte, prefix, indent string) error {
origLen := dst.Len()
var scan scanner
scan.reset()
needIndent := false
depth := 0
for _, c := range src {
scan.bytes++
v := scan.step(&scan, c)
if v == scanSkipSpace {
continue
}
if v == scanError {
break
}
if needIndent && v != scanEndObject && v != scanEndArray {
needIndent = false
depth++
newline(dst, prefix, indent, depth)
}
// Emit semantically uninteresting bytes
// (in particular, punctuation in strings) unmodified.
if v == scanContinue {
dst.WriteByte(c)
continue
}
// Add spacing around real punctuation.
switch c {
case '{', '[':
// delay indent so that empty object and array are formatted as {} and [].
needIndent = true
dst.WriteByte(c)
case ',':
dst.WriteByte(c)
newline(dst, prefix, indent, depth)
case ':':
dst.WriteByte(c)
dst.WriteByte(' ')
case '}', ']':
if needIndent {
// suppress indent in empty object/array
needIndent = false
} else {
depth--
newline(dst, prefix, indent, depth)
}
dst.WriteByte(c)
default:
dst.WriteByte(c)
}
}
if scan.eof() == scanError {
dst.Truncate(origLen)
return scan.err
}
return nil
}
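A short usage sketch for the two exported helpers in this file, assuming the fork's import path resolves within this module:

package main

import (
	"bytes"
	"fmt"

	"github.com/go-jose/go-jose/v3/json"
)

func main() {
	src := []byte(`{ "a": [1, 2], "b": {} }`)

	// Compact elides insignificant whitespace.
	var compacted bytes.Buffer
	if err := json.Compact(&compacted, src); err != nil {
		panic(err)
	}
	fmt.Println(compacted.String()) // {"a":[1,2],"b":{}}

	// Indent pretty-prints with two-space nesting.
	var indented bytes.Buffer
	if err := json.Indent(&indented, src, "", "  "); err != nil {
		panic(err)
	}
	fmt.Println(indented.String())
}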

623
vendor/github.com/go-jose/go-jose/v3/json/scanner.go generated vendored Normal file

@ -0,0 +1,623 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
// JSON value parser state machine.
// Just about at the limit of what is reasonable to write by hand.
// Some parts are a bit tedious, but overall it nicely factors out the
// otherwise common code from the multiple scanning functions
// in this package (Compact, Indent, checkValid, nextValue, etc).
//
// This file starts with two simple examples using the scanner
// before diving into the scanner itself.
import "strconv"
// checkValid verifies that data is valid JSON-encoded data.
// scan is passed in for use by checkValid to avoid an allocation.
func checkValid(data []byte, scan *scanner) error {
scan.reset()
for _, c := range data {
scan.bytes++
if scan.step(scan, c) == scanError {
return scan.err
}
}
if scan.eof() == scanError {
return scan.err
}
return nil
}
// nextValue splits data after the next whole JSON value,
// returning that value and the bytes that follow it as separate slices.
// scan is passed in for use by nextValue to avoid an allocation.
func nextValue(data []byte, scan *scanner) (value, rest []byte, err error) {
scan.reset()
for i, c := range data {
v := scan.step(scan, c)
if v >= scanEndObject {
switch v {
// probe the scanner with a space to determine whether we will
// get scanEnd on the next character. Otherwise, if the next character
// is not a space, scanEndTop allocates a needless error.
case scanEndObject, scanEndArray:
if scan.step(scan, ' ') == scanEnd {
return data[:i+1], data[i+1:], nil
}
case scanError:
return nil, nil, scan.err
case scanEnd:
return data[:i], data[i:], nil
}
}
}
if scan.eof() == scanError {
return nil, nil, scan.err
}
return data, nil, nil
}
// A SyntaxError is a description of a JSON syntax error.
type SyntaxError struct {
msg string // description of error
Offset int64 // error occurred after reading Offset bytes
}
func (e *SyntaxError) Error() string { return e.msg }
// A scanner is a JSON scanning state machine.
// Callers call scan.reset() and then pass bytes in one at a time
// by calling scan.step(&scan, c) for each byte.
// The return value, referred to as an opcode, tells the
// caller about significant parsing events like beginning
// and ending literals, objects, and arrays, so that the
// caller can follow along if it wishes.
// The return value scanEnd indicates that a single top-level
// JSON value has been completed, *before* the byte that
// just got passed in. (The indication must be delayed in order
// to recognize the end of numbers: is 123 a whole value or
// the beginning of 12345e+6?).
type scanner struct {
// The step is a func to be called to execute the next transition.
// Also tried using an integer constant and a single func
// with a switch, but using the func directly was 10% faster
// on a 64-bit Mac Mini, and it's nicer to read.
step func(*scanner, byte) int
// Reached end of top-level value.
endTop bool
// Stack of what we're in the middle of - array values, object keys, object values.
parseState []int
// Error that happened, if any.
err error
// 1-byte redo (see undo method)
redo bool
redoCode int
redoState func(*scanner, byte) int
// total bytes consumed, updated by decoder.Decode
bytes int64
}
// These values are returned by the state transition functions
// assigned to scanner.state and the method scanner.eof.
// They give details about the current state of the scan that
// callers might be interested to know about.
// It is okay to ignore the return value of any particular
// call to scanner.state: if one call returns scanError,
// every subsequent call will return scanError too.
const (
// Continue.
scanContinue = iota // uninteresting byte
scanBeginLiteral // end implied by next result != scanContinue
scanBeginObject // begin object
scanObjectKey // just finished object key (string)
scanObjectValue // just finished non-last object value
scanEndObject // end object (implies scanObjectValue if possible)
scanBeginArray // begin array
scanArrayValue // just finished array value
scanEndArray // end array (implies scanArrayValue if possible)
scanSkipSpace // space byte; can skip; known to be last "continue" result
// Stop.
scanEnd // top-level value ended *before* this byte; known to be first "stop" result
scanError // hit an error, scanner.err.
)
// These values are stored in the parseState stack.
// They give the current state of a composite value
// being scanned. If the parser is inside a nested value
// the parseState describes the nested state, outermost at entry 0.
const (
parseObjectKey = iota // parsing object key (before colon)
parseObjectValue // parsing object value (after colon)
parseArrayValue // parsing array value
)
// reset prepares the scanner for use.
// It must be called before calling s.step.
func (s *scanner) reset() {
s.step = stateBeginValue
s.parseState = s.parseState[0:0]
s.err = nil
s.redo = false
s.endTop = false
}
// eof tells the scanner that the end of input has been reached.
// It returns a scan status just as s.step does.
func (s *scanner) eof() int {
if s.err != nil {
return scanError
}
if s.endTop {
return scanEnd
}
s.step(s, ' ')
if s.endTop {
return scanEnd
}
if s.err == nil {
s.err = &SyntaxError{"unexpected end of JSON input", s.bytes}
}
return scanError
}
// pushParseState pushes a new parse state p onto the parse stack.
func (s *scanner) pushParseState(p int) {
s.parseState = append(s.parseState, p)
}
// popParseState pops a parse state (already obtained) off the stack
// and updates s.step accordingly.
func (s *scanner) popParseState() {
n := len(s.parseState) - 1
s.parseState = s.parseState[0:n]
s.redo = false
if n == 0 {
s.step = stateEndTop
s.endTop = true
} else {
s.step = stateEndValue
}
}
func isSpace(c byte) bool {
return c == ' ' || c == '\t' || c == '\r' || c == '\n'
}
// stateBeginValueOrEmpty is the state after reading `[`.
func stateBeginValueOrEmpty(s *scanner, c byte) int {
if c <= ' ' && isSpace(c) {
return scanSkipSpace
}
if c == ']' {
return stateEndValue(s, c)
}
return stateBeginValue(s, c)
}
// stateBeginValue is the state at the beginning of the input.
func stateBeginValue(s *scanner, c byte) int {
if c <= ' ' && isSpace(c) {
return scanSkipSpace
}
switch c {
case '{':
s.step = stateBeginStringOrEmpty
s.pushParseState(parseObjectKey)
return scanBeginObject
case '[':
s.step = stateBeginValueOrEmpty
s.pushParseState(parseArrayValue)
return scanBeginArray
case '"':
s.step = stateInString
return scanBeginLiteral
case '-':
s.step = stateNeg
return scanBeginLiteral
case '0': // beginning of 0.123
s.step = state0
return scanBeginLiteral
case 't': // beginning of true
s.step = stateT
return scanBeginLiteral
case 'f': // beginning of false
s.step = stateF
return scanBeginLiteral
case 'n': // beginning of null
s.step = stateN
return scanBeginLiteral
}
if '1' <= c && c <= '9' { // beginning of 1234.5
s.step = state1
return scanBeginLiteral
}
return s.error(c, "looking for beginning of value")
}
// stateBeginStringOrEmpty is the state after reading `{`.
func stateBeginStringOrEmpty(s *scanner, c byte) int {
if c <= ' ' && isSpace(c) {
return scanSkipSpace
}
if c == '}' {
n := len(s.parseState)
s.parseState[n-1] = parseObjectValue
return stateEndValue(s, c)
}
return stateBeginString(s, c)
}
// stateBeginString is the state after reading `{"key": value,`.
func stateBeginString(s *scanner, c byte) int {
if c <= ' ' && isSpace(c) {
return scanSkipSpace
}
if c == '"' {
s.step = stateInString
return scanBeginLiteral
}
return s.error(c, "looking for beginning of object key string")
}
// stateEndValue is the state after completing a value,
// such as after reading `{}` or `true` or `["x"`.
func stateEndValue(s *scanner, c byte) int {
n := len(s.parseState)
if n == 0 {
// Completed top-level before the current byte.
s.step = stateEndTop
s.endTop = true
return stateEndTop(s, c)
}
if c <= ' ' && isSpace(c) {
s.step = stateEndValue
return scanSkipSpace
}
ps := s.parseState[n-1]
switch ps {
case parseObjectKey:
if c == ':' {
s.parseState[n-1] = parseObjectValue
s.step = stateBeginValue
return scanObjectKey
}
return s.error(c, "after object key")
case parseObjectValue:
if c == ',' {
s.parseState[n-1] = parseObjectKey
s.step = stateBeginString
return scanObjectValue
}
if c == '}' {
s.popParseState()
return scanEndObject
}
return s.error(c, "after object key:value pair")
case parseArrayValue:
if c == ',' {
s.step = stateBeginValue
return scanArrayValue
}
if c == ']' {
s.popParseState()
return scanEndArray
}
return s.error(c, "after array element")
}
return s.error(c, "")
}
// stateEndTop is the state after finishing the top-level value,
// such as after reading `{}` or `[1,2,3]`.
// Only space characters should be seen now.
func stateEndTop(s *scanner, c byte) int {
if c != ' ' && c != '\t' && c != '\r' && c != '\n' {
// Complain about non-space byte on next call.
s.error(c, "after top-level value")
}
return scanEnd
}
// stateInString is the state after reading `"`.
func stateInString(s *scanner, c byte) int {
if c == '"' {
s.step = stateEndValue
return scanContinue
}
if c == '\\' {
s.step = stateInStringEsc
return scanContinue
}
if c < 0x20 {
return s.error(c, "in string literal")
}
return scanContinue
}
// stateInStringEsc is the state after reading `"\` during a quoted string.
func stateInStringEsc(s *scanner, c byte) int {
switch c {
case 'b', 'f', 'n', 'r', 't', '\\', '/', '"':
s.step = stateInString
return scanContinue
case 'u':
s.step = stateInStringEscU
return scanContinue
}
return s.error(c, "in string escape code")
}
// stateInStringEscU is the state after reading `"\u` during a quoted string.
func stateInStringEscU(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInStringEscU1
return scanContinue
}
return s.error(c, "in \\u hexadecimal character escape")
}
// stateInStringEscU1 is the state after reading `"\u1` during a quoted string.
func stateInStringEscU1(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInStringEscU12
return scanContinue
}
return s.error(c, "in \\u hexadecimal character escape")
}
// stateInStringEscU12 is the state after reading `"\u12` during a quoted string.
func stateInStringEscU12(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInStringEscU123
return scanContinue
}
return s.error(c, "in \\u hexadecimal character escape")
}
// stateInStringEscU123 is the state after reading `"\u123` during a quoted string.
func stateInStringEscU123(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInString
return scanContinue
}
return s.error(c, "in \\u hexadecimal character escape")
}
// stateNeg is the state after reading `-` during a number.
func stateNeg(s *scanner, c byte) int {
if c == '0' {
s.step = state0
return scanContinue
}
if '1' <= c && c <= '9' {
s.step = state1
return scanContinue
}
return s.error(c, "in numeric literal")
}
// state1 is the state after reading a non-zero integer during a number,
// such as after reading `1` or `100` but not `0`.
func state1(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
s.step = state1
return scanContinue
}
return state0(s, c)
}
// state0 is the state after reading `0` during a number.
func state0(s *scanner, c byte) int {
if c == '.' {
s.step = stateDot
return scanContinue
}
if c == 'e' || c == 'E' {
s.step = stateE
return scanContinue
}
return stateEndValue(s, c)
}
// stateDot is the state after reading the integer and decimal point in a number,
// such as after reading `1.`.
func stateDot(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
s.step = stateDot0
return scanContinue
}
return s.error(c, "after decimal point in numeric literal")
}
// stateDot0 is the state after reading the integer, decimal point, and subsequent
// digits of a number, such as after reading `3.14`.
func stateDot0(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
return scanContinue
}
if c == 'e' || c == 'E' {
s.step = stateE
return scanContinue
}
return stateEndValue(s, c)
}
// stateE is the state after reading the mantissa and e in a number,
// such as after reading `314e` or `0.314e`.
func stateE(s *scanner, c byte) int {
if c == '+' || c == '-' {
s.step = stateESign
return scanContinue
}
return stateESign(s, c)
}
// stateESign is the state after reading the mantissa, e, and sign in a number,
// such as after reading `314e-` or `0.314e+`.
func stateESign(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
s.step = stateE0
return scanContinue
}
return s.error(c, "in exponent of numeric literal")
}
// stateE0 is the state after reading the mantissa, e, optional sign,
// and at least one digit of the exponent in a number,
// such as after reading `314e-2` or `0.314e+1` or `3.14e0`.
func stateE0(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
return scanContinue
}
return stateEndValue(s, c)
}
// stateT is the state after reading `t`.
func stateT(s *scanner, c byte) int {
if c == 'r' {
s.step = stateTr
return scanContinue
}
return s.error(c, "in literal true (expecting 'r')")
}
// stateTr is the state after reading `tr`.
func stateTr(s *scanner, c byte) int {
if c == 'u' {
s.step = stateTru
return scanContinue
}
return s.error(c, "in literal true (expecting 'u')")
}
// stateTru is the state after reading `tru`.
func stateTru(s *scanner, c byte) int {
if c == 'e' {
s.step = stateEndValue
return scanContinue
}
return s.error(c, "in literal true (expecting 'e')")
}
// stateF is the state after reading `f`.
func stateF(s *scanner, c byte) int {
if c == 'a' {
s.step = stateFa
return scanContinue
}
return s.error(c, "in literal false (expecting 'a')")
}
// stateFa is the state after reading `fa`.
func stateFa(s *scanner, c byte) int {
if c == 'l' {
s.step = stateFal
return scanContinue
}
return s.error(c, "in literal false (expecting 'l')")
}
// stateFal is the state after reading `fal`.
func stateFal(s *scanner, c byte) int {
if c == 's' {
s.step = stateFals
return scanContinue
}
return s.error(c, "in literal false (expecting 's')")
}
// stateFals is the state after reading `fals`.
func stateFals(s *scanner, c byte) int {
if c == 'e' {
s.step = stateEndValue
return scanContinue
}
return s.error(c, "in literal false (expecting 'e')")
}
// stateN is the state after reading `n`.
func stateN(s *scanner, c byte) int {
if c == 'u' {
s.step = stateNu
return scanContinue
}
return s.error(c, "in literal null (expecting 'u')")
}
// stateNu is the state after reading `nu`.
func stateNu(s *scanner, c byte) int {
if c == 'l' {
s.step = stateNul
return scanContinue
}
return s.error(c, "in literal null (expecting 'l')")
}
// stateNul is the state after reading `nul`.
func stateNul(s *scanner, c byte) int {
if c == 'l' {
s.step = stateEndValue
return scanContinue
}
return s.error(c, "in literal null (expecting 'l')")
}
// stateError is the state after reaching a syntax error,
// such as after reading `[1}` or `5.1.2`.
func stateError(s *scanner, c byte) int {
return scanError
}
// error records an error and switches to the error state.
func (s *scanner) error(c byte, context string) int {
s.step = stateError
s.err = &SyntaxError{"invalid character " + quoteChar(c) + " " + context, s.bytes}
return scanError
}
// quoteChar formats c as a quoted character literal
func quoteChar(c byte) string {
// special cases - different from quoted strings
if c == '\'' {
return `'\''`
}
if c == '"' {
return `'"'`
}
// use quoted string with different quotation marks
s := strconv.Quote(string(c))
return "'" + s[1:len(s)-1] + "'"
}
// undo causes the scanner to return scanCode from the next state transition.
// This gives callers a simple 1-byte undo mechanism.
func (s *scanner) undo(scanCode int) {
if s.redo {
panic("json: invalid use of scanner")
}
s.redoCode = scanCode
s.redoState = s.step
s.step = stateRedo
s.redo = true
}
// stateRedo helps implement the scanner's 1-byte undo.
func stateRedo(s *scanner, c byte) int {
s.redo = false
s.step = s.redoState
return s.redoCode
}
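Since scanner and checkValid are unexported, the byte-at-a-time protocol described above can only be exercised from inside the package. A minimal hypothetical _test.go sketch:

package json

import "testing"

// TestScannerSketch feeds whole inputs through checkValid, which drives the
// scanner one byte at a time and reports the first syntax error, if any.
func TestScannerSketch(t *testing.T) {
	var scan scanner
	if err := checkValid([]byte(`{"a": [1, 2, 3]}`), &scan); err != nil {
		t.Fatalf("expected valid JSON, got %v", err)
	}
	if err := checkValid([]byte(`[1}`), &scan); err == nil {
		t.Fatal("expected an error for mismatched delimiters")
	}
}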

485
vendor/github.com/go-jose/go-jose/v3/json/stream.go generated vendored Normal file

@ -0,0 +1,485 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"bytes"
"errors"
"io"
)
// A Decoder reads and decodes JSON objects from an input stream.
type Decoder struct {
r io.Reader
buf []byte
d decodeState
scanp int // start of unread data in buf
scan scanner
err error
tokenState int
tokenStack []int
}
// NewDecoder returns a new decoder that reads from r.
//
// The decoder introduces its own buffering and may
// read data from r beyond the JSON values requested.
func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r}
}
// UseNumber causes the Decoder to unmarshal a number into an interface{} as a
// Number instead of as a float64.
//
// Deprecated: Use SetNumberType instead.
func (dec *Decoder) UseNumber() { dec.d.numberType = UnmarshalJSONNumber }
// SetNumberType causes the Decoder to unmarshal a number into an interface{} as a
// Number, float64 or int64 depending on `t` enum value.
func (dec *Decoder) SetNumberType(t NumberUnmarshalType) { dec.d.numberType = t }
// Decode reads the next JSON-encoded value from its
// input and stores it in the value pointed to by v.
//
// See the documentation for Unmarshal for details about
// the conversion of JSON into a Go value.
func (dec *Decoder) Decode(v interface{}) error {
if dec.err != nil {
return dec.err
}
if err := dec.tokenPrepareForDecode(); err != nil {
return err
}
if !dec.tokenValueAllowed() {
return &SyntaxError{msg: "not at beginning of value"}
}
// Read whole value into buffer.
n, err := dec.readValue()
if err != nil {
return err
}
dec.d.init(dec.buf[dec.scanp : dec.scanp+n])
dec.scanp += n
// Don't save err from unmarshal into dec.err:
// the connection is still usable since we read a complete JSON
// object from it before the error happened.
err = dec.d.unmarshal(v)
// fixup token streaming state
dec.tokenValueEnd()
return err
}
// Buffered returns a reader of the data remaining in the Decoder's
// buffer. The reader is valid until the next call to Decode.
func (dec *Decoder) Buffered() io.Reader {
return bytes.NewReader(dec.buf[dec.scanp:])
}
// readValue reads a JSON value into dec.buf.
// It returns the length of the encoding.
func (dec *Decoder) readValue() (int, error) {
dec.scan.reset()
scanp := dec.scanp
var err error
Input:
for {
// Look in the buffer for a new value.
for i, c := range dec.buf[scanp:] {
dec.scan.bytes++
v := dec.scan.step(&dec.scan, c)
if v == scanEnd {
scanp += i
break Input
}
// scanEnd is delayed one byte.
// We might block trying to get that byte from src,
// so instead invent a space byte.
if (v == scanEndObject || v == scanEndArray) && dec.scan.step(&dec.scan, ' ') == scanEnd {
scanp += i + 1
break Input
}
if v == scanError {
dec.err = dec.scan.err
return 0, dec.scan.err
}
}
scanp = len(dec.buf)
// Did the last read have an error?
// Delayed until now to allow buffer scan.
if err != nil {
if err == io.EOF {
if dec.scan.step(&dec.scan, ' ') == scanEnd {
break Input
}
if nonSpace(dec.buf) {
err = io.ErrUnexpectedEOF
}
}
dec.err = err
return 0, err
}
n := scanp - dec.scanp
err = dec.refill()
scanp = dec.scanp + n
}
return scanp - dec.scanp, nil
}
func (dec *Decoder) refill() error {
// Make room to read more into the buffer.
// First slide down data already consumed.
if dec.scanp > 0 {
n := copy(dec.buf, dec.buf[dec.scanp:])
dec.buf = dec.buf[:n]
dec.scanp = 0
}
// Grow buffer if not large enough.
const minRead = 512
if cap(dec.buf)-len(dec.buf) < minRead {
newBuf := make([]byte, len(dec.buf), 2*cap(dec.buf)+minRead)
copy(newBuf, dec.buf)
dec.buf = newBuf
}
// Read. Delay error for next iteration (after scan).
n, err := dec.r.Read(dec.buf[len(dec.buf):cap(dec.buf)])
dec.buf = dec.buf[0 : len(dec.buf)+n]
return err
}
func nonSpace(b []byte) bool {
for _, c := range b {
if !isSpace(c) {
return true
}
}
return false
}
// An Encoder writes JSON objects to an output stream.
type Encoder struct {
w io.Writer
err error
}
// NewEncoder returns a new encoder that writes to w.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{w: w}
}
// Encode writes the JSON encoding of v to the stream,
// followed by a newline character.
//
// See the documentation for Marshal for details about the
// conversion of Go values to JSON.
func (enc *Encoder) Encode(v interface{}) error {
if enc.err != nil {
return enc.err
}
e := newEncodeState()
err := e.marshal(v)
if err != nil {
return err
}
// Terminate each value with a newline.
// This makes the output look a little nicer
// when debugging, and some kind of space
// is required if the encoded value was a number,
// so that the reader knows there aren't more
// digits coming.
e.WriteByte('\n')
if _, err = enc.w.Write(e.Bytes()); err != nil {
enc.err = err
}
encodeStatePool.Put(e)
return err
}
// RawMessage is a raw encoded JSON object.
// It implements Marshaler and Unmarshaler and can
// be used to delay JSON decoding or precompute a JSON encoding.
type RawMessage []byte
// MarshalJSON returns *m as the JSON encoding of m.
func (m *RawMessage) MarshalJSON() ([]byte, error) {
return *m, nil
}
// UnmarshalJSON sets *m to a copy of data.
func (m *RawMessage) UnmarshalJSON(data []byte) error {
if m == nil {
return errors.New("json.RawMessage: UnmarshalJSON on nil pointer")
}
*m = append((*m)[0:0], data...)
return nil
}
var _ Marshaler = (*RawMessage)(nil)
var _ Unmarshaler = (*RawMessage)(nil)
// A Token holds a value of one of these types:
//
// Delim, for the four JSON delimiters [ ] { }
// bool, for JSON booleans
// float64, for JSON numbers
// Number, for JSON numbers
// string, for JSON string literals
// nil, for JSON null
//
type Token interface{}
const (
tokenTopValue = iota
tokenArrayStart
tokenArrayValue
tokenArrayComma
tokenObjectStart
tokenObjectKey
tokenObjectColon
tokenObjectValue
tokenObjectComma
)
// tokenPrepareForDecode advances the token state from a separator state to a value state.
func (dec *Decoder) tokenPrepareForDecode() error {
// Note: Not calling peek before switch, to avoid
// putting peek into the standard Decode path.
// peek is only called when using the Token API.
switch dec.tokenState {
case tokenArrayComma:
c, err := dec.peek()
if err != nil {
return err
}
if c != ',' {
return &SyntaxError{"expected comma after array element", 0}
}
dec.scanp++
dec.tokenState = tokenArrayValue
case tokenObjectColon:
c, err := dec.peek()
if err != nil {
return err
}
if c != ':' {
return &SyntaxError{"expected colon after object key", 0}
}
dec.scanp++
dec.tokenState = tokenObjectValue
}
return nil
}
func (dec *Decoder) tokenValueAllowed() bool {
switch dec.tokenState {
case tokenTopValue, tokenArrayStart, tokenArrayValue, tokenObjectValue:
return true
}
return false
}
func (dec *Decoder) tokenValueEnd() {
switch dec.tokenState {
case tokenArrayStart, tokenArrayValue:
dec.tokenState = tokenArrayComma
case tokenObjectValue:
dec.tokenState = tokenObjectComma
}
}
// A Delim is a JSON array or object delimiter, one of [ ] { or }.
type Delim rune
func (d Delim) String() string {
return string(d)
}
// Token returns the next JSON token in the input stream.
// At the end of the input stream, Token returns nil, io.EOF.
//
// Token guarantees that the delimiters [ ] { } it returns are
// properly nested and matched: if Token encounters an unexpected
// delimiter in the input, it will return an error.
//
// The input stream consists of basic JSON values—bool, string,
// number, and null—along with delimiters [ ] { } of type Delim
// to mark the start and end of arrays and objects.
// Commas and colons are elided.
func (dec *Decoder) Token() (Token, error) {
for {
c, err := dec.peek()
if err != nil {
return nil, err
}
switch c {
case '[':
if !dec.tokenValueAllowed() {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenStack = append(dec.tokenStack, dec.tokenState)
dec.tokenState = tokenArrayStart
return Delim('['), nil
case ']':
if dec.tokenState != tokenArrayStart && dec.tokenState != tokenArrayComma {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenState = dec.tokenStack[len(dec.tokenStack)-1]
dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
dec.tokenValueEnd()
return Delim(']'), nil
case '{':
if !dec.tokenValueAllowed() {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenStack = append(dec.tokenStack, dec.tokenState)
dec.tokenState = tokenObjectStart
return Delim('{'), nil
case '}':
if dec.tokenState != tokenObjectStart && dec.tokenState != tokenObjectComma {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenState = dec.tokenStack[len(dec.tokenStack)-1]
dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
dec.tokenValueEnd()
return Delim('}'), nil
case ':':
if dec.tokenState != tokenObjectColon {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenState = tokenObjectValue
continue
case ',':
if dec.tokenState == tokenArrayComma {
dec.scanp++
dec.tokenState = tokenArrayValue
continue
}
if dec.tokenState == tokenObjectComma {
dec.scanp++
dec.tokenState = tokenObjectKey
continue
}
return dec.tokenError(c)
case '"':
if dec.tokenState == tokenObjectStart || dec.tokenState == tokenObjectKey {
var x string
old := dec.tokenState
dec.tokenState = tokenTopValue
err := dec.Decode(&x)
dec.tokenState = old
if err != nil {
clearOffset(err)
return nil, err
}
dec.tokenState = tokenObjectColon
return x, nil
}
fallthrough
default:
if !dec.tokenValueAllowed() {
return dec.tokenError(c)
}
var x interface{}
if err := dec.Decode(&x); err != nil {
clearOffset(err)
return nil, err
}
return x, nil
}
}
}
func clearOffset(err error) {
if s, ok := err.(*SyntaxError); ok {
s.Offset = 0
}
}
func (dec *Decoder) tokenError(c byte) (Token, error) {
var context string
switch dec.tokenState {
case tokenTopValue:
context = " looking for beginning of value"
case tokenArrayStart, tokenArrayValue, tokenObjectValue:
context = " looking for beginning of value"
case tokenArrayComma:
context = " after array element"
case tokenObjectKey:
context = " looking for beginning of object key string"
case tokenObjectColon:
context = " after object key"
case tokenObjectComma:
context = " after object key:value pair"
}
return nil, &SyntaxError{"invalid character " + quoteChar(c) + " " + context, 0}
}
// More reports whether there is another element in the
// current array or object being parsed.
func (dec *Decoder) More() bool {
c, err := dec.peek()
return err == nil && c != ']' && c != '}'
}
func (dec *Decoder) peek() (byte, error) {
var err error
for {
for i := dec.scanp; i < len(dec.buf); i++ {
c := dec.buf[i]
if isSpace(c) {
continue
}
dec.scanp = i
return c, nil
}
// buffer has been scanned, now report any error
if err != nil {
return 0, err
}
err = dec.refill()
}
}
/*
TODO
// EncodeToken writes the given JSON token to the stream.
// It returns an error if the delimiters [ ] { } are not properly used.
//
// EncodeToken does not call Flush, because usually it is part of
// a larger operation such as Encode, and those will call Flush when finished.
// Callers that create an Encoder and then invoke EncodeToken directly,
// without using Encode, need to call Flush when finished to ensure that
// the JSON is written to the underlying writer.
func (e *Encoder) EncodeToken(t Token) error {
...
}
*/
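A usage sketch for the streaming Token API defined above (assuming the fork's import path resolves within this module): walk a JSON document token by token, receiving Delim values for the structural characters and plain Go values for everything else.

package main

import (
	"fmt"
	"io"
	"strings"

	"github.com/go-jose/go-jose/v3/json"
)

func main() {
	dec := json.NewDecoder(strings.NewReader(`{"keys": [{"kty": "oct"}, {"kty": "RSA"}]}`))
	for {
		tok, err := dec.Token()
		if err == io.EOF {
			break // end of input
		}
		if err != nil {
			panic(err)
		}
		fmt.Printf("%T %v\n", tok, tok)
	}
}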

44
vendor/github.com/go-jose/go-jose/v3/json/tags.go generated vendored Normal file

@ -0,0 +1,44 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"strings"
)
// tagOptions is the string following a comma in a struct field's "json"
// tag, or the empty string. It does not include the leading comma.
type tagOptions string
// parseTag splits a struct field's json tag into its name and
// comma-separated options.
func parseTag(tag string) (string, tagOptions) {
if idx := strings.Index(tag, ","); idx != -1 {
return tag[:idx], tagOptions(tag[idx+1:])
}
return tag, tagOptions("")
}
// Contains reports whether a comma-separated list of options
// contains a particular substr flag. substr must be surrounded by a
// string boundary or commas.
func (o tagOptions) Contains(optionName string) bool {
if len(o) == 0 {
return false
}
s := string(o)
for s != "" {
var next string
i := strings.Index(s, ",")
if i >= 0 {
s, next = s[:i], s[i+1:]
}
if s == optionName {
return true
}
s = next
}
return false
}
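An in-package sketch (hypothetical, since parseTag and tagOptions are unexported) of how a struct tag value such as `json:"name,omitempty"` is split and queried:

package json

import "fmt"

// exampleParseTag demonstrates splitting a tag into its name and options.
func exampleParseTag() {
	name, opts := parseTag("name,omitempty")
	fmt.Println(name)                       // name
	fmt.Println(opts.Contains("omitempty")) // true
	fmt.Println(opts.Contains("string"))    // false
}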

295
vendor/github.com/go-jose/go-jose/v3/jwe.go generated vendored Normal file

@ -0,0 +1,295 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jose
import (
"encoding/base64"
"fmt"
"strings"
"github.com/go-jose/go-jose/v3/json"
)
// rawJSONWebEncryption represents a raw JWE JSON object. Used for parsing/serializing.
type rawJSONWebEncryption struct {
Protected *byteBuffer `json:"protected,omitempty"`
Unprotected *rawHeader `json:"unprotected,omitempty"`
Header *rawHeader `json:"header,omitempty"`
Recipients []rawRecipientInfo `json:"recipients,omitempty"`
Aad *byteBuffer `json:"aad,omitempty"`
EncryptedKey *byteBuffer `json:"encrypted_key,omitempty"`
Iv *byteBuffer `json:"iv,omitempty"`
Ciphertext *byteBuffer `json:"ciphertext,omitempty"`
Tag *byteBuffer `json:"tag,omitempty"`
}
// rawRecipientInfo represents a raw JWE Per-Recipient header JSON object. Used for parsing/serializing.
type rawRecipientInfo struct {
Header *rawHeader `json:"header,omitempty"`
EncryptedKey string `json:"encrypted_key,omitempty"`
}
// JSONWebEncryption represents an encrypted JWE object after parsing.
type JSONWebEncryption struct {
Header Header
protected, unprotected *rawHeader
recipients []recipientInfo
aad, iv, ciphertext, tag []byte
original *rawJSONWebEncryption
}
// recipientInfo represents a raw JWE Per-Recipient header JSON object after parsing.
type recipientInfo struct {
header *rawHeader
encryptedKey []byte
}
// GetAuthData retrieves the (optional) authenticated data attached to the object.
func (obj JSONWebEncryption) GetAuthData() []byte {
if obj.aad != nil {
out := make([]byte, len(obj.aad))
copy(out, obj.aad)
return out
}
return nil
}
// Get the merged header values
func (obj JSONWebEncryption) mergedHeaders(recipient *recipientInfo) rawHeader {
out := rawHeader{}
out.merge(obj.protected)
out.merge(obj.unprotected)
if recipient != nil {
out.merge(recipient.header)
}
return out
}
// Get the additional authenticated data from a JWE object.
func (obj JSONWebEncryption) computeAuthData() []byte {
var protected string
switch {
case obj.original != nil && obj.original.Protected != nil:
protected = obj.original.Protected.base64()
case obj.protected != nil:
protected = base64.RawURLEncoding.EncodeToString(mustSerializeJSON(obj.protected))
default:
protected = ""
}
output := []byte(protected)
if obj.aad != nil {
output = append(output, '.')
output = append(output, []byte(base64.RawURLEncoding.EncodeToString(obj.aad))...)
}
return output
}
// ParseEncrypted parses an encrypted message in compact or JWE JSON Serialization format.
func ParseEncrypted(input string) (*JSONWebEncryption, error) {
input = stripWhitespace(input)
if strings.HasPrefix(input, "{") {
return parseEncryptedFull(input)
}
return parseEncryptedCompact(input)
}
// parseEncryptedFull parses a message in JWE JSON (full) serialization format.
func parseEncryptedFull(input string) (*JSONWebEncryption, error) {
var parsed rawJSONWebEncryption
err := json.Unmarshal([]byte(input), &parsed)
if err != nil {
return nil, err
}
return parsed.sanitized()
}
// sanitized produces a cleaned-up JWE object from the raw JSON.
func (parsed *rawJSONWebEncryption) sanitized() (*JSONWebEncryption, error) {
obj := &JSONWebEncryption{
original: parsed,
unprotected: parsed.Unprotected,
}
// Check that there is not a nonce in the unprotected headers
if parsed.Unprotected != nil {
if nonce := parsed.Unprotected.getNonce(); nonce != "" {
return nil, ErrUnprotectedNonce
}
}
if parsed.Header != nil {
if nonce := parsed.Header.getNonce(); nonce != "" {
return nil, ErrUnprotectedNonce
}
}
if parsed.Protected != nil && len(parsed.Protected.bytes()) > 0 {
err := json.Unmarshal(parsed.Protected.bytes(), &obj.protected)
if err != nil {
return nil, fmt.Errorf("go-jose/go-jose: invalid protected header: %s, %s", err, parsed.Protected.base64())
}
}
// Note: this must be called _after_ we parse the protected header,
// otherwise fields from the protected header will not get picked up.
var err error
mergedHeaders := obj.mergedHeaders(nil)
obj.Header, err = mergedHeaders.sanitized()
if err != nil {
return nil, fmt.Errorf("go-jose/go-jose: cannot sanitize merged headers: %v (%v)", err, mergedHeaders)
}
if len(parsed.Recipients) == 0 {
obj.recipients = []recipientInfo{
{
header: parsed.Header,
encryptedKey: parsed.EncryptedKey.bytes(),
},
}
} else {
obj.recipients = make([]recipientInfo, len(parsed.Recipients))
for r := range parsed.Recipients {
encryptedKey, err := base64URLDecode(parsed.Recipients[r].EncryptedKey)
if err != nil {
return nil, err
}
// Check that there is not a nonce in the unprotected header
if parsed.Recipients[r].Header != nil && parsed.Recipients[r].Header.getNonce() != "" {
return nil, ErrUnprotectedNonce
}
obj.recipients[r].header = parsed.Recipients[r].Header
obj.recipients[r].encryptedKey = encryptedKey
}
}
for _, recipient := range obj.recipients {
headers := obj.mergedHeaders(&recipient)
if headers.getAlgorithm() == "" || headers.getEncryption() == "" {
return nil, fmt.Errorf("go-jose/go-jose: message is missing alg/enc headers")
}
}
obj.iv = parsed.Iv.bytes()
obj.ciphertext = parsed.Ciphertext.bytes()
obj.tag = parsed.Tag.bytes()
obj.aad = parsed.Aad.bytes()
return obj, nil
}
// parseEncryptedCompact parses a message in compact format.
func parseEncryptedCompact(input string) (*JSONWebEncryption, error) {
parts := strings.Split(input, ".")
if len(parts) != 5 {
return nil, fmt.Errorf("go-jose/go-jose: compact JWE format must have five parts")
}
rawProtected, err := base64URLDecode(parts[0])
if err != nil {
return nil, err
}
encryptedKey, err := base64URLDecode(parts[1])
if err != nil {
return nil, err
}
iv, err := base64URLDecode(parts[2])
if err != nil {
return nil, err
}
ciphertext, err := base64URLDecode(parts[3])
if err != nil {
return nil, err
}
tag, err := base64URLDecode(parts[4])
if err != nil {
return nil, err
}
raw := &rawJSONWebEncryption{
Protected: newBuffer(rawProtected),
EncryptedKey: newBuffer(encryptedKey),
Iv: newBuffer(iv),
Ciphertext: newBuffer(ciphertext),
Tag: newBuffer(tag),
}
return raw.sanitized()
}
// CompactSerialize serializes an object using the compact serialization format.
func (obj JSONWebEncryption) CompactSerialize() (string, error) {
if len(obj.recipients) != 1 || obj.unprotected != nil ||
obj.protected == nil || obj.recipients[0].header != nil {
return "", ErrNotSupported
}
serializedProtected := mustSerializeJSON(obj.protected)
return fmt.Sprintf(
"%s.%s.%s.%s.%s",
base64.RawURLEncoding.EncodeToString(serializedProtected),
base64.RawURLEncoding.EncodeToString(obj.recipients[0].encryptedKey),
base64.RawURLEncoding.EncodeToString(obj.iv),
base64.RawURLEncoding.EncodeToString(obj.ciphertext),
base64.RawURLEncoding.EncodeToString(obj.tag)), nil
}
// FullSerialize serializes an object using the full JSON serialization format.
func (obj JSONWebEncryption) FullSerialize() string {
raw := rawJSONWebEncryption{
Unprotected: obj.unprotected,
Iv: newBuffer(obj.iv),
Ciphertext: newBuffer(obj.ciphertext),
EncryptedKey: newBuffer(obj.recipients[0].encryptedKey),
Tag: newBuffer(obj.tag),
Aad: newBuffer(obj.aad),
Recipients: []rawRecipientInfo{},
}
if len(obj.recipients) > 1 {
for _, recipient := range obj.recipients {
info := rawRecipientInfo{
Header: recipient.header,
EncryptedKey: base64.RawURLEncoding.EncodeToString(recipient.encryptedKey),
}
raw.Recipients = append(raw.Recipients, info)
}
} else {
// Use flattened serialization
raw.Header = obj.recipients[0].header
raw.EncryptedKey = newBuffer(obj.recipients[0].encryptedKey)
}
if obj.protected != nil {
raw.Protected = newBuffer(mustSerializeJSON(obj.protected))
}
return string(mustSerializeJSON(raw))
}
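A round-trip sketch tying ParseEncrypted and CompactSerialize together. NewEncrypter, Encrypt, and Decrypt are defined elsewhere in this library (not in this file), so treat this as a sketch of the public API rather than a verified snippet from this diff:

package main

import (
	"crypto/rand"
	"fmt"

	jose "github.com/go-jose/go-jose/v3"
)

func main() {
	// 16-byte content-encryption key for direct encryption with A128GCM.
	key := make([]byte, 16)
	if _, err := rand.Read(key); err != nil {
		panic(err)
	}

	enc, err := jose.NewEncrypter(jose.A128GCM, jose.Recipient{Algorithm: jose.DIRECT, Key: key}, nil)
	if err != nil {
		panic(err)
	}
	obj, err := enc.Encrypt([]byte("hello, JWE"))
	if err != nil {
		panic(err)
	}

	// Serialize to the five-part compact form, then parse it back.
	compact, err := obj.CompactSerialize()
	if err != nil {
		panic(err)
	}
	parsed, err := jose.ParseEncrypted(compact)
	if err != nil {
		panic(err)
	}
	pt, err := parsed.Decrypt(key)
	fmt.Println(string(pt), err) // hello, JWE <nil>
}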

798
vendor/github.com/go-jose/go-jose/v3/jwk.go generated vendored Normal file

@ -0,0 +1,798 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jose
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"crypto/sha1"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"math/big"
"net/url"
"reflect"
"strings"
"github.com/go-jose/go-jose/v3/json"
)
// rawJSONWebKey represents a public or private key in JWK format, used for parsing/serializing.
type rawJSONWebKey struct {
Use string `json:"use,omitempty"`
Kty string `json:"kty,omitempty"`
Kid string `json:"kid,omitempty"`
Crv string `json:"crv,omitempty"`
Alg string `json:"alg,omitempty"`
K *byteBuffer `json:"k,omitempty"`
X *byteBuffer `json:"x,omitempty"`
Y *byteBuffer `json:"y,omitempty"`
N *byteBuffer `json:"n,omitempty"`
E *byteBuffer `json:"e,omitempty"`
// -- Following fields are only used for private keys --
// RSA uses D, P and Q, while ECDSA uses only D. Fields Dp, Dq, and Qi are
// completely optional. Therefore for RSA/ECDSA, D != nil is a contract that
// we have a private key whereas D == nil means we have only a public key.
D *byteBuffer `json:"d,omitempty"`
P *byteBuffer `json:"p,omitempty"`
Q *byteBuffer `json:"q,omitempty"`
Dp *byteBuffer `json:"dp,omitempty"`
Dq *byteBuffer `json:"dq,omitempty"`
Qi *byteBuffer `json:"qi,omitempty"`
// Certificates
X5c []string `json:"x5c,omitempty"`
X5u string `json:"x5u,omitempty"`
X5tSHA1 string `json:"x5t,omitempty"`
X5tSHA256 string `json:"x5t#S256,omitempty"`
}
// JSONWebKey represents a public or private key in JWK format.
type JSONWebKey struct {
// Cryptographic key, can be a symmetric or asymmetric key.
Key interface{}
// Key identifier, parsed from `kid` header.
KeyID string
// Key algorithm, parsed from `alg` header.
Algorithm string
// Key use, parsed from `use` header.
Use string
// X.509 certificate chain, parsed from `x5c` header.
Certificates []*x509.Certificate
// X.509 certificate URL, parsed from `x5u` header.
CertificatesURL *url.URL
// X.509 certificate thumbprint (SHA-1), parsed from `x5t` header.
CertificateThumbprintSHA1 []byte
// X.509 certificate thumbprint (SHA-256), parsed from `x5t#S256` header.
CertificateThumbprintSHA256 []byte
}
// MarshalJSON serializes the given key to its JSON representation.
func (k JSONWebKey) MarshalJSON() ([]byte, error) {
var raw *rawJSONWebKey
var err error
switch key := k.Key.(type) {
case ed25519.PublicKey:
raw = fromEdPublicKey(key)
case *ecdsa.PublicKey:
raw, err = fromEcPublicKey(key)
case *rsa.PublicKey:
raw = fromRsaPublicKey(key)
case ed25519.PrivateKey:
raw, err = fromEdPrivateKey(key)
case *ecdsa.PrivateKey:
raw, err = fromEcPrivateKey(key)
case *rsa.PrivateKey:
raw, err = fromRsaPrivateKey(key)
case []byte:
raw, err = fromSymmetricKey(key)
default:
return nil, fmt.Errorf("go-jose/go-jose: unknown key type '%s'", reflect.TypeOf(key))
}
if err != nil {
return nil, err
}
raw.Kid = k.KeyID
raw.Alg = k.Algorithm
raw.Use = k.Use
for _, cert := range k.Certificates {
raw.X5c = append(raw.X5c, base64.StdEncoding.EncodeToString(cert.Raw))
}
x5tSHA1Len := len(k.CertificateThumbprintSHA1)
x5tSHA256Len := len(k.CertificateThumbprintSHA256)
if x5tSHA1Len > 0 {
if x5tSHA1Len != sha1.Size {
return nil, fmt.Errorf("go-jose/go-jose: invalid SHA-1 thumbprint (must be %d bytes, not %d)", sha1.Size, x5tSHA1Len)
}
raw.X5tSHA1 = base64.RawURLEncoding.EncodeToString(k.CertificateThumbprintSHA1)
}
if x5tSHA256Len > 0 {
if x5tSHA256Len != sha256.Size {
return nil, fmt.Errorf("go-jose/go-jose: invalid SHA-256 thumbprint (must be %d bytes, not %d)", sha256.Size, x5tSHA256Len)
}
raw.X5tSHA256 = base64.RawURLEncoding.EncodeToString(k.CertificateThumbprintSHA256)
}
// If cert chain is attached (as opposed to being behind a URL), check the
// keys thumbprints to make sure they match what is expected. This is to
// ensure we don't accidentally produce a JWK with semantically inconsistent
// data in the headers.
if len(k.Certificates) > 0 {
expectedSHA1 := sha1.Sum(k.Certificates[0].Raw)
expectedSHA256 := sha256.Sum256(k.Certificates[0].Raw)
if len(k.CertificateThumbprintSHA1) > 0 && !bytes.Equal(k.CertificateThumbprintSHA1, expectedSHA1[:]) {
return nil, errors.New("go-jose/go-jose: invalid SHA-1 thumbprint, does not match cert chain")
}
if len(k.CertificateThumbprintSHA256) > 0 && !bytes.Equal(k.CertificateThumbprintSHA256, expectedSHA256[:]) {
return nil, errors.New("go-jose/go-jose: invalid or SHA-256 thumbprint, does not match cert chain")
}
}
if k.CertificatesURL != nil {
raw.X5u = k.CertificatesURL.String()
}
return json.Marshal(raw)
}
// UnmarshalJSON reads a key from its JSON representation.
func (k *JSONWebKey) UnmarshalJSON(data []byte) (err error) {
var raw rawJSONWebKey
err = json.Unmarshal(data, &raw)
if err != nil {
return err
}
certs, err := parseCertificateChain(raw.X5c)
if err != nil {
return fmt.Errorf("go-jose/go-jose: failed to unmarshal x5c field: %s", err)
}
var key interface{}
var certPub interface{}
var keyPub interface{}
if len(certs) > 0 {
// We need to check that leaf public key matches the key embedded in this
// JWK, as required by the standard (see RFC 7517, Section 4.7). Otherwise
// the JWK parsed could be semantically invalid. Technically, we should also
// check key usage fields and other extensions on the cert here, but the
// standard doesn't exactly explain how they're supposed to map from the
// JWK representation to the X.509 extensions.
certPub = certs[0].PublicKey
}
switch raw.Kty {
case "EC":
if raw.D != nil {
key, err = raw.ecPrivateKey()
if err == nil {
keyPub = key.(*ecdsa.PrivateKey).Public()
}
} else {
key, err = raw.ecPublicKey()
keyPub = key
}
case "RSA":
if raw.D != nil {
key, err = raw.rsaPrivateKey()
if err == nil {
keyPub = key.(*rsa.PrivateKey).Public()
}
} else {
key, err = raw.rsaPublicKey()
keyPub = key
}
case "oct":
if certPub != nil {
return errors.New("go-jose/go-jose: invalid JWK, found 'oct' (symmetric) key with cert chain")
}
key, err = raw.symmetricKey()
case "OKP":
if raw.Crv == "Ed25519" && raw.X != nil {
if raw.D != nil {
key, err = raw.edPrivateKey()
if err == nil {
keyPub = key.(ed25519.PrivateKey).Public()
}
} else {
key, err = raw.edPublicKey()
keyPub = key
}
} else {
err = fmt.Errorf("go-jose/go-jose: unknown curve %s'", raw.Crv)
}
default:
err = fmt.Errorf("go-jose/go-jose: unknown json web key type '%s'", raw.Kty)
}
if err != nil {
return
}
if certPub != nil && keyPub != nil {
if !reflect.DeepEqual(certPub, keyPub) {
return errors.New("go-jose/go-jose: invalid JWK, public keys in key and x5c fields do not match")
}
}
*k = JSONWebKey{Key: key, KeyID: raw.Kid, Algorithm: raw.Alg, Use: raw.Use, Certificates: certs}
if raw.X5u != "" {
k.CertificatesURL, err = url.Parse(raw.X5u)
if err != nil {
return fmt.Errorf("go-jose/go-jose: invalid JWK, x5u header is invalid URL: %w", err)
}
}
// x5t parameters are base64url-encoded SHA thumbprints
// See RFC 7517, Section 4.8, https://tools.ietf.org/html/rfc7517#section-4.8
x5tSHA1bytes, err := base64URLDecode(raw.X5tSHA1)
if err != nil {
return errors.New("go-jose/go-jose: invalid JWK, x5t header has invalid encoding")
}
// RFC 7517, Section 4.8 is ambiguous as to whether the digest output should be raw bytes or hex.
// For this reason, after base64 decoding, if the length equals sha1.Size the value is likely a raw
// byte checksum and we leave it as-is. Otherwise, if the checksum was hex encoded, we expect a
// 40-byte array and try to hex decode it. When marshalling, we always emit the base64 encoding of
// the raw byte checksum.
if len(x5tSHA1bytes) == 2*sha1.Size {
hx, err := hex.DecodeString(string(x5tSHA1bytes))
if err != nil {
return fmt.Errorf("go-jose/go-jose: invalid JWK, unable to hex decode x5t: %v", err)
}
x5tSHA1bytes = hx
}
k.CertificateThumbprintSHA1 = x5tSHA1bytes
x5tSHA256bytes, err := base64URLDecode(raw.X5tSHA256)
if err != nil {
return errors.New("go-jose/go-jose: invalid JWK, x5t#S256 header has invalid encoding")
}
if len(x5tSHA256bytes) == 2*sha256.Size {
hx256, err := hex.DecodeString(string(x5tSHA256bytes))
if err != nil {
return fmt.Errorf("go-jose/go-jose: invalid JWK, unable to hex decode x5t#S256: %v", err)
}
x5tSHA256bytes = hx256
}
k.CertificateThumbprintSHA256 = x5tSHA256bytes
x5tSHA1Len := len(k.CertificateThumbprintSHA1)
x5tSHA256Len := len(k.CertificateThumbprintSHA256)
if x5tSHA1Len > 0 && x5tSHA1Len != sha1.Size {
return errors.New("go-jose/go-jose: invalid JWK, x5t header is of incorrect size")
}
if x5tSHA256Len > 0 && x5tSHA256Len != sha256.Size {
return errors.New("go-jose/go-jose: invalid JWK, x5t#S256 header is of incorrect size")
}
// If certificate chain *and* thumbprints are set, verify correctness.
if len(k.Certificates) > 0 {
leaf := k.Certificates[0]
sha1sum := sha1.Sum(leaf.Raw)
sha256sum := sha256.Sum256(leaf.Raw)
if len(k.CertificateThumbprintSHA1) > 0 && !bytes.Equal(sha1sum[:], k.CertificateThumbprintSHA1) {
return errors.New("go-jose/go-jose: invalid JWK, x5c thumbprint does not match x5t value")
}
if len(k.CertificateThumbprintSHA256) > 0 && !bytes.Equal(sha256sum[:], k.CertificateThumbprintSHA256) {
return errors.New("go-jose/go-jose: invalid JWK, x5c thumbprint does not match x5t#S256 value")
}
}
return
}
// JSONWebKeySet represents a JWK Set object.
type JSONWebKeySet struct {
Keys []JSONWebKey `json:"keys"`
}
// Key is a convenience method that returns keys by key ID. The specification
// states that a JWK Set "SHOULD" use distinct key IDs, but allows for some
// cases where they are not distinct. Hence the method returns a slice
// of JSONWebKeys.
func (s *JSONWebKeySet) Key(kid string) []JSONWebKey {
var keys []JSONWebKey
for _, key := range s.Keys {
if key.KeyID == kid {
keys = append(keys, key)
}
}
return keys
}
const rsaThumbprintTemplate = `{"e":"%s","kty":"RSA","n":"%s"}`
const ecThumbprintTemplate = `{"crv":"%s","kty":"EC","x":"%s","y":"%s"}`
const edThumbprintTemplate = `{"crv":"%s","kty":"OKP","x":"%s"}`
func ecThumbprintInput(curve elliptic.Curve, x, y *big.Int) (string, error) {
coordLength := curveSize(curve)
crv, err := curveName(curve)
if err != nil {
return "", err
}
if len(x.Bytes()) > coordLength || len(y.Bytes()) > coordLength {
return "", errors.New("go-jose/go-jose: invalid elliptic key (too large)")
}
return fmt.Sprintf(ecThumbprintTemplate, crv,
newFixedSizeBuffer(x.Bytes(), coordLength).base64(),
newFixedSizeBuffer(y.Bytes(), coordLength).base64()), nil
}
func rsaThumbprintInput(n *big.Int, e int) (string, error) {
return fmt.Sprintf(rsaThumbprintTemplate,
newBufferFromInt(uint64(e)).base64(),
newBuffer(n.Bytes()).base64()), nil
}
func edThumbprintInput(ed ed25519.PublicKey) (string, error) {
crv := "Ed25519"
if len(ed) > 32 {
return "", errors.New("go-jose/go-jose: invalid elliptic key (too large)")
}
return fmt.Sprintf(edThumbprintTemplate, crv,
newFixedSizeBuffer(ed, 32).base64()), nil
}
// Thumbprint computes the JWK Thumbprint of a key using the
// indicated hash algorithm.
func (k *JSONWebKey) Thumbprint(hash crypto.Hash) ([]byte, error) {
var input string
var err error
switch key := k.Key.(type) {
case ed25519.PublicKey:
input, err = edThumbprintInput(key)
case *ecdsa.PublicKey:
input, err = ecThumbprintInput(key.Curve, key.X, key.Y)
case *ecdsa.PrivateKey:
input, err = ecThumbprintInput(key.Curve, key.X, key.Y)
case *rsa.PublicKey:
input, err = rsaThumbprintInput(key.N, key.E)
case *rsa.PrivateKey:
input, err = rsaThumbprintInput(key.N, key.E)
case ed25519.PrivateKey:
input, err = edThumbprintInput(ed25519.PublicKey(key[32:]))
default:
return nil, fmt.Errorf("go-jose/go-jose: unknown key type '%s'", reflect.TypeOf(key))
}
if err != nil {
return nil, err
}
h := hash.New()
_, _ = h.Write([]byte(input))
return h.Sum(nil), nil
}
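// Illustrative sketch (not part of this file): computing an RFC 7638-style
// thumbprint for a freshly generated Ed25519 key via the exported API above.
// Assumes a standalone program importing this library as "jose" along with
// crypto, crypto/ed25519, crypto/rand, and fmt:
//
//	pub, _, _ := ed25519.GenerateKey(rand.Reader)
//	jwk := jose.JSONWebKey{Key: pub, KeyID: "example-key", Algorithm: "EdDSA", Use: "sig"}
//	tp, err := jwk.Thumbprint(crypto.SHA256) // 32-byte digest of the canonical JWK
//	fmt.Printf("%x %v\n", tp, err)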
// IsPublic returns true if the JWK represents a public key (not symmetric, not private).
func (k *JSONWebKey) IsPublic() bool {
switch k.Key.(type) {
case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey:
return true
default:
return false
}
}
// Public creates JSONWebKey with corresponding public key if JWK represents asymmetric private key.
func (k *JSONWebKey) Public() JSONWebKey {
if k.IsPublic() {
return *k
}
ret := *k
switch key := k.Key.(type) {
case *ecdsa.PrivateKey:
ret.Key = key.Public()
case *rsa.PrivateKey:
ret.Key = key.Public()
case ed25519.PrivateKey:
ret.Key = key.Public()
default:
return JSONWebKey{} // returning invalid key
}
return ret
}
// Valid checks that the key contains the expected parameters.
func (k *JSONWebKey) Valid() bool {
if k.Key == nil {
return false
}
switch key := k.Key.(type) {
case *ecdsa.PublicKey:
if key.Curve == nil || key.X == nil || key.Y == nil {
return false
}
case *ecdsa.PrivateKey:
if key.Curve == nil || key.X == nil || key.Y == nil || key.D == nil {
return false
}
case *rsa.PublicKey:
if key.N == nil || key.E == 0 {
return false
}
case *rsa.PrivateKey:
if key.N == nil || key.E == 0 || key.D == nil || len(key.Primes) < 2 {
return false
}
case ed25519.PublicKey:
if len(key) != 32 {
return false
}
case ed25519.PrivateKey:
if len(key) != 64 {
return false
}
default:
return false
}
return true
}
func (key rawJSONWebKey) rsaPublicKey() (*rsa.PublicKey, error) {
if key.N == nil || key.E == nil {
return nil, fmt.Errorf("go-jose/go-jose: invalid RSA key, missing n/e values")
}
return &rsa.PublicKey{
N: key.N.bigInt(),
E: key.E.toInt(),
}, nil
}
func fromEdPublicKey(pub ed25519.PublicKey) *rawJSONWebKey {
return &rawJSONWebKey{
Kty: "OKP",
Crv: "Ed25519",
X: newBuffer(pub),
}
}
func fromRsaPublicKey(pub *rsa.PublicKey) *rawJSONWebKey {
return &rawJSONWebKey{
Kty: "RSA",
N: newBuffer(pub.N.Bytes()),
E: newBufferFromInt(uint64(pub.E)),
}
}
func (key rawJSONWebKey) ecPublicKey() (*ecdsa.PublicKey, error) {
var curve elliptic.Curve
switch key.Crv {
case "P-256":
curve = elliptic.P256()
case "P-384":
curve = elliptic.P384()
case "P-521":
curve = elliptic.P521()
default:
return nil, fmt.Errorf("go-jose/go-jose: unsupported elliptic curve '%s'", key.Crv)
}
if key.X == nil || key.Y == nil {
return nil, errors.New("go-jose/go-jose: invalid EC key, missing x/y values")
}
// The length of this octet string MUST be the full size of a coordinate for
// the curve specified in the "crv" parameter.
// https://tools.ietf.org/html/rfc7518#section-6.2.1.2
if curveSize(curve) != len(key.X.data) {
return nil, fmt.Errorf("go-jose/go-jose: invalid EC public key, wrong length for x")
}
if curveSize(curve) != len(key.Y.data) {
return nil, fmt.Errorf("go-jose/go-jose: invalid EC public key, wrong length for y")
}
x := key.X.bigInt()
y := key.Y.bigInt()
if !curve.IsOnCurve(x, y) {
return nil, errors.New("go-jose/go-jose: invalid EC key, X/Y are not on declared curve")
}
return &ecdsa.PublicKey{
Curve: curve,
X: x,
Y: y,
}, nil
}
func fromEcPublicKey(pub *ecdsa.PublicKey) (*rawJSONWebKey, error) {
if pub == nil || pub.X == nil || pub.Y == nil {
return nil, fmt.Errorf("go-jose/go-jose: invalid EC key (nil, or X/Y missing)")
}
name, err := curveName(pub.Curve)
if err != nil {
return nil, err
}
size := curveSize(pub.Curve)
xBytes := pub.X.Bytes()
yBytes := pub.Y.Bytes()
if len(xBytes) > size || len(yBytes) > size {
return nil, fmt.Errorf("go-jose/go-jose: invalid EC key (X/Y too large)")
}
key := &rawJSONWebKey{
Kty: "EC",
Crv: name,
X: newFixedSizeBuffer(xBytes, size),
Y: newFixedSizeBuffer(yBytes, size),
}
return key, nil
}
func (key rawJSONWebKey) edPrivateKey() (ed25519.PrivateKey, error) {
var missing []string
switch {
case key.D == nil:
missing = append(missing, "D")
case key.X == nil:
missing = append(missing, "X")
}
if len(missing) > 0 {
return nil, fmt.Errorf("go-jose/go-jose: invalid Ed25519 private key, missing %s value(s)", strings.Join(missing, ", "))
}
privateKey := make([]byte, ed25519.PrivateKeySize)
copy(privateKey[0:32], key.D.bytes())
copy(privateKey[32:], key.X.bytes())
rv := ed25519.PrivateKey(privateKey)
return rv, nil
}
func (key rawJSONWebKey) edPublicKey() (ed25519.PublicKey, error) {
if key.X == nil {
return nil, fmt.Errorf("go-jose/go-jose: invalid Ed key, missing x value")
}
publicKey := make([]byte, ed25519.PublicKeySize)
copy(publicKey[0:32], key.X.bytes())
rv := ed25519.PublicKey(publicKey)
return rv, nil
}
func (key rawJSONWebKey) rsaPrivateKey() (*rsa.PrivateKey, error) {
var missing []string
switch {
case key.N == nil:
missing = append(missing, "N")
case key.E == nil:
missing = append(missing, "E")
case key.D == nil:
missing = append(missing, "D")
case key.P == nil:
missing = append(missing, "P")
case key.Q == nil:
missing = append(missing, "Q")
}
if len(missing) > 0 {
return nil, fmt.Errorf("go-jose/go-jose: invalid RSA private key, missing %s value(s)", strings.Join(missing, ", "))
}
rv := &rsa.PrivateKey{
PublicKey: rsa.PublicKey{
N: key.N.bigInt(),
E: key.E.toInt(),
},
D: key.D.bigInt(),
Primes: []*big.Int{
key.P.bigInt(),
key.Q.bigInt(),
},
}
if key.Dp != nil {
rv.Precomputed.Dp = key.Dp.bigInt()
}
if key.Dq != nil {
rv.Precomputed.Dq = key.Dq.bigInt()
}
if key.Qi != nil {
rv.Precomputed.Qinv = key.Qi.bigInt()
}
err := rv.Validate()
return rv, err
}
func fromEdPrivateKey(ed ed25519.PrivateKey) (*rawJSONWebKey, error) {
raw := fromEdPublicKey(ed25519.PublicKey(ed[32:]))
raw.D = newBuffer(ed[0:32])
return raw, nil
}
func fromRsaPrivateKey(rsa *rsa.PrivateKey) (*rawJSONWebKey, error) {
if len(rsa.Primes) != 2 {
return nil, ErrUnsupportedKeyType
}
raw := fromRsaPublicKey(&rsa.PublicKey)
raw.D = newBuffer(rsa.D.Bytes())
raw.P = newBuffer(rsa.Primes[0].Bytes())
raw.Q = newBuffer(rsa.Primes[1].Bytes())
if rsa.Precomputed.Dp != nil {
raw.Dp = newBuffer(rsa.Precomputed.Dp.Bytes())
}
if rsa.Precomputed.Dq != nil {
raw.Dq = newBuffer(rsa.Precomputed.Dq.Bytes())
}
if rsa.Precomputed.Qinv != nil {
raw.Qi = newBuffer(rsa.Precomputed.Qinv.Bytes())
}
return raw, nil
}
func (key rawJSONWebKey) ecPrivateKey() (*ecdsa.PrivateKey, error) {
var curve elliptic.Curve
switch key.Crv {
case "P-256":
curve = elliptic.P256()
case "P-384":
curve = elliptic.P384()
case "P-521":
curve = elliptic.P521()
default:
return nil, fmt.Errorf("go-jose/go-jose: unsupported elliptic curve '%s'", key.Crv)
}
if key.X == nil || key.Y == nil || key.D == nil {
return nil, fmt.Errorf("go-jose/go-jose: invalid EC private key, missing x/y/d values")
}
// The length of this octet string MUST be the full size of a coordinate for
// the curve specified in the "crv" parameter.
// https://tools.ietf.org/html/rfc7518#section-6.2.1.2
if curveSize(curve) != len(key.X.data) {
return nil, fmt.Errorf("go-jose/go-jose: invalid EC private key, wrong length for x")
}
if curveSize(curve) != len(key.Y.data) {
return nil, fmt.Errorf("go-jose/go-jose: invalid EC private key, wrong length for y")
}
// https://tools.ietf.org/html/rfc7518#section-6.2.2.1
if dSize(curve) != len(key.D.data) {
return nil, fmt.Errorf("go-jose/go-jose: invalid EC private key, wrong length for d")
}
x := key.X.bigInt()
y := key.Y.bigInt()
if !curve.IsOnCurve(x, y) {
return nil, errors.New("go-jose/go-jose: invalid EC key, X/Y are not on declared curve")
}
return &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: curve,
X: x,
Y: y,
},
D: key.D.bigInt(),
}, nil
}
func fromEcPrivateKey(ec *ecdsa.PrivateKey) (*rawJSONWebKey, error) {
raw, err := fromEcPublicKey(&ec.PublicKey)
if err != nil {
return nil, err
}
if ec.D == nil {
return nil, fmt.Errorf("go-jose/go-jose: invalid EC private key")
}
raw.D = newFixedSizeBuffer(ec.D.Bytes(), dSize(ec.PublicKey.Curve))
return raw, nil
}
// dSize returns the size in octets for the "d" member of an elliptic curve
// private key.
// The length of this octet string MUST be ceiling(log-base-2(n)/8)
// octets (where n is the order of the curve).
// https://tools.ietf.org/html/rfc7518#section-6.2.2.1
func dSize(curve elliptic.Curve) int {
order := curve.Params().P
bitLen := order.BitLen()
size := bitLen / 8
if bitLen%8 != 0 {
size++
}
return size
}
func fromSymmetricKey(key []byte) (*rawJSONWebKey, error) {
return &rawJSONWebKey{
Kty: "oct",
K: newBuffer(key),
}, nil
}
func (key rawJSONWebKey) symmetricKey() ([]byte, error) {
if key.K == nil {
return nil, fmt.Errorf("go-jose/go-jose: invalid OCT (symmetric) key, missing k value")
}
return key.K.bytes(), nil
}
func tryJWKS(key interface{}, headers ...Header) interface{} {
var jwks JSONWebKeySet
switch jwksType := key.(type) {
case *JSONWebKeySet:
jwks = *jwksType
case JSONWebKeySet:
jwks = jwksType
default:
return key
}
var kid string
for _, header := range headers {
if header.KeyID != "" {
kid = header.KeyID
break
}
}
if kid == "" {
return key
}
keys := jwks.Key(kid)
if len(keys) == 0 {
return key
}
return keys[0].Key
}

366
vendor/github.com/go-jose/go-jose/v3/jws.go generated vendored Normal file
View File

@ -0,0 +1,366 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jose
import (
"bytes"
"encoding/base64"
"errors"
"fmt"
"strings"
"github.com/go-jose/go-jose/v3/json"
)
// rawJSONWebSignature represents a raw JWS JSON object. Used for parsing/serializing.
type rawJSONWebSignature struct {
Payload *byteBuffer `json:"payload,omitempty"`
Signatures []rawSignatureInfo `json:"signatures,omitempty"`
Protected *byteBuffer `json:"protected,omitempty"`
Header *rawHeader `json:"header,omitempty"`
Signature *byteBuffer `json:"signature,omitempty"`
}
// rawSignatureInfo represents a single JWS signature over the JWS payload and protected header.
type rawSignatureInfo struct {
Protected *byteBuffer `json:"protected,omitempty"`
Header *rawHeader `json:"header,omitempty"`
Signature *byteBuffer `json:"signature,omitempty"`
}
// JSONWebSignature represents a signed JWS object after parsing.
type JSONWebSignature struct {
payload []byte
// Signatures attached to this object (may be more than one for multi-sig).
// Be careful about accessing these directly, prefer to use Verify() or
// VerifyMulti() to ensure that the data you're getting is verified.
Signatures []Signature
}
// Signature represents a single signature over the JWS payload and protected header.
type Signature struct {
// Merged header fields. Contains both protected and unprotected header
// values. Prefer using Protected and Unprotected fields instead of this.
// Values in this header may or may not have been signed and in general
// should not be trusted.
Header Header
// Protected header. Values in this header were signed and
// will be verified as part of the signature verification process.
Protected Header
// Unprotected header. Values in this header were not signed
// and in general should not be trusted.
Unprotected Header
// The actual signature value
Signature []byte
protected *rawHeader
header *rawHeader
original *rawSignatureInfo
}
// ParseSigned parses a signed message in compact or JWS JSON Serialization format.
func ParseSigned(signature string) (*JSONWebSignature, error) {
signature = stripWhitespace(signature)
if strings.HasPrefix(signature, "{") {
return parseSignedFull(signature)
}
return parseSignedCompact(signature, nil)
}
// ParseDetached parses a signed message in compact serialization format with detached payload.
func ParseDetached(signature string, payload []byte) (*JSONWebSignature, error) {
if payload == nil {
return nil, errors.New("go-jose/go-jose: nil payload")
}
return parseSignedCompact(stripWhitespace(signature), payload)
}
// Get a header value
func (sig Signature) mergedHeaders() rawHeader {
out := rawHeader{}
out.merge(sig.protected)
out.merge(sig.header)
return out
}
// Compute data to be signed
func (obj JSONWebSignature) computeAuthData(payload []byte, signature *Signature) ([]byte, error) {
var authData bytes.Buffer
protectedHeader := new(rawHeader)
if signature.original != nil && signature.original.Protected != nil {
if err := json.Unmarshal(signature.original.Protected.bytes(), protectedHeader); err != nil {
return nil, err
}
authData.WriteString(signature.original.Protected.base64())
} else if signature.protected != nil {
protectedHeader = signature.protected
authData.WriteString(base64.RawURLEncoding.EncodeToString(mustSerializeJSON(protectedHeader)))
}
needsBase64 := true
if protectedHeader != nil {
var err error
if needsBase64, err = protectedHeader.getB64(); err != nil {
needsBase64 = true
}
}
authData.WriteByte('.')
if needsBase64 {
authData.WriteString(base64.RawURLEncoding.EncodeToString(payload))
} else {
authData.Write(payload)
}
return authData.Bytes(), nil
}
// parseSignedFull parses a message in full format.
func parseSignedFull(input string) (*JSONWebSignature, error) {
var parsed rawJSONWebSignature
err := json.Unmarshal([]byte(input), &parsed)
if err != nil {
return nil, err
}
return parsed.sanitized()
}
// sanitized produces a cleaned-up JWS object from the raw JSON.
func (parsed *rawJSONWebSignature) sanitized() (*JSONWebSignature, error) {
if parsed.Payload == nil {
return nil, fmt.Errorf("go-jose/go-jose: missing payload in JWS message")
}
obj := &JSONWebSignature{
payload: parsed.Payload.bytes(),
Signatures: make([]Signature, len(parsed.Signatures)),
}
if len(parsed.Signatures) == 0 {
// No signatures array, must be flattened serialization
signature := Signature{}
if parsed.Protected != nil && len(parsed.Protected.bytes()) > 0 {
signature.protected = &rawHeader{}
err := json.Unmarshal(parsed.Protected.bytes(), signature.protected)
if err != nil {
return nil, err
}
}
// Check that there is not a nonce in the unprotected header
if parsed.Header != nil && parsed.Header.getNonce() != "" {
return nil, ErrUnprotectedNonce
}
signature.header = parsed.Header
signature.Signature = parsed.Signature.bytes()
// Make a fake "original" rawSignatureInfo to store the unprocessed
// Protected header. This is necessary because the Protected header can
// contain arbitrary fields not registered as part of the spec. See
// https://tools.ietf.org/html/draft-ietf-jose-json-web-signature-41#section-4
// If we unmarshal Protected into a rawHeader with its explicit list of fields,
// we cannot marshal losslessly. So we have to keep around the original bytes.
// This is used in computeAuthData, which will first attempt to use
// the original bytes of a protected header, and fall back on marshaling the
// header struct only if those bytes are not available.
signature.original = &rawSignatureInfo{
Protected: parsed.Protected,
Header: parsed.Header,
Signature: parsed.Signature,
}
var err error
signature.Header, err = signature.mergedHeaders().sanitized()
if err != nil {
return nil, err
}
if signature.header != nil {
signature.Unprotected, err = signature.header.sanitized()
if err != nil {
return nil, err
}
}
if signature.protected != nil {
signature.Protected, err = signature.protected.sanitized()
if err != nil {
return nil, err
}
}
// As per RFC 7515 Section 4.1.3, only public keys are allowed to be embedded.
jwk := signature.Header.JSONWebKey
if jwk != nil && (!jwk.Valid() || !jwk.IsPublic()) {
return nil, errors.New("go-jose/go-jose: invalid embedded jwk, must be public key")
}
obj.Signatures = append(obj.Signatures, signature)
}
for i, sig := range parsed.Signatures {
if sig.Protected != nil && len(sig.Protected.bytes()) > 0 {
obj.Signatures[i].protected = &rawHeader{}
err := json.Unmarshal(sig.Protected.bytes(), obj.Signatures[i].protected)
if err != nil {
return nil, err
}
}
// Check that there is not a nonce in the unprotected header
if sig.Header != nil && sig.Header.getNonce() != "" {
return nil, ErrUnprotectedNonce
}
var err error
obj.Signatures[i].Header, err = obj.Signatures[i].mergedHeaders().sanitized()
if err != nil {
return nil, err
}
if obj.Signatures[i].header != nil {
obj.Signatures[i].Unprotected, err = obj.Signatures[i].header.sanitized()
if err != nil {
return nil, err
}
}
if obj.Signatures[i].protected != nil {
obj.Signatures[i].Protected, err = obj.Signatures[i].protected.sanitized()
if err != nil {
return nil, err
}
}
obj.Signatures[i].Signature = sig.Signature.bytes()
// As per RFC 7515 Section 4.1.3, only public keys are allowed to be embedded.
jwk := obj.Signatures[i].Header.JSONWebKey
if jwk != nil && (!jwk.Valid() || !jwk.IsPublic()) {
return nil, errors.New("go-jose/go-jose: invalid embedded jwk, must be public key")
}
// Copy value of sig
original := sig
obj.Signatures[i].header = sig.Header
obj.Signatures[i].original = &original
}
return obj, nil
}
// parseSignedCompact parses a message in compact format.
func parseSignedCompact(input string, payload []byte) (*JSONWebSignature, error) {
parts := strings.Split(input, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("go-jose/go-jose: compact JWS format must have three parts")
}
if parts[1] != "" && payload != nil {
return nil, fmt.Errorf("go-jose/go-jose: payload is not detached")
}
rawProtected, err := base64URLDecode(parts[0])
if err != nil {
return nil, err
}
if payload == nil {
payload, err = base64URLDecode(parts[1])
if err != nil {
return nil, err
}
}
signature, err := base64URLDecode(parts[2])
if err != nil {
return nil, err
}
raw := &rawJSONWebSignature{
Payload: newBuffer(payload),
Protected: newBuffer(rawProtected),
Signature: newBuffer(signature),
}
return raw.sanitized()
}
func (obj JSONWebSignature) compactSerialize(detached bool) (string, error) {
if len(obj.Signatures) != 1 || obj.Signatures[0].header != nil || obj.Signatures[0].protected == nil {
return "", ErrNotSupported
}
serializedProtected := base64.RawURLEncoding.EncodeToString(mustSerializeJSON(obj.Signatures[0].protected))
payload := ""
signature := base64.RawURLEncoding.EncodeToString(obj.Signatures[0].Signature)
if !detached {
payload = base64.RawURLEncoding.EncodeToString(obj.payload)
}
return fmt.Sprintf("%s.%s.%s", serializedProtected, payload, signature), nil
}
// CompactSerialize serializes an object using the compact serialization format.
func (obj JSONWebSignature) CompactSerialize() (string, error) {
return obj.compactSerialize(false)
}
// DetachedCompactSerialize serializes an object using the compact serialization format with detached payload.
func (obj JSONWebSignature) DetachedCompactSerialize() (string, error) {
return obj.compactSerialize(true)
}
// FullSerialize serializes an object using the full JSON serialization format.
func (obj JSONWebSignature) FullSerialize() string {
raw := rawJSONWebSignature{
Payload: newBuffer(obj.payload),
}
if len(obj.Signatures) == 1 {
if obj.Signatures[0].protected != nil {
serializedProtected := mustSerializeJSON(obj.Signatures[0].protected)
raw.Protected = newBuffer(serializedProtected)
}
raw.Header = obj.Signatures[0].header
raw.Signature = newBuffer(obj.Signatures[0].Signature)
} else {
raw.Signatures = make([]rawSignatureInfo, len(obj.Signatures))
for i, signature := range obj.Signatures {
raw.Signatures[i] = rawSignatureInfo{
Header: signature.header,
Signature: newBuffer(signature.Signature),
}
if signature.protected != nil {
raw.Signatures[i].Protected = newBuffer(mustSerializeJSON(signature.protected))
}
}
}
return string(mustSerializeJSON(raw))
}

334
vendor/github.com/go-jose/go-jose/v3/jwt/builder.go generated vendored Normal file
View File

@ -0,0 +1,334 @@
/*-
* Copyright 2016 Zbigniew Mandziejewicz
* Copyright 2016 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jwt
import (
"bytes"
"reflect"
"github.com/go-jose/go-jose/v3/json"
"github.com/go-jose/go-jose/v3"
)
// Builder is a utility for making JSON Web Tokens. Calls can be chained, and
// errors are accumulated until the final call to CompactSerialize/FullSerialize.
type Builder interface {
// Claims encodes claims into JWE/JWS form. Multiple calls will merge claims
// into single JSON object. If you are passing private claims, make sure to set
// struct field tags to specify the name for the JSON key to be used when
// serializing.
Claims(i interface{}) Builder
// Token builds a JSONWebToken from provided data.
Token() (*JSONWebToken, error)
// FullSerialize serializes a token using the JWS/JWE JSON Serialization format.
FullSerialize() (string, error)
// CompactSerialize serializes a token using the compact serialization format.
CompactSerialize() (string, error)
}
// NestedBuilder is a utility for making Signed-Then-Encrypted JSON Web Tokens.
// Calls can be chained, and errors are accumulated until final call to
// CompactSerialize/FullSerialize.
type NestedBuilder interface {
// Claims encodes claims into JWE/JWS form. Multiple calls will merge claims
// into single JSON object. If you are passing private claims, make sure to set
// struct field tags to specify the name for the JSON key to be used when
// serializing.
Claims(i interface{}) NestedBuilder
// Token builds a NestedJSONWebToken from provided data.
Token() (*NestedJSONWebToken, error)
// FullSerialize serializes a token using the JSON Serialization format.
FullSerialize() (string, error)
// CompactSerialize serializes a token using the compact serialization format.
CompactSerialize() (string, error)
}
type builder struct {
payload map[string]interface{}
err error
}
type signedBuilder struct {
builder
sig jose.Signer
}
type encryptedBuilder struct {
builder
enc jose.Encrypter
}
type nestedBuilder struct {
builder
sig jose.Signer
enc jose.Encrypter
}
// Signed creates builder for signed tokens.
func Signed(sig jose.Signer) Builder {
return &signedBuilder{
sig: sig,
}
}
// Encrypted creates builder for encrypted tokens.
func Encrypted(enc jose.Encrypter) Builder {
return &encryptedBuilder{
enc: enc,
}
}
// SignedAndEncrypted creates builder for signed-then-encrypted tokens.
// ErrInvalidContentType will be returned if encrypter doesn't have JWT content type.
func SignedAndEncrypted(sig jose.Signer, enc jose.Encrypter) NestedBuilder {
if contentType, _ := enc.Options().ExtraHeaders[jose.HeaderContentType].(jose.ContentType); contentType != "JWT" {
return &nestedBuilder{
builder: builder{
err: ErrInvalidContentType,
},
}
}
return &nestedBuilder{
sig: sig,
enc: enc,
}
}
func (b builder) claims(i interface{}) builder {
if b.err != nil {
return b
}
m, ok := i.(map[string]interface{})
switch {
case ok:
return b.merge(m)
case reflect.Indirect(reflect.ValueOf(i)).Kind() == reflect.Struct:
m, err := normalize(i)
if err != nil {
return builder{
err: err,
}
}
return b.merge(m)
default:
return builder{
err: ErrInvalidClaims,
}
}
}
func normalize(i interface{}) (map[string]interface{}, error) {
m := make(map[string]interface{})
raw, err := json.Marshal(i)
if err != nil {
return nil, err
}
d := json.NewDecoder(bytes.NewReader(raw))
d.SetNumberType(json.UnmarshalJSONNumber)
if err := d.Decode(&m); err != nil {
return nil, err
}
return m, nil
}
func (b *builder) merge(m map[string]interface{}) builder {
p := make(map[string]interface{})
for k, v := range b.payload {
p[k] = v
}
for k, v := range m {
p[k] = v
}
return builder{
payload: p,
}
}
func (b *builder) token(p func(interface{}) ([]byte, error), h []jose.Header) (*JSONWebToken, error) {
return &JSONWebToken{
payload: p,
Headers: h,
}, nil
}
func (b *signedBuilder) Claims(i interface{}) Builder {
return &signedBuilder{
builder: b.builder.claims(i),
sig: b.sig,
}
}
func (b *signedBuilder) Token() (*JSONWebToken, error) {
sig, err := b.sign()
if err != nil {
return nil, err
}
h := make([]jose.Header, len(sig.Signatures))
for i, v := range sig.Signatures {
h[i] = v.Header
}
return b.builder.token(sig.Verify, h)
}
func (b *signedBuilder) CompactSerialize() (string, error) {
sig, err := b.sign()
if err != nil {
return "", err
}
return sig.CompactSerialize()
}
func (b *signedBuilder) FullSerialize() (string, error) {
sig, err := b.sign()
if err != nil {
return "", err
}
return sig.FullSerialize(), nil
}
func (b *signedBuilder) sign() (*jose.JSONWebSignature, error) {
if b.err != nil {
return nil, b.err
}
p, err := json.Marshal(b.payload)
if err != nil {
return nil, err
}
return b.sig.Sign(p)
}
func (b *encryptedBuilder) Claims(i interface{}) Builder {
return &encryptedBuilder{
builder: b.builder.claims(i),
enc: b.enc,
}
}
func (b *encryptedBuilder) CompactSerialize() (string, error) {
enc, err := b.encrypt()
if err != nil {
return "", err
}
return enc.CompactSerialize()
}
func (b *encryptedBuilder) FullSerialize() (string, error) {
enc, err := b.encrypt()
if err != nil {
return "", err
}
return enc.FullSerialize(), nil
}
func (b *encryptedBuilder) Token() (*JSONWebToken, error) {
enc, err := b.encrypt()
if err != nil {
return nil, err
}
return b.builder.token(enc.Decrypt, []jose.Header{enc.Header})
}
func (b *encryptedBuilder) encrypt() (*jose.JSONWebEncryption, error) {
if b.err != nil {
return nil, b.err
}
p, err := json.Marshal(b.payload)
if err != nil {
return nil, err
}
return b.enc.Encrypt(p)
}
func (b *nestedBuilder) Claims(i interface{}) NestedBuilder {
return &nestedBuilder{
builder: b.builder.claims(i),
sig: b.sig,
enc: b.enc,
}
}
func (b *nestedBuilder) Token() (*NestedJSONWebToken, error) {
enc, err := b.signAndEncrypt()
if err != nil {
return nil, err
}
return &NestedJSONWebToken{
enc: enc,
Headers: []jose.Header{enc.Header},
}, nil
}
func (b *nestedBuilder) CompactSerialize() (string, error) {
enc, err := b.signAndEncrypt()
if err != nil {
return "", err
}
return enc.CompactSerialize()
}
func (b *nestedBuilder) FullSerialize() (string, error) {
enc, err := b.signAndEncrypt()
if err != nil {
return "", err
}
return enc.FullSerialize(), nil
}
func (b *nestedBuilder) signAndEncrypt() (*jose.JSONWebEncryption, error) {
if b.err != nil {
return nil, b.err
}
p, err := json.Marshal(b.payload)
if err != nil {
return nil, err
}
sig, err := b.sig.Sign(p)
if err != nil {
return nil, err
}
p2, err := sig.CompactSerialize()
if err != nil {
return nil, err
}
return b.enc.Encrypt([]byte(p2))
}

130
vendor/github.com/go-jose/go-jose/v3/jwt/claims.go generated vendored Normal file
View File

@ -0,0 +1,130 @@
/*-
* Copyright 2016 Zbigniew Mandziejewicz
* Copyright 2016 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jwt
import (
"strconv"
"time"
"github.com/go-jose/go-jose/v3/json"
)
// Claims represents public claim values (as specified in RFC 7519).
type Claims struct {
Issuer string `json:"iss,omitempty"`
Subject string `json:"sub,omitempty"`
Audience Audience `json:"aud,omitempty"`
Expiry *NumericDate `json:"exp,omitempty"`
NotBefore *NumericDate `json:"nbf,omitempty"`
IssuedAt *NumericDate `json:"iat,omitempty"`
ID string `json:"jti,omitempty"`
}
// NumericDate represents date and time as the number of seconds since the
// epoch, ignoring leap seconds. Non-integer values can be represented
// in the serialized format, but we round to the nearest second.
// See RFC7519 Section 2: https://tools.ietf.org/html/rfc7519#section-2
type NumericDate int64
// NewNumericDate constructs NumericDate from time.Time value.
func NewNumericDate(t time.Time) *NumericDate {
if t.IsZero() {
return nil
}
// While RFC 7519 technically states that NumericDate values may be
// non-integer values, we don't bother serializing timestamps in
// claims with sub-second accurancy and just round to the nearest
// second instead. Not convined sub-second accuracy is useful here.
out := NumericDate(t.Unix())
return &out
}
// MarshalJSON serializes the given NumericDate into its JSON representation.
func (n NumericDate) MarshalJSON() ([]byte, error) {
return []byte(strconv.FormatInt(int64(n), 10)), nil
}
// UnmarshalJSON reads a date from its JSON representation.
func (n *NumericDate) UnmarshalJSON(b []byte) error {
s := string(b)
f, err := strconv.ParseFloat(s, 64)
if err != nil {
return ErrUnmarshalNumericDate
}
*n = NumericDate(f)
return nil
}
// Time returns time.Time representation of NumericDate.
func (n *NumericDate) Time() time.Time {
if n == nil {
return time.Time{}
}
return time.Unix(int64(*n), 0)
}
// Audience represents the recipients that the token is intended for.
type Audience []string
// UnmarshalJSON reads an audience from its JSON representation.
func (s *Audience) UnmarshalJSON(b []byte) error {
var v interface{}
if err := json.Unmarshal(b, &v); err != nil {
return err
}
switch v := v.(type) {
case string:
*s = []string{v}
case []interface{}:
a := make([]string, len(v))
for i, e := range v {
s, ok := e.(string)
if !ok {
return ErrUnmarshalAudience
}
a[i] = s
}
*s = a
default:
return ErrUnmarshalAudience
}
return nil
}
// MarshalJSON converts audience to json representation.
func (s Audience) MarshalJSON() ([]byte, error) {
if len(s) == 1 {
return json.Marshal(s[0])
}
return json.Marshal([]string(s))
}
//Contains checks whether a given string is included in the Audience
func (s Audience) Contains(v string) bool {
for _, a := range s {
if a == v {
return true
}
}
return false
}

22
vendor/github.com/go-jose/go-jose/v3/jwt/doc.go generated vendored Normal file
View File

@ -0,0 +1,22 @@
/*-
* Copyright 2017 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
Package jwt provides an implementation of the JSON Web Token standard.
*/
package jwt

53
vendor/github.com/go-jose/go-jose/v3/jwt/errors.go generated vendored Normal file
View File

@ -0,0 +1,53 @@
/*-
* Copyright 2016 Zbigniew Mandziejewicz
* Copyright 2016 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jwt
import "errors"
// ErrUnmarshalAudience indicates that aud claim could not be unmarshalled.
var ErrUnmarshalAudience = errors.New("go-jose/go-jose/jwt: expected string or array value to unmarshal to Audience")
// ErrUnmarshalNumericDate indicates that JWT NumericDate could not be unmarshalled.
var ErrUnmarshalNumericDate = errors.New("go-jose/go-jose/jwt: expected number value to unmarshal NumericDate")
// ErrInvalidClaims indicates that given claims have invalid type.
var ErrInvalidClaims = errors.New("go-jose/go-jose/jwt: expected claims to be value convertible into JSON object")
// ErrInvalidIssuer indicates invalid iss claim.
var ErrInvalidIssuer = errors.New("go-jose/go-jose/jwt: validation failed, invalid issuer claim (iss)")
// ErrInvalidSubject indicates invalid sub claim.
var ErrInvalidSubject = errors.New("go-jose/go-jose/jwt: validation failed, invalid subject claim (sub)")
// ErrInvalidAudience indicated invalid aud claim.
var ErrInvalidAudience = errors.New("go-jose/go-jose/jwt: validation failed, invalid audience claim (aud)")
// ErrInvalidID indicates invalid jti claim.
var ErrInvalidID = errors.New("go-jose/go-jose/jwt: validation failed, invalid ID claim (jti)")
// ErrNotValidYet indicates that token is used before time indicated in nbf claim.
var ErrNotValidYet = errors.New("go-jose/go-jose/jwt: validation failed, token not valid yet (nbf)")
// ErrExpired indicates that token is used after expiry time indicated in exp claim.
var ErrExpired = errors.New("go-jose/go-jose/jwt: validation failed, token is expired (exp)")
// ErrIssuedInTheFuture indicates that the iat field is in the future.
var ErrIssuedInTheFuture = errors.New("go-jose/go-jose/jwt: validation field, token issued in the future (iat)")
// ErrInvalidContentType indicates that token requires JWT cty header.
var ErrInvalidContentType = errors.New("go-jose/go-jose/jwt: expected content type to be JWT (cty header)")

133
vendor/github.com/go-jose/go-jose/v3/jwt/jwt.go generated vendored Normal file
View File

@ -0,0 +1,133 @@
/*-
* Copyright 2016 Zbigniew Mandziejewicz
* Copyright 2016 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jwt
import (
"fmt"
"strings"
jose "github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/json"
)
// JSONWebToken represents a JSON Web Token (as specified in RFC7519).
type JSONWebToken struct {
payload func(k interface{}) ([]byte, error)
unverifiedPayload func() []byte
Headers []jose.Header
}
type NestedJSONWebToken struct {
enc *jose.JSONWebEncryption
Headers []jose.Header
}
// Claims deserializes a JSONWebToken into dest using the provided key.
func (t *JSONWebToken) Claims(key interface{}, dest ...interface{}) error {
b, err := t.payload(key)
if err != nil {
return err
}
for _, d := range dest {
if err := json.Unmarshal(b, d); err != nil {
return err
}
}
return nil
}
// UnsafeClaimsWithoutVerification deserializes the claims of a
// JSONWebToken into the dests. For signed JWTs, the claims are not
// verified. This function won't work for encrypted JWTs.
func (t *JSONWebToken) UnsafeClaimsWithoutVerification(dest ...interface{}) error {
if t.unverifiedPayload == nil {
return fmt.Errorf("go-jose/go-jose: Cannot get unverified claims")
}
claims := t.unverifiedPayload()
for _, d := range dest {
if err := json.Unmarshal(claims, d); err != nil {
return err
}
}
return nil
}
func (t *NestedJSONWebToken) Decrypt(decryptionKey interface{}) (*JSONWebToken, error) {
b, err := t.enc.Decrypt(decryptionKey)
if err != nil {
return nil, err
}
sig, err := ParseSigned(string(b))
if err != nil {
return nil, err
}
return sig, nil
}
// ParseSigned parses token from JWS form.
func ParseSigned(s string) (*JSONWebToken, error) {
sig, err := jose.ParseSigned(s)
if err != nil {
return nil, err
}
headers := make([]jose.Header, len(sig.Signatures))
for i, signature := range sig.Signatures {
headers[i] = signature.Header
}
return &JSONWebToken{
payload: sig.Verify,
unverifiedPayload: sig.UnsafePayloadWithoutVerification,
Headers: headers,
}, nil
}
// ParseEncrypted parses token from JWE form.
func ParseEncrypted(s string) (*JSONWebToken, error) {
enc, err := jose.ParseEncrypted(s)
if err != nil {
return nil, err
}
return &JSONWebToken{
payload: enc.Decrypt,
Headers: []jose.Header{enc.Header},
}, nil
}
// ParseSignedAndEncrypted parses signed-then-encrypted token from JWE form.
func ParseSignedAndEncrypted(s string) (*NestedJSONWebToken, error) {
enc, err := jose.ParseEncrypted(s)
if err != nil {
return nil, err
}
contentType, _ := enc.Header.ExtraHeaders[jose.HeaderContentType].(string)
if strings.ToUpper(contentType) != "JWT" {
return nil, ErrInvalidContentType
}
return &NestedJSONWebToken{
enc: enc,
Headers: []jose.Header{enc.Header},
}, nil
}

120
vendor/github.com/go-jose/go-jose/v3/jwt/validation.go generated vendored Normal file
View File

@ -0,0 +1,120 @@
/*-
* Copyright 2016 Zbigniew Mandziejewicz
* Copyright 2016 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jwt
import "time"
const (
// DefaultLeeway defines the default leeway for matching NotBefore/Expiry claims.
DefaultLeeway = 1.0 * time.Minute
)
// Expected defines values used for protected claims validation.
// If field has zero value then validation is skipped, with the exception of
// Time, where the zero value means "now." To skip validating them, set the
// corresponding field in the Claims struct to nil.
type Expected struct {
// Issuer matches the "iss" claim exactly.
Issuer string
// Subject matches the "sub" claim exactly.
Subject string
// Audience matches the values in "aud" claim, regardless of their order.
Audience Audience
// ID matches the "jti" claim exactly.
ID string
// Time matches the "exp", "nbf" and "iat" claims with leeway.
Time time.Time
}
// WithTime copies expectations with new time.
func (e Expected) WithTime(t time.Time) Expected {
e.Time = t
return e
}
// Validate checks claims in a token against expected values.
// A default leeway value of one minute is used to compare time values.
//
// The default leeway will cause the token to be deemed valid until one
// minute after the expiration time. If you're a server application that
// wants to give an extra minute to client tokens, use this
// function. If you're a client application wondering if the server
// will accept your token, use ValidateWithLeeway with a leeway <=0,
// otherwise this function might make you think a token is valid when
// it is not.
func (c Claims) Validate(e Expected) error {
return c.ValidateWithLeeway(e, DefaultLeeway)
}
// ValidateWithLeeway checks claims in a token against expected values. A
// custom leeway may be specified for comparing time values. You may pass a
// zero value to check time values with no leeway, but you should note that
// numeric date values are rounded to the nearest second and sub-second
// precision is not supported.
//
// The leeway gives some extra time to the token from the server's
// point of view. That is, if the token is expired, ValidateWithLeeway
// will still accept the token for 'leeway' amount of time. This fails
// if you're using this function to check if a server will accept your
// token, because it will think the token is valid even after it
// expires. So if you're a client validating if the token is valid to
// be submitted to a server, use leeway <=0, if you're a server
// validation a token, use leeway >=0.
func (c Claims) ValidateWithLeeway(e Expected, leeway time.Duration) error {
if e.Issuer != "" && e.Issuer != c.Issuer {
return ErrInvalidIssuer
}
if e.Subject != "" && e.Subject != c.Subject {
return ErrInvalidSubject
}
if e.ID != "" && e.ID != c.ID {
return ErrInvalidID
}
if len(e.Audience) != 0 {
for _, v := range e.Audience {
if !c.Audience.Contains(v) {
return ErrInvalidAudience
}
}
}
// validate using the e.Time, or time.Now if not provided
validationTime := e.Time
if validationTime.IsZero() {
validationTime = time.Now()
}
if c.NotBefore != nil && validationTime.Add(leeway).Before(c.NotBefore.Time()) {
return ErrNotValidYet
}
if c.Expiry != nil && validationTime.Add(-leeway).After(c.Expiry.Time()) {
return ErrExpired
}
// IssuedAt is optional but cannot be in the future. This is not required by the RFC, but
// something is misconfigured if this happens and we should not trust it.
if c.IssuedAt != nil && validationTime.Add(leeway).Before(c.IssuedAt.Time()) {
return ErrIssuedInTheFuture
}
return nil
}

144
vendor/github.com/go-jose/go-jose/v3/opaque.go generated vendored Normal file
View File

@ -0,0 +1,144 @@
/*-
* Copyright 2018 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jose
// OpaqueSigner is an interface that supports signing payloads with opaque
// private key(s). Private key operations performed by implementers may, for
// example, occur in a hardware module. An OpaqueSigner may rotate signing keys
// transparently to the user of this interface.
type OpaqueSigner interface {
// Public returns the public key of the current signing key.
Public() *JSONWebKey
// Algs returns a list of supported signing algorithms.
Algs() []SignatureAlgorithm
// SignPayload signs a payload with the current signing key using the given
// algorithm.
SignPayload(payload []byte, alg SignatureAlgorithm) ([]byte, error)
}
type opaqueSigner struct {
signer OpaqueSigner
}
func newOpaqueSigner(alg SignatureAlgorithm, signer OpaqueSigner) (recipientSigInfo, error) {
var algSupported bool
for _, salg := range signer.Algs() {
if alg == salg {
algSupported = true
break
}
}
if !algSupported {
return recipientSigInfo{}, ErrUnsupportedAlgorithm
}
return recipientSigInfo{
sigAlg: alg,
publicKey: signer.Public,
signer: &opaqueSigner{
signer: signer,
},
}, nil
}
func (o *opaqueSigner) signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error) {
out, err := o.signer.SignPayload(payload, alg)
if err != nil {
return Signature{}, err
}
return Signature{
Signature: out,
protected: &rawHeader{},
}, nil
}
// OpaqueVerifier is an interface that supports verifying payloads with opaque
// public key(s). An OpaqueSigner may rotate signing keys transparently to the
// user of this interface.
type OpaqueVerifier interface {
VerifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error
}
type opaqueVerifier struct {
verifier OpaqueVerifier
}
func (o *opaqueVerifier) verifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error {
return o.verifier.VerifyPayload(payload, signature, alg)
}
// OpaqueKeyEncrypter is an interface that supports encrypting keys with an opaque key.
type OpaqueKeyEncrypter interface {
// KeyID returns the kid
KeyID() string
// Algs returns a list of supported key encryption algorithms.
Algs() []KeyAlgorithm
// encryptKey encrypts the CEK using the given algorithm.
encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error)
}
type opaqueKeyEncrypter struct {
encrypter OpaqueKeyEncrypter
}
func newOpaqueKeyEncrypter(alg KeyAlgorithm, encrypter OpaqueKeyEncrypter) (recipientKeyInfo, error) {
var algSupported bool
for _, salg := range encrypter.Algs() {
if alg == salg {
algSupported = true
break
}
}
if !algSupported {
return recipientKeyInfo{}, ErrUnsupportedAlgorithm
}
return recipientKeyInfo{
keyID: encrypter.KeyID(),
keyAlg: alg,
keyEncrypter: &opaqueKeyEncrypter{
encrypter: encrypter,
},
}, nil
}
func (oke *opaqueKeyEncrypter) encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) {
return oke.encrypter.encryptKey(cek, alg)
}
//OpaqueKeyDecrypter is an interface that supports decrypting keys with an opaque key.
type OpaqueKeyDecrypter interface {
DecryptKey(encryptedKey []byte, header Header) ([]byte, error)
}
type opaqueKeyDecrypter struct {
decrypter OpaqueKeyDecrypter
}
func (okd *opaqueKeyDecrypter) decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) {
mergedHeaders := rawHeader{}
mergedHeaders.merge(&headers)
mergedHeaders.merge(recipient.header)
header, err := mergedHeaders.sanitized()
if err != nil {
return nil, err
}
return okd.decrypter.DecryptKey(recipient.encryptedKey, header)
}

520
vendor/github.com/go-jose/go-jose/v3/shared.go generated vendored Normal file
View File

@ -0,0 +1,520 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jose
import (
"crypto/elliptic"
"crypto/x509"
"encoding/base64"
"errors"
"fmt"
"github.com/go-jose/go-jose/v3/json"
)
// KeyAlgorithm represents a key management algorithm.
type KeyAlgorithm string
// SignatureAlgorithm represents a signature (or MAC) algorithm.
type SignatureAlgorithm string
// ContentEncryption represents a content encryption algorithm.
type ContentEncryption string
// CompressionAlgorithm represents an algorithm used for plaintext compression.
type CompressionAlgorithm string
// ContentType represents type of the contained data.
type ContentType string
var (
// ErrCryptoFailure represents an error in cryptographic primitive. This
// occurs when, for example, a message had an invalid authentication tag or
// could not be decrypted.
ErrCryptoFailure = errors.New("go-jose/go-jose: error in cryptographic primitive")
// ErrUnsupportedAlgorithm indicates that a selected algorithm is not
// supported. This occurs when trying to instantiate an encrypter for an
// algorithm that is not yet implemented.
ErrUnsupportedAlgorithm = errors.New("go-jose/go-jose: unknown/unsupported algorithm")
// ErrUnsupportedKeyType indicates that the given key type/format is not
// supported. This occurs when trying to instantiate an encrypter and passing
// it a key of an unrecognized type or with unsupported parameters, such as
// an RSA private key with more than two primes.
ErrUnsupportedKeyType = errors.New("go-jose/go-jose: unsupported key type/format")
// ErrInvalidKeySize indicates that the given key is not the correct size
// for the selected algorithm. This can occur, for example, when trying to
// encrypt with AES-256 but passing only a 128-bit key as input.
ErrInvalidKeySize = errors.New("go-jose/go-jose: invalid key size for algorithm")
// ErrNotSupported serialization of object is not supported. This occurs when
// trying to compact-serialize an object which can't be represented in
// compact form.
ErrNotSupported = errors.New("go-jose/go-jose: compact serialization not supported for object")
// ErrUnprotectedNonce indicates that while parsing a JWS or JWE object, a
// nonce header parameter was included in an unprotected header object.
ErrUnprotectedNonce = errors.New("go-jose/go-jose: Nonce parameter included in unprotected header")
)
// Key management algorithms
const (
ED25519 = KeyAlgorithm("ED25519")
RSA1_5 = KeyAlgorithm("RSA1_5") // RSA-PKCS1v1.5
RSA_OAEP = KeyAlgorithm("RSA-OAEP") // RSA-OAEP-SHA1
RSA_OAEP_256 = KeyAlgorithm("RSA-OAEP-256") // RSA-OAEP-SHA256
A128KW = KeyAlgorithm("A128KW") // AES key wrap (128)
A192KW = KeyAlgorithm("A192KW") // AES key wrap (192)
A256KW = KeyAlgorithm("A256KW") // AES key wrap (256)
DIRECT = KeyAlgorithm("dir") // Direct encryption
ECDH_ES = KeyAlgorithm("ECDH-ES") // ECDH-ES
ECDH_ES_A128KW = KeyAlgorithm("ECDH-ES+A128KW") // ECDH-ES + AES key wrap (128)
ECDH_ES_A192KW = KeyAlgorithm("ECDH-ES+A192KW") // ECDH-ES + AES key wrap (192)
ECDH_ES_A256KW = KeyAlgorithm("ECDH-ES+A256KW") // ECDH-ES + AES key wrap (256)
A128GCMKW = KeyAlgorithm("A128GCMKW") // AES-GCM key wrap (128)
A192GCMKW = KeyAlgorithm("A192GCMKW") // AES-GCM key wrap (192)
A256GCMKW = KeyAlgorithm("A256GCMKW") // AES-GCM key wrap (256)
PBES2_HS256_A128KW = KeyAlgorithm("PBES2-HS256+A128KW") // PBES2 + HMAC-SHA256 + AES key wrap (128)
PBES2_HS384_A192KW = KeyAlgorithm("PBES2-HS384+A192KW") // PBES2 + HMAC-SHA384 + AES key wrap (192)
PBES2_HS512_A256KW = KeyAlgorithm("PBES2-HS512+A256KW") // PBES2 + HMAC-SHA512 + AES key wrap (256)
)
// Signature algorithms
const (
EdDSA = SignatureAlgorithm("EdDSA")
HS256 = SignatureAlgorithm("HS256") // HMAC using SHA-256
HS384 = SignatureAlgorithm("HS384") // HMAC using SHA-384
HS512 = SignatureAlgorithm("HS512") // HMAC using SHA-512
RS256 = SignatureAlgorithm("RS256") // RSASSA-PKCS-v1.5 using SHA-256
RS384 = SignatureAlgorithm("RS384") // RSASSA-PKCS-v1.5 using SHA-384
RS512 = SignatureAlgorithm("RS512") // RSASSA-PKCS-v1.5 using SHA-512
ES256 = SignatureAlgorithm("ES256") // ECDSA using P-256 and SHA-256
ES384 = SignatureAlgorithm("ES384") // ECDSA using P-384 and SHA-384
ES512 = SignatureAlgorithm("ES512") // ECDSA using P-521 and SHA-512
PS256 = SignatureAlgorithm("PS256") // RSASSA-PSS using SHA256 and MGF1-SHA256
PS384 = SignatureAlgorithm("PS384") // RSASSA-PSS using SHA384 and MGF1-SHA384
PS512 = SignatureAlgorithm("PS512") // RSASSA-PSS using SHA512 and MGF1-SHA512
)
// Content encryption algorithms
const (
A128CBC_HS256 = ContentEncryption("A128CBC-HS256") // AES-CBC + HMAC-SHA256 (128)
A192CBC_HS384 = ContentEncryption("A192CBC-HS384") // AES-CBC + HMAC-SHA384 (192)
A256CBC_HS512 = ContentEncryption("A256CBC-HS512") // AES-CBC + HMAC-SHA512 (256)
A128GCM = ContentEncryption("A128GCM") // AES-GCM (128)
A192GCM = ContentEncryption("A192GCM") // AES-GCM (192)
A256GCM = ContentEncryption("A256GCM") // AES-GCM (256)
)
// Compression algorithms
const (
NONE = CompressionAlgorithm("") // No compression
DEFLATE = CompressionAlgorithm("DEF") // DEFLATE (RFC 1951)
)
// A key in the protected header of a JWS object. Use of the Header...
// constants is preferred to enhance type safety.
type HeaderKey string
const (
HeaderType = "typ" // string
HeaderContentType = "cty" // string
// These are set by go-jose and shouldn't need to be set by consumers of the
// library.
headerAlgorithm = "alg" // string
headerEncryption = "enc" // ContentEncryption
headerCompression = "zip" // CompressionAlgorithm
headerCritical = "crit" // []string
headerAPU = "apu" // *byteBuffer
headerAPV = "apv" // *byteBuffer
headerEPK = "epk" // *JSONWebKey
headerIV = "iv" // *byteBuffer
headerTag = "tag" // *byteBuffer
headerX5c = "x5c" // []*x509.Certificate
headerJWK = "jwk" // *JSONWebKey
headerKeyID = "kid" // string
headerNonce = "nonce" // string
headerB64 = "b64" // bool
headerP2C = "p2c" // *byteBuffer (int)
headerP2S = "p2s" // *byteBuffer ([]byte)
)
// supportedCritical is the set of supported extensions that are understood and processed.
var supportedCritical = map[string]bool{
headerB64: true,
}
// rawHeader represents the JOSE header for JWE/JWS objects (used for parsing).
//
// The decoding of the constituent items is deferred because we want to marshal
// some members into particular structs rather than generic maps, but at the
// same time we need to receive any extra fields unhandled by this library to
// pass through to consuming code in case it wants to examine them.
type rawHeader map[HeaderKey]*json.RawMessage
// Header represents the read-only JOSE header for JWE/JWS objects.
type Header struct {
KeyID string
JSONWebKey *JSONWebKey
Algorithm string
Nonce string
// Unverified certificate chain parsed from x5c header.
certificates []*x509.Certificate
// Any headers not recognised above get unmarshalled
// from JSON in a generic manner and placed in this map.
ExtraHeaders map[HeaderKey]interface{}
}
// Certificates verifies & returns the certificate chain present
// in the x5c header field of a message, if one was present. Returns
// an error if there was no x5c header present or the chain could
// not be validated with the given verify options.
func (h Header) Certificates(opts x509.VerifyOptions) ([][]*x509.Certificate, error) {
if len(h.certificates) == 0 {
return nil, errors.New("go-jose/go-jose: no x5c header present in message")
}
leaf := h.certificates[0]
if opts.Intermediates == nil {
opts.Intermediates = x509.NewCertPool()
for _, intermediate := range h.certificates[1:] {
opts.Intermediates.AddCert(intermediate)
}
}
return leaf.Verify(opts)
}
func (parsed rawHeader) set(k HeaderKey, v interface{}) error {
b, err := json.Marshal(v)
if err != nil {
return err
}
parsed[k] = makeRawMessage(b)
return nil
}
// getString gets a string from the raw JSON, defaulting to "".
func (parsed rawHeader) getString(k HeaderKey) string {
v, ok := parsed[k]
if !ok || v == nil {
return ""
}
var s string
err := json.Unmarshal(*v, &s)
if err != nil {
return ""
}
return s
}
// getByteBuffer gets a byte buffer from the raw JSON. Returns (nil, nil) if
// not specified.
func (parsed rawHeader) getByteBuffer(k HeaderKey) (*byteBuffer, error) {
v := parsed[k]
if v == nil {
return nil, nil
}
var bb *byteBuffer
err := json.Unmarshal(*v, &bb)
if err != nil {
return nil, err
}
return bb, nil
}
// getAlgorithm extracts parsed "alg" from the raw JSON as a KeyAlgorithm.
func (parsed rawHeader) getAlgorithm() KeyAlgorithm {
return KeyAlgorithm(parsed.getString(headerAlgorithm))
}
// getSignatureAlgorithm extracts parsed "alg" from the raw JSON as a SignatureAlgorithm.
func (parsed rawHeader) getSignatureAlgorithm() SignatureAlgorithm {
return SignatureAlgorithm(parsed.getString(headerAlgorithm))
}
// getEncryption extracts parsed "enc" from the raw JSON.
func (parsed rawHeader) getEncryption() ContentEncryption {
return ContentEncryption(parsed.getString(headerEncryption))
}
// getCompression extracts parsed "zip" from the raw JSON.
func (parsed rawHeader) getCompression() CompressionAlgorithm {
return CompressionAlgorithm(parsed.getString(headerCompression))
}
func (parsed rawHeader) getNonce() string {
return parsed.getString(headerNonce)
}
// getEPK extracts parsed "epk" from the raw JSON.
func (parsed rawHeader) getEPK() (*JSONWebKey, error) {
v := parsed[headerEPK]
if v == nil {
return nil, nil
}
var epk *JSONWebKey
err := json.Unmarshal(*v, &epk)
if err != nil {
return nil, err
}
return epk, nil
}
// getAPU extracts parsed "apu" from the raw JSON.
func (parsed rawHeader) getAPU() (*byteBuffer, error) {
return parsed.getByteBuffer(headerAPU)
}
// getAPV extracts parsed "apv" from the raw JSON.
func (parsed rawHeader) getAPV() (*byteBuffer, error) {
return parsed.getByteBuffer(headerAPV)
}
// getIV extracts parsed "iv" from the raw JSON.
func (parsed rawHeader) getIV() (*byteBuffer, error) {
return parsed.getByteBuffer(headerIV)
}
// getTag extracts parsed "tag" from the raw JSON.
func (parsed rawHeader) getTag() (*byteBuffer, error) {
return parsed.getByteBuffer(headerTag)
}
// getJWK extracts parsed "jwk" from the raw JSON.
func (parsed rawHeader) getJWK() (*JSONWebKey, error) {
v := parsed[headerJWK]
if v == nil {
return nil, nil
}
var jwk *JSONWebKey
err := json.Unmarshal(*v, &jwk)
if err != nil {
return nil, err
}
return jwk, nil
}
// getCritical extracts parsed "crit" from the raw JSON. If omitted, it
// returns an empty slice.
func (parsed rawHeader) getCritical() ([]string, error) {
v := parsed[headerCritical]
if v == nil {
return nil, nil
}
var q []string
err := json.Unmarshal(*v, &q)
if err != nil {
return nil, err
}
return q, nil
}
// getS2C extracts parsed "p2c" from the raw JSON.
func (parsed rawHeader) getP2C() (int, error) {
v := parsed[headerP2C]
if v == nil {
return 0, nil
}
var p2c int
err := json.Unmarshal(*v, &p2c)
if err != nil {
return 0, err
}
return p2c, nil
}
// getS2S extracts parsed "p2s" from the raw JSON.
func (parsed rawHeader) getP2S() (*byteBuffer, error) {
return parsed.getByteBuffer(headerP2S)
}
// getB64 extracts parsed "b64" from the raw JSON, defaulting to true.
func (parsed rawHeader) getB64() (bool, error) {
v := parsed[headerB64]
if v == nil {
return true, nil
}
var b64 bool
err := json.Unmarshal(*v, &b64)
if err != nil {
return true, err
}
return b64, nil
}
// sanitized produces a cleaned-up header object from the raw JSON.
func (parsed rawHeader) sanitized() (h Header, err error) {
for k, v := range parsed {
if v == nil {
continue
}
switch k {
case headerJWK:
var jwk *JSONWebKey
err = json.Unmarshal(*v, &jwk)
if err != nil {
err = fmt.Errorf("failed to unmarshal JWK: %v: %#v", err, string(*v))
return
}
h.JSONWebKey = jwk
case headerKeyID:
var s string
err = json.Unmarshal(*v, &s)
if err != nil {
err = fmt.Errorf("failed to unmarshal key ID: %v: %#v", err, string(*v))
return
}
h.KeyID = s
case headerAlgorithm:
var s string
err = json.Unmarshal(*v, &s)
if err != nil {
err = fmt.Errorf("failed to unmarshal algorithm: %v: %#v", err, string(*v))
return
}
h.Algorithm = s
case headerNonce:
var s string
err = json.Unmarshal(*v, &s)
if err != nil {
err = fmt.Errorf("failed to unmarshal nonce: %v: %#v", err, string(*v))
return
}
h.Nonce = s
case headerX5c:
c := []string{}
err = json.Unmarshal(*v, &c)
if err != nil {
err = fmt.Errorf("failed to unmarshal x5c header: %v: %#v", err, string(*v))
return
}
h.certificates, err = parseCertificateChain(c)
if err != nil {
err = fmt.Errorf("failed to unmarshal x5c header: %v: %#v", err, string(*v))
return
}
default:
if h.ExtraHeaders == nil {
h.ExtraHeaders = map[HeaderKey]interface{}{}
}
var v2 interface{}
err = json.Unmarshal(*v, &v2)
if err != nil {
err = fmt.Errorf("failed to unmarshal value: %v: %#v", err, string(*v))
return
}
h.ExtraHeaders[k] = v2
}
}
return
}
func parseCertificateChain(chain []string) ([]*x509.Certificate, error) {
out := make([]*x509.Certificate, len(chain))
for i, cert := range chain {
raw, err := base64.StdEncoding.DecodeString(cert)
if err != nil {
return nil, err
}
out[i], err = x509.ParseCertificate(raw)
if err != nil {
return nil, err
}
}
return out, nil
}
func (parsed rawHeader) isSet(k HeaderKey) bool {
dvr := parsed[k]
if dvr == nil {
return false
}
var dv interface{}
err := json.Unmarshal(*dvr, &dv)
if err != nil {
return true
}
if dvStr, ok := dv.(string); ok {
return dvStr != ""
}
return true
}
// merge copies headers from src into the receiver, giving precedence to headers already present in the receiver.
func (parsed rawHeader) merge(src *rawHeader) {
if src == nil {
return
}
for k, v := range *src {
if parsed.isSet(k) {
continue
}
parsed[k] = v
}
}
// Get JOSE name of curve
func curveName(crv elliptic.Curve) (string, error) {
switch crv {
case elliptic.P256():
return "P-256", nil
case elliptic.P384():
return "P-384", nil
case elliptic.P521():
return "P-521", nil
default:
return "", fmt.Errorf("go-jose/go-jose: unsupported/unknown elliptic curve")
}
}
// Get size of curve in bytes
func curveSize(crv elliptic.Curve) int {
bits := crv.Params().BitSize
div := bits / 8
mod := bits % 8
if mod == 0 {
return div
}
return div + 1
}
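// For example, P-256 yields 32 (256/8 with no remainder), while P-521
// yields 66 (521/8 = 65 remainder 1, so 65+1).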
func makeRawMessage(b []byte) *json.RawMessage {
rm := json.RawMessage(b)
return &rm
}

vendor/github.com/go-jose/go-jose/v3/signing.go (new vendored file, 450 lines)
@@ -0,0 +1,450 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jose
import (
"bytes"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"encoding/base64"
"errors"
"fmt"
"github.com/go-jose/go-jose/v3/json"
)
// NonceSource represents a source of random nonces to go into JWS objects
type NonceSource interface {
Nonce() (string, error)
}
// Signer represents a signer which takes a payload and produces a signed JWS object.
type Signer interface {
Sign(payload []byte) (*JSONWebSignature, error)
Options() SignerOptions
}
// SigningKey represents an algorithm/key used to sign a message.
type SigningKey struct {
Algorithm SignatureAlgorithm
Key interface{}
}
// SignerOptions represents options that can be set when creating signers.
type SignerOptions struct {
NonceSource NonceSource
EmbedJWK bool
// Optional map of additional keys to be inserted into the protected header
// of a JWS object. Some specifications which make use of JWS like to insert
// additional values here. All values must be JSON-serializable.
ExtraHeaders map[HeaderKey]interface{}
}
// WithHeader adds an arbitrary value to the ExtraHeaders map, initializing it
// if necessary. It returns itself and so can be used in a fluent style.
func (so *SignerOptions) WithHeader(k HeaderKey, v interface{}) *SignerOptions {
if so.ExtraHeaders == nil {
so.ExtraHeaders = map[HeaderKey]interface{}{}
}
so.ExtraHeaders[k] = v
return so
}
// WithContentType adds a content type ("cty") header and returns the updated
// SignerOptions.
func (so *SignerOptions) WithContentType(contentType ContentType) *SignerOptions {
return so.WithHeader(HeaderContentType, contentType)
}
// WithType adds a type ("typ") header and returns the updated SignerOptions.
func (so *SignerOptions) WithType(typ ContentType) *SignerOptions {
return so.WithHeader(HeaderType, typ)
}
// WithCritical adds the given names to the critical ("crit") header and returns
// the updated SignerOptions.
func (so *SignerOptions) WithCritical(names ...string) *SignerOptions {
if so.ExtraHeaders[headerCritical] == nil {
so.WithHeader(headerCritical, make([]string, 0, len(names)))
}
crit := so.ExtraHeaders[headerCritical].([]string)
so.ExtraHeaders[headerCritical] = append(crit, names...)
return so
}
// WithBase64 adds a base64url-encoded payload ("b64") header and returns the
// updated SignerOptions. When the "b64" value is false, the payload is not
// base64url-encoded.
func (so *SignerOptions) WithBase64(b64 bool) *SignerOptions {
if !b64 {
so.WithHeader(headerB64, b64)
so.WithCritical(headerB64)
}
return so
}
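// Illustrative sketch (not part of the vendored file): the With* helpers
// compose fluently, so callers typically chain them when building options.
// The key ID below is a made-up example.
//
//	opts := (&SignerOptions{}).
//		WithType("JWT").
//		WithContentType("JWT").
//		WithHeader("kid", "example-key-1")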
type payloadSigner interface {
signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error)
}
type payloadVerifier interface {
verifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error
}
type genericSigner struct {
recipients []recipientSigInfo
nonceSource NonceSource
embedJWK bool
extraHeaders map[HeaderKey]interface{}
}
type recipientSigInfo struct {
sigAlg SignatureAlgorithm
publicKey func() *JSONWebKey
signer payloadSigner
}
func staticPublicKey(jwk *JSONWebKey) func() *JSONWebKey {
return func() *JSONWebKey {
return jwk
}
}
// NewSigner creates an appropriate signer based on the key type
func NewSigner(sig SigningKey, opts *SignerOptions) (Signer, error) {
return NewMultiSigner([]SigningKey{sig}, opts)
}
// NewMultiSigner creates a signer for multiple recipients
func NewMultiSigner(sigs []SigningKey, opts *SignerOptions) (Signer, error) {
signer := &genericSigner{recipients: []recipientSigInfo{}}
if opts != nil {
signer.nonceSource = opts.NonceSource
signer.embedJWK = opts.EmbedJWK
signer.extraHeaders = opts.ExtraHeaders
}
for _, sig := range sigs {
err := signer.addRecipient(sig.Algorithm, sig.Key)
if err != nil {
return nil, err
}
}
return signer, nil
}
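// A minimal round-trip sketch, assuming an ed25519 key pair (priv/pub)
// generated by the caller; error handling elided for brevity:
//
//	signer, _ := NewSigner(SigningKey{Algorithm: EdDSA, Key: priv}, nil)
//	obj, _ := signer.Sign([]byte("payload"))
//	compact, _ := obj.CompactSerialize()
//	parsed, _ := ParseSigned(compact)
//	out, _ := parsed.Verify(pub) // out == []byte("payload")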
// newVerifier creates a verifier based on the key type
func newVerifier(verificationKey interface{}) (payloadVerifier, error) {
switch verificationKey := verificationKey.(type) {
case ed25519.PublicKey:
return &edEncrypterVerifier{
publicKey: verificationKey,
}, nil
case *rsa.PublicKey:
return &rsaEncrypterVerifier{
publicKey: verificationKey,
}, nil
case *ecdsa.PublicKey:
return &ecEncrypterVerifier{
publicKey: verificationKey,
}, nil
case []byte:
return &symmetricMac{
key: verificationKey,
}, nil
case JSONWebKey:
return newVerifier(verificationKey.Key)
case *JSONWebKey:
return newVerifier(verificationKey.Key)
}
if ov, ok := verificationKey.(OpaqueVerifier); ok {
return &opaqueVerifier{verifier: ov}, nil
}
return nil, ErrUnsupportedKeyType
}
func (ctx *genericSigner) addRecipient(alg SignatureAlgorithm, signingKey interface{}) error {
recipient, err := makeJWSRecipient(alg, signingKey)
if err != nil {
return err
}
ctx.recipients = append(ctx.recipients, recipient)
return nil
}
func makeJWSRecipient(alg SignatureAlgorithm, signingKey interface{}) (recipientSigInfo, error) {
switch signingKey := signingKey.(type) {
case ed25519.PrivateKey:
return newEd25519Signer(alg, signingKey)
case *rsa.PrivateKey:
return newRSASigner(alg, signingKey)
case *ecdsa.PrivateKey:
return newECDSASigner(alg, signingKey)
case []byte:
return newSymmetricSigner(alg, signingKey)
case JSONWebKey:
return newJWKSigner(alg, signingKey)
case *JSONWebKey:
return newJWKSigner(alg, *signingKey)
}
if signer, ok := signingKey.(OpaqueSigner); ok {
return newOpaqueSigner(alg, signer)
}
return recipientSigInfo{}, ErrUnsupportedKeyType
}
func newJWKSigner(alg SignatureAlgorithm, signingKey JSONWebKey) (recipientSigInfo, error) {
recipient, err := makeJWSRecipient(alg, signingKey.Key)
if err != nil {
return recipientSigInfo{}, err
}
if recipient.publicKey != nil && recipient.publicKey() != nil {
// recipient.publicKey is a JWK synthesized for embedding when recipientSigInfo
// was created for the inner key (such as an RSA or ECDSA public key). It contains
// the public key for embedding, but doesn't have extra params like the key ID.
publicKey := signingKey
publicKey.Key = recipient.publicKey().Key
recipient.publicKey = staticPublicKey(&publicKey)
// This should be impossible, but let's check anyway.
if !recipient.publicKey().IsPublic() {
return recipientSigInfo{}, errors.New("go-jose/go-jose: public key was unexpectedly not public")
}
}
return recipient, nil
}
func (ctx *genericSigner) Sign(payload []byte) (*JSONWebSignature, error) {
obj := &JSONWebSignature{}
obj.payload = payload
obj.Signatures = make([]Signature, len(ctx.recipients))
for i, recipient := range ctx.recipients {
protected := map[HeaderKey]interface{}{
headerAlgorithm: string(recipient.sigAlg),
}
if recipient.publicKey != nil && recipient.publicKey() != nil {
// We want to embed the JWK or set the kid header, but not both. Having a protected
// header that contains an embedded JWK while also simultaneously containing the kid
// header is confusing, and at least in ACME the two are considered to be mutually
// exclusive. The fact that both can exist at the same time is a somewhat unfortunate
// result of the JOSE spec. We've decided that this library will only include one or
// the other to avoid this confusion.
//
// See https://github.com/go-jose/go-jose/issues/157 for more context.
if ctx.embedJWK {
protected[headerJWK] = recipient.publicKey()
} else {
keyID := recipient.publicKey().KeyID
if keyID != "" {
protected[headerKeyID] = keyID
}
}
}
if ctx.nonceSource != nil {
nonce, err := ctx.nonceSource.Nonce()
if err != nil {
return nil, fmt.Errorf("go-jose/go-jose: Error generating nonce: %v", err)
}
protected[headerNonce] = nonce
}
for k, v := range ctx.extraHeaders {
protected[k] = v
}
serializedProtected := mustSerializeJSON(protected)
needsBase64 := true
if b64, ok := protected[headerB64]; ok {
if needsBase64, ok = b64.(bool); !ok {
return nil, errors.New("go-jose/go-jose: Invalid b64 header parameter")
}
}
var input bytes.Buffer
input.WriteString(base64.RawURLEncoding.EncodeToString(serializedProtected))
input.WriteByte('.')
if needsBase64 {
input.WriteString(base64.RawURLEncoding.EncodeToString(payload))
} else {
input.Write(payload)
}
signatureInfo, err := recipient.signer.signPayload(input.Bytes(), recipient.sigAlg)
if err != nil {
return nil, err
}
signatureInfo.protected = &rawHeader{}
for k, v := range protected {
b, err := json.Marshal(v)
if err != nil {
return nil, fmt.Errorf("go-jose/go-jose: Error marshalling item %#v: %v", k, err)
}
(*signatureInfo.protected)[k] = makeRawMessage(b)
}
obj.Signatures[i] = signatureInfo
}
return obj, nil
}
func (ctx *genericSigner) Options() SignerOptions {
return SignerOptions{
NonceSource: ctx.nonceSource,
EmbedJWK: ctx.embedJWK,
ExtraHeaders: ctx.extraHeaders,
}
}
// Verify validates the signature on the object and returns the payload.
// This function does not support multi-signature objects; if you need
// multi-signature verification, use VerifyMulti instead.
//
// Be careful when verifying signatures based on embedded JWKs inside the
// payload header. You cannot assume that the key received in a payload is
// trusted.
func (obj JSONWebSignature) Verify(verificationKey interface{}) ([]byte, error) {
err := obj.DetachedVerify(obj.payload, verificationKey)
if err != nil {
return nil, err
}
return obj.payload, nil
}
// UnsafePayloadWithoutVerification returns the payload without
// verifying it. The content returned from this function cannot be
// trusted.
func (obj JSONWebSignature) UnsafePayloadWithoutVerification() []byte {
return obj.payload
}
// DetachedVerify validates a detached signature on the given payload. In
// most cases, you will probably want to use Verify instead. DetachedVerify
// is only useful if you have a payload and signature that are separated from
// each other.
func (obj JSONWebSignature) DetachedVerify(payload []byte, verificationKey interface{}) error {
key := tryJWKS(verificationKey, obj.headers()...)
verifier, err := newVerifier(key)
if err != nil {
return err
}
if len(obj.Signatures) > 1 {
return errors.New("go-jose/go-jose: too many signatures in payload; expecting only one")
}
signature := obj.Signatures[0]
headers := signature.mergedHeaders()
critical, err := headers.getCritical()
if err != nil {
return err
}
for _, name := range critical {
if !supportedCritical[name] {
return ErrCryptoFailure
}
}
input, err := obj.computeAuthData(payload, &signature)
if err != nil {
return ErrCryptoFailure
}
alg := headers.getSignatureAlgorithm()
err = verifier.verifyPayload(input, signature.Signature, alg)
if err == nil {
return nil
}
return ErrCryptoFailure
}
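// A hedged sketch of the detached flow, reusing the signer and keys from the
// round-trip example above:
//
//	sig, _ := obj.DetachedCompactSerialize() // compact form with an empty payload slot
//	parsed, _ := ParseDetached(sig, []byte("payload"))
//	err := parsed.DetachedVerify([]byte("payload"), pub)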
// VerifyMulti validates (one of the multiple) signatures on the object and
// returns the index of the signature that was verified, along with the signature
// object and the payload. We return the signature and index to guarantee that
// callers are getting the verified value.
func (obj JSONWebSignature) VerifyMulti(verificationKey interface{}) (int, Signature, []byte, error) {
idx, sig, err := obj.DetachedVerifyMulti(obj.payload, verificationKey)
if err != nil {
return -1, Signature{}, nil, err
}
return idx, sig, obj.payload, nil
}
// DetachedVerifyMulti validates a detached signature on the given payload with
// a signature/object that has potentially multiple signers. This returns the index
// of the signature that was verified, along with the signature object. We return
// the signature and index to guarantee that callers are getting the verified value.
//
// In most cases, you will probably want to use Verify or VerifyMulti instead.
// DetachedVerifyMulti is only useful if you have a payload and signature that are
// separated from each other, and the signature can have multiple signers at the
// same time.
func (obj JSONWebSignature) DetachedVerifyMulti(payload []byte, verificationKey interface{}) (int, Signature, error) {
key := tryJWKS(verificationKey, obj.headers()...)
verifier, err := newVerifier(key)
if err != nil {
return -1, Signature{}, err
}
outer:
for i, signature := range obj.Signatures {
headers := signature.mergedHeaders()
critical, err := headers.getCritical()
if err != nil {
continue
}
for _, name := range critical {
if !supportedCritical[name] {
continue outer
}
}
input, err := obj.computeAuthData(payload, &signature)
if err != nil {
continue
}
alg := headers.getSignatureAlgorithm()
err = verifier.verifyPayload(input, signature.Signature, alg)
if err == nil {
return i, signature, nil
}
}
return -1, Signature{}, ErrCryptoFailure
}
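// A sketch of multi-signature verification, assuming two ed25519 key pairs
// (priv1/pub1 and priv2/pub2) generated by the caller:
//
//	signer, _ := NewMultiSigner([]SigningKey{
//		{Algorithm: EdDSA, Key: priv1},
//		{Algorithm: EdDSA, Key: priv2},
//	}, nil)
//	obj, _ := signer.Sign([]byte("payload"))
//	idx, _, _, err := obj.VerifyMulti(pub2) // idx == 1 when the second signature matches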
func (obj JSONWebSignature) headers() []Header {
headers := make([]Header, len(obj.Signatures))
for i, sig := range obj.Signatures {
headers[i] = sig.Header
}
return headers
}

vendor/github.com/go-jose/go-jose/v3/symmetric.go (new vendored file, 495 lines)
@@ -0,0 +1,495 @@
/*-
* Copyright 2014 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jose
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/sha512"
"crypto/subtle"
"errors"
"fmt"
"hash"
"io"
"golang.org/x/crypto/pbkdf2"
josecipher "github.com/go-jose/go-jose/v3/cipher"
)
// RandReader is a cryptographically secure random number generator (stubbed out in tests).
var RandReader = rand.Reader
const (
// RFC7518 recommends a minimum of 1,000 iterations:
// https://tools.ietf.org/html/rfc7518#section-4.8.1.2
// NIST recommends a minimum of 10,000:
// https://pages.nist.gov/800-63-3/sp800-63b.html
// 1Password uses 100,000:
// https://support.1password.com/pbkdf2/
defaultP2C = 100000
// Default salt size: 128 bits
defaultP2SSize = 16
)
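// A hedged usage sketch: callers opt into PBES2 through a Recipient whose key
// is a password; PBES2Count and PBES2Salt are optional and fall back to the
// defaults above (the count below is illustrative, not a recommendation):
//
//	enc, _ := NewEncrypter(A128GCM,
//		Recipient{Algorithm: PBES2_HS256_A128KW, Key: []byte("password"), PBES2Count: 150000},
//		nil)
//	jwe, _ := enc.Encrypt([]byte("secret"))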
// Dummy key cipher for shared symmetric key mode
type symmetricKeyCipher struct {
key []byte // Pre-shared content-encryption key
p2c int // PBES2 Count
p2s []byte // PBES2 Salt Input
}
// Signer/verifier for MAC modes
type symmetricMac struct {
key []byte
}
// Input/output from an AEAD operation
type aeadParts struct {
iv, ciphertext, tag []byte
}
// A content cipher based on an AEAD construction
type aeadContentCipher struct {
keyBytes int
authtagBytes int
getAead func(key []byte) (cipher.AEAD, error)
}
// Random key generator
type randomKeyGenerator struct {
size int
}
// Static key generator
type staticKeyGenerator struct {
key []byte
}
// Create a new content cipher based on AES-GCM
func newAESGCM(keySize int) contentCipher {
return &aeadContentCipher{
keyBytes: keySize,
authtagBytes: 16,
getAead: func(key []byte) (cipher.AEAD, error) {
aes, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return cipher.NewGCM(aes)
},
}
}
// Create a new content cipher based on AES-CBC+HMAC
func newAESCBC(keySize int) contentCipher {
return &aeadContentCipher{
keyBytes: keySize * 2,
authtagBytes: keySize,
getAead: func(key []byte) (cipher.AEAD, error) {
return josecipher.NewCBCHMAC(key, aes.NewCipher)
},
}
}
// Get an AEAD cipher object for the given content encryption algorithm
func getContentCipher(alg ContentEncryption) contentCipher {
switch alg {
case A128GCM:
return newAESGCM(16)
case A192GCM:
return newAESGCM(24)
case A256GCM:
return newAESGCM(32)
case A128CBC_HS256:
return newAESCBC(16)
case A192CBC_HS384:
return newAESCBC(24)
case A256CBC_HS512:
return newAESCBC(32)
default:
return nil
}
}
// getPbkdf2Params returns the key length and hash function used in
// pbkdf2.Key.
func getPbkdf2Params(alg KeyAlgorithm) (int, func() hash.Hash) {
switch alg {
case PBES2_HS256_A128KW:
return 16, sha256.New
case PBES2_HS384_A192KW:
return 24, sha512.New384
case PBES2_HS512_A256KW:
return 32, sha512.New
default:
panic("invalid algorithm")
}
}
// getRandomSalt generates a new salt of the given size.
func getRandomSalt(size int) ([]byte, error) {
salt := make([]byte, size)
_, err := io.ReadFull(RandReader, salt)
if err != nil {
return nil, err
}
return salt, nil
}
// newSymmetricRecipient creates a JWE recipient based on a shared symmetric key
// (direct mode, AES key wrap, AES-GCM key wrap, or PBES2).
func newSymmetricRecipient(keyAlg KeyAlgorithm, key []byte) (recipientKeyInfo, error) {
switch keyAlg {
case DIRECT, A128GCMKW, A192GCMKW, A256GCMKW, A128KW, A192KW, A256KW:
case PBES2_HS256_A128KW, PBES2_HS384_A192KW, PBES2_HS512_A256KW:
default:
return recipientKeyInfo{}, ErrUnsupportedAlgorithm
}
return recipientKeyInfo{
keyAlg: keyAlg,
keyEncrypter: &symmetricKeyCipher{
key: key,
},
}, nil
}
// newSymmetricSigner creates a recipientSigInfo based on the given key.
func newSymmetricSigner(sigAlg SignatureAlgorithm, key []byte) (recipientSigInfo, error) {
// Verify that the signature algorithm is supported by this signer
switch sigAlg {
case HS256, HS384, HS512:
default:
return recipientSigInfo{}, ErrUnsupportedAlgorithm
}
return recipientSigInfo{
sigAlg: sigAlg,
signer: &symmetricMac{
key: key,
},
}, nil
}
// Generate a random key for the given content cipher
func (ctx randomKeyGenerator) genKey() ([]byte, rawHeader, error) {
key := make([]byte, ctx.size)
_, err := io.ReadFull(RandReader, key)
if err != nil {
return nil, rawHeader{}, err
}
return key, rawHeader{}, nil
}
// Key size for random generator
func (ctx randomKeyGenerator) keySize() int {
return ctx.size
}
// Generate a static key (for direct mode)
func (ctx staticKeyGenerator) genKey() ([]byte, rawHeader, error) {
cek := make([]byte, len(ctx.key))
copy(cek, ctx.key)
return cek, rawHeader{}, nil
}
// Key size for static generator
func (ctx staticKeyGenerator) keySize() int {
return len(ctx.key)
}
// Get key size for this cipher
func (ctx aeadContentCipher) keySize() int {
return ctx.keyBytes
}
// Encrypt some data
func (ctx aeadContentCipher) encrypt(key, aad, pt []byte) (*aeadParts, error) {
// Get a new AEAD instance
aead, err := ctx.getAead(key)
if err != nil {
return nil, err
}
// Initialize a new nonce
iv := make([]byte, aead.NonceSize())
_, err = io.ReadFull(RandReader, iv)
if err != nil {
return nil, err
}
ciphertextAndTag := aead.Seal(nil, iv, pt, aad)
offset := len(ciphertextAndTag) - ctx.authtagBytes
return &aeadParts{
iv: iv,
ciphertext: ciphertextAndTag[:offset],
tag: ciphertextAndTag[offset:],
}, nil
}
// Decrypt some data
func (ctx aeadContentCipher) decrypt(key, aad []byte, parts *aeadParts) ([]byte, error) {
aead, err := ctx.getAead(key)
if err != nil {
return nil, err
}
if len(parts.iv) != aead.NonceSize() || len(parts.tag) < ctx.authtagBytes {
return nil, ErrCryptoFailure
}
return aead.Open(nil, parts.iv, append(parts.ciphertext, parts.tag...), aad)
}
// Encrypt the content encryption key.
func (ctx *symmetricKeyCipher) encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) {
switch alg {
case DIRECT:
return recipientInfo{
header: &rawHeader{},
}, nil
case A128GCMKW, A192GCMKW, A256GCMKW:
aead := newAESGCM(len(ctx.key))
parts, err := aead.encrypt(ctx.key, []byte{}, cek)
if err != nil {
return recipientInfo{}, err
}
header := &rawHeader{}
if err = header.set(headerIV, newBuffer(parts.iv)); err != nil {
return recipientInfo{}, err
}
if err = header.set(headerTag, newBuffer(parts.tag)); err != nil {
return recipientInfo{}, err
}
return recipientInfo{
header: header,
encryptedKey: parts.ciphertext,
}, nil
case A128KW, A192KW, A256KW:
block, err := aes.NewCipher(ctx.key)
if err != nil {
return recipientInfo{}, err
}
jek, err := josecipher.KeyWrap(block, cek)
if err != nil {
return recipientInfo{}, err
}
return recipientInfo{
encryptedKey: jek,
header: &rawHeader{},
}, nil
case PBES2_HS256_A128KW, PBES2_HS384_A192KW, PBES2_HS512_A256KW:
if len(ctx.p2s) == 0 {
salt, err := getRandomSalt(defaultP2SSize)
if err != nil {
return recipientInfo{}, err
}
ctx.p2s = salt
}
if ctx.p2c <= 0 {
ctx.p2c = defaultP2C
}
// salt is UTF8(Alg) || 0x00 || Salt Input
salt := bytes.Join([][]byte{[]byte(alg), ctx.p2s}, []byte{0x00})
// derive key
keyLen, h := getPbkdf2Params(alg)
key := pbkdf2.Key(ctx.key, salt, ctx.p2c, keyLen, h)
// use AES cipher with derived key
block, err := aes.NewCipher(key)
if err != nil {
return recipientInfo{}, err
}
jek, err := josecipher.KeyWrap(block, cek)
if err != nil {
return recipientInfo{}, err
}
header := &rawHeader{}
if err = header.set(headerP2C, ctx.p2c); err != nil {
return recipientInfo{}, err
}
if err = header.set(headerP2S, newBuffer(ctx.p2s)); err != nil {
return recipientInfo{}, err
}
return recipientInfo{
encryptedKey: jek,
header: header,
}, nil
}
return recipientInfo{}, ErrUnsupportedAlgorithm
}
// Decrypt the content encryption key.
func (ctx *symmetricKeyCipher) decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) {
switch headers.getAlgorithm() {
case DIRECT:
cek := make([]byte, len(ctx.key))
copy(cek, ctx.key)
return cek, nil
case A128GCMKW, A192GCMKW, A256GCMKW:
aead := newAESGCM(len(ctx.key))
iv, err := headers.getIV()
if err != nil {
return nil, fmt.Errorf("go-jose/go-jose: invalid IV: %v", err)
}
tag, err := headers.getTag()
if err != nil {
return nil, fmt.Errorf("go-jose/go-jose: invalid tag: %v", err)
}
parts := &aeadParts{
iv: iv.bytes(),
ciphertext: recipient.encryptedKey,
tag: tag.bytes(),
}
cek, err := aead.decrypt(ctx.key, []byte{}, parts)
if err != nil {
return nil, err
}
return cek, nil
case A128KW, A192KW, A256KW:
block, err := aes.NewCipher(ctx.key)
if err != nil {
return nil, err
}
cek, err := josecipher.KeyUnwrap(block, recipient.encryptedKey)
if err != nil {
return nil, err
}
return cek, nil
case PBES2_HS256_A128KW, PBES2_HS384_A192KW, PBES2_HS512_A256KW:
p2s, err := headers.getP2S()
if err != nil {
return nil, fmt.Errorf("go-jose/go-jose: invalid P2S: %v", err)
}
if p2s == nil || len(p2s.data) == 0 {
return nil, fmt.Errorf("go-jose/go-jose: invalid P2S: must be present")
}
p2c, err := headers.getP2C()
if err != nil {
return nil, fmt.Errorf("go-jose/go-jose: invalid P2C: %v", err)
}
if p2c <= 0 {
return nil, fmt.Errorf("go-jose/go-jose: invalid P2C: must be a positive integer")
}
// salt is UTF8(Alg) || 0x00 || Salt Input
alg := headers.getAlgorithm()
salt := bytes.Join([][]byte{[]byte(alg), p2s.bytes()}, []byte{0x00})
// derive key
keyLen, h := getPbkdf2Params(alg)
key := pbkdf2.Key(ctx.key, salt, p2c, keyLen, h)
// use AES cipher with derived key
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
cek, err := josecipher.KeyUnwrap(block, recipient.encryptedKey)
if err != nil {
return nil, err
}
return cek, nil
}
return nil, ErrUnsupportedAlgorithm
}
// Sign the given payload
func (ctx symmetricMac) signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error) {
mac, err := ctx.hmac(payload, alg)
if err != nil {
return Signature{}, errors.New("go-jose/go-jose: failed to compute hmac")
}
return Signature{
Signature: mac,
protected: &rawHeader{},
}, nil
}
// Verify the given payload
func (ctx symmetricMac) verifyPayload(payload []byte, mac []byte, alg SignatureAlgorithm) error {
expected, err := ctx.hmac(payload, alg)
if err != nil {
return errors.New("go-jose/go-jose: failed to compute hmac")
}
if len(mac) != len(expected) {
return errors.New("go-jose/go-jose: invalid hmac")
}
match := subtle.ConstantTimeCompare(mac, expected)
if match != 1 {
return errors.New("go-jose/go-jose: invalid hmac")
}
return nil
}
// Compute the HMAC based on the given alg value
func (ctx symmetricMac) hmac(payload []byte, alg SignatureAlgorithm) ([]byte, error) {
var hash func() hash.Hash
switch alg {
case HS256:
hash = sha256.New
case HS384:
hash = sha512.New384
case HS512:
hash = sha512.New
default:
return nil, ErrUnsupportedAlgorithm
}
hmac := hmac.New(hash, ctx.key)
// According to documentation, Write() on hash never fails
_, _ = hmac.Write(payload)
return hmac.Sum(nil), nil
}
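// A brief HS256 sketch against the symmetric signer above; the 32-byte key
// below is a placeholder:
//
//	key := []byte("0123456789abcdef0123456789abcdef")
//	signer, _ := NewSigner(SigningKey{Algorithm: HS256, Key: key}, nil)
//	obj, _ := signer.Sign([]byte("payload"))
//	out, _ := obj.Verify(key) // constant-time HMAC comparison, per verifyPayload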


@@ -559,7 +559,7 @@ func (s *String) ReadASN1BitString(out *encoding_asn1.BitString) bool {
 	return true
 }

-// ReadASN1BitString decodes an ASN.1 BIT STRING into out and advances. It is
+// ReadASN1BitStringAsBytes decodes an ASN.1 BIT STRING into out and advances. It is
 // an error if the BIT STRING is not a whole number of bytes. It reports
 // whether the read was successful.
 func (s *String) ReadASN1BitStringAsBytes(out *[]byte) bool {


@@ -303,9 +303,9 @@ func (b *Builder) add(bytes ...byte) {
 	b.result = append(b.result, bytes...)
 }

-// Unwrite rolls back n bytes written directly to the Builder. An attempt by a
-// child builder passed to a continuation to unwrite bytes from its parent will
-// panic.
+// Unwrite rolls back non-negative n bytes written directly to the Builder.
+// An attempt by a child builder passed to a continuation to unwrite bytes
+// from its parent will panic.
 func (b *Builder) Unwrite(n int) {
 	if b.err != nil {
 		return
@@ -317,6 +317,9 @@ func (b *Builder) Unwrite(n int) {
 	if length < 0 {
 		panic("cryptobyte: internal error")
 	}
+	if n < 0 {
+		panic("cryptobyte: attempted to unwrite negative number of bytes")
+	}
 	if n > length {
 		panic("cryptobyte: attempted to unwrite more than was written")
 	}


@@ -5,71 +5,18 @@
 // Package curve25519 provides an implementation of the X25519 function, which
 // performs scalar multiplication on the elliptic curve known as Curve25519.
 // See RFC 7748.
+//
+// Starting in Go 1.20, this package is a wrapper for the X25519 implementation
+// in the crypto/ecdh package.
 package curve25519 // import "golang.org/x/crypto/curve25519"

-import (
-	"crypto/subtle"
-	"errors"
-	"strconv"
-
-	"golang.org/x/crypto/curve25519/internal/field"
-)
-
 // ScalarMult sets dst to the product scalar * point.
 //
 // Deprecated: when provided a low-order point, ScalarMult will set dst to all
 // zeroes, irrespective of the scalar. Instead, use the X25519 function, which
 // will return an error.
 func ScalarMult(dst, scalar, point *[32]byte) {
-	var e [32]byte
-
-	copy(e[:], scalar[:])
-	e[0] &= 248
-	e[31] &= 127
-	e[31] |= 64
-
-	var x1, x2, z2, x3, z3, tmp0, tmp1 field.Element
-	x1.SetBytes(point[:])
-	x2.One()
-	x3.Set(&x1)
-	z3.One()
-
-	swap := 0
-	for pos := 254; pos >= 0; pos-- {
-		b := e[pos/8] >> uint(pos&7)
-		b &= 1
-		swap ^= int(b)
-		x2.Swap(&x3, swap)
-		z2.Swap(&z3, swap)
-		swap = int(b)
-
-		tmp0.Subtract(&x3, &z3)
-		tmp1.Subtract(&x2, &z2)
-		x2.Add(&x2, &z2)
-		z2.Add(&x3, &z3)
-		z3.Multiply(&tmp0, &x2)
-		z2.Multiply(&z2, &tmp1)
-		tmp0.Square(&tmp1)
-		tmp1.Square(&x2)
-		x3.Add(&z3, &z2)
-		z2.Subtract(&z3, &z2)
-		x2.Multiply(&tmp1, &tmp0)
-		tmp1.Subtract(&tmp1, &tmp0)
-		z2.Square(&z2)
-		z3.Mult32(&tmp1, 121666)
-		x3.Square(&x3)
-		tmp0.Add(&tmp0, &z3)
-		z3.Multiply(&x1, &z2)
-		z2.Multiply(&tmp1, &tmp0)
-	}
-
-	x2.Swap(&x3, swap)
-	z2.Swap(&z3, swap)
-
-	z2.Invert(&z2)
-	x2.Multiply(&x2, &z2)
-	copy(dst[:], x2.Bytes())
+	scalarMult(dst, scalar, point)
 }

 // ScalarBaseMult sets dst to the product scalar * base where base is the
@@ -78,7 +25,7 @@ func ScalarMult(dst, scalar, point *[32]byte) {
 // It is recommended to use the X25519 function with Basepoint instead, as
 // copying into fixed size arrays can lead to unexpected bugs.
 func ScalarBaseMult(dst, scalar *[32]byte) {
-	ScalarMult(dst, scalar, &basePoint)
+	scalarBaseMult(dst, scalar)
 }

 const (
@@ -91,21 +38,10 @@ const (
 // Basepoint is the canonical Curve25519 generator.
 var Basepoint []byte

-var basePoint = [32]byte{9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
+var basePoint = [32]byte{9}

 func init() { Basepoint = basePoint[:] }

-func checkBasepoint() {
-	if subtle.ConstantTimeCompare(Basepoint, []byte{
-		0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
-		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
-		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
-		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
-	}) != 1 {
-		panic("curve25519: global Basepoint value was modified")
-	}
-}
-
 // X25519 returns the result of the scalar multiplication (scalar * point),
 // according to RFC 7748, Section 5. scalar, point and the return value are
 // slices of 32 bytes.
@@ -121,26 +57,3 @@ func X25519(scalar, point []byte) ([]byte, error) {
 	var dst [32]byte
 	return x25519(&dst, scalar, point)
 }
-
-func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) {
-	var in [32]byte
-	if l := len(scalar); l != 32 {
-		return nil, errors.New("bad scalar length: " + strconv.Itoa(l) + ", expected 32")
-	}
-	if l := len(point); l != 32 {
-		return nil, errors.New("bad point length: " + strconv.Itoa(l) + ", expected 32")
-	}
-	copy(in[:], scalar)
-	if &point[0] == &Basepoint[0] {
-		checkBasepoint()
-		ScalarBaseMult(dst, &in)
-	} else {
-		var base, zero [32]byte
-		copy(base[:], point)
-		ScalarMult(dst, &in, &base)
-		if subtle.ConstantTimeCompare(dst[:], zero[:]) == 1 {
-			return nil, errors.New("bad input point: low order point")
-		}
-	}
-	return dst[:], nil
-}
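A brief usage sketch of the public API, which this refactor leaves unchanged; peerPub is a placeholder for the remote party's 32-byte public key:

	var priv [32]byte
	if _, err := rand.Read(priv[:]); err != nil { // crypto/rand
		panic(err)
	}
	pub, _ := curve25519.X25519(priv[:], curve25519.Basepoint)
	shared, err := curve25519.X25519(priv[:], peerPub) // returns an error for low-order points
	_, _ = pub, shared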


@@ -0,0 +1,105 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.20
package curve25519
import (
"crypto/subtle"
"errors"
"strconv"
"golang.org/x/crypto/curve25519/internal/field"
)
func scalarMult(dst, scalar, point *[32]byte) {
var e [32]byte
copy(e[:], scalar[:])
e[0] &= 248
e[31] &= 127
e[31] |= 64
var x1, x2, z2, x3, z3, tmp0, tmp1 field.Element
x1.SetBytes(point[:])
x2.One()
x3.Set(&x1)
z3.One()
swap := 0
for pos := 254; pos >= 0; pos-- {
b := e[pos/8] >> uint(pos&7)
b &= 1
swap ^= int(b)
x2.Swap(&x3, swap)
z2.Swap(&z3, swap)
swap = int(b)
tmp0.Subtract(&x3, &z3)
tmp1.Subtract(&x2, &z2)
x2.Add(&x2, &z2)
z2.Add(&x3, &z3)
z3.Multiply(&tmp0, &x2)
z2.Multiply(&z2, &tmp1)
tmp0.Square(&tmp1)
tmp1.Square(&x2)
x3.Add(&z3, &z2)
z2.Subtract(&z3, &z2)
x2.Multiply(&tmp1, &tmp0)
tmp1.Subtract(&tmp1, &tmp0)
z2.Square(&z2)
z3.Mult32(&tmp1, 121666)
x3.Square(&x3)
tmp0.Add(&tmp0, &z3)
z3.Multiply(&x1, &z2)
z2.Multiply(&tmp1, &tmp0)
}
x2.Swap(&x3, swap)
z2.Swap(&z3, swap)
z2.Invert(&z2)
x2.Multiply(&x2, &z2)
copy(dst[:], x2.Bytes())
}
func scalarBaseMult(dst, scalar *[32]byte) {
checkBasepoint()
scalarMult(dst, scalar, &basePoint)
}
func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) {
var in [32]byte
if l := len(scalar); l != 32 {
return nil, errors.New("bad scalar length: " + strconv.Itoa(l) + ", expected 32")
}
if l := len(point); l != 32 {
return nil, errors.New("bad point length: " + strconv.Itoa(l) + ", expected 32")
}
copy(in[:], scalar)
if &point[0] == &Basepoint[0] {
scalarBaseMult(dst, &in)
} else {
var base, zero [32]byte
copy(base[:], point)
scalarMult(dst, &in, &base)
if subtle.ConstantTimeCompare(dst[:], zero[:]) == 1 {
return nil, errors.New("bad input point: low order point")
}
}
return dst[:], nil
}
func checkBasepoint() {
if subtle.ConstantTimeCompare(Basepoint, []byte{
0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
}) != 1 {
panic("curve25519: global Basepoint value was modified")
}
}


@@ -0,0 +1,46 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.20
package curve25519
import "crypto/ecdh"
func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) {
curve := ecdh.X25519()
pub, err := curve.NewPublicKey(point)
if err != nil {
return nil, err
}
priv, err := curve.NewPrivateKey(scalar)
if err != nil {
return nil, err
}
out, err := priv.ECDH(pub)
if err != nil {
return nil, err
}
copy(dst[:], out)
return dst[:], nil
}
func scalarMult(dst, scalar, point *[32]byte) {
if _, err := x25519(dst, scalar[:], point[:]); err != nil {
// The only error condition for x25519 when the inputs are 32 bytes long
// is if the output would have been the all-zero value.
for i := range dst {
dst[i] = 0
}
}
}
func scalarBaseMult(dst, scalar *[32]byte) {
curve := ecdh.X25519()
priv, err := curve.NewPrivateKey(scalar[:])
if err != nil {
panic("curve25519: internal error: scalarBaseMult was not 32 bytes")
}
copy(dst[:], priv.PublicKey().Bytes())
}


@@ -245,7 +245,7 @@ func feSquareGeneric(v, a *Element) {
 	v.carryPropagate()
 }

-// carryPropagate brings the limbs below 52 bits by applying the reduction
+// carryPropagateGeneric brings the limbs below 52 bits by applying the reduction
 // identity (a * 2²⁵⁵ + b = a * 19 + b) to the l4 carry. TODO inline
 func (v *Element) carryPropagateGeneric() *Element {
 	c0 := v.l0 >> 51


@@ -114,7 +114,8 @@ var cipherModes = map[string]*cipherMode{
 	"arcfour": {16, 0, streamCipherMode(0, newRC4)},

 	// AEAD ciphers
-	gcmCipherID:        {16, 12, newGCMCipher},
+	gcm128CipherID:     {16, 12, newGCMCipher},
+	gcm256CipherID:     {32, 12, newGCMCipher},
 	chacha20Poly1305ID: {64, 0, newChaCha20Cipher},

 	// CBC mode is insecure and so is not included in the default config.


@@ -28,7 +28,7 @@ const (
 // supportedCiphers lists ciphers we support but might not recommend.
 var supportedCiphers = []string{
 	"aes128-ctr", "aes192-ctr", "aes256-ctr",
-	"aes128-gcm@openssh.com",
+	"aes128-gcm@openssh.com", gcm256CipherID,
 	chacha20Poly1305ID,
 	"arcfour256", "arcfour128", "arcfour",
 	aes128cbcID,
@@ -37,7 +37,7 @@ var supportedCiphers = []string{
 // preferredCiphers specifies the default preference for ciphers.
 var preferredCiphers = []string{
-	"aes128-gcm@openssh.com",
+	"aes128-gcm@openssh.com", gcm256CipherID,
 	chacha20Poly1305ID,
 	"aes128-ctr", "aes192-ctr", "aes256-ctr",
 }
@@ -168,7 +168,7 @@ func (a *directionAlgorithms) rekeyBytes() int64 {
 	// 2^(BLOCKSIZE/4) blocks. For all AES flavors BLOCKSIZE is
 	// 128.
 	switch a.Cipher {
-	case "aes128-ctr", "aes192-ctr", "aes256-ctr", gcmCipherID, aes128cbcID:
+	case "aes128-ctr", "aes192-ctr", "aes256-ctr", gcm128CipherID, gcm256CipherID, aes128cbcID:
 		return 16 * (1 << 32)
 	}
@@ -178,7 +178,8 @@ func (a *directionAlgorithms) rekeyBytes() int64 {
 }

 var aeadCiphers = map[string]bool{
-	gcmCipherID:        true,
+	gcm128CipherID:     true,
+	gcm256CipherID:     true,
 	chacha20Poly1305ID: true,
 }
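A hedged sketch of opting into the newly supported cipher from client code (host and auth details elided; host key verification is skipped here for brevity only, never in real use):

	cfg := &ssh.ClientConfig{
		User: "user",
		Config: ssh.Config{
			Ciphers: []string{"aes256-gcm@openssh.com", "aes128-gcm@openssh.com"},
		},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}
	_ = cfg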


@@ -97,7 +97,7 @@ func (c *connection) Close() error {
 	return c.sshConn.conn.Close()
 }

-// sshconn provides net.Conn metadata, but disallows direct reads and
+// sshConn provides net.Conn metadata, but disallows direct reads and
 // writes.
 type sshConn struct {
 	conn net.Conn


@@ -1087,9 +1087,9 @@ func (*PassphraseMissingError) Error() string {
 	return "ssh: this private key is passphrase protected"
 }

-// ParseRawPrivateKey returns a private key from a PEM encoded private key. It
-// supports RSA (PKCS#1), PKCS#8, DSA (OpenSSL), and ECDSA private keys. If the
-// private key is encrypted, it will return a PassphraseMissingError.
+// ParseRawPrivateKey returns a private key from a PEM encoded private key. It supports
+// RSA, DSA, ECDSA, and Ed25519 private keys in PKCS#1, PKCS#8, OpenSSL, and OpenSSH
+// formats. If the private key is encrypted, it will return a PassphraseMissingError.
 func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) {
 	block, _ := pem.Decode(pemBytes)
 	if block == nil {
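A short sketch of the documented passphrase behavior; pemBytes and passphrase are assumed to come from the caller:

	key, err := ssh.ParseRawPrivateKey(pemBytes)
	if _, missing := err.(*ssh.PassphraseMissingError); missing {
		key, err = ssh.ParseRawPrivateKeyWithPassphrase(pemBytes, passphrase)
	}
	_ = key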


@@ -17,7 +17,8 @@ import (
 const debugTransport = false

 const (
-	gcmCipherID    = "aes128-gcm@openssh.com"
+	gcm128CipherID = "aes128-gcm@openssh.com"
+	gcm256CipherID = "aes256-gcm@openssh.com"
 	aes128cbcID    = "aes128-cbc"
 	tripledescbcID = "3des-cbc"
 )


@@ -88,13 +88,9 @@ func (p *pipe) Write(d []byte) (n int, err error) {
 		p.c.L = &p.mu
 	}
 	defer p.c.Signal()
-	if p.err != nil {
+	if p.err != nil || p.breakErr != nil {
 		return 0, errClosedPipeWrite
 	}
-	if p.breakErr != nil {
-		p.unread += len(d)
-		return len(d), nil // discard when there is no reader
-	}
 	return p.b.Write(d)
 }


@@ -1822,15 +1822,18 @@ func (sc *serverConn) processData(f *DataFrame) error {
 		}
 		if len(data) > 0 {
+			st.bodyBytes += int64(len(data))
 			wrote, err := st.body.Write(data)
 			if err != nil {
+				// The handler has closed the request body.
+				// Return the connection-level flow control for the discarded data,
+				// but not the stream-level flow control.
 				sc.sendWindowUpdate(nil, int(f.Length)-wrote)
-				return sc.countError("body_write_err", streamError(id, ErrCodeStreamClosed))
+				return nil
 			}
 			if wrote != len(data) {
 				panic("internal error: bad Writer")
 			}
-			st.bodyBytes += int64(len(data))
 		}

 		// Return any padded flow control now, since we won't


@@ -560,10 +560,11 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
 		traceGotConn(req, cc, reused)
 		res, err := cc.RoundTrip(req)
 		if err != nil && retry <= 6 {
+			roundTripErr := err
 			if req, err = shouldRetryRequest(req, err); err == nil {
 				// After the first retry, do exponential backoff with 10% jitter.
 				if retry == 0 {
-					t.vlogf("RoundTrip retrying after failure: %v", err)
+					t.vlogf("RoundTrip retrying after failure: %v", roundTripErr)
 					continue
 				}
 				backoff := float64(uint(1) << (uint(retry) - 1))
@@ -572,7 +573,7 @@
 				timer := backoffNewTimer(d)
 				select {
 				case <-timer.C:
-					t.vlogf("RoundTrip retrying after failure: %v", err)
+					t.vlogf("RoundTrip retrying after failure: %v", roundTripErr)
 					continue
 				case <-req.Context().Done():
 					timer.Stop()
@@ -2555,6 +2556,9 @@ func (b transportResponseBody) Close() error {
 	cs := b.cs
 	cc := cs.cc

+	cs.bufPipe.BreakWithError(errClosedResponseBody)
+	cs.abortStream(errClosedResponseBody)
+
 	unread := cs.bufPipe.Len()
 	if unread > 0 {
 		cc.mu.Lock()
@@ -2573,9 +2577,6 @@ func (b transportResponseBody) Close() error {
 		cc.wmu.Unlock()
 	}

-	cs.bufPipe.BreakWithError(errClosedResponseBody)
-	cs.abortStream(errClosedResponseBody)
-
 	select {
 	case <-cs.donec:
 	case <-cs.ctx.Done():


@@ -13,7 +13,7 @@ import "encoding/binary"
 // a rune to a uint16. The values take two forms. For v >= 0x8000:
 //   bits
 //   15:    1 (inverse of NFD_QC bit of qcInfo)
-//   13..7: qcInfo (see below). isYesD is always true (no decompostion).
+//   13..7: qcInfo (see below). isYesD is always true (no decomposition).
 //    6..0: ccc (compressed CCC value).
 // For v < 0x8000, the respective rune has a decomposition and v is an index
 // into a byte array of UTF-8 decomposition sequences and additional info and


@@ -53,10 +53,13 @@ func New(files []*ast.File) *Inspector {
 // of an ast.Node during a traversal.
 type event struct {
 	node  ast.Node
-	typ   uint64 // typeOf(node)
-	index int    // 1 + index of corresponding pop event, or 0 if this is a pop
+	typ   uint64 // typeOf(node) on push event, or union of typ strictly between push and pop events on pop events
+	index int    // index of corresponding push or pop event
 }

+// TODO: Experiment with storing only the second word of event.node (unsafe.Pointer).
+// Type can be recovered from the sole bit in typ.
+
 // Preorder visits all the nodes of the files supplied to New in
 // depth-first order. It calls f(n) for each node n before it visits
 // n's children.
@@ -72,10 +75,17 @@ func (in *Inspector) Preorder(types []ast.Node, f func(ast.Node)) {
 	mask := maskOf(types)
 	for i := 0; i < len(in.events); {
 		ev := in.events[i]
-		if ev.typ&mask != 0 {
-			if ev.index > 0 {
+		if ev.index > i {
+			// push
+			if ev.typ&mask != 0 {
 				f(ev.node)
 			}
+			pop := ev.index
+			if in.events[pop].typ&mask == 0 {
+				// Subtrees do not contain types: skip them and pop.
+				i = pop + 1
+				continue
+			}
 		}
 		i++
 	}
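A minimal caller-side sketch of Preorder; files is assumed to be a []*ast.File produced by go/parser. The subtree-skipping above makes this cheap even when the requested node type is sparse:

	insp := inspector.New(files)
	insp.Preorder([]ast.Node{(*ast.CallExpr)(nil)}, func(n ast.Node) {
		call := n.(*ast.CallExpr) // only call expressions reach this callback
		_ = call
	})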
@@ -94,15 +104,24 @@ func (in *Inspector) Nodes(types []ast.Node, f func(n ast.Node, push bool) (proc
 	mask := maskOf(types)
 	for i := 0; i < len(in.events); {
 		ev := in.events[i]
-		if ev.typ&mask != 0 {
-			if ev.index > 0 {
-				// push
+		if ev.index > i {
+			// push
+			pop := ev.index
+			if ev.typ&mask != 0 {
 				if !f(ev.node, true) {
-					i = ev.index // jump to corresponding pop + 1
+					i = pop + 1 // jump to corresponding pop + 1
 					continue
 				}
-			} else {
-				// pop
+			}
+			if in.events[pop].typ&mask == 0 {
+				// Subtrees do not contain types: skip them.
+				i = pop
+				continue
+			}
+		} else {
+			// pop
+			push := ev.index
+			if in.events[push].typ&mask != 0 {
 				f(ev.node, false)
 			}
 		}
@@ -119,19 +138,26 @@ func (in *Inspector) WithStack(types []ast.Node, f func(n ast.Node, push bool, s
 	var stack []ast.Node
 	for i := 0; i < len(in.events); {
 		ev := in.events[i]
-		if ev.index > 0 {
+		if ev.index > i {
 			// push
+			pop := ev.index
 			stack = append(stack, ev.node)
 			if ev.typ&mask != 0 {
 				if !f(ev.node, true, stack) {
-					i = ev.index
+					i = pop + 1
 					stack = stack[:len(stack)-1]
 					continue
 				}
 			}
+			if in.events[pop].typ&mask == 0 {
+				// Subtrees do not contain types: skip them.
+				i = pop
+				continue
+			}
 		} else {
 			// pop
-			if ev.typ&mask != 0 {
+			push := ev.index
+			if in.events[push].typ&mask != 0 {
 				f(ev.node, false, stack)
 			}
 			stack = stack[:len(stack)-1]
@@ -157,25 +183,31 @@ func traverse(files []*ast.File) []event {
 	events := make([]event, 0, capacity)

 	var stack []event
+	stack = append(stack, event{}) // include an extra event so file nodes have a parent
 	for _, f := range files {
 		ast.Inspect(f, func(n ast.Node) bool {
 			if n != nil {
 				// push
 				ev := event{
 					node:  n,
-					typ:   typeOf(n),
+					typ:   0,           // temporarily used to accumulate type bits of subtree
 					index: len(events), // push event temporarily holds own index
 				}
 				stack = append(stack, ev)
 				events = append(events, ev)
 			} else {
 				// pop
-				ev := stack[len(stack)-1]
-				stack = stack[:len(stack)-1]
+				top := len(stack) - 1
+				ev := stack[top]
+				typ := typeOf(ev.node)
+				push := ev.index
+				parent := top - 1

-				events[ev.index].index = len(events) + 1 // make push refer to pop
+				events[push].typ = typ            // set type of push
+				stack[parent].typ |= typ | ev.typ // parent's typ contains push and pop's typs.
+				events[push].index = len(events)  // make push refer to pop

-				ev.index = 0 // turn ev into a pop event
+				stack = stack[:top]
 				events = append(events, ev)
 			}
 			return true


@@ -11,6 +11,7 @@ package inspector
 import (
 	"go/ast"
+	"math"

 	"golang.org/x/tools/internal/typeparams"
 )
@@ -218,7 +219,7 @@ func typeOf(n ast.Node) uint64 {
 func maskOf(nodes []ast.Node) uint64 {
 	if nodes == nil {
-		return 1<<64 - 1 // match all node types
+		return math.MaxUint64 // match all node types
 	}
 	var mask uint64
 	for _, n := range nodes {


@@ -27,10 +27,9 @@ import (
 	"go/token"
 	"go/types"
 	"io"
-	"io/ioutil"
 	"os/exec"

-	"golang.org/x/tools/go/internal/gcimporter"
+	"golang.org/x/tools/internal/gcimporter"
 )

 // Find returns the name of an object (.o) or archive (.a) file
@@ -85,9 +84,26 @@ func NewReader(r io.Reader) (io.Reader, error) {
 	}
 }

+// readAll works the same way as io.ReadAll, but avoids allocations and copies
+// by preallocating a byte slice of the necessary size if the size is known up
+// front. This is always possible when the input is an archive. In that case,
+// NewReader will return the known size using an io.LimitedReader.
+func readAll(r io.Reader) ([]byte, error) {
+	if lr, ok := r.(*io.LimitedReader); ok {
+		data := make([]byte, lr.N)
+		_, err := io.ReadFull(lr, data)
+		return data, err
+	}
+	return io.ReadAll(r)
+}
+
 // Read reads export data from in, decodes it, and returns type
 // information for the package.
-// The package name is specified by path.
+//
+// The package path (effectively its linker symbol prefix) is
+// specified by path, since unlike the package name, this information
+// may not be recorded in the export data.
+//
 // File position information is added to fset.
 //
 // Read may inspect and add to the imports map to ensure that references
@@ -98,7 +114,7 @@ func NewReader(r io.Reader) (io.Reader, error) {
 //
 // On return, the state of the reader is undefined.
 func Read(in io.Reader, fset *token.FileSet, imports map[string]*types.Package, path string) (*types.Package, error) {
-	data, err := ioutil.ReadAll(in)
+	data, err := readAll(in)
 	if err != nil {
 		return nil, fmt.Errorf("reading export data for %q: %v", path, err)
 	}
@@ -107,12 +123,6 @@ func Read(in io.Reader, fset *token.FileSet, imports map[string]*types.Package,
 		return nil, fmt.Errorf("can't read export data for %q directly from an archive file (call gcexportdata.NewReader first to extract export data)", path)
 	}

-	// The App Engine Go runtime v1.6 uses the old export data format.
-	// TODO(adonovan): delete once v1.7 has been around for a while.
-	if bytes.HasPrefix(data, []byte("package ")) {
-		return gcimporter.ImportData(imports, path, path, bytes.NewReader(data))
-	}
-
 	// The indexed export format starts with an 'i'; the older
 	// binary export format starts with a 'c', 'd', or 'v'
 	// (from "version"). Select appropriate importer.
@@ -161,7 +171,7 @@ func Write(out io.Writer, fset *token.FileSet, pkg *types.Package) error {
 //
 // Experimental: This API is experimental and may change in the future.
 func ReadBundle(in io.Reader, fset *token.FileSet, imports map[string]*types.Package) ([]*types.Package, error) {
-	data, err := ioutil.ReadAll(in)
+	data, err := readAll(in)
 	if err != nil {
 		return nil, fmt.Errorf("reading export bundle: %v", err)
 	}

(File diff suppressed because it is too large.)


@@ -60,6 +60,7 @@ func (r *responseDeduper) addAll(dr *driverResponse) {
 	for _, root := range dr.Roots {
 		r.addRoot(root)
 	}
+	r.dr.GoVersion = dr.GoVersion
 }

 func (r *responseDeduper) addPackage(p *Package) {
@@ -454,11 +455,14 @@ func (state *golistState) createDriverResponse(words ...string) (*driverResponse
 	if err != nil {
 		return nil, err
 	}

 	seen := make(map[string]*jsonPackage)
 	pkgs := make(map[string]*Package)
 	additionalErrors := make(map[string][]Error)

 	// Decode the JSON and convert it to Package form.
-	var response driverResponse
+	response := &driverResponse{
+		GoVersion: goVersion,
+	}
 	for dec := json.NewDecoder(buf); dec.More(); {
 		p := new(jsonPackage)
 		if err := dec.Decode(p); err != nil {
@@ -600,17 +604,12 @@ func (state *golistState) createDriverResponse(words ...string) (*driverResponse
 		// Work around https://golang.org/issue/28749:
 		// cmd/go puts assembly, C, and C++ files in CompiledGoFiles.
-		// Filter out any elements of CompiledGoFiles that are also in OtherFiles.
-		// We have to keep this workaround in place until go1.12 is a distant memory.
-		if len(pkg.OtherFiles) > 0 {
-			other := make(map[string]bool, len(pkg.OtherFiles))
-			for _, f := range pkg.OtherFiles {
-				other[f] = true
-			}
-
+		// Remove files from CompiledGoFiles that are non-go files
+		// (or are not files that look like they are from the cache).
+		if len(pkg.CompiledGoFiles) > 0 {
 			out := pkg.CompiledGoFiles[:0]
 			for _, f := range pkg.CompiledGoFiles {
-				if other[f] {
+				if ext := filepath.Ext(f); ext != ".go" && ext != "" { // ext == "" means the file is from the cache, so probably cgo-processed file
 					continue
 				}
 				out = append(out, f)
@@ -730,7 +729,7 @@ func (state *golistState) createDriverResponse(words ...string) (*driverResponse
 	}

 	sort.Slice(response.Packages, func(i, j int) bool { return response.Packages[i].ID < response.Packages[j].ID })
-	return &response, nil
+	return response, nil
 }

 func (state *golistState) shouldAddFilenameFromError(p *jsonPackage) bool {
@@ -756,6 +755,7 @@ func (state *golistState) shouldAddFilenameFromError(p *jsonPackage) bool {
 	return len(p.Error.ImportStack) == 0 || p.Error.ImportStack[len(p.Error.ImportStack)-1] == p.ImportPath
 }

+// getGoVersion returns the effective minor version of the go command.
 func (state *golistState) getGoVersion() (int, error) {
 	state.goVersionOnce.Do(func() {
 		state.goVersion, state.goVersionError = gocommand.GoVersion(state.ctx, state.cfgInvocation(), state.cfg.gocmdRunner)


@ -15,10 +15,12 @@ import (
"go/scanner" "go/scanner"
"go/token" "go/token"
"go/types" "go/types"
"io"
"io/ioutil" "io/ioutil"
"log" "log"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -233,6 +235,11 @@ type driverResponse struct {
// Imports will be connected and then type and syntax information added in a // Imports will be connected and then type and syntax information added in a
// later pass (see refine). // later pass (see refine).
Packages []*Package Packages []*Package
// GoVersion is the minor version number used by the driver
// (e.g. the go command on the PATH) when selecting .go files.
// Zero means unknown.
GoVersion int
} }
// Load loads and returns the Go packages named by the given patterns. // Load loads and returns the Go packages named by the given patterns.
@ -256,7 +263,7 @@ func Load(cfg *Config, patterns ...string) ([]*Package, error) {
return nil, err return nil, err
} }
l.sizes = response.Sizes l.sizes = response.Sizes
return l.refine(response.Roots, response.Packages...) return l.refine(response)
} }
// defaultDriver is a driver that implements go/packages' fallback behavior. // defaultDriver is a driver that implements go/packages' fallback behavior.
@ -297,6 +304,9 @@ type Package struct {
// of the package, or while parsing or type-checking its files. // of the package, or while parsing or type-checking its files.
Errors []Error Errors []Error
// TypeErrors contains the subset of errors produced during type checking.
TypeErrors []types.Error
// GoFiles lists the absolute file paths of the package's Go source files. // GoFiles lists the absolute file paths of the package's Go source files.
GoFiles []string GoFiles []string
@ -532,6 +542,7 @@ type loaderPackage struct {
needsrc bool // load from source (Mode >= LoadTypes) needsrc bool // load from source (Mode >= LoadTypes)
needtypes bool // type information is either requested or depended on needtypes bool // type information is either requested or depended on
initial bool // package was matched by a pattern initial bool // package was matched by a pattern
goVersion int // minor version number of go command on PATH
} }
// loader holds the working state of a single call to load. // loader holds the working state of a single call to load.
@ -618,7 +629,8 @@ func newLoader(cfg *Config) *loader {
// refine connects the supplied packages into a graph and then adds type and // refine connects the supplied packages into a graph and then adds type and
// syntax information as requested by the LoadMode. // syntax information as requested by the LoadMode.
func (ld *loader) refine(roots []string, list ...*Package) ([]*Package, error) { func (ld *loader) refine(response *driverResponse) ([]*Package, error) {
roots := response.Roots
rootMap := make(map[string]int, len(roots)) rootMap := make(map[string]int, len(roots))
for i, root := range roots { for i, root := range roots {
rootMap[root] = i rootMap[root] = i
@ -626,7 +638,7 @@ func (ld *loader) refine(roots []string, list ...*Package) ([]*Package, error) {
ld.pkgs = make(map[string]*loaderPackage) ld.pkgs = make(map[string]*loaderPackage)
// first pass, fixup and build the map and roots // first pass, fixup and build the map and roots
var initial = make([]*loaderPackage, len(roots)) var initial = make([]*loaderPackage, len(roots))
for _, pkg := range list { for _, pkg := range response.Packages {
rootIndex := -1 rootIndex := -1
if i, found := rootMap[pkg.ID]; found { if i, found := rootMap[pkg.ID]; found {
rootIndex = i rootIndex = i
@ -648,6 +660,7 @@ func (ld *loader) refine(roots []string, list ...*Package) ([]*Package, error) {
Package: pkg, Package: pkg,
needtypes: needtypes, needtypes: needtypes,
needsrc: needsrc, needsrc: needsrc,
goVersion: response.GoVersion,
} }
ld.pkgs[lpkg.ID] = lpkg ld.pkgs[lpkg.ID] = lpkg
if rootIndex >= 0 { if rootIndex >= 0 {
@ -865,12 +878,19 @@ func (ld *loader) loadPackage(lpkg *loaderPackage) {
// never has to create a types.Package for an indirect dependency, // never has to create a types.Package for an indirect dependency,
// which would then require that such created packages be explicitly // which would then require that such created packages be explicitly
// inserted back into the Import graph as a final step after export data loading. // inserted back into the Import graph as a final step after export data loading.
// (Hence this return is after the Types assignment.)
// The Diamond test exercises this case. // The Diamond test exercises this case.
if !lpkg.needtypes && !lpkg.needsrc { if !lpkg.needtypes && !lpkg.needsrc {
return return
} }
if !lpkg.needsrc { if !lpkg.needsrc {
ld.loadFromExportData(lpkg) if err := ld.loadFromExportData(lpkg); err != nil {
lpkg.Errors = append(lpkg.Errors, Error{
Pos: "-",
Msg: err.Error(),
Kind: UnknownError, // e.g. can't find/open/parse export data
})
}
return // not a source package, don't get syntax trees return // not a source package, don't get syntax trees
} }
@ -902,6 +922,7 @@ func (ld *loader) loadPackage(lpkg *loaderPackage) {
case types.Error: case types.Error:
// from type checker // from type checker
lpkg.TypeErrors = append(lpkg.TypeErrors, err)
errs = append(errs, Error{ errs = append(errs, Error{
Pos: err.Fset.Position(err.Pos).String(), Pos: err.Fset.Position(err.Pos).String(),
Msg: err.Msg, Msg: err.Msg,
@ -923,11 +944,41 @@ func (ld *loader) loadPackage(lpkg *loaderPackage) {
lpkg.Errors = append(lpkg.Errors, errs...) lpkg.Errors = append(lpkg.Errors, errs...)
} }
// If the go command on the PATH is newer than the runtime,
// then the go/{scanner,ast,parser,types} packages from the
// standard library may be unable to process the files
// selected by go list.
//
// There is currently no way to downgrade the effective
// version of the go command (see issue 52078), so we proceed
// with the newer go command but, in case of parse or type
// errors, we emit an additional diagnostic.
//
// See:
// - golang.org/issue/52078 (flag to set release tags)
// - golang.org/issue/50825 (gopls legacy version support)
// - golang.org/issue/55883 (go/packages confusing error)
//
// Should we assert a hard minimum of (currently) go1.16 here?
var runtimeVersion int
if _, err := fmt.Sscanf(runtime.Version(), "go1.%d", &runtimeVersion); err == nil && runtimeVersion < lpkg.goVersion {
defer func() {
if len(lpkg.Errors) > 0 {
appendError(Error{
Pos: "-",
Msg: fmt.Sprintf("This application uses version go1.%d of the source-processing packages but runs version go1.%d of 'go list'. It may fail to process source files that rely on newer language features. If so, rebuild the application using a newer version of Go.", runtimeVersion, lpkg.goVersion),
Kind: UnknownError,
})
}
}()
}
if ld.Config.Mode&NeedTypes != 0 && len(lpkg.CompiledGoFiles) == 0 && lpkg.ExportFile != "" { if ld.Config.Mode&NeedTypes != 0 && len(lpkg.CompiledGoFiles) == 0 && lpkg.ExportFile != "" {
// The config requested loading sources and types, but sources are missing. // The config requested loading sources and types, but sources are missing.
// Add an error to the package and fall back to loading from export data. // Add an error to the package and fall back to loading from export data.
appendError(Error{"-", fmt.Sprintf("sources missing for package %s", lpkg.ID), ParseError}) appendError(Error{"-", fmt.Sprintf("sources missing for package %s", lpkg.ID), ParseError})
ld.loadFromExportData(lpkg) _ = ld.loadFromExportData(lpkg) // ignore any secondary errors
return // can't get syntax trees for this package return // can't get syntax trees for this package
} }
@ -981,7 +1032,7 @@ func (ld *loader) loadPackage(lpkg *loaderPackage) {
tc := &types.Config{ tc := &types.Config{
Importer: importer, Importer: importer,
// Type-check bodies of functions only in non-initial packages. // Type-check bodies of functions only in initial packages.
// Example: for import graph A->B->C and initial packages {A,C}, // Example: for import graph A->B->C and initial packages {A,C},
// we can ignore function bodies in B. // we can ignore function bodies in B.
IgnoreFuncBodies: ld.Mode&NeedDeps == 0 && !lpkg.initial, IgnoreFuncBodies: ld.Mode&NeedDeps == 0 && !lpkg.initial,
@ -1151,9 +1202,10 @@ func sameFile(x, y string) bool {
return false return false
} }
// loadFromExportData returns type information for the specified // loadFromExportData ensures that type information is present for the specified
// package, loading it from an export data file on the first request. // package, loading it from an export data file on the first request.
func (ld *loader) loadFromExportData(lpkg *loaderPackage) (*types.Package, error) { // On success it sets lpkg.Types to a new Package.
func (ld *loader) loadFromExportData(lpkg *loaderPackage) error {
if lpkg.PkgPath == "" { if lpkg.PkgPath == "" {
log.Fatalf("internal error: Package %s has no PkgPath", lpkg) log.Fatalf("internal error: Package %s has no PkgPath", lpkg)
} }
@ -1164,8 +1216,8 @@ func (ld *loader) loadFromExportData(lpkg *loaderPackage) (*types.Package, error
// must be sequential. (Finer-grained locking would require // must be sequential. (Finer-grained locking would require
// changes to the gcexportdata API.) // changes to the gcexportdata API.)
// //
// The exportMu lock guards the Package.Pkg field and the // The exportMu lock guards the lpkg.Types field and the
// types.Package it points to, for each Package in the graph. // types.Package it points to, for each loaderPackage in the graph.
// //
// Not all accesses to Package.Pkg need to be protected by exportMu: // Not all accesses to Package.Pkg need to be protected by exportMu:
// graph ordering ensures that direct dependencies of source // graph ordering ensures that direct dependencies of source
@ -1174,18 +1226,18 @@ func (ld *loader) loadFromExportData(lpkg *loaderPackage) (*types.Package, error
defer ld.exportMu.Unlock() defer ld.exportMu.Unlock()
if tpkg := lpkg.Types; tpkg != nil && tpkg.Complete() { if tpkg := lpkg.Types; tpkg != nil && tpkg.Complete() {
return tpkg, nil // cache hit return nil // cache hit
} }
lpkg.IllTyped = true // fail safe lpkg.IllTyped = true // fail safe
if lpkg.ExportFile == "" { if lpkg.ExportFile == "" {
// Errors while building export data will have been printed to stderr. // Errors while building export data will have been printed to stderr.
return nil, fmt.Errorf("no export data file") return fmt.Errorf("no export data file")
} }
f, err := os.Open(lpkg.ExportFile) f, err := os.Open(lpkg.ExportFile)
if err != nil { if err != nil {
return nil, err return err
} }
defer f.Close() defer f.Close()
@ -1197,7 +1249,7 @@ func (ld *loader) loadFromExportData(lpkg *loaderPackage) (*types.Package, error
// queries.) // queries.)
r, err := gcexportdata.NewReader(f) r, err := gcexportdata.NewReader(f)
if err != nil { if err != nil {
return nil, fmt.Errorf("reading %s: %v", lpkg.ExportFile, err) return fmt.Errorf("reading %s: %v", lpkg.ExportFile, err)
} }
// Build the view. // Build the view.
@ -1241,7 +1293,7 @@ func (ld *loader) loadFromExportData(lpkg *loaderPackage) (*types.Package, error
// (May modify incomplete packages in view but not create new ones.) // (May modify incomplete packages in view but not create new ones.)
tpkg, err := gcexportdata.Read(r, ld.Fset, view, lpkg.PkgPath) tpkg, err := gcexportdata.Read(r, ld.Fset, view, lpkg.PkgPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("reading %s: %v", lpkg.ExportFile, err) return fmt.Errorf("reading %s: %v", lpkg.ExportFile, err)
} }
if _, ok := view["go.shape"]; ok { if _, ok := view["go.shape"]; ok {
// Account for the pseudopackage "go.shape" that gets // Account for the pseudopackage "go.shape" that gets
@ -1254,8 +1306,7 @@ func (ld *loader) loadFromExportData(lpkg *loaderPackage) (*types.Package, error
lpkg.Types = tpkg lpkg.Types = tpkg
lpkg.IllTyped = false lpkg.IllTyped = false
return nil
return tpkg, nil
} }
// impliedLoadMode returns loadMode with its dependencies. // impliedLoadMode returns loadMode with its dependencies.
@ -1271,3 +1322,5 @@ func impliedLoadMode(loadMode LoadMode) LoadMode {
func usesExportData(cfg *Config) bool { func usesExportData(cfg *Config) bool {
return cfg.Mode&NeedExportFile != 0 || cfg.Mode&NeedTypes != 0 && cfg.Mode&NeedDeps == 0 return cfg.Mode&NeedExportFile != 0 || cfg.Mode&NeedTypes != 0 && cfg.Mode&NeedDeps == 0
} }
var _ interface{} = io.Discard // assert build toolchain is go1.16 or later
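The version-skew check added to loadPackage above can be tried in isolation. Below is a minimal, hypothetical sketch (not part of this diff): the goVersion value stands in for what driverResponse.GoVersion would supply, and the fmt.Sscanf parse of runtime.Version() mirrors the vendored code.

package main

import (
	"fmt"
	"runtime"
)

func main() {
	goVersion := 21 // hypothetical: minor version reported by the go command on PATH

	var runtimeVersion int
	// runtime.Version() returns e.g. "go1.19.5"; Sscanf stops after the minor number.
	if _, err := fmt.Sscanf(runtime.Version(), "go1.%d", &runtimeVersion); err == nil && runtimeVersion < goVersion {
		fmt.Printf("warning: built with go1.%d but 'go list' is go1.%d; newer language features may fail to parse\n",
			runtimeVersion, goVersion)
	}
}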

View File

@ -0,0 +1,119 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build darwin && cgo
// +build darwin,cgo
package fastwalk
/*
#include <dirent.h>
// fastwalk_readdir_r wraps readdir_r so that we don't have to pass a dirent**
// result pointer which triggers CGO's "Go pointer to Go pointer" check unless
// we allocate the result dirent* with malloc.
//
// fastwalk_readdir_r returns 0 on success, -1 upon reaching the end of the
// directory, or a positive error number to indicate failure.
static int fastwalk_readdir_r(DIR *fd, struct dirent *entry) {
struct dirent *result;
int ret = readdir_r(fd, entry, &result);
if (ret == 0 && result == NULL) {
ret = -1; // EOF
}
return ret;
}
*/
import "C"
import (
"os"
"syscall"
"unsafe"
)
func readDir(dirName string, fn func(dirName, entName string, typ os.FileMode) error) error {
fd, err := openDir(dirName)
if err != nil {
return &os.PathError{Op: "opendir", Path: dirName, Err: err}
}
defer C.closedir(fd)
skipFiles := false
var dirent syscall.Dirent
for {
ret := int(C.fastwalk_readdir_r(fd, (*C.struct_dirent)(unsafe.Pointer(&dirent))))
if ret != 0 {
if ret == -1 {
break // EOF
}
if ret == int(syscall.EINTR) {
continue
}
return &os.PathError{Op: "readdir", Path: dirName, Err: syscall.Errno(ret)}
}
if dirent.Ino == 0 {
continue
}
typ := dtToType(dirent.Type)
if skipFiles && typ.IsRegular() {
continue
}
name := (*[len(syscall.Dirent{}.Name)]byte)(unsafe.Pointer(&dirent.Name))[:]
name = name[:dirent.Namlen]
for i, c := range name {
if c == 0 {
name = name[:i]
break
}
}
// Check for useless names before allocating a string.
if string(name) == "." || string(name) == ".." {
continue
}
if err := fn(dirName, string(name), typ); err != nil {
if err != ErrSkipFiles {
return err
}
skipFiles = true
}
}
return nil
}
func dtToType(typ uint8) os.FileMode {
switch typ {
case syscall.DT_BLK:
return os.ModeDevice
case syscall.DT_CHR:
return os.ModeDevice | os.ModeCharDevice
case syscall.DT_DIR:
return os.ModeDir
case syscall.DT_FIFO:
return os.ModeNamedPipe
case syscall.DT_LNK:
return os.ModeSymlink
case syscall.DT_REG:
return 0
case syscall.DT_SOCK:
return os.ModeSocket
}
return ^os.FileMode(0)
}
// openDir wraps opendir(3) and handles any EINTR errors. The returned *DIR
// needs to be closed with closedir(3).
func openDir(path string) (*C.DIR, error) {
name, err := syscall.BytePtrFromString(path)
if err != nil {
return nil, err
}
for {
fd, err := C.opendir((*C.char)(unsafe.Pointer(name)))
if err != syscall.EINTR {
return fd, err
}
}
}
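For orientation, here is a hedged sketch of how the readDir contract above is consumed through fastwalk's public entry point. golang.org/x/tools/internal/fastwalk is an internal package, so this only compiles inside the x/tools module, and the /tmp root is an arbitrary assumption.

package main

import (
	"fmt"
	"os"

	"golang.org/x/tools/internal/fastwalk" // internal: importable only within x/tools
)

func main() {
	err := fastwalk.Walk("/tmp", func(path string, typ os.FileMode) error {
		if typ.IsDir() {
			fmt.Println("dir:", path)
		}
		// Returning fastwalk.ErrSkipFiles here would skip the remaining
		// files in the current directory, matching the readDir contract above.
		return nil
	})
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
	}
}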

View File

@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build (linux || darwin) && !appengine //go:build (linux || (darwin && !cgo)) && !appengine
// +build linux darwin // +build linux darwin,!cgo
// +build !appengine // +build !appengine
package fastwalk package fastwalk
@ -11,5 +11,5 @@ package fastwalk
import "syscall" import "syscall"
func direntInode(dirent *syscall.Dirent) uint64 { func direntInode(dirent *syscall.Dirent) uint64 {
return uint64(dirent.Ino) return dirent.Ino
} }

View File

@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build darwin || freebsd || openbsd || netbsd //go:build (darwin && !cgo) || freebsd || openbsd || netbsd
// +build darwin freebsd openbsd netbsd // +build darwin,!cgo freebsd openbsd netbsd
package fastwalk package fastwalk

View File

@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build (linux || darwin || freebsd || openbsd || netbsd) && !appengine //go:build (linux || freebsd || openbsd || netbsd || (darwin && !cgo)) && !appengine
// +build linux darwin freebsd openbsd netbsd // +build linux freebsd openbsd netbsd darwin,!cgo
// +build !appengine // +build !appengine
package fastwalk package fastwalk

View File

@ -12,7 +12,6 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"go/ast"
"go/constant" "go/constant"
"go/token" "go/token"
"go/types" "go/types"
@ -145,7 +144,7 @@ func BExportData(fset *token.FileSet, pkg *types.Package) (b []byte, err error)
objcount := 0 objcount := 0
scope := pkg.Scope() scope := pkg.Scope()
for _, name := range scope.Names() { for _, name := range scope.Names() {
if !ast.IsExported(name) { if !token.IsExported(name) {
continue continue
} }
if trace { if trace {
@ -482,7 +481,7 @@ func (p *exporter) method(m *types.Func) {
p.pos(m) p.pos(m)
p.string(m.Name()) p.string(m.Name())
if m.Name() != "_" && !ast.IsExported(m.Name()) { if m.Name() != "_" && !token.IsExported(m.Name()) {
p.pkg(m.Pkg(), false) p.pkg(m.Pkg(), false)
} }
@ -501,7 +500,7 @@ func (p *exporter) fieldName(f *types.Var) {
// 3) field name doesn't match base type name (alias name) // 3) field name doesn't match base type name (alias name)
bname := basetypeName(f.Type()) bname := basetypeName(f.Type())
if name == bname { if name == bname {
if ast.IsExported(name) { if token.IsExported(name) {
name = "" // 1) we don't need to know the field name or package name = "" // 1) we don't need to know the field name or package
} else { } else {
name = "?" // 2) use unexported name "?" to force package export name = "?" // 2) use unexported name "?" to force package export
@ -514,7 +513,7 @@ func (p *exporter) fieldName(f *types.Var) {
} }
p.string(name) p.string(name)
if name != "" && !ast.IsExported(name) { if name != "" && !token.IsExported(name) {
p.pkg(f.Pkg(), false) p.pkg(f.Pkg(), false)
} }
} }

View File

@ -0,0 +1,265 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file is a reduced copy of $GOROOT/src/go/internal/gcimporter/gcimporter.go.
// Package gcimporter provides various functions for reading
// gc-generated object files that can be used to implement the
// Importer interface defined by the Go 1.5 standard library package.
package gcimporter // import "golang.org/x/tools/internal/gcimporter"
import (
"bufio"
"bytes"
"fmt"
"go/build"
"go/token"
"go/types"
"io"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
)
const (
// Enable debug during development: it adds some additional checks, and
// prevents errors from being recovered.
debug = false
// If trace is set, debugging output is printed to std out.
trace = false
)
var exportMap sync.Map // package dir → func() (string, bool)
// lookupGorootExport returns the location of the export data
// (normally found in the build cache, but located in GOROOT/pkg
// in prior Go releases) for the package located in pkgDir.
//
// (We use the package's directory instead of its import path
// mainly to simplify handling of the packages in src/vendor
// and cmd/vendor.)
func lookupGorootExport(pkgDir string) (string, bool) {
f, ok := exportMap.Load(pkgDir)
if !ok {
var (
listOnce sync.Once
exportPath string
)
f, _ = exportMap.LoadOrStore(pkgDir, func() (string, bool) {
listOnce.Do(func() {
cmd := exec.Command("go", "list", "-export", "-f", "{{.Export}}", pkgDir)
cmd.Dir = build.Default.GOROOT
var output []byte
output, err := cmd.Output()
if err != nil {
return
}
exports := strings.Split(string(bytes.TrimSpace(output)), "\n")
if len(exports) != 1 {
return
}
exportPath = exports[0]
})
return exportPath, exportPath != ""
})
}
return f.(func() (string, bool))()
}
var pkgExts = [...]string{".a", ".o"}
// FindPkg returns the filename and unique package id for an import
// path based on package information provided by build.Import (using
// the build.Default build.Context). A relative srcDir is interpreted
// relative to the current working directory.
// If no file was found, an empty filename is returned.
func FindPkg(path, srcDir string) (filename, id string) {
if path == "" {
return
}
var noext string
switch {
default:
// "x" -> "$GOPATH/pkg/$GOOS_$GOARCH/x.ext", "x"
// Don't require the source files to be present.
if abs, err := filepath.Abs(srcDir); err == nil { // see issue 14282
srcDir = abs
}
bp, _ := build.Import(path, srcDir, build.FindOnly|build.AllowBinary)
if bp.PkgObj == "" {
var ok bool
if bp.Goroot && bp.Dir != "" {
filename, ok = lookupGorootExport(bp.Dir)
}
if !ok {
id = path // make sure we have an id to print in error message
return
}
} else {
noext = strings.TrimSuffix(bp.PkgObj, ".a")
id = bp.ImportPath
}
case build.IsLocalImport(path):
// "./x" -> "/this/directory/x.ext", "/this/directory/x"
noext = filepath.Join(srcDir, path)
id = noext
case filepath.IsAbs(path):
// for completeness only - go/build.Import
// does not support absolute imports
// "/x" -> "/x.ext", "/x"
noext = path
id = path
}
if false { // for debugging
if path != id {
fmt.Printf("%s -> %s\n", path, id)
}
}
if filename != "" {
if f, err := os.Stat(filename); err == nil && !f.IsDir() {
return
}
}
// try extensions
for _, ext := range pkgExts {
filename = noext + ext
if f, err := os.Stat(filename); err == nil && !f.IsDir() {
return
}
}
filename = "" // not found
return
}
// Import imports a gc-generated package given its import path and srcDir, adds
// the corresponding package object to the packages map, and returns the object.
// The packages map must contain all packages already imported.
func Import(packages map[string]*types.Package, path, srcDir string, lookup func(path string) (io.ReadCloser, error)) (pkg *types.Package, err error) {
var rc io.ReadCloser
var filename, id string
if lookup != nil {
// With custom lookup specified, assume that caller has
// converted path to a canonical import path for use in the map.
if path == "unsafe" {
return types.Unsafe, nil
}
id = path
// No need to re-import if the package was imported completely before.
if pkg = packages[id]; pkg != nil && pkg.Complete() {
return
}
f, err := lookup(path)
if err != nil {
return nil, err
}
rc = f
} else {
filename, id = FindPkg(path, srcDir)
if filename == "" {
if path == "unsafe" {
return types.Unsafe, nil
}
return nil, fmt.Errorf("can't find import: %q", id)
}
// no need to re-import if the package was imported completely before
if pkg = packages[id]; pkg != nil && pkg.Complete() {
return
}
// open file
f, err := os.Open(filename)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
// add file name to error
err = fmt.Errorf("%s: %v", filename, err)
}
}()
rc = f
}
defer rc.Close()
var hdr string
var size int64
buf := bufio.NewReader(rc)
if hdr, size, err = FindExportData(buf); err != nil {
return
}
switch hdr {
case "$$B\n":
var data []byte
data, err = ioutil.ReadAll(buf)
if err != nil {
break
}
// TODO(gri): allow clients of go/importer to provide a FileSet.
// Or, define a new standard go/types/gcexportdata package.
fset := token.NewFileSet()
// The indexed export format starts with an 'i'; the older
// binary export format starts with a 'c', 'd', or 'v'
// (from "version"). Select appropriate importer.
if len(data) > 0 {
switch data[0] {
case 'i':
_, pkg, err := IImportData(fset, packages, data[1:], id)
return pkg, err
case 'v', 'c', 'd':
_, pkg, err := BImportData(fset, packages, data, id)
return pkg, err
case 'u':
_, pkg, err := UImportData(fset, packages, data[1:size], id)
return pkg, err
default:
l := len(data)
if l > 10 {
l = 10
}
return nil, fmt.Errorf("unexpected export data with prefix %q for path %s", string(data[:l]), id)
}
}
default:
err = fmt.Errorf("unknown export data header: %q", hdr)
}
return
}
func deref(typ types.Type) types.Type {
if p, _ := typ.(*types.Pointer); p != nil {
return p.Elem()
}
return typ
}
type byPath []*types.Package
func (a byPath) Len() int { return len(a) }
func (a byPath) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a byPath) Less(i, j int) bool { return a[i].Path() < a[j].Path() }
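A hedged usage sketch of the Import entry point above (again an internal package, shown for illustration only). With a nil lookup, Import falls back to FindPkg, which may consult the build cache via lookupGorootExport; whether export data is actually located depends on the local toolchain, so treat this as a sketch rather than a guaranteed-working program.

package main

import (
	"fmt"
	"go/types"

	"golang.org/x/tools/internal/gcimporter" // internal; illustration only
)

func main() {
	imports := make(map[string]*types.Package)
	// With lookup == nil, Import uses FindPkg to locate export data
	// (.a/.o files, or the build cache for GOROOT packages).
	pkg, err := gcimporter.Import(imports, "fmt", ".", nil)
	if err != nil {
		fmt.Println("import failed:", err)
		return
	}
	fmt.Println(pkg.Path(), "complete:", pkg.Complete())
}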

View File

@ -12,7 +12,6 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"go/ast"
"go/constant" "go/constant"
"go/token" "go/token"
"go/types" "go/types"
@ -23,9 +22,45 @@ import (
"strconv" "strconv"
"strings" "strings"
"golang.org/x/tools/internal/tokeninternal"
"golang.org/x/tools/internal/typeparams" "golang.org/x/tools/internal/typeparams"
) )
// IExportShallow encodes "shallow" export data for the specified package.
//
// No promises are made about the encoding other than that it can be
// decoded by the same version of IImportShallow. If you plan to save
// export data in the file system, be sure to include a cryptographic
// digest of the executable in the key to avoid version skew.
func IExportShallow(fset *token.FileSet, pkg *types.Package) ([]byte, error) {
// In principle this operation can only fail if out.Write fails,
// but that's impossible for bytes.Buffer---and as a matter of
// fact iexportCommon doesn't even check for I/O errors.
// TODO(adonovan): handle I/O errors properly.
// TODO(adonovan): use byte slices throughout, avoiding copying.
const bundle, shallow = false, true
var out bytes.Buffer
err := iexportCommon(&out, fset, bundle, shallow, iexportVersion, []*types.Package{pkg})
return out.Bytes(), err
}
// IImportShallow decodes "shallow" types.Package data encoded by IExportShallow
// in the same executable. This function cannot import data from
// cmd/compile or gcexportdata.Write.
func IImportShallow(fset *token.FileSet, imports map[string]*types.Package, data []byte, path string, insert InsertType) (*types.Package, error) {
const bundle = false
pkgs, err := iimportCommon(fset, imports, data, bundle, path, insert)
if err != nil {
return nil, err
}
return pkgs[0], nil
}
// InsertType is the type of a function that creates a types.TypeName
// object for a named type and inserts it into the scope of the
// specified Package.
type InsertType = func(pkg *types.Package, name string)
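A hedged round-trip sketch of the shallow API introduced above: type-check a trivial package, encode it with IExportShallow, then decode with IImportShallow. The package has no cross-package references, so the insert callback should never fire; the package path and source text are invented for the example.

package main

import (
	"fmt"
	"go/ast"
	"go/importer"
	"go/parser"
	"go/token"
	"go/types"

	"golang.org/x/tools/internal/gcimporter" // internal; illustration only
)

func main() {
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, "p.go", "package p\nconst C = 1\n", 0)
	if err != nil {
		panic(err)
	}
	conf := types.Config{Importer: importer.Default()}
	pkg, err := conf.Check("example.com/p", fset, []*ast.File{f}, nil)
	if err != nil {
		panic(err)
	}
	data, err := gcimporter.IExportShallow(fset, pkg)
	if err != nil {
		panic(err)
	}
	// insert should never be called: p has no cross-package references.
	insert := func(pkg *types.Package, name string) { panic("unexpected external reference") }
	pkg2, err := gcimporter.IImportShallow(token.NewFileSet(), map[string]*types.Package{}, data, "example.com/p", insert)
	if err != nil {
		panic(err)
	}
	fmt.Println(pkg2.Scope().Lookup("C")) // e.g. const example.com/p.C untyped int
}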
// Current bundled export format version. Increase with each format change. // Current bundled export format version. Increase with each format change.
// 0: initial implementation // 0: initial implementation
const bundleVersion = 0 const bundleVersion = 0
@ -36,15 +71,17 @@ const bundleVersion = 0
// The package path of the top-level package will not be recorded, // The package path of the top-level package will not be recorded,
// so that calls to IImportData can override with a provided package path. // so that calls to IImportData can override with a provided package path.
func IExportData(out io.Writer, fset *token.FileSet, pkg *types.Package) error { func IExportData(out io.Writer, fset *token.FileSet, pkg *types.Package) error {
return iexportCommon(out, fset, false, iexportVersion, []*types.Package{pkg}) const bundle, shallow = false, false
return iexportCommon(out, fset, bundle, shallow, iexportVersion, []*types.Package{pkg})
} }
// IExportBundle writes an indexed export bundle for pkgs to out. // IExportBundle writes an indexed export bundle for pkgs to out.
func IExportBundle(out io.Writer, fset *token.FileSet, pkgs []*types.Package) error { func IExportBundle(out io.Writer, fset *token.FileSet, pkgs []*types.Package) error {
return iexportCommon(out, fset, true, iexportVersion, pkgs) const bundle, shallow = true, false
return iexportCommon(out, fset, bundle, shallow, iexportVersion, pkgs)
} }
func iexportCommon(out io.Writer, fset *token.FileSet, bundle bool, version int, pkgs []*types.Package) (err error) { func iexportCommon(out io.Writer, fset *token.FileSet, bundle, shallow bool, version int, pkgs []*types.Package) (err error) {
if !debug { if !debug {
defer func() { defer func() {
if e := recover(); e != nil { if e := recover(); e != nil {
@ -61,6 +98,7 @@ func iexportCommon(out io.Writer, fset *token.FileSet, bundle bool, version int,
p := iexporter{ p := iexporter{
fset: fset, fset: fset,
version: version, version: version,
shallow: shallow,
allPkgs: map[*types.Package]bool{}, allPkgs: map[*types.Package]bool{},
stringIndex: map[string]uint64{}, stringIndex: map[string]uint64{},
declIndex: map[types.Object]uint64{}, declIndex: map[types.Object]uint64{},
@ -82,7 +120,7 @@ func iexportCommon(out io.Writer, fset *token.FileSet, bundle bool, version int,
for _, pkg := range pkgs { for _, pkg := range pkgs {
scope := pkg.Scope() scope := pkg.Scope()
for _, name := range scope.Names() { for _, name := range scope.Names() {
if ast.IsExported(name) { if token.IsExported(name) {
p.pushDecl(scope.Lookup(name)) p.pushDecl(scope.Lookup(name))
} }
} }
@ -101,6 +139,17 @@ func iexportCommon(out io.Writer, fset *token.FileSet, bundle bool, version int,
p.doDecl(p.declTodo.popHead()) p.doDecl(p.declTodo.popHead())
} }
// Produce an index of the offset of each file record in files.
var files intWriter
var fileOffset []uint64 // fileOffset[i] is offset in files of file encoded as i
if p.shallow {
fileOffset = make([]uint64, len(p.fileInfos))
for i, info := range p.fileInfos {
fileOffset[i] = uint64(files.Len())
p.encodeFile(&files, info.file, info.needed)
}
}
// Append indices to data0 section. // Append indices to data0 section.
dataLen := uint64(p.data0.Len()) dataLen := uint64(p.data0.Len())
w := p.newWriter() w := p.newWriter()
@ -126,16 +175,75 @@ func iexportCommon(out io.Writer, fset *token.FileSet, bundle bool, version int,
} }
hdr.uint64(uint64(p.version)) hdr.uint64(uint64(p.version))
hdr.uint64(uint64(p.strings.Len())) hdr.uint64(uint64(p.strings.Len()))
if p.shallow {
hdr.uint64(uint64(files.Len()))
hdr.uint64(uint64(len(fileOffset)))
for _, offset := range fileOffset {
hdr.uint64(offset)
}
}
hdr.uint64(dataLen) hdr.uint64(dataLen)
// Flush output. // Flush output.
io.Copy(out, &hdr) io.Copy(out, &hdr)
io.Copy(out, &p.strings) io.Copy(out, &p.strings)
if p.shallow {
io.Copy(out, &files)
}
io.Copy(out, &p.data0) io.Copy(out, &p.data0)
return nil return nil
} }
// encodeFile writes to w a representation of the file sufficient to
// faithfully restore position information about all needed offsets.
// Mutates the needed array.
func (p *iexporter) encodeFile(w *intWriter, file *token.File, needed []uint64) {
_ = needed[0] // precondition: needed is non-empty
w.uint64(p.stringOff(file.Name()))
size := uint64(file.Size())
w.uint64(size)
// Sort the set of needed offsets. Duplicates are harmless.
sort.Slice(needed, func(i, j int) bool { return needed[i] < needed[j] })
lines := tokeninternal.GetLines(file) // byte offset of each line start
w.uint64(uint64(len(lines)))
// Rather than record the entire array of line start offsets,
// we save only a sparse list of (index, offset) pairs for
// the start of each line that contains a needed position.
var sparse [][2]int // (index, offset) pairs
outer:
for i, lineStart := range lines {
lineEnd := size
if i < len(lines)-1 {
lineEnd = uint64(lines[i+1])
}
// Does this line contain a needed offset?
if needed[0] < lineEnd {
sparse = append(sparse, [2]int{i, lineStart})
for needed[0] < lineEnd {
needed = needed[1:]
if len(needed) == 0 {
break outer
}
}
}
}
// Delta-encode the columns.
w.uint64(uint64(len(sparse)))
var prev [2]int
for _, pair := range sparse {
w.uint64(uint64(pair[0] - prev[0]))
w.uint64(uint64(pair[1] - prev[1]))
prev = pair
}
}
// writeIndex writes out an object index. mainIndex indicates whether // writeIndex writes out an object index. mainIndex indicates whether
// we're writing out the main index, which is also read by // we're writing out the main index, which is also read by
// non-compiler tools and includes a complete package description // non-compiler tools and includes a complete package description
@ -205,7 +313,8 @@ type iexporter struct {
out *bytes.Buffer out *bytes.Buffer
version int version int
localpkg *types.Package shallow bool // don't put types from other packages in the index
localpkg *types.Package // (nil in bundle mode)
// allPkgs tracks all packages that have been referenced by // allPkgs tracks all packages that have been referenced by
// the export data, so we can ensure to include them in the // the export data, so we can ensure to include them in the
@ -217,6 +326,12 @@ type iexporter struct {
strings intWriter strings intWriter
stringIndex map[string]uint64 stringIndex map[string]uint64
// In shallow mode, object positions are encoded as (file, offset).
// Each file is recorded as a line-number table.
// Only the lines of needed positions are saved faithfully.
fileInfo map[*token.File]uint64 // value is index in fileInfos
fileInfos []*filePositions
data0 intWriter data0 intWriter
declIndex map[types.Object]uint64 declIndex map[types.Object]uint64
tparamNames map[types.Object]string // typeparam->exported name tparamNames map[types.Object]string // typeparam->exported name
@ -225,6 +340,11 @@ type iexporter struct {
indent int // for tracing support indent int // for tracing support
} }
type filePositions struct {
file *token.File
needed []uint64 // unordered list of needed file offsets
}
func (p *iexporter) trace(format string, args ...interface{}) { func (p *iexporter) trace(format string, args ...interface{}) {
if !trace { if !trace {
// Call sites should also be guarded, but having this check here allows // Call sites should also be guarded, but having this check here allows
@ -248,6 +368,25 @@ func (p *iexporter) stringOff(s string) uint64 {
return off return off
} }
// fileIndexAndOffset returns the index of the token.File and the byte offset of pos within it.
func (p *iexporter) fileIndexAndOffset(file *token.File, pos token.Pos) (uint64, uint64) {
index, ok := p.fileInfo[file]
if !ok {
index = uint64(len(p.fileInfo))
p.fileInfos = append(p.fileInfos, &filePositions{file: file})
if p.fileInfo == nil {
p.fileInfo = make(map[*token.File]uint64)
}
p.fileInfo[file] = index
}
// Record each needed offset.
info := p.fileInfos[index]
offset := uint64(file.Offset(pos))
info.needed = append(info.needed, offset)
return index, offset
}
// pushDecl adds n to the declaration work queue, if not already present. // pushDecl adds n to the declaration work queue, if not already present.
func (p *iexporter) pushDecl(obj types.Object) { func (p *iexporter) pushDecl(obj types.Object) {
// Package unsafe is known to the compiler and predeclared. // Package unsafe is known to the compiler and predeclared.
@ -256,6 +395,11 @@ func (p *iexporter) pushDecl(obj types.Object) {
panic("cannot export package unsafe") panic("cannot export package unsafe")
} }
// Shallow export data: don't index decls from other packages.
if p.shallow && obj.Pkg() != p.localpkg {
return
}
if _, ok := p.declIndex[obj]; ok { if _, ok := p.declIndex[obj]; ok {
return return
} }
@ -303,7 +447,13 @@ func (p *iexporter) doDecl(obj types.Object) {
case *types.Func: case *types.Func:
sig, _ := obj.Type().(*types.Signature) sig, _ := obj.Type().(*types.Signature)
if sig.Recv() != nil { if sig.Recv() != nil {
panic(internalErrorf("unexpected method: %v", sig)) // We shouldn't see methods in the package scope,
// but the type checker may repair "func () F() {}"
// to "func (Invalid) F()" and then treat it like "func F()",
// so allow that. See golang/go#57729.
if sig.Recv().Type() != types.Typ[types.Invalid] {
panic(internalErrorf("unexpected method: %v", sig))
}
} }
// Function. // Function.
@ -415,13 +565,30 @@ func (w *exportWriter) tag(tag byte) {
} }
func (w *exportWriter) pos(pos token.Pos) { func (w *exportWriter) pos(pos token.Pos) {
if w.p.version >= iexportVersionPosCol { if w.p.shallow {
w.posV2(pos)
} else if w.p.version >= iexportVersionPosCol {
w.posV1(pos) w.posV1(pos)
} else { } else {
w.posV0(pos) w.posV0(pos)
} }
} }
// posV2 encoding (used only in shallow mode) records positions as
// (file, offset), where file is the index in the token.File table
// (which records the file name and newline offsets) and offset is a
// byte offset. It effectively ignores //line directives.
func (w *exportWriter) posV2(pos token.Pos) {
if pos == token.NoPos {
w.uint64(0)
return
}
file := w.p.fset.File(pos) // fset must be non-nil
index, offset := w.p.fileIndexAndOffset(file, pos)
w.uint64(1 + index)
w.uint64(offset)
}
func (w *exportWriter) posV1(pos token.Pos) { func (w *exportWriter) posV1(pos token.Pos) {
if w.p.fset == nil { if w.p.fset == nil {
w.int64(0) w.int64(0)
@ -497,7 +664,7 @@ func (w *exportWriter) pkg(pkg *types.Package) {
w.string(w.exportPath(pkg)) w.string(w.exportPath(pkg))
} }
func (w *exportWriter) qualifiedIdent(obj types.Object) { func (w *exportWriter) qualifiedType(obj *types.TypeName) {
name := w.p.exportName(obj) name := w.p.exportName(obj)
// Ensure any referenced declarations are written out too. // Ensure any referenced declarations are written out too.
@ -556,11 +723,11 @@ func (w *exportWriter) doTyp(t types.Type, pkg *types.Package) {
return return
} }
w.startType(definedType) w.startType(definedType)
w.qualifiedIdent(t.Obj()) w.qualifiedType(t.Obj())
case *typeparams.TypeParam: case *typeparams.TypeParam:
w.startType(typeParamType) w.startType(typeParamType)
w.qualifiedIdent(t.Obj()) w.qualifiedType(t.Obj())
case *types.Pointer: case *types.Pointer:
w.startType(pointerType) w.startType(pointerType)
@ -602,14 +769,17 @@ func (w *exportWriter) doTyp(t types.Type, pkg *types.Package) {
case *types.Struct: case *types.Struct:
w.startType(structType) w.startType(structType)
w.setPkg(pkg, true)
n := t.NumFields() n := t.NumFields()
if n > 0 {
w.setPkg(t.Field(0).Pkg(), true) // qualifying package for field objects
} else {
w.setPkg(pkg, true)
}
w.uint64(uint64(n)) w.uint64(uint64(n))
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
f := t.Field(i) f := t.Field(i)
w.pos(f.Pos()) w.pos(f.Pos())
w.string(f.Name()) w.string(f.Name()) // unexported fields implicitly qualified by prior setPkg
w.typ(f.Type(), pkg) w.typ(f.Type(), pkg)
w.bool(f.Anonymous()) w.bool(f.Anonymous())
w.string(t.Tag(i)) // note (or tag) w.string(t.Tag(i)) // note (or tag)
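The sparse position encoding in encodeFile/fileIndexAndOffset above is compact but easy to miss on first read. The following standalone sketch (illustrative names, not part of the diff) isolates the two steps: keep only the line starts that contain a needed offset, then delta-encode the surviving (index, offset) pairs.

package main

import (
	"fmt"
	"sort"
)

// encodeSparse mirrors encodeFile above: retain only the lines that
// contain a needed offset, then store consecutive differences.
func encodeSparse(lines []int, needed []uint64, size uint64) [][2]int {
	sort.Slice(needed, func(i, j int) bool { return needed[i] < needed[j] })
	var sparse [][2]int
outer:
	for i, lineStart := range lines {
		lineEnd := size
		if i < len(lines)-1 {
			lineEnd = uint64(lines[i+1])
		}
		if len(needed) > 0 && needed[0] < lineEnd {
			sparse = append(sparse, [2]int{i, lineStart})
			for needed[0] < lineEnd {
				needed = needed[1:]
				if len(needed) == 0 {
					break outer
				}
			}
		}
	}
	// Delta-encode the (index, offset) pairs.
	deltas := make([][2]int, len(sparse))
	var prev [2]int
	for i, pair := range sparse {
		deltas[i] = [2]int{pair[0] - prev[0], pair[1] - prev[1]}
		prev = pair
	}
	return deltas
}

func main() {
	lines := []int{0, 10, 25, 40}                // byte offset of each line start
	needed := []uint64{12, 30}                   // offsets whose positions must survive
	fmt.Println(encodeSparse(lines, needed, 50)) // [[1 10] [1 15]]
}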

View File

@ -51,6 +51,8 @@ const (
iexportVersionPosCol = 1 iexportVersionPosCol = 1
iexportVersionGo1_18 = 2 iexportVersionGo1_18 = 2
iexportVersionGenerics = 2 iexportVersionGenerics = 2
iexportVersionCurrent = 2
) )
type ident struct { type ident struct {
@ -83,7 +85,7 @@ const (
// If the export data version is not recognized or the format is otherwise // If the export data version is not recognized or the format is otherwise
// compromised, an error is returned. // compromised, an error is returned.
func IImportData(fset *token.FileSet, imports map[string]*types.Package, data []byte, path string) (int, *types.Package, error) { func IImportData(fset *token.FileSet, imports map[string]*types.Package, data []byte, path string) (int, *types.Package, error) {
pkgs, err := iimportCommon(fset, imports, data, false, path) pkgs, err := iimportCommon(fset, imports, data, false, path, nil)
if err != nil { if err != nil {
return 0, nil, err return 0, nil, err
} }
@ -92,11 +94,11 @@ func IImportData(fset *token.FileSet, imports map[string]*types.Package, data []
// IImportBundle imports a set of packages from the serialized package bundle. // IImportBundle imports a set of packages from the serialized package bundle.
func IImportBundle(fset *token.FileSet, imports map[string]*types.Package, data []byte) ([]*types.Package, error) { func IImportBundle(fset *token.FileSet, imports map[string]*types.Package, data []byte) ([]*types.Package, error) {
return iimportCommon(fset, imports, data, true, "") return iimportCommon(fset, imports, data, true, "", nil)
} }
func iimportCommon(fset *token.FileSet, imports map[string]*types.Package, data []byte, bundle bool, path string) (pkgs []*types.Package, err error) { func iimportCommon(fset *token.FileSet, imports map[string]*types.Package, data []byte, bundle bool, path string, insert InsertType) (pkgs []*types.Package, err error) {
const currentVersion = 1 const currentVersion = iexportVersionCurrent
version := int64(-1) version := int64(-1)
if !debug { if !debug {
defer func() { defer func() {
@ -135,19 +137,34 @@ func iimportCommon(fset *token.FileSet, imports map[string]*types.Package, data
} }
sLen := int64(r.uint64()) sLen := int64(r.uint64())
var fLen int64
var fileOffset []uint64
if insert != nil {
// Shallow mode uses a different position encoding.
fLen = int64(r.uint64())
fileOffset = make([]uint64, r.uint64())
for i := range fileOffset {
fileOffset[i] = r.uint64()
}
}
dLen := int64(r.uint64()) dLen := int64(r.uint64())
whence, _ := r.Seek(0, io.SeekCurrent) whence, _ := r.Seek(0, io.SeekCurrent)
stringData := data[whence : whence+sLen] stringData := data[whence : whence+sLen]
declData := data[whence+sLen : whence+sLen+dLen] fileData := data[whence+sLen : whence+sLen+fLen]
r.Seek(sLen+dLen, io.SeekCurrent) declData := data[whence+sLen+fLen : whence+sLen+fLen+dLen]
r.Seek(sLen+fLen+dLen, io.SeekCurrent)
p := iimporter{ p := iimporter{
version: int(version), version: int(version),
ipath: path, ipath: path,
insert: insert,
stringData: stringData, stringData: stringData,
stringCache: make(map[uint64]string), stringCache: make(map[uint64]string),
fileOffset: fileOffset,
fileData: fileData,
fileCache: make([]*token.File, len(fileOffset)),
pkgCache: make(map[uint64]*types.Package), pkgCache: make(map[uint64]*types.Package),
declData: declData, declData: declData,
@ -185,11 +202,18 @@ func iimportCommon(fset *token.FileSet, imports map[string]*types.Package, data
} else if pkg.Name() != pkgName { } else if pkg.Name() != pkgName {
errorf("conflicting names %s and %s for package %q", pkg.Name(), pkgName, path) errorf("conflicting names %s and %s for package %q", pkg.Name(), pkgName, path)
} }
if i == 0 && !bundle {
p.localpkg = pkg
}
p.pkgCache[pkgPathOff] = pkg p.pkgCache[pkgPathOff] = pkg
// Read index for package.
nameIndex := make(map[string]uint64) nameIndex := make(map[string]uint64)
for nSyms := r.uint64(); nSyms > 0; nSyms-- { nSyms := r.uint64()
// In shallow mode we don't expect an index for other packages.
assert(nSyms == 0 || p.localpkg == pkg || p.insert == nil)
for ; nSyms > 0; nSyms-- {
name := p.stringAt(r.uint64()) name := p.stringAt(r.uint64())
nameIndex[name] = r.uint64() nameIndex[name] = r.uint64()
} }
@ -265,8 +289,14 @@ type iimporter struct {
version int version int
ipath string ipath string
localpkg *types.Package
insert func(pkg *types.Package, name string) // "shallow" mode only
stringData []byte stringData []byte
stringCache map[uint64]string stringCache map[uint64]string
fileOffset []uint64 // fileOffset[i] is offset in fileData for info about file encoded as i
fileData []byte
fileCache []*token.File // memoized decoding of file encoded as i
pkgCache map[uint64]*types.Package pkgCache map[uint64]*types.Package
declData []byte declData []byte
@ -308,6 +338,13 @@ func (p *iimporter) doDecl(pkg *types.Package, name string) {
off, ok := p.pkgIndex[pkg][name] off, ok := p.pkgIndex[pkg][name]
if !ok { if !ok {
// In "shallow" mode, call back to the application to
// find the object and insert it into the package scope.
if p.insert != nil {
assert(pkg != p.localpkg)
p.insert(pkg, name) // "can't fail"
return
}
errorf("%v.%v not in index", pkg, name) errorf("%v.%v not in index", pkg, name)
} }
@ -332,6 +369,55 @@ func (p *iimporter) stringAt(off uint64) string {
return s return s
} }
func (p *iimporter) fileAt(index uint64) *token.File {
file := p.fileCache[index]
if file == nil {
off := p.fileOffset[index]
file = p.decodeFile(intReader{bytes.NewReader(p.fileData[off:]), p.ipath})
p.fileCache[index] = file
}
return file
}
func (p *iimporter) decodeFile(rd intReader) *token.File {
filename := p.stringAt(rd.uint64())
size := int(rd.uint64())
file := p.fake.fset.AddFile(filename, -1, size)
// SetLines requires a nondecreasing sequence.
// Because it is common for clients to derive the interval
// [start, start+len(name)] from a start position, and we
// want to ensure that the end offset is on the same line,
// we fill in the gaps of the sparse encoding with values
// that strictly increase by the largest possible amount.
// This allows us to avoid having to record the actual end
// offset of each needed line.
lines := make([]int, int(rd.uint64()))
var index, offset int
for i, n := 0, int(rd.uint64()); i < n; i++ {
index += int(rd.uint64())
offset += int(rd.uint64())
lines[index] = offset
// Ensure monotonicity between points.
for j := index - 1; j > 0 && lines[j] == 0; j-- {
lines[j] = lines[j+1] - 1
}
}
// Ensure monotonicity after last point.
for j := len(lines) - 1; j > 0 && lines[j] == 0; j-- {
size--
lines[j] = size
}
if !file.SetLines(lines) {
errorf("SetLines failed: %d", lines) // can't happen
}
return file
}
func (p *iimporter) pkgAt(off uint64) *types.Package { func (p *iimporter) pkgAt(off uint64) *types.Package {
if pkg, ok := p.pkgCache[off]; ok { if pkg, ok := p.pkgCache[off]; ok {
return pkg return pkg
@ -625,6 +711,9 @@ func (r *importReader) qualifiedIdent() (*types.Package, string) {
} }
func (r *importReader) pos() token.Pos { func (r *importReader) pos() token.Pos {
if r.p.insert != nil { // shallow mode
return r.posv2()
}
if r.p.version >= iexportVersionPosCol { if r.p.version >= iexportVersionPosCol {
r.posv1() r.posv1()
} else { } else {
@ -661,6 +750,15 @@ func (r *importReader) posv1() {
} }
} }
func (r *importReader) posv2() token.Pos {
file := r.uint64()
if file == 0 {
return token.NoPos
}
tf := r.p.fileAt(file - 1)
return tf.Pos(int(r.uint64()))
}
func (r *importReader) typ() types.Type { func (r *importReader) typ() types.Type {
return r.p.typAt(r.uint64(), nil) return r.p.typAt(r.uint64(), nil)
} }
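The decode side, decodeFile above, has to re-inflate that sparse table into something token.File.SetLines will accept, i.e. a strictly increasing sequence. A minimal sketch of the gap-filling, with absolute (index, offset) pairs standing in for the delta-decoded values:

package main

import "fmt"

// fillMonotonic mirrors decodeFile above: given a sparse set of
// (index, offset) line starts, fill the gaps so the sequence is
// strictly increasing, as token.File.SetLines requires.
func fillMonotonic(n int, sparse [][2]int, size int) []int {
	lines := make([]int, n)
	for _, pair := range sparse {
		lines[pair[0]] = pair[1]
		// Fill backwards from each real point, decreasing by 1 per line.
		for j := pair[0] - 1; j > 0 && lines[j] == 0; j-- {
			lines[j] = lines[j+1] - 1
		}
	}
	// Fill the tail downwards from the file size.
	for j := n - 1; j > 0 && lines[j] == 0; j-- {
		size--
		lines[j] = size
	}
	return lines
}

func main() {
	// Two real line starts at indices 1 and 2; the rest are synthesized.
	fmt.Println(fillMonotonic(4, [][2]int{{1, 10}, {2, 25}}, 50)) // [0 10 25 49]
}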

View File

@ -21,3 +21,17 @@ func additionalPredeclared() []types.Type {
types.Universe.Lookup("any").Type(), types.Universe.Lookup("any").Type(),
} }
} }
// See cmd/compile/internal/types.SplitVargenSuffix.
func splitVargenSuffix(name string) (base, suffix string) {
i := len(name)
for i > 0 && name[i-1] >= '0' && name[i-1] <= '9' {
i--
}
const dot = "·"
if i >= len(dot) && name[i-len(dot):i] == dot {
i -= len(dot)
return name[:i], name[i:]
}
return name, ""
}
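A quick illustration of the helper above, duplicated verbatim so it can run standalone: the "·N" suffix is the vargen marker cmd/compile appends when a local type is promoted to package scope (see the objIdx change further below).

package main

import "fmt"

// splitVargenSuffix as vendored above: split a trailing "·N" vargen
// suffix from a name, if present.
func splitVargenSuffix(name string) (base, suffix string) {
	i := len(name)
	for i > 0 && name[i-1] >= '0' && name[i-1] <= '9' {
		i--
	}
	const dot = "·"
	if i >= len(dot) && name[i-len(dot):i] == dot {
		i -= len(dot)
		return name[:i], name[i:]
	}
	return name, ""
}

func main() {
	fmt.Println(splitVargenSuffix("T·3")) // "T" "·3"
	fmt.Println(splitVargenSuffix("T3"))  // "T3" ""
}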

View File

@ -14,7 +14,7 @@ import (
"go/types" "go/types"
"strings" "strings"
"golang.org/x/tools/go/internal/pkgbits" "golang.org/x/tools/internal/pkgbits"
) )
// A pkgReader holds the shared state for reading a unified IR package // A pkgReader holds the shared state for reading a unified IR package
@ -36,6 +36,12 @@ type pkgReader struct {
// laterFns holds functions that need to be invoked at the end of // laterFns holds functions that need to be invoked at the end of
// import reading. // import reading.
laterFns []func() laterFns []func()
// laterFors is used in case of 'type A B' to ensure that B is processed before A.
laterFors map[types.Type]int
// ifaces holds a list of constructed Interfaces, which need to have
// Complete called after importing is done.
ifaces []*types.Interface
} }
// later adds a function to be invoked at the end of import reading. // later adds a function to be invoked at the end of import reading.
@ -63,6 +69,15 @@ func UImportData(fset *token.FileSet, imports map[string]*types.Package, data []
return return
} }
// laterFor adds a function to be invoked at the end of import reading, and records the type that function is finishing.
func (pr *pkgReader) laterFor(t types.Type, fn func()) {
if pr.laterFors == nil {
pr.laterFors = make(map[types.Type]int)
}
pr.laterFors[t] = len(pr.laterFns)
pr.laterFns = append(pr.laterFns, fn)
}
// readUnifiedPackage reads a package description from the given // readUnifiedPackage reads a package description from the given
// unified IR export data decoder. // unified IR export data decoder.
func readUnifiedPackage(fset *token.FileSet, ctxt *types.Context, imports map[string]*types.Package, input pkgbits.PkgDecoder) *types.Package { func readUnifiedPackage(fset *token.FileSet, ctxt *types.Context, imports map[string]*types.Package, input pkgbits.PkgDecoder) *types.Package {
@ -102,6 +117,10 @@ func readUnifiedPackage(fset *token.FileSet, ctxt *types.Context, imports map[st
fn() fn()
} }
for _, iface := range pr.ifaces {
iface.Complete()
}
pkg.MarkComplete() pkg.MarkComplete()
return pkg return pkg
} }
@ -139,6 +158,17 @@ func (pr *pkgReader) newReader(k pkgbits.RelocKind, idx pkgbits.Index, marker pk
} }
} }
func (pr *pkgReader) tempReader(k pkgbits.RelocKind, idx pkgbits.Index, marker pkgbits.SyncMarker) *reader {
return &reader{
Decoder: pr.TempDecoder(k, idx, marker),
p: pr,
}
}
func (pr *pkgReader) retireReader(r *reader) {
pr.RetireDecoder(&r.Decoder)
}
// @@@ Positions // @@@ Positions
func (r *reader) pos() token.Pos { func (r *reader) pos() token.Pos {
@ -163,26 +193,29 @@ func (pr *pkgReader) posBaseIdx(idx pkgbits.Index) string {
return b return b
} }
r := pr.newReader(pkgbits.RelocPosBase, idx, pkgbits.SyncPosBase) var filename string
{
r := pr.tempReader(pkgbits.RelocPosBase, idx, pkgbits.SyncPosBase)
// Within types2, position bases have a lot more details (e.g., // Within types2, position bases have a lot more details (e.g.,
// keeping track of where //line directives appeared exactly). // keeping track of where //line directives appeared exactly).
// //
// For go/types, we just track the file name. // For go/types, we just track the file name.
filename := r.String() filename = r.String()
if r.Bool() { // file base if r.Bool() { // file base
// Was: "b = token.NewTrimmedFileBase(filename, true)" // Was: "b = token.NewTrimmedFileBase(filename, true)"
} else { // line base } else { // line base
pos := r.pos() pos := r.pos()
line := r.Uint() line := r.Uint()
col := r.Uint() col := r.Uint()
// Was: "b = token.NewLineBase(pos, filename, true, line, col)" // Was: "b = token.NewLineBase(pos, filename, true, line, col)"
_, _, _ = pos, line, col _, _, _ = pos, line, col
}
pr.retireReader(r)
} }
b := filename b := filename
pr.posBases[idx] = b pr.posBases[idx] = b
return b return b
@ -231,11 +264,35 @@ func (r *reader) doPkg() *types.Package {
for i := range imports { for i := range imports {
imports[i] = r.pkg() imports[i] = r.pkg()
} }
pkg.SetImports(imports) pkg.SetImports(flattenImports(imports))
return pkg return pkg
} }
// flattenImports returns the transitive closure of all imported
// packages rooted from pkgs.
func flattenImports(pkgs []*types.Package) []*types.Package {
var res []*types.Package
seen := make(map[*types.Package]struct{})
for _, pkg := range pkgs {
if _, ok := seen[pkg]; ok {
continue
}
seen[pkg] = struct{}{}
res = append(res, pkg)
// pkg.Imports() is already flattened.
for _, pkg := range pkg.Imports() {
if _, ok := seen[pkg]; ok {
continue
}
seen[pkg] = struct{}{}
res = append(res, pkg)
}
}
return res
}
// @@@ Types // @@@ Types
func (r *reader) typ() types.Type { func (r *reader) typ() types.Type {
@ -264,12 +321,15 @@ func (pr *pkgReader) typIdx(info typeInfo, dict *readerDict) types.Type {
return typ return typ
} }
r := pr.newReader(pkgbits.RelocType, idx, pkgbits.SyncTypeIdx) var typ types.Type
r.dict = dict {
r := pr.tempReader(pkgbits.RelocType, idx, pkgbits.SyncTypeIdx)
typ := r.doTyp() r.dict = dict
assert(typ != nil)
typ = r.doTyp()
assert(typ != nil)
pr.retireReader(r)
}
// See comment in pkgReader.typIdx explaining how this happens. // See comment in pkgReader.typIdx explaining how this happens.
if prev := *where; prev != nil { if prev := *where; prev != nil {
return prev return prev
@ -372,6 +432,16 @@ func (r *reader) interfaceType() *types.Interface {
if implicit { if implicit {
iface.MarkImplicit() iface.MarkImplicit()
} }
// We need to call iface.Complete(), but if there are any embedded
// defined types, then we may not have set their underlying
// interface type yet. So we need to defer calling Complete until
// after we've called SetUnderlying everywhere.
//
// TODO(mdempsky): After CL 424876 lands, it should be safe to call
// iface.Complete() immediately.
r.p.ifaces = append(r.p.ifaces, iface)
return iface return iface
} }
@ -425,18 +495,30 @@ func (r *reader) obj() (types.Object, []types.Type) {
} }
func (pr *pkgReader) objIdx(idx pkgbits.Index) (*types.Package, string) { func (pr *pkgReader) objIdx(idx pkgbits.Index) (*types.Package, string) {
rname := pr.newReader(pkgbits.RelocName, idx, pkgbits.SyncObject1)
objPkg, objName := rname.qualifiedIdent() var objPkg *types.Package
assert(objName != "") var objName string
var tag pkgbits.CodeObj
{
rname := pr.tempReader(pkgbits.RelocName, idx, pkgbits.SyncObject1)
tag := pkgbits.CodeObj(rname.Code(pkgbits.SyncCodeObj)) objPkg, objName = rname.qualifiedIdent()
assert(objName != "")
tag = pkgbits.CodeObj(rname.Code(pkgbits.SyncCodeObj))
pr.retireReader(rname)
}
if tag == pkgbits.ObjStub { if tag == pkgbits.ObjStub {
assert(objPkg == nil || objPkg == types.Unsafe) assert(objPkg == nil || objPkg == types.Unsafe)
return objPkg, objName return objPkg, objName
} }
// Ignore local types promoted to global scope (#55110).
if _, suffix := splitVargenSuffix(objName); suffix != "" {
return objPkg, objName
}
if objPkg.Scope().Lookup(objName) == nil { if objPkg.Scope().Lookup(objName) == nil {
dict := pr.objDictIdx(idx) dict := pr.objDictIdx(idx)
@ -477,15 +559,56 @@ func (pr *pkgReader) objIdx(idx pkgbits.Index) (*types.Package, string) {
named.SetTypeParams(r.typeParamNames()) named.SetTypeParams(r.typeParamNames())
// TODO(mdempsky): Rewrite receiver types to underlying is an setUnderlying := func(underlying types.Type) {
// Interface? The go/types importer does this (I think because // If the underlying type is an interface, we need to
// unit tests expected that), but cmd/compile doesn't care // duplicate its methods so we can replace the receiver
// about it, so maybe we can avoid worrying about that here. // parameter's type (#49906).
rhs := r.typ() if iface, ok := underlying.(*types.Interface); ok && iface.NumExplicitMethods() != 0 {
r.p.later(func() { methods := make([]*types.Func, iface.NumExplicitMethods())
underlying := rhs.Underlying() for i := range methods {
fn := iface.ExplicitMethod(i)
sig := fn.Type().(*types.Signature)
recv := types.NewVar(fn.Pos(), fn.Pkg(), "", named)
methods[i] = types.NewFunc(fn.Pos(), fn.Pkg(), fn.Name(), types.NewSignature(recv, sig.Params(), sig.Results(), sig.Variadic()))
}
embeds := make([]types.Type, iface.NumEmbeddeds())
for i := range embeds {
embeds[i] = iface.EmbeddedType(i)
}
newIface := types.NewInterfaceType(methods, embeds)
r.p.ifaces = append(r.p.ifaces, newIface)
underlying = newIface
}
named.SetUnderlying(underlying) named.SetUnderlying(underlying)
}) }
// Since go.dev/cl/455279, we can assume rhs.Underlying() will
// always be non-nil. However, to temporarily support users of
// older snapshot releases, we continue to fallback to the old
// behavior for now.
//
// TODO(mdempsky): Remove fallback code and simplify after
// allowing time for snapshot users to upgrade.
rhs := r.typ()
if underlying := rhs.Underlying(); underlying != nil {
setUnderlying(underlying)
} else {
pk := r.p
pk.laterFor(named, func() {
// First be sure that the rhs is initialized, if it needs to be initialized.
delete(pk.laterFors, named) // prevent cycles
if i, ok := pk.laterFors[rhs]; ok {
f := pk.laterFns[i]
pk.laterFns[i] = func() {} // function is running now, so replace it with a no-op
f() // initialize RHS
}
setUnderlying(rhs.Underlying())
})
}
for i, n := 0, r.Len(); i < n; i++ { for i, n := 0, r.Len(); i < n; i++ {
named.AddMethod(r.method()) named.AddMethod(r.method())
@@ -502,25 +625,28 @@ func (pr *pkgReader) objIdx(idx pkgbits.Index) (*types.Package, string) {
 }

 func (pr *pkgReader) objDictIdx(idx pkgbits.Index) *readerDict {
-	r := pr.newReader(pkgbits.RelocObjDict, idx, pkgbits.SyncObject1)
-
 	var dict readerDict

-	if implicits := r.Len(); implicits != 0 {
-		errorf("unexpected object with %v implicit type parameter(s)", implicits)
-	}
+	{
+		r := pr.tempReader(pkgbits.RelocObjDict, idx, pkgbits.SyncObject1)
+		if implicits := r.Len(); implicits != 0 {
+			errorf("unexpected object with %v implicit type parameter(s)", implicits)
+		}

-	dict.bounds = make([]typeInfo, r.Len())
-	for i := range dict.bounds {
-		dict.bounds[i] = r.typInfo()
-	}
+		dict.bounds = make([]typeInfo, r.Len())
+		for i := range dict.bounds {
+			dict.bounds[i] = r.typInfo()
+		}

-	dict.derived = make([]derivedInfo, r.Len())
-	dict.derivedTypes = make([]types.Type, len(dict.derived))
-	for i := range dict.derived {
-		dict.derived[i] = derivedInfo{r.Reloc(pkgbits.RelocType), r.Bool()}
-	}
+		dict.derived = make([]derivedInfo, r.Len())
+		dict.derivedTypes = make([]types.Type, len(dict.derived))
+		for i := range dict.derived {
+			dict.derived[i] = derivedInfo{r.Reloc(pkgbits.RelocType), r.Bool()}
+		}
+
+		pr.retireReader(r)
+	}
 	// function references follow, but reader doesn't need those

 	return &dict

View File

@@ -10,8 +10,10 @@ import (
 	"context"
 	"fmt"
 	"io"
+	"log"
 	"os"
 	"regexp"
+	"runtime"
 	"strconv"
 	"strings"
 	"sync"
@@ -232,6 +234,12 @@ func (i *Invocation) run(ctx context.Context, stdout, stderr io.Writer) error {
 	return runCmdContext(ctx, cmd)
 }

+// DebugHangingGoCommands may be set by tests to enable additional
+// instrumentation (including panics) for debugging hanging Go commands.
+//
+// See golang/go#54461 for details.
+var DebugHangingGoCommands = false
+
 // runCmdContext is like exec.CommandContext except it sends os.Interrupt
 // before os.Kill.
 func runCmdContext(ctx context.Context, cmd *exec.Cmd) error {
@@ -243,11 +251,24 @@ func runCmdContext(ctx context.Context, cmd *exec.Cmd) error {
 		resChan <- cmd.Wait()
 	}()

-	select {
-	case err := <-resChan:
-		return err
-	case <-ctx.Done():
+	// If we're interested in debugging hanging Go commands, stop waiting after a
+	// minute and panic with interesting information.
+	if DebugHangingGoCommands {
+		select {
+		case err := <-resChan:
+			return err
+		case <-time.After(1 * time.Minute):
+			HandleHangingGoCommand(cmd.Process)
+		case <-ctx.Done():
+		}
+	} else {
+		select {
+		case err := <-resChan:
+			return err
+		case <-ctx.Done():
+		}
 	}

 	// Cancelled. Interrupt and see if it ends voluntarily.
 	cmd.Process.Signal(os.Interrupt)
 	select {
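Note: the hunk above duplicates the select so that, with DebugHangingGoCommands set, the wait gains a timeout arm that panics with diagnostics instead of blocking forever. A hedged, self-contained sketch of the same pattern; the flag name, durations, and the early return on cancellation are illustrative (the real code falls through to the interrupt/kill sequence instead):

package main

import (
	"context"
	"fmt"
	"time"
)

// debugHang stands in for DebugHangingGoCommands.
var debugHang = true

func wait(ctx context.Context, res <-chan error) error {
	if debugHang {
		select {
		case err := <-res:
			return err
		case <-time.After(50 * time.Millisecond): // one minute in the real code
			panic("hanging command detected")
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	select {
	case err := <-res:
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}

func main() {
	res := make(chan error, 1)
	go func() {
		time.Sleep(10 * time.Millisecond)
		res <- nil // simulate the command finishing in time
	}()
	fmt.Println(wait(context.Background(), res)) // <nil>
}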
@@ -255,11 +276,63 @@ func runCmdContext(ctx context.Context, cmd *exec.Cmd) error {
 		return err
 	case <-time.After(time.Second):
 	}

 	// Didn't shut down in response to interrupt. Kill it hard.
-	cmd.Process.Kill()
+	// TODO(rfindley): per advice from bcmills@, it may be better to send SIGQUIT
+	// on certain platforms, such as unix.
+	if err := cmd.Process.Kill(); err != nil && DebugHangingGoCommands {
+		// Don't panic here as this reliably fails on windows with EINVAL.
+		log.Printf("error killing the Go command: %v", err)
+	}
+
+	// See above: don't wait indefinitely if we're debugging hanging Go commands.
+	if DebugHangingGoCommands {
+		select {
+		case err := <-resChan:
+			return err
+		case <-time.After(10 * time.Second): // a shorter wait as resChan should return quickly following Kill
+			HandleHangingGoCommand(cmd.Process)
+		}
+	}
 	return <-resChan
 }

+func HandleHangingGoCommand(proc *os.Process) {
+	switch runtime.GOOS {
+	case "linux", "darwin", "freebsd", "netbsd":
+		fmt.Fprintln(os.Stderr, `DETECTED A HANGING GO COMMAND
+
+The gopls test runner has detected a hanging go command. In order to debug
+this, the output of ps and lsof/fstat is printed below.
+
+See golang/go#54461 for more details.`)
+
+		fmt.Fprintln(os.Stderr, "\nps axo ppid,pid,command:")
+		fmt.Fprintln(os.Stderr, "-------------------------")
+		psCmd := exec.Command("ps", "axo", "ppid,pid,command")
+		psCmd.Stdout = os.Stderr
+		psCmd.Stderr = os.Stderr
+		if err := psCmd.Run(); err != nil {
+			panic(fmt.Sprintf("running ps: %v", err))
+		}
+
+		listFiles := "lsof"
+		if runtime.GOOS == "freebsd" || runtime.GOOS == "netbsd" {
+			listFiles = "fstat"
+		}
+
+		fmt.Fprintln(os.Stderr, "\n"+listFiles+":")
+		fmt.Fprintln(os.Stderr, "-----")
+		listFilesCmd := exec.Command(listFiles)
+		listFilesCmd.Stdout = os.Stderr
+		listFilesCmd.Stderr = os.Stderr
+		if err := listFilesCmd.Run(); err != nil {
+			panic(fmt.Sprintf("running %s: %v", listFiles, err))
+		}
+	}
+	panic(fmt.Sprintf("detected hanging go command (pid %d): see golang/go#54461 for more details", proc.Pid))
+}
+
 func cmdDebugStr(cmd *exec.Cmd) string {
 	env := make(map[string]string)
 	for _, kv := range cmd.Env {
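Note: DebugHangingGoCommands is designed to be flipped by a test binary. A sketch of how a test within golang.org/x/tools itself might enable it; this is illustrative only, since the package is internal and cannot be imported from other modules:

package gocommand_test

import (
	"os"
	"testing"

	"golang.org/x/tools/internal/gocommand"
)

func TestMain(m *testing.M) {
	// Panic with ps/lsof output if any go command hangs during the run.
	gocommand.DebugHangingGoCommands = true
	os.Exit(m.Run())
}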

View File

@@ -7,11 +7,19 @@ package gocommand
 import (
 	"context"
 	"fmt"
+	"regexp"
 	"strings"
 )

-// GoVersion checks the go version by running "go list" with modules off.
-// It returns the X in Go 1.X.
+// GoVersion reports the minor version number of the highest release
+// tag built into the go command on the PATH.
+//
+// Note that this may be higher than the version of the go tool used
+// to build this application, and thus the versions of the standard
+// go/{scanner,parser,ast,types} packages that are linked into it.
+// In that case, callers should either downgrade to the version of
+// go used to build the application, or report an error that the
+// application is too old to use the go command on the PATH.
 func GoVersion(ctx context.Context, inv Invocation, r *Runner) (int, error) {
 	inv.Verb = "list"
 	inv.Args = []string{"-e", "-f", `{{context.ReleaseTags}}`, `--`, `unsafe`}
@@ -38,7 +46,7 @@ func GoVersion(ctx context.Context, inv Invocation, r *Runner) (int, error) {
 	if len(stdout) < 3 {
 		return 0, fmt.Errorf("bad ReleaseTags output: %q", stdout)
 	}
-	// Split up "[go1.1 go1.15]"
+	// Split up "[go1.1 go1.15]" and return highest go1.X value.
 	tags := strings.Fields(stdout[1 : len(stdout)-2])
 	for i := len(tags) - 1; i >= 0; i-- {
 		var version int
@@ -49,3 +57,25 @@ func GoVersion(ctx context.Context, inv Invocation, r *Runner) (int, error) {
 	}
 	return 0, fmt.Errorf("no parseable ReleaseTags in %v", tags)
 }
+
+// GoVersionOutput returns the complete output of the go version command.
+func GoVersionOutput(ctx context.Context, inv Invocation, r *Runner) (string, error) {
+	inv.Verb = "version"
+	goVersion, err := r.Run(ctx, inv)
+	if err != nil {
+		return "", err
+	}
+	return goVersion.String(), nil
+}
+
+// ParseGoVersionOutput extracts the Go version string
+// from the output of the "go version" command.
+// Given an unrecognized form, it returns an empty string.
+func ParseGoVersionOutput(data string) string {
+	re := regexp.MustCompile(`^go version (go\S+|devel \S+)`)
+	m := re.FindStringSubmatch(data)
+	if len(m) != 2 {
+		return "" // unrecognized version
+	}
+	return m[1]
+}
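Note: ParseGoVersionOutput anchors on the "go version " prefix and accepts either a release tag or a devel string. A standalone sketch exercising the same regexp; the local helper simply mirrors the function above:

package main

import (
	"fmt"
	"regexp"
)

// parseGoVersionOutput mirrors ParseGoVersionOutput from the hunk above.
func parseGoVersionOutput(data string) string {
	re := regexp.MustCompile(`^go version (go\S+|devel \S+)`)
	m := re.FindStringSubmatch(data)
	if len(m) != 2 {
		return "" // unrecognized version
	}
	return m[1]
}

func main() {
	fmt.Println(parseGoVersionOutput("go version go1.20.3 linux/amd64"))      // go1.20.3
	fmt.Println(parseGoVersionOutput("go version devel +abc123 linux/amd64")) // devel +abc123
	fmt.Println(parseGoVersionOutput("not a version line"))                   // (empty)
}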

View File

@@ -697,6 +697,9 @@ func candidateImportName(pkg *pkg) string {
 // GetAllCandidates calls wrapped for each package whose name starts with
 // searchPrefix, and can be imported from filename with the package name filePkg.
+//
+// Beware that the wrapped function may be called multiple times concurrently.
+// TODO(adonovan): encapsulate the concurrency.
 func GetAllCandidates(ctx context.Context, wrapped func(ImportFix), searchPrefix, filename, filePkg string, env *ProcessEnv) error {
 	callback := &scanCallback{
 		rootFound: func(gopathwalk.Root) bool {
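Note: since wrapped may be invoked concurrently, a caller must synchronize any state it mutates from the callback. A self-contained sketch of a guarded callback; importFix is a stand-in type, not the real imports.ImportFix:

package main

import (
	"fmt"
	"sync"
)

type importFix struct{ path string }

func main() {
	var (
		mu    sync.Mutex
		fixes []importFix
	)
	// wrapped may be invoked from many goroutines, so guard the slice.
	wrapped := func(fix importFix) {
		mu.Lock()
		defer mu.Unlock()
		fixes = append(fixes, fix)
	}

	var wg sync.WaitGroup
	for _, p := range []string{"fmt", "os", "io"} {
		p := p
		wg.Add(1)
		go func() {
			defer wg.Done()
			wrapped(importFix{path: p})
		}()
	}
	wg.Wait()
	fmt.Println(len(fixes)) // 3
}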
@@ -796,7 +799,7 @@ func GetPackageExports(ctx context.Context, wrapped func(PackageExport), searchP
 	return getCandidatePkgs(ctx, callback, filename, filePkg, env)
 }

-var RequiredGoEnvVars = []string{"GO111MODULE", "GOFLAGS", "GOINSECURE", "GOMOD", "GOMODCACHE", "GONOPROXY", "GONOSUMDB", "GOPATH", "GOPROXY", "GOROOT", "GOSUMDB", "GOWORK"}
+var requiredGoEnvVars = []string{"GO111MODULE", "GOFLAGS", "GOINSECURE", "GOMOD", "GOMODCACHE", "GONOPROXY", "GONOSUMDB", "GOPATH", "GOPROXY", "GOROOT", "GOSUMDB", "GOWORK"}

 // ProcessEnv contains environment variables and settings that affect the use of
 // the go command, the go/build package, etc.
@@ -807,6 +810,11 @@ type ProcessEnv struct {
 	ModFlag string
 	ModFile string

+	// SkipPathInScan returns true if the path should be skipped from scans of
+	// the RootCurrentModule root type. The function argument is a clean,
+	// absolute path.
+	SkipPathInScan func(string) bool
+
 	// Env overrides the OS environment, and can be used to specify
 	// GOPROXY, GO111MODULE, etc. PATH cannot be set here, because
 	// exec.Command will not honor it.
@@ -861,7 +869,7 @@ func (e *ProcessEnv) init() error {
 	}

 	foundAllRequired := true
-	for _, k := range RequiredGoEnvVars {
+	for _, k := range requiredGoEnvVars {
 		if _, ok := e.Env[k]; !ok {
 			foundAllRequired = false
 			break
@@ -877,7 +885,7 @@ func (e *ProcessEnv) init() error {
 	}

 	goEnv := map[string]string{}
-	stdout, err := e.invokeGo(context.TODO(), "env", append([]string{"-json"}, RequiredGoEnvVars...)...)
+	stdout, err := e.invokeGo(context.TODO(), "env", append([]string{"-json"}, requiredGoEnvVars...)...)
 	if err != nil {
 		return err
 	}
@@ -1367,9 +1375,9 @@ func (r *gopathResolver) scan(ctx context.Context, callback *scanCallback) error
 		return err
 	}
 	var roots []gopathwalk.Root
-	roots = append(roots, gopathwalk.Root{filepath.Join(goenv["GOROOT"], "src"), gopathwalk.RootGOROOT})
+	roots = append(roots, gopathwalk.Root{Path: filepath.Join(goenv["GOROOT"], "src"), Type: gopathwalk.RootGOROOT})
 	for _, p := range filepath.SplitList(goenv["GOPATH"]) {
-		roots = append(roots, gopathwalk.Root{filepath.Join(p, "src"), gopathwalk.RootGOPATH})
+		roots = append(roots, gopathwalk.Root{Path: filepath.Join(p, "src"), Type: gopathwalk.RootGOPATH})
 	}
 	// The callback is not necessarily safe to use in the goroutine below. Process roots eagerly.
 	roots = filterRoots(roots, callback.rootFound)
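Note: the replacements above convert positional gopathwalk.Root literals to keyed ones, which keep compiling (and stay vet-clean) if the struct grows or reorders fields. A tiny sketch of the difference; Root here is a local stand-in:

package main

import "fmt"

type Root struct {
	Path string
	Type int
}

func main() {
	positional := Root{"/go/src", 1}        // silently wrong if fields are ever reordered
	keyed := Root{Path: "/go/src", Type: 1} // robust to reordering and new fields
	fmt.Println(positional == keyed)        // true
}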

View File

@@ -129,22 +129,22 @@ func (r *ModuleResolver) init() error {
 	})

 	r.roots = []gopathwalk.Root{
-		{filepath.Join(goenv["GOROOT"], "/src"), gopathwalk.RootGOROOT},
+		{Path: filepath.Join(goenv["GOROOT"], "/src"), Type: gopathwalk.RootGOROOT},
 	}
 	r.mainByDir = make(map[string]*gocommand.ModuleJSON)
 	for _, main := range r.mains {
-		r.roots = append(r.roots, gopathwalk.Root{main.Dir, gopathwalk.RootCurrentModule})
+		r.roots = append(r.roots, gopathwalk.Root{Path: main.Dir, Type: gopathwalk.RootCurrentModule})
 		r.mainByDir[main.Dir] = main
 	}
 	if vendorEnabled {
-		r.roots = append(r.roots, gopathwalk.Root{r.dummyVendorMod.Dir, gopathwalk.RootOther})
+		r.roots = append(r.roots, gopathwalk.Root{Path: r.dummyVendorMod.Dir, Type: gopathwalk.RootOther})
 	} else {
 		addDep := func(mod *gocommand.ModuleJSON) {
 			if mod.Replace == nil {
 				// This is redundant with the cache, but we'll skip it cheaply enough.
-				r.roots = append(r.roots, gopathwalk.Root{mod.Dir, gopathwalk.RootModuleCache})
+				r.roots = append(r.roots, gopathwalk.Root{Path: mod.Dir, Type: gopathwalk.RootModuleCache})
 			} else {
-				r.roots = append(r.roots, gopathwalk.Root{mod.Dir, gopathwalk.RootOther})
+				r.roots = append(r.roots, gopathwalk.Root{Path: mod.Dir, Type: gopathwalk.RootOther})
 			}
 		}
 		// Walk dependent modules before scanning the full mod cache, direct deps first.
@@ -158,7 +158,7 @@ func (r *ModuleResolver) init() error {
 				addDep(mod)
 			}
 		}
-		r.roots = append(r.roots, gopathwalk.Root{r.moduleCacheDir, gopathwalk.RootModuleCache})
+		r.roots = append(r.roots, gopathwalk.Root{Path: r.moduleCacheDir, Type: gopathwalk.RootModuleCache})
 	}

 	r.scannedRoots = map[gopathwalk.Root]bool{}
@@ -466,6 +466,16 @@ func (r *ModuleResolver) scan(ctx context.Context, callback *scanCallback) error
 	// We assume cached directories are fully cached, including all their
 	// children, and have not changed. We can skip them.
 	skip := func(root gopathwalk.Root, dir string) bool {
+		if r.env.SkipPathInScan != nil && root.Type == gopathwalk.RootCurrentModule {
+			if root.Path == dir {
+				return false
+			}
+
+			if r.env.SkipPathInScan(filepath.Clean(dir)) {
+				return true
+			}
+		}
+
 		info, ok := r.cacheLoad(dir)
 		if !ok {
 			return false
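Note: the skip hook above never skips the module root itself and otherwise defers to the caller's SkipPathInScan predicate on a cleaned path. A self-contained sketch of those semantics; every type and name here is a stand-in:

package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

type root struct{ path, typ string }

func skip(skipPath func(string) bool, r root, dir string) bool {
	if skipPath != nil && r.typ == "current-module" {
		if r.path == dir {
			return false // always descend into the module root itself
		}
		if skipPath(filepath.Clean(dir)) {
			return true
		}
	}
	return false
}

func main() {
	noNodeModules := func(dir string) bool { return strings.Contains(dir, "node_modules") }
	r := root{path: "/work/mod", typ: "current-module"}
	fmt.Println(skip(noNodeModules, r, "/work/mod"))              // false
	fmt.Println(skip(noNodeModules, r, "/work/mod/node_modules")) // true
	fmt.Println(skip(noNodeModules, r, "/work/mod/pkg"))          // false
}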

View File

@@ -52,6 +52,7 @@ func sortImports(localPrefix string, tokFile *token.File, f *ast.File) {
 		d.Specs = specs

 		// Deduping can leave a blank line before the rparen; clean that up.
+		// Ignore line directives.
 		if len(d.Specs) > 0 {
 			lastSpec := d.Specs[len(d.Specs)-1]
 			lastLine := tokFile.PositionFor(lastSpec.Pos(), false).Line

File diff suppressed because it is too large

View File

@@ -6,9 +6,11 @@ package pkgbits
 import (
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"go/constant"
 	"go/token"
+	"io"
 	"math/big"
 	"os"
 	"runtime"
@@ -51,6 +53,8 @@ type PkgDecoder struct {
 	// For example, section K's end positions start at elemEndsEnds[K-1]
 	// (or 0, if K==0) and end at elemEndsEnds[K].
 	elemEndsEnds [numRelocs]uint32
+
+	scratchRelocEnt []RelocEnt
 }

 // PkgPath returns the package path for the package
@@ -94,7 +98,7 @@ func NewPkgDecoder(pkgPath, input string) PkgDecoder {
 	pr.elemEnds = make([]uint32, pr.elemEndsEnds[len(pr.elemEndsEnds)-1])
 	assert(binary.Read(r, binary.LittleEndian, pr.elemEnds[:]) == nil)

-	pos, err := r.Seek(0, os.SEEK_CUR)
+	pos, err := r.Seek(0, io.SeekCurrent)
 	assert(err == nil)

 	pr.elemData = input[pos:]
@@ -164,6 +168,21 @@ func (pr *PkgDecoder) NewDecoder(k RelocKind, idx Index, marker SyncMarker) Deco
 	return r
 }

+// TempDecoder returns a Decoder for the given (section, index) pair,
+// and decodes the given SyncMarker from the element bitstream.
+// If possible the Decoder should be RetireDecoder'd when it is no longer
+// needed, this will avoid heap allocations.
+func (pr *PkgDecoder) TempDecoder(k RelocKind, idx Index, marker SyncMarker) Decoder {
+	r := pr.TempDecoderRaw(k, idx)
+	r.Sync(marker)
+	return r
+}
+
+func (pr *PkgDecoder) RetireDecoder(d *Decoder) {
+	pr.scratchRelocEnt = d.Relocs
+	d.Relocs = nil
+}
+
 // NewDecoderRaw returns a Decoder for the given (section, index) pair.
 //
 // Most callers should use NewDecoder instead.
@@ -187,6 +206,30 @@ func (pr *PkgDecoder) NewDecoderRaw(k RelocKind, idx Index) Decoder {
 	return r
 }

+func (pr *PkgDecoder) TempDecoderRaw(k RelocKind, idx Index) Decoder {
+	r := Decoder{
+		common: pr,
+		k:      k,
+		Idx:    idx,
+	}
+
+	r.Data.Reset(pr.DataIdx(k, idx))
+	r.Sync(SyncRelocs)
+	l := r.Len()
+	if cap(pr.scratchRelocEnt) >= l {
+		r.Relocs = pr.scratchRelocEnt[:l]
+		pr.scratchRelocEnt = nil
+	} else {
+		r.Relocs = make([]RelocEnt, l)
+	}
+	for i := range r.Relocs {
+		r.Sync(SyncReloc)
+		r.Relocs[i] = RelocEnt{RelocKind(r.Len()), Index(r.Len())}
+	}
+
+	return r
+}
+
 // A Decoder provides methods for decoding an individual element's
 // bitstream data.
 type Decoder struct {
@@ -206,11 +249,39 @@ func (r *Decoder) checkErr(err error) {
 }

 func (r *Decoder) rawUvarint() uint64 {
-	x, err := binary.ReadUvarint(&r.Data)
+	x, err := readUvarint(&r.Data)
 	r.checkErr(err)
 	return x
 }

+// readUvarint is a type-specialized copy of encoding/binary.ReadUvarint.
+// This avoids the interface conversion and thus has better escape properties,
+// which flows up the stack.
+func readUvarint(r *strings.Reader) (uint64, error) {
+	var x uint64
+	var s uint
+	for i := 0; i < binary.MaxVarintLen64; i++ {
+		b, err := r.ReadByte()
+		if err != nil {
+			if i > 0 && err == io.EOF {
+				err = io.ErrUnexpectedEOF
+			}
+			return x, err
+		}
+		if b < 0x80 {
+			if i == binary.MaxVarintLen64-1 && b > 1 {
+				return x, overflow
+			}
+			return x | uint64(b)<<s, nil
+		}
+		x |= uint64(b&0x7f) << s
+		s += 7
+	}
+	return x, overflow
+}
+
+var overflow = errors.New("pkgbits: readUvarint overflows a 64-bit integer")
+
 func (r *Decoder) rawVarint() int64 {
 	ux := r.rawUvarint()
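Note: readUvarint above is the varint decoding of encoding/binary.ReadUvarint specialized to *strings.Reader, so no io.ByteReader interface value is created and the reader does not escape. A quick standard-library check of the encoding it decodes:

package main

import (
	"encoding/binary"
	"fmt"
	"strings"
)

func main() {
	// Encode 300 as a uvarint, then decode it back through a strings.Reader.
	buf := make([]byte, binary.MaxVarintLen64)
	n := binary.PutUvarint(buf, 300)
	x, err := binary.ReadUvarint(strings.NewReader(string(buf[:n])))
	fmt.Println(x, err) // 300 <nil>
}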
@@ -237,7 +308,7 @@ func (r *Decoder) Sync(mWant SyncMarker) {
 		return
 	}

-	pos, _ := r.Data.Seek(0, os.SEEK_CUR) // TODO(mdempsky): io.SeekCurrent after #44505 is resolved
+	pos, _ := r.Data.Seek(0, io.SeekCurrent)
 	mHave := SyncMarker(r.rawUvarint())
 	writerPCs := make([]int, r.rawUvarint())
 	for i := range writerPCs {
@@ -302,7 +373,7 @@ func (r *Decoder) Int64() int64 {
 	return r.rawVarint()
 }

-// Int64 decodes and returns a uint64 value from the element bitstream.
+// Uint64 decodes and returns a uint64 value from the element bitstream.
 func (r *Decoder) Uint64() uint64 {
 	r.Sync(SyncUint64)
 	return r.rawUvarint()
@@ -409,8 +480,12 @@ func (r *Decoder) bigFloat() *big.Float {
 // PeekPkgPath returns the package path for the specified package
 // index.
 func (pr *PkgDecoder) PeekPkgPath(idx Index) string {
-	r := pr.NewDecoder(RelocPkg, idx, SyncPkgDef)
-	path := r.String()
+	var path string
+	{
+		r := pr.TempDecoder(RelocPkg, idx, SyncPkgDef)
+		path = r.String()
+		pr.RetireDecoder(&r)
+	}
 	if path == "" {
 		path = pr.pkgPath
 	}
@@ -420,14 +495,23 @@ func (pr *PkgDecoder) PeekPkgPath(idx Index) string {
 // PeekObj returns the package path, object name, and CodeObj for the
 // specified object index.
 func (pr *PkgDecoder) PeekObj(idx Index) (string, string, CodeObj) {
-	r := pr.NewDecoder(RelocName, idx, SyncObject1)
-	r.Sync(SyncSym)
-	r.Sync(SyncPkg)
-	path := pr.PeekPkgPath(r.Reloc(RelocPkg))
-	name := r.String()
+	var ridx Index
+	var name string
+	var rcode int
+	{
+		r := pr.TempDecoder(RelocName, idx, SyncObject1)
+		r.Sync(SyncSym)
+		r.Sync(SyncPkg)
+		ridx = r.Reloc(RelocPkg)
+		name = r.String()
+		rcode = r.Code(SyncCodeObj)
+		pr.RetireDecoder(&r)
+	}
+	path := pr.PeekPkgPath(ridx)
 	assert(name != "")

-	tag := CodeObj(r.Code(SyncCodeObj))
+	tag := CodeObj(rcode)
 	return path, name, tag
 }
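Note: TempDecoder and RetireDecoder form a one-slot scratch pool: retiring a decoder donates its Relocs slice back to the PkgDecoder, and the next TempDecoderRaw reuses that backing array when it is big enough. A generic sketch of the trick; the types and names are illustrative:

package main

import "fmt"

type pool struct{ scratch []int }

// acquire returns a slice of length n, reusing retired scratch space if possible.
func (p *pool) acquire(n int) []int {
	if cap(p.scratch) >= n {
		s := p.scratch[:n]
		p.scratch = nil // hand ownership to the caller
		return s
	}
	return make([]int, n)
}

// retire donates the slice back for the next acquire.
func (p *pool) retire(s []int) { p.scratch = s }

func main() {
	var p pool
	a := p.acquire(4)
	p.retire(a)
	b := p.acquire(3)        // reuses a's backing array
	fmt.Println(cap(b) >= 3) // true
}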

View File

@@ -147,8 +147,9 @@ func (pw *PkgEncoder) NewEncoderRaw(k RelocKind) Encoder {
 type Encoder struct {
 	p *PkgEncoder

-	Relocs []RelocEnt
-	Data   bytes.Buffer // accumulated element bitstream data
+	Relocs   []RelocEnt
+	RelocMap map[RelocEnt]uint32
+	Data     bytes.Buffer // accumulated element bitstream data

 	encodingRelocHeader bool
@@ -210,15 +211,18 @@ func (w *Encoder) rawVarint(x int64) {
 }

 func (w *Encoder) rawReloc(r RelocKind, idx Index) int {
-	// TODO(mdempsky): Use map for lookup; this takes quadratic time.
-	for i, rEnt := range w.Relocs {
-		if rEnt.Kind == r && rEnt.Idx == idx {
-			return i
+	e := RelocEnt{r, idx}
+	if w.RelocMap != nil {
+		if i, ok := w.RelocMap[e]; ok {
+			return int(i)
 		}
+	} else {
+		w.RelocMap = make(map[RelocEnt]uint32)
 	}

 	i := len(w.Relocs)
-	w.Relocs = append(w.Relocs, RelocEnt{r, idx})
+	w.RelocMap[e] = uint32(i)
+	w.Relocs = append(w.Relocs, e)
 	return i
 }
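Note: rawReloc previously scanned w.Relocs linearly on every write, which is quadratic over an element's relocations; the RelocMap above memoizes entry-to-index lookups. A self-contained sketch of the same interning pattern; entry and table are illustrative stand-ins:

package main

import "fmt"

type entry struct{ kind, idx int }

type table struct {
	entries []entry
	index   map[entry]uint32
}

// intern returns a stable index for e, adding it on first sight.
func (t *table) intern(e entry) int {
	if t.index != nil {
		if i, ok := t.index[e]; ok {
			return int(i)
		}
	} else {
		t.index = make(map[entry]uint32)
	}
	i := len(t.entries)
	t.index[e] = uint32(i)
	t.entries = append(t.entries, e)
	return i
}

func main() {
	var t table
	fmt.Println(t.intern(entry{1, 7})) // 0
	fmt.Println(t.intern(entry{2, 9})) // 1
	fmt.Println(t.intern(entry{1, 7})) // 0 (deduplicated)
}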
@@ -289,7 +293,7 @@ func (w *Encoder) Len(x int) { assert(x >= 0); w.Uint64(uint64(x)) }
 // Int encodes and writes an int value into the element bitstream.
 func (w *Encoder) Int(x int) { w.Int64(int64(x)) }

-// Len encodes and writes a uint value into the element bitstream.
+// Uint encodes and writes a uint value into the element bitstream.
 func (w *Encoder) Uint(x uint) { w.Uint64(uint64(x)) }

 // Reloc encodes and writes a relocation for the given (section,
// Reloc encodes and writes a relocation for the given (section, // Reloc encodes and writes a relocation for the given (section,

Some files were not shown because too many files have changed in this diff