Skip to content

Instantly share code, notes, and snippets.

@sify21
Last active September 30, 2021 10:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sify21/ade08319e214eef6560e860b7530424c to your computer and use it in GitHub Desktop.
Save sify21/ade08319e214eef6560e860b7530424c to your computer and use it in GitHub Desktop.
AES-GCM-SIV (RFC 8452) implementation in Go. A direct translation of https://github.com/bjornedstrom/aes-gcm-siv-py
package aead
import (
"crypto/aes"
"encoding/binary"
"errors"
"math/big"
)
var _mod = big.NewInt(0)
var _inv = big.NewInt(0)
var _max = big.NewInt(0)
var _one = big.NewInt(1)
var _zero = big.NewInt(0)
func init() {
for _, i := range []uint{0, 121, 126, 127, 128} {
_mod.Add(_mod, _max.Lsh(_one, i))
}
for _, i := range []uint{0, 114, 121, 124, 127} {
_inv.Add(_inv, _max.Lsh(_one, i))
}
_max.Lsh(_one, 128)
}
func add(x, y *big.Int) (r big.Int) {
if x.Cmp(_max) != -1 {
panic(x)
}
if y.Cmp(_max) != -1 {
panic(x)
}
r.Xor(x, y)
return
}
func mul(x, y *big.Int) (res big.Int) {
var tmp big.Int
if x.Cmp(_max) != -1 {
panic(x)
}
if y.Cmp(_max) != -1 {
panic(x)
}
for bit := 0; bit < 128; bit++ {
if y.Bit(bit) == 1 {
res.Xor(&res, tmp.Lsh(_one, uint(bit)).Mul(&tmp, x))
}
}
return mod(&res, _mod)
}
// dot is POLYVAL's multiplication from RFC 8452 §3: dot(a, b) = a*b*x^-128.
// It is computed as two field products, first a*b and then times _inv.
func dot(a, b *big.Int) (res big.Int) {
	ab := mul(a, b)
	res = mul(&ab, _inv)
	return res
}
func mod(a, m *big.Int) (r big.Int) {
var a2, m2 big.Int
m2.Set(m)
r.Set(a)
i := 0
for m2.Cmp(&r) < 0 {
m2.Lsh(&m2, 1)
i += 1
}
for i >= 0 {
a2.Xor(&r, &m2)
if a2.Cmp(&r) < 0 {
r.Set(&a2)
}
m2.Rsh(&m2, 1)
i -= 1
}
return
}
type polyvalIUF struct {
s big.Int
h big.Int
nonce []byte
}
// newPolyvalIUF starts a POLYVAL computation with a zero accumulator.
// h holds the key bytes, decoded little-endian by b2i (expected to be 16
// bytes; not checked here). nonce is copied, so the caller may reuse its
// slice afterwards.
func newPolyvalIUF(h, nonce []byte) (ret polyvalIUF) {
	ret = polyvalIUF{
		h:     b2i(h),
		nonce: append(make([]byte, 0, len(nonce)), nonce...),
	}
	// ret.s is the big.Int zero value, which already represents 0.
	return ret
}
// update absorbs inp into the POLYVAL state. The input is cut into 16-byte
// blocks, the last one zero-padded, and each block goes through update16.
// An empty inp leaves the state unchanged.
func (self *polyvalIUF) update(inp []byte) {
	blocks := self.split16(inp)
	for i := range blocks {
		self.update16(blocks[i])
	}
}
// update16 absorbs exactly one 16-byte block with the POLYVAL step from
// RFC 8452 §3: s = dot(s + block, h). Any other length panics with the
// offending slice.
func (self *polyvalIUF) update16(inp []byte) {
	const blockSize = 16
	if len(inp) != blockSize {
		panic(inp)
	}
	x := b2i(inp)
	sum := add(&self.s, &x)
	self.s = dot(&sum, &self.h)
}
// split16 cuts inp into freshly allocated 16-byte chunks and zero-pads the
// final one. The chunks do not alias inp. An empty inp yields a nil slice.
func (self *polyvalIUF) split16(inp []byte) (r [][]byte) {
	for rest := inp; len(rest) > 0; {
		chunk := make([]byte, 16)
		n := copy(chunk, rest) // copies min(16, len(rest)); the tail stays zero
		rest = rest[n:]
		r = append(r, chunk)
	}
	return r
}
// digest finalizes POLYVAL for AES-GCM-SIV (RFC 8452 §4) in three steps:
//   - encode the accumulator as a 16-byte little-endian block;
//   - XOR its first 12 bytes with the nonce;
//   - clear the most significant bit of the last byte.
//
// The result is always exactly 16 bytes. i2b may return fewer bytes when
// the accumulator's high bytes are zero, so its output is copied into a
// fixed 16-byte buffer rather than indexed directly. s is always reduced
// below 2^128 by mul/mod, so nothing is truncated.
// The nonce is expected to be at least 12 bytes; shorter nonces panic.
func (self *polyvalIUF) digest() []byte {
	out := make([]byte, 16)
	copy(out, i2b(&self.s))
	for i := 0; i < 12; i++ {
		out[i] ^= self.nonce[i]
	}
	out[15] &= 0x7f
	return out
}
func b2i(inp []byte) (r big.Int) {
for i := len(inp); i > 0; i-- {
c := big.NewInt(int64(inp[i-1]))
r.Lsh(&r, 8).Or(&r, c)
}
return
}
func i2b(i *big.Int) (s []byte) {
if i.Cmp(_zero) == 0 {
s = make([]byte, 16)
return
}
var m, tmp big.Int
m.Set(i)
for m.Cmp(_zero) != 0 {
s = append(s, byte(tmp.And(tmp.SetInt64(255), &m).Int64()))
m.Rsh(&m, 8)
}
return
}
func le_uint32(i uint32) []byte {
b := make([]byte, 4)
binary.LittleEndian.PutUint32(b, i)
return b
}
// read_le_uint32 decodes the first four bytes of b as a little-endian
// uint32. Bytes after index 3 are ignored; fewer than four bytes panics.
func read_le_uint32(b []byte) uint32 {
	_ = b[3] // one bounds check up front, as encoding/binary does
	return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
}
func le_uint64(i uint64) []byte {
b := make([]byte, 8)
binary.LittleEndian.PutUint64(b, i)
return b
}
// AESGCMSIV is an AES-GCM-SIV (RFC 8452) instance bound to a single nonce.
// Build it with NewAESGCMSIV; the zero value holds no keys.
type AESGCMSIV struct {
	// Record keys derived by NewAESGCMSIV. msgAuthKey is the POLYVAL key;
	// msgEncKey is the AES key used for both the tag and the CTR keystream.
	msgAuthKey, msgEncKey []byte
	// nonce is the nonce the keys were derived for. Every Encrypt and
	// Decrypt call on this value reuses it.
	nonce []byte
}
// NewAESGCMSIV derives the per-nonce record keys of AES-GCM-SIV
// (RFC 8452 §4) from the key-generating key and the nonce.
//
// key_gen_key must be 16 bytes (AEAD_AES_128_GCM_SIV) or 32 bytes
// (AEAD_AES_256_GCM_SIV). nonce must be exactly 12 bytes. Other lengths are
// not defined by RFC 8452; with no error return available, they panic rather
// than silently yield non-standard keys. A truncated nonce would also let
// distinct nonces share keys.
//
// Each 8-byte key half is the first 8 bytes of
// AES(key_gen_key, LE32(counter) || nonce):
//   - counters 0-1 form the 16-byte authentication (POLYVAL) key;
//   - counters 2-3 form the encryption key, plus counters 4-5 for a 32-byte key.
//
// The returned value keeps its own copy of nonce.
func NewAESGCMSIV(key_gen_key, nonce []byte) (ret AESGCMSIV) {
	if n := len(key_gen_key); n != 16 && n != 32 {
		panic("aesgcmsiv: key-generating key must be 16 or 32 bytes")
	}
	if len(nonce) != 12 {
		panic("aesgcmsiv: nonce must be 12 bytes")
	}
	block, err := aes.NewCipher(key_gen_key)
	if err != nil {
		// Not expected after the length check; never drop the error silently.
		panic("aesgcmsiv: " + err.Error())
	}
	half := func(counter uint32) []byte {
		var in, out [aes.BlockSize]byte
		binary.LittleEndian.PutUint32(in[:4], counter)
		copy(in[4:], nonce)
		block.Encrypt(out[:], in[:])
		return out[:8]
	}
	for c := uint32(0); c < 2; c++ {
		ret.msgAuthKey = append(ret.msgAuthKey, half(c)...)
	}
	// The encryption key is as long as the key-generating key: 2 or 4 halves.
	encHalves := uint32(len(key_gen_key) / 8)
	for c := uint32(2); c < 2+encHalves; c++ {
		ret.msgEncKey = append(ret.msgEncKey, half(c)...)
	}
	ret.nonce = append(ret.nonce, nonce...)
	return ret
}
// aesCtr XORs inp with the AES-GCM-SIV CTR keystream (RFC 8452 §4) and
// returns the result in a new slice. CTR is symmetric, so this both encrypts
// and decrypts.
//
// The keystream is AES(key, counter) for successive counter blocks, starting
// at initialBlock (which must be 16 bytes and is not modified). Between
// blocks only the first 32 bits, read little-endian, are incremented, with
// wrap-around. An empty inp returns nil without touching key. An invalid key
// or counter block panics, since the signature has no error return.
func (self *AESGCMSIV) aesCtr(key, initialBlock, inp []byte) (output []byte) {
	if len(inp) == 0 {
		return nil
	}
	block, err := aes.NewCipher(key)
	if err != nil {
		panic("aesgcmsiv: " + err.Error())
	}
	if len(initialBlock) != aes.BlockSize {
		panic("aesgcmsiv: initial counter block must be 16 bytes")
	}
	var counter, keystream [aes.BlockSize]byte
	copy(counter[:], initialBlock)
	output = make([]byte, len(inp))
	for off := 0; off < len(inp); off += aes.BlockSize {
		block.Encrypt(keystream[:], counter[:])
		// uint32 addition wraps at 2^32, as RFC 8452 requires.
		binary.LittleEndian.PutUint32(counter[:4], binary.LittleEndian.Uint32(counter[:4])+1)
		end := off + aes.BlockSize
		if end > len(inp) {
			end = len(inp)
		}
		for j := off; j < end; j++ {
			output[j] = inp[j] ^ keystream[j-off]
		}
	}
	return output
}
// polyval_calc computes the masked POLYVAL value S_s for a message, as in
// RFC 8452 §4. The input is:
//   - the zero-padded AAD;
//   - then the zero-padded plaintext;
//   - then a length block LE64(bits of AAD) || LE64(bits of plaintext).
//
// It is keyed by msgAuthKey and finalized with the instance's nonce.
func (self *AESGCMSIV) polyval_calc(plaintext, additional_data []byte) []byte {
	var lengths [16]byte
	binary.LittleEndian.PutUint64(lengths[:8], uint64(len(additional_data))*8)
	binary.LittleEndian.PutUint64(lengths[8:], uint64(len(plaintext))*8)

	acc := newPolyvalIUF(self.msgAuthKey, self.nonce)
	acc.update(additional_data)
	acc.update(plaintext)
	acc.update(lengths[:])
	return acc.digest()
}
// Encrypt seals plaintext and authenticates additional_data with the keys
// and nonce fixed at construction (RFC 8452 §4). It returns the ciphertext
// followed by the 16-byte tag.
//
// Plaintext or additional data longer than 2^36 bytes is rejected, as
// RFC 8452 requires. An instance without a valid encryption key (for example
// the zero value) returns the error from aes.NewCipher.
func (self *AESGCMSIV) Encrypt(plaintext, additional_data []byte) (ret []byte, err error) {
	// Compare as uint64 so the constant 1<<36 does not overflow int on
	// 32-bit platforms, where it would not compile.
	if uint64(len(plaintext)) > 1<<36 {
		return nil, errors.New("plaintext too large")
	}
	if uint64(len(additional_data)) > 1<<36 {
		return nil, errors.New("additional_data too large")
	}
	block, err := aes.NewCipher(self.msgEncKey)
	if err != nil {
		return nil, err
	}

	// The tag is the AES encryption of the masked POLYVAL value.
	S_s := self.polyval_calc(plaintext, additional_data)
	tag := make([]byte, aes.BlockSize)
	block.Encrypt(tag, S_s)

	// The initial CTR block is the tag with its most significant bit set.
	counter_block := make([]byte, aes.BlockSize)
	copy(counter_block, tag)
	counter_block[15] |= 0x80
	return append(self.aesCtr(self.msgEncKey, counter_block, plaintext), tag...), nil
}
// Decrypt verifies and opens ciphertext with the keys and nonce fixed at
// construction (RFC 8452 §5). ciphertext is the encrypted data followed by
// the 16-byte tag, as produced by Encrypt.
//
// On success it returns the plaintext. On any failure it returns a nil
// plaintext and a non-nil error. If the tag does not verify, the decrypted
// bytes are wiped and never returned, so unauthenticated data cannot reach
// the caller.
func (self *AESGCMSIV) Decrypt(ciphertext, additional_data []byte) (plaintext []byte, err error) {
	// Compare as uint64 so the constant 1<<36 does not overflow int on
	// 32-bit platforms, where it would not compile.
	if len(ciphertext) < 16 || uint64(len(ciphertext)) > (1<<36)+16 {
		return nil, errors.New("ciphertext too small or too large")
	}
	if uint64(len(additional_data)) > 1<<36 {
		return nil, errors.New("additional_data too large")
	}
	block, err := aes.NewCipher(self.msgEncKey)
	if err != nil {
		return nil, err
	}

	tag := make([]byte, 16)
	copy(tag, ciphertext[len(ciphertext)-16:])
	ciphertext = ciphertext[:len(ciphertext)-16]

	// The initial CTR block is the tag with its most significant bit set.
	counter_block := make([]byte, 16)
	copy(counter_block, tag)
	counter_block[15] |= 0x80
	plaintext = self.aesCtr(self.msgEncKey, counter_block, ciphertext)

	S_s := self.polyval_calc(plaintext, additional_data)
	expected_tag := make([]byte, 16)
	block.Encrypt(expected_tag, S_s)

	// Accumulate differences without an early exit so timing does not
	// reveal where the tags diverge.
	var diff byte
	for i := range expected_tag {
		diff |= expected_tag[i] ^ tag[i]
	}
	if diff != 0 {
		for i := range plaintext {
			plaintext[i] = 0
		}
		return nil, errors.New("auth fail")
	}
	return plaintext, nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment