xs/vendor/github.com/kuking/go-frodokem/impl.go

536 lines
14 KiB
Go
Raw Permalink Normal View History

package go_frodokem
import (
"crypto/aes"
"encoding/binary"
"errors"
"golang.org/x/crypto/sha3"
"math"
)
// Name returns the name of this particular FrodoKEM variant, e.g. "Frodo640AES".
func (k *FrodoKEM) Name() string {
	return k.name
}
// SharedSecretLen returns the length, in bytes, of the shared secret this
// variant produces (the internal lenSS parameter is kept in bits).
func (k *FrodoKEM) SharedSecretLen() int {
	return k.lenSS / 8
}
// PublicKeyLen returns the length, in bytes, of a public key for this variant.
func (k *FrodoKEM) PublicKeyLen() int {
	return k.lenPkBytes
}
// SecretKeyLen returns the length, in bytes, of a secret key for this variant.
func (k *FrodoKEM) SecretKeyLen() int {
	return k.lenSkBytes
}
// CipherTextLen returns the length, in bytes, of the cipher-text that
// encapsulates a shared secret for this variant.
func (k *FrodoKEM) CipherTextLen() int {
	return k.lenCtBytes
}
// Keygen generates a new key pair for this variant.
//
// The public key is seedA || b, where b is the packed matrix B = A*S + E.
// The secret key is s || seedA || b || S^T || pkh, where S^T is stored row by
// row as 16-bit little-endian words and pkh = k.shake(pk, lenPkh/8).
func (k *FrodoKEM) Keygen() (pk []uint8, sk []uint8) {
	// A single draw of randomness provides s, seedSE and z, in that order.
	sLen, seedSELen, zLen := k.lenS/8, k.lenSeedSE/8, k.lenZ/8
	randomness := make([]byte, sLen+seedSELen+zLen)
	k.rng(randomness)
	s := randomness[:sLen]
	seedSE := randomness[sLen : sLen+seedSELen]
	z := randomness[sLen+seedSELen:]

	seedA := k.shake(z, k.lenSeedA/8)
	A := k.gen(seedA)
	// Prefix byte 0x5f for key generation (Encapsulate/Dencapsulate use 0x96).
	r := unpackUint16(k.shake(append([]byte{0x5f}, seedSE...), 2*k.n*k.nBar*k.lenChi/8))
	Stransposed := k.sampleMatrix(r[:k.n*k.nBar], k.nBar, k.n)
	S := matrixTranspose(Stransposed)
	E := k.sampleMatrix(r[k.n*k.nBar:2*k.n*k.nBar], k.n, k.nBar)
	B := matrixAddWithMod(matrixMulWithMod(A, S, k.q), E, k.q)
	b := k.pack(B)

	// pk and sk are built in freshly allocated, pre-sized buffers. Appending to
	// s directly would write into the randomness buffer that s is a prefix of,
	// overwriting seedSE and z in place.
	pk = make([]uint8, 0, len(seedA)+len(b))
	pk = append(pk, seedA...)
	pk = append(pk, b...)
	pkh := k.shake(pk, k.lenPkh/8)

	// Serialize S^T (nBar rows of n entries) as little-endian 16-bit words.
	stb := make([]uint8, 0, 2*k.nBar*k.n)
	for _, row := range Stransposed {
		for _, v := range row {
			stb = append(stb, uint8(v), uint8(v>>8)) // low byte, then high byte
		}
	}

	sk = make([]uint8, 0, len(s)+len(seedA)+len(b)+len(stb)+len(pkh))
	for _, part := range [][]uint8{s, seedA, b, stb, pkh} {
		sk = append(sk, part...)
	}
	return pk, sk
}
// Encapsulate generates a fresh shared secret for the owner of pk and returns
// the cipher-text that encapsulates it together with the shared secret.
// An error is returned when pk does not have the length expected by this variant.
func (k *FrodoKEM) Encapsulate(pk []uint8) (ct []uint8, ssEnc []uint8, err error) {
	if len(pk) != k.lenSeedA/8+k.d*k.n*k.nBar/8 {
		err = errors.New("incorrect public key length")
		return
	}
	// cat concatenates a and b into a newly allocated slice with exact capacity,
	// so no secret ever ends up in the spare capacity of a slice we hand out.
	cat := func(a, b []uint8) []uint8 {
		out := make([]uint8, 0, len(a)+len(b))
		return append(append(out, a...), b...)
	}

	seedA := pk[0 : k.lenSeedA/8]
	b := pk[k.lenSeedA/8:]
	mu := make([]uint8, k.lenMu/8)
	k.rng(mu)

	pkh := k.shake(pk, k.lenPkh/8)
	seedSEk := k.shake(cat(pkh, mu), k.lenSeedSE/8+k.lenK/8)
	seedSE := seedSEk[0 : k.lenSeedSE/8]
	kSecret := seedSEk[k.lenSeedSE/8 : k.lenSeedSE/8+k.lenK/8]

	// Exactly 2*mBar*n + mBar*nBar samples are consumed below (S', E', E''),
	// the same count Dencapsulate uses when it re-derives them.
	nSamples := 2*k.mBar*k.n + k.mBar*k.nBar
	r := unpackUint16(k.shake(cat([]byte{0x96}, seedSE), nSamples*k.lenChi/8))
	Sprime := k.sampleMatrix(r[0:k.mBar*k.n], k.mBar, k.n)
	Eprime := k.sampleMatrix(r[k.mBar*k.n:2*k.mBar*k.n], k.mBar, k.n)
	Eprimeprime := k.sampleMatrix(r[2*k.mBar*k.n:2*k.mBar*k.n+k.mBar*k.nBar], k.mBar, k.nBar)

	// c1 = pack(S'A + E')
	A := k.gen(seedA)
	Bprime := matrixAddWithMod(matrixMulWithMod2(Sprime, A, k.q), Eprime, k.q)
	c1 := k.pack(Bprime)

	// c2 = pack(S'B + E'' + encode(mu))
	B := k.unpack(b, k.n, k.nBar)
	V := matrixAddWithMod(matrixMulWithMod2(Sprime, B, k.q), Eprimeprime, k.q)
	C := uMatrixAdd(V, k.encode(mu), k.q)
	c2 := k.pack(C)

	ct = cat(c1, c2)
	ssEnc = k.shake(cat(ct, kSecret), k.lenSS/8)
	return
}
// Dencapsulate returns the shared secret encapsulated in ct, using the secret
// key sk (this is the FrodoKEM decapsulation; the name is kept for API
// compatibility). An error is returned only when ct or sk has the wrong length.
// If the re-encryption check fails, no error is reported: the result is instead
// derived from ct and the secret value s (implicit rejection).
func (k *FrodoKEM) Dencapsulate(sk []uint8, ct []uint8) (ssDec []uint8, err error) {
	if len(ct) != k.lenCtBytes {
		err = errors.New("incorrect cipher length")
		return
	}
	if len(sk) != k.lenSkBytes {
		err = errors.New("incorrect secret key length")
		return
	}
	// cat concatenates a and b into a newly allocated slice. Appending directly to
	// ct or to sub-slices of sk would write into the caller's backing arrays
	// whenever they have spare capacity.
	cat := func(a, b []uint8) []uint8 {
		out := make([]uint8, 0, len(a)+len(b))
		return append(append(out, a...), b...)
	}

	c1, c2 := k.unwrapCt(ct)
	s, seedA, b, Stransposed, pkh := k.unwrapSk(sk)
	S := matrixTranspose(Stransposed)
	Bprime := k.unpack(c1, k.mBar, k.n)
	C := k.unpack(c2, k.mBar, k.nBar)
	M := matrixSubWithMod(C, matrixMulWithMod(Bprime, S, k.q), k.q)
	muPrime := k.decode(M)

	seedSEkPrime := k.shake(cat(pkh, muPrime), k.lenSeedSE/8+k.lenK/8)
	seedSEprime := seedSEkPrime[0 : k.lenSeedSE/8]
	kprime := seedSEkPrime[k.lenSeedSE/8:]

	// Re-encrypt mu' exactly as Encapsulate does (same prefix and sample count).
	nSamples := 2*k.mBar*k.n + k.mBar*k.nBar
	r := unpackUint16(k.shake(cat([]byte{0x96}, seedSEprime), nSamples*k.lenChi/8))
	Sprime := k.sampleMatrix(r[0:k.mBar*k.n], k.mBar, k.n)
	Eprime := k.sampleMatrix(r[k.mBar*k.n:2*k.mBar*k.n], k.mBar, k.n)
	Eprimeprime := k.sampleMatrix(r[2*k.mBar*k.n:2*k.mBar*k.n+k.mBar*k.nBar], k.mBar, k.nBar)
	A := k.gen(seedA)
	Bprimeprime := matrixAddWithMod(matrixMulWithMod2(Sprime, A, k.q), Eprime, k.q)
	B := k.unpack(b, k.n, k.nBar)
	V := matrixAddWithMod(matrixMulWithMod2(Sprime, B, k.q), Eprimeprime, k.q)
	Cprime := uMatrixAdd(V, k.encode(muPrime), k.q)

	// ok is 1 when the re-encryption reproduces both parts of ct, 0 otherwise.
	ok := constantUint16Equals(Bprime, Bprimeprime) & constantUint16Equals(C, Cprime)

	// Select k' (ok == 1) or s (ok == 0) with a byte mask rather than a branch,
	// so the outcome of the check does not steer control flow.
	kSel := s
	if len(kprime) == len(s) {
		mask := uint8(0) - uint8(ok) // 0xff when ok == 1, 0x00 otherwise
		kSel = make([]uint8, len(s))
		for i := range kSel {
			kSel[i] = kprime[i]&mask | s[i]&^mask
		}
	} else if ok == 1 {
		// Only reachable if lenK != lenS; those lengths are public parameters.
		kSel = kprime
	}
	ssDec = k.shake(cat(ct, kSel), k.lenSS/8)
	return
}
// OverrideRng replaces the random number generator (by default crypto/rand)
// that Keygen and Encapsulate draw from. newRng must fill the entire slice it is
// given with random bytes; its signature offers no way to report failure.
func (k *FrodoKEM) OverrideRng(newRng func([]byte)) {
	k.rng = newRng
}
// unwrapCt splits a cipher-text into c1, the packed mBar x n matrix B', and c2,
// the packed mBar x nBar matrix C. Both results are capped with full slice
// expressions, so appending to them reallocates instead of overwriting ct.
func (k *FrodoKEM) unwrapCt(ct []uint8) (c1 []uint8, c2 []uint8) {
	c1Len := k.mBar * k.n * k.d / 8
	c2Len := k.mBar * k.nBar * k.d / 8
	c1 = ct[0:c1Len:c1Len]
	c2 = ct[c1Len : c1Len+c2Len : c1Len+c2Len]
	return
}
// unwrapSk splits a secret key laid out as s || seedA || b || S^T || pkh into
// its fields. S^T (nBar rows of n entries) is decoded from 16-bit little-endian
// words. The byte fields are sub-slices of sk capped with full slice
// expressions, so a caller appending to one of them (Dencapsulate appends to pkh)
// reallocates instead of overwriting the next field or memory past len(sk).
func (k *FrodoKEM) unwrapSk(sk []uint8) (s []uint8, seedA []uint8, b []uint8, Stransposed [][]int16, pkh []uint8) {
	ofs := 0
	next := func(size int) []uint8 {
		part := sk[ofs : ofs+size : ofs+size]
		ofs += size
		return part
	}
	s = next(k.lenS / 8)
	seedA = next(k.lenSeedA / 8)
	b = next(k.d * k.n * k.nBar / 8)
	sBytes := next(k.n * k.nBar * 2)
	Stransposed = make([][]int16, k.nBar)
	for i := range Stransposed {
		row := make([]int16, k.n)
		for j := range row {
			row[j] = int16(binary.LittleEndian.Uint16(sBytes[2*(i*k.n+j):]))
		}
		Stransposed[i] = row
	}
	pkh = next(k.lenPkh / 8)
	return
}
// sample maps a 16-bit random value r to a small signed integer using the
// cumulative table k.tChi (FrodoKEM specification, Frodo.Sample). The magnitude
// is the number of entries tChi[0..len-2] that r>>1 exceeds, and bit 0 of r
// selects the sign. The computation is branch-free in r, because the sampled
// values become secret matrices.
func (k *FrodoKEM) sample(r uint16) (e int16) {
	t := int64(r >> 1)
	for z := 0; z < len(k.tChi)-1; z++ {
		// tChi[z] - t is negative exactly when t > tChi[z]; its sign bit is the increment.
		e += int16(uint64(int64(k.tChi[z])-t) >> 63)
	}
	// Conditional negation without a branch: with sign in {0, 1},
	// (e ^ -1) + 1 == -e and (e ^ 0) + 0 == e.
	sign := int16(r & 1)
	e = (e ^ -sign) + sign
	return e
}
// sampleMatrix builds an n1 x n2 matrix whose entries are k.sample applied to
// the values of r in row-major order; r must hold at least n1*n2 values.
func (k *FrodoKEM) sampleMatrix(r []uint16, n1 int, n2 int) (E [][]int16) {
	E = make([][]int16, n1)
	for row := range E {
		offset := row * n2
		cells := make([]int16, n2)
		for col := range cells {
			cells[col] = k.sample(r[offset+col])
		}
		E[row] = cells
	}
	return
}
// pack serializes C row by row using k.d bits per entry, most-significant bit
// first (FrodoKEM specification, Algorithm 3: Frodo.Pack). The result has
// ceil(k.d*rows*cols/8) bytes. In practice the parameter sets always yield a
// whole number of bytes; a trailing partial byte, if any, is left-aligned so it
// reads back in the same MSB-first order that unpack uses. An empty matrix packs
// to an empty slice.
func (k *FrodoKEM) pack(C [][]uint16) (r []byte) {
	if len(C) == 0 {
		return []byte{}
	}
	rows, cols := len(C), len(C[0])
	r = make([]byte, (k.d*rows*cols+7)/8)
	out := 0
	var acc, filled uint8
	for i := 0; i < rows; i++ {
		for j := 0; j < cols; j++ {
			for bit := k.d - 1; bit >= 0; bit-- {
				acc = acc<<1 | uint16BitN(C[i][j], bit)
				if filled++; filled == 8 {
					r[out] = acc
					out++
					acc, filled = 0, 0
				}
			}
		}
	}
	if filled != 0 {
		r[out] = acc << (8 - filled)
	}
	return r
}
// unpack reads an n1 x n2 matrix from b, taking k.d bits per entry in row-major
// order with the most-significant bit first (FrodoKEM specification,
// Algorithm 4: Frodo.Unpack).
func (k *FrodoKEM) unpack(b []uint8, n1 int, n2 int) (C [][]uint16) {
	pos := 0 // absolute bit position in b, counted from the MSB of b[0]
	C = make([][]uint16, n1)
	for i := range C {
		row := make([]uint16, n2)
		for j := range row {
			var v uint16
			for n := 0; n < k.d; n++ {
				v = v<<1 | uint16(uint8BitN(b[pos/8], 7-pos%8))
				pos++
			}
			row[j] = v
		}
		C[i] = row
	}
	return
}
// encode turns the bit string b into an mBar x nBar matrix (FrodoKEM
// specification, Algorithm 1: Frodo.Encode). Each entry consumes the next k.b
// bits of b, least-significant bit of each byte first, and the resulting value
// is scaled by q/2^B, where q == 0 stands for 65536.
func (k *FrodoKEM) encode(b []uint8) (K [][]uint16) {
	q := int(k.q)
	if q == 0 {
		q = 65536
	}
	scale := q
	if k.b > 0 {
		scale = q / (2 << (k.b - 1))
	}
	step := uint16(scale)
	pos := 0 // absolute bit position in b, LSB of b[0] first
	K = make([][]uint16, k.mBar)
	for i := range K {
		row := make([]uint16, k.nBar)
		for j := range row {
			var v uint16
			for l := 0; l < k.b; l++ {
				v |= uint16(uint8BitN(b[pos/8], pos%8)) << l
				pos++
			}
			row[j] = v * step
		}
		K[i] = row
	}
	return
}
// decode maps the mBar x nBar matrix K back to k.b*mBar*nBar/8 bytes
// (FrodoKEM specification, Algorithm 2: Frodo.Decode). Every entry becomes
// round(K[i][j] * 2^B / q) mod 2^B. Its k.b bits are written to the output
// least-significant bit first, mirroring the bit order of encode.
func (k *FrodoKEM) decode(K [][]uint16) (b []uint8) {
	b = make([]uint8, k.b*k.mBar*k.nBar/8)
	// q == 0 stands for q = 2^16 = 65536, the same convention encode uses.
	fixedQ := float64(k.q)
	if k.q == 0 {
		fixedQ = 65536
	}
	twoPowerB := int32(2 << (k.b - 1))
	twoPowerBf := float64(twoPowerB)
	bIdx := 0
	BBit := 0
	for i := 0; i < k.mBar; i++ {
		for j := 0; j < k.nBar; j++ {
			// For power-of-two q (as in the FrodoKEM parameter sets) the quotient is
			// exact in float64, so math.Round yields the exact nearest integer.
			tmp := uint8(int32(math.Round(float64(K[i][j])*twoPowerBf/fixedQ)) % twoPowerB)
			for l := 0; l < k.b; l++ {
				// OR the bit in unconditionally: the decoded value is secret, so
				// avoid branching on it.
				b[bIdx] |= uint8BitN(tmp, l) << BBit
				BBit++
				if BBit == 8 {
					bIdx++
					BBit = 0
				}
			}
		}
	}
	return
}
// genSHAKE128 expands seedA into the n x n matrix A. Row i is the SHAKE128
// output of (i as 16-bit little-endian) || seedA, read as n little-endian
// 16-bit words, each reduced mod q unless q == 0.
func (k *FrodoKEM) genSHAKE128(seedA []byte) (A [][]uint16) {
	input := make([]byte, 2, 2+len(seedA))
	input = append(input, seedA...)
	stream := make([]byte, 2*k.n)
	A = make([][]uint16, k.n)
	for row := range A {
		binary.LittleEndian.PutUint16(input, uint16(row))
		sha3.ShakeSum128(stream, input)
		values := make([]uint16, k.n)
		for col := range values {
			v := binary.LittleEndian.Uint16(stream[2*col:])
			if k.q != 0 {
				v %= k.q
			}
			values[col] = v
		}
		A[row] = values
	}
	return
}
// genAES128 expands seedA into the n x n matrix A using AES with seedA as the
// key. For row i and every column block j (step 8), the 16-byte block
// <i> || <j> || 0...0 (16-bit little-endian fields) is encrypted. The result
// yields A[i][j..j+7] as little-endian 16-bit words, each reduced mod q unless
// q == 0. It panics if seedA is not a valid AES key.
func (k *FrodoKEM) genAES128(seedA []byte) (A [][]uint16) {
	A = make([][]uint16, k.n)
	block, err := aes.NewCipher(seedA)
	if err != nil {
		panic(err)
	}
	var in, out [16]byte
	for i := range A {
		A[i] = make([]uint16, k.n)
		binary.LittleEndian.PutUint16(in[0:2], uint16(i))
		for j := 0; j < k.n; j += 8 {
			binary.LittleEndian.PutUint16(in[2:4], uint16(j))
			block.Encrypt(out[:], in[:])
			for l := 0; l < 8; l++ {
				v := binary.LittleEndian.Uint16(out[2*l:])
				if k.q != 0 {
					v %= k.q
				}
				A[i][j+l] = v
			}
		}
	}
	return
}
// constantUint16Equals reports whether a and b hold the same values, returning
// 1 when they are equal and 0 otherwise. Element differences are OR-accumulated
// without branching, so the running time does not depend on where, or whether,
// the matrices differ. It panics if the dimensions differ; dimensions are not
// treated as secret.
func constantUint16Equals(a [][]uint16, b [][]uint16) (ret int) {
	if len(a) != len(b) {
		panic("Can not compare matrices of different size")
	}
	var diff uint16
	for i := range a {
		if len(a[i]) != len(b[i]) {
			panic("Can not compare matrices of different size")
		}
		for j := range a[i] {
			diff |= a[i][j] ^ b[i][j]
		}
	}
	// diff == 0: uint32(0)-1 has bit 31 set       -> 1
	// diff != 0: diff-1 < 2^16, bit 31 is clear   -> 0
	return int((uint32(diff) - 1) >> 31)
}
// matrixAddWithMod returns X + Y element-wise using wrapping 16-bit arithmetic,
// then reduces every entry mod q (q == 0 leaves the 16-bit result unreduced).
// It panics if the row counts or first-row widths differ.
func matrixAddWithMod(X [][]uint16, Y [][]int16, q uint16) (R [][]uint16) {
	xRows, xCols := len(X), len(X[0])
	yRows, yCols := len(Y), len(Y[0])
	if xRows != yRows || xCols != yCols {
		panic("can't add these matrices")
	}
	R = make([][]uint16, xRows)
	for i := range R {
		out := make([]uint16, xCols)
		for j := range out {
			sum := X[i][j] + uint16(Y[i][j])
			if q != 0 {
				sum %= q
			}
			out[j] = sum
		}
		R[i] = out
	}
	return
}
// uMatrixAdd returns X + Y element-wise using wrapping 16-bit arithmetic, then
// reduces every entry mod q (q == 0 leaves the 16-bit result unreduced).
// It panics if the row counts or first-row widths differ.
func uMatrixAdd(X [][]uint16, Y [][]uint16, q uint16) (R [][]uint16) {
	xRows, xCols := len(X), len(X[0])
	yRows, yCols := len(Y), len(Y[0])
	if xRows != yRows || xCols != yCols {
		panic("can't add these matrices")
	}
	R = make([][]uint16, xRows)
	for i, xRow := range X {
		out := make([]uint16, xCols)
		for j := range out {
			sum := xRow[j] + Y[i][j]
			if q != 0 {
				sum %= q
			}
			out[j] = sum
		}
		R[i] = out
	}
	return
}
// matrixSubWithMod returns X - Y element-wise using wrapping 16-bit arithmetic,
// then reduces every entry mod q (q == 0 leaves the 16-bit result unreduced).
// It panics if the row counts or first-row widths differ.
func matrixSubWithMod(X [][]uint16, Y [][]uint16, q uint16) (R [][]uint16) {
	xRows, xCols := len(X), len(X[0])
	yRows, yCols := len(Y), len(Y[0])
	if xRows != yRows || xCols != yCols {
		panic("can't sub these matrices")
	}
	R = make([][]uint16, xRows)
	for i, xRow := range X {
		out := make([]uint16, xCols)
		for j := range out {
			d := xRow[j] - Y[i][j]
			if q != 0 {
				d %= q
			}
			out[j] = d
		}
		R[i] = out
	}
	return
}
// matrixMulWithMod returns the product X*Y computed with wrapping 16-bit
// arithmetic; each resulting entry is then reduced mod q unless q == 0.
// The inner dimension is len(X[0]); Y must have at least that many rows.
func matrixMulWithMod(X [][]uint16, Y [][]int16, q uint16) (R [][]uint16) {
	inner := len(X[0])
	outCols := len(Y[0])
	R = make([][]uint16, len(X))
	for i, xRow := range X {
		out := make([]uint16, outCols)
		for j := range out {
			var acc uint16
			for m := 0; m < inner; m++ {
				// Two's-complement products agree mod 2^16 regardless of signedness.
				acc += xRow[m] * uint16(Y[m][j])
			}
			if q != 0 {
				acc %= q
			}
			out[j] = acc
		}
		R[i] = out
	}
	return
}
// matrixMulWithMod2 returns the product X*Y (signed X, unsigned Y) computed
// with wrapping 16-bit arithmetic; each resulting entry is then reduced mod q
// unless q == 0. The inner dimension is len(X[0]); Y must have at least that
// many rows.
func matrixMulWithMod2(X [][]int16, Y [][]uint16, q uint16) (R [][]uint16) {
	inner := len(X[0])
	outCols := len(Y[0])
	R = make([][]uint16, len(X))
	for i, xRow := range X {
		out := make([]uint16, outCols)
		for j := range out {
			var acc uint16
			for m := 0; m < inner; m++ {
				// Two's-complement products agree mod 2^16 regardless of signedness.
				acc += uint16(xRow[m]) * Y[m][j]
			}
			if q != 0 {
				acc %= q
			}
			out[j] = acc
		}
		R[i] = out
	}
	return
}
// matrixTranspose returns the transpose of O. The column count is taken from
// O[0], so O must be non-empty.
func matrixTranspose(O [][]int16) (T [][]int16) {
	rows, cols := len(O), len(O[0])
	T = make([][]int16, cols)
	for c := 0; c < cols; c++ {
		column := make([]int16, rows)
		for r, src := range O {
			column[r] = src[c]
		}
		T[c] = column
	}
	return
}
// unpackUint16 reinterprets bytes as consecutive little-endian 16-bit words.
// A trailing odd byte is ignored; an input shorter than two bytes yields an
// empty (non-nil) slice.
func unpackUint16(bytes []byte) (r []uint16) {
	r = make([]uint16, len(bytes)/2)
	for i := range r {
		r[i] = binary.LittleEndian.Uint16(bytes[2*i:])
	}
	return r
}
// uint8setBitN returns val with bit i (0 = least significant) set.
func uint8setBitN(val uint8, i int) uint8 {
	mask := uint8(1) << i
	return val | mask
}
// uint16BitN returns bit i (0 = least significant) of val as 0 or 1.
func uint16BitN(val uint16, i int) uint8 {
	shifted := uint8(val >> i)
	return shifted & 1
}
// uint8BitN returns bit i (0 = least significant) of val as 0 or 1.
func uint8BitN(val uint8, i int) uint8 {
	if val&(1<<i) != 0 {
		return 1
	}
	return 0
}