Skip to content

Instantly share code, notes, and snippets.

@emmansun
Last active February 21, 2022 05:36
Show Gist options
  • Save emmansun/c05b5e2262997ae546950f806fcd0146 to your computer and use it in GitHub Desktop.
Save emmansun/c05b5e2262997ae546950f806fcd0146 to your computer and use it in GitHub Desktop.
package main
import (
"fmt"
"math/bits"
)
// CPU instruction simulation for SM3PARTW1 / SM3PARTW2 / SM3SS1 / SM3TT1A / SM3TT1B / SM3TT2A / SM3TT2B
// register128 models one 128-bit vector register as four 32-bit lanes.
// value[0] holds bits <31:0> and value[3] holds bits <127:96>.
type register128 struct {
	value [4]uint32
}
// p0 returns x XOR (x rotated left by 9) XOR (x rotated left by 17).
func p0(x uint32) uint32 {
	// Rotating by 17 is the same as rotating the 9-bit rotation by 8 more bits.
	rot9 := bits.RotateLeft32(x, 9)
	rot17 := bits.RotateLeft32(rot9, 8)
	return x ^ rot9 ^ rot17
}
// p1 returns x XOR (x rotated left by 15) XOR (x rotated left by 23).
func p1(x uint32) uint32 {
	// Rotating by 23 is the same as rotating the 15-bit rotation by 8 more bits.
	rot15 := bits.RotateLeft32(x, 15)
	rot23 := bits.RotateLeft32(rot15, 8)
	return x ^ rot15 ^ rot23
}
// ff returns the bitwise majority of x, y and z: each result bit is set
// when at least two of the three corresponding input bits are set.
func ff(x, y, z uint32) uint32 {
	bothXY := x & y
	eitherXY := x | y
	return bothXY | (eitherXY & z)
}
// gg is a bitwise select: where a bit of x is set the result takes the bit
// from y, otherwise it takes the bit from z.
func gg(x, y, z uint32) uint32 {
	fromY := x & y
	fromZ := z &^ x // z AND NOT x
	return fromY | fromZ
}
// eor stores the lane-wise XOR of Vn and Vm into Vd.
// Vd may be the same register as Vn or Vm, because each lane is read
// before that same lane is written.
func eor(Vd, Vn, Vm *register128) {
	for lane := range Vd.value {
		Vd.value[lane] = Vn.value[lane] ^ Vm.value[lane]
	}
}
// copyRegister copies all four lanes of src into dst.
func copyRegister(dst, src *register128) {
	// Arrays are values in Go, so a plain assignment copies every lane.
	dst.value = src.value
}
// SM3PARTW1 simulates the SM3PARTW1 instruction and overwrites Vd with the result.
//
// Lanes 0..2: Vd[i] = P1(Vd[i] XOR Vn[i] XOR ROL(Vm[i+1], 15)).
// Lane 3:     Vd[3] = P1(Vd[3] XOR Vn[3] XOR ROL(newVd[0], 15)),
// where newVd[0] is the already finished lane 0 of the result.
//
// The result is assembled in a local array first, so Vd may alias Vn or Vm.
func SM3PARTW1(Vd, Vn, Vm *register128) {
	var out [4]uint32

	// result<95:0> = (Vd EOR Vn)<95:0> EOR ROL(Vm<127:32>, 15) per lane, then P1.
	for lane := 0; lane < 3; lane++ {
		mixed := Vd.value[lane] ^ Vn.value[lane] ^ bits.RotateLeft32(Vm.value[lane+1], 15)
		out[lane] = p1(mixed)
	}

	// The top lane feeds on lane 0 *after* P1 has been applied to it.
	top := Vd.value[3] ^ Vn.value[3] ^ bits.RotateLeft32(out[0], 15)
	out[3] = p1(top)

	// V[d] = result
	Vd.value = out
}
// SM3PARTW2 simulates the SM3PARTW2 instruction and overwrites Vd with the result.
//
//	tmp    = Vn XOR ROL(Vm, 7)            (per lane)
//	result = Vd XOR tmp                   (per lane)
//	result[3] ^= P1(ROL(tmp[0], 15))
//
// The result is assembled in a local array first, so Vd may alias Vn or Vm.
func SM3PARTW2(Vd, Vn, Vm *register128) {
	var mixed, out [4]uint32

	for lane := range out {
		mixed[lane] = Vn.value[lane] ^ bits.RotateLeft32(Vm.value[lane], 7)
		out[lane] = Vd.value[lane] ^ mixed[lane]
	}

	// Fold the low lane of tmp, rotated and passed through P1, into the top lane.
	out[3] ^= p1(bits.RotateLeft32(mixed[0], 15))

	// V[d] = result
	Vd.value = out
}
// SM3SS1 simulates the SM3SS1 instruction.
//
//	Vm[3]: T constant for the round
//	Vn[3]: sm3 state word A
//	Va[3]: sm3 state word E
//	Vd:    result; lane 3 receives ROL(ROL(A, 12) + E + T, 7), lanes 0..2 are zeroed
func SM3SS1(Vd, Vn, Vm, Va *register128) {
	rotatedA := bits.RotateLeft32(Vn.value[3], 12)
	sum := Vm.value[3] + Va.value[3] + rotatedA

	// V[d] = result, with only the top lane populated.
	Vd.value = [4]uint32{3: bits.RotateLeft32(sum, 7)}
}
// SM3TT1A simulates the SM3TT1A instruction (XOR form of the boolean function).
//
//	imm2:  j, the lane of Vm to use
//	Vd:    sm3 state (D, C, B, A) in lanes 0..3; updated in place
//	Vn[3]: ss1
//	Vm:    W' words
func SM3TT1A(Vd, Vn, Vm *register128, imm2 byte) {
	d, c, b, a := Vd.value[0], Vd.value[1], Vd.value[2], Vd.value[3]

	ss2 := Vn.value[3] ^ bits.RotateLeft32(a, 12)
	tt1 := (c ^ (a ^ b)) + d + ss2 + Vm.value[imm2]

	// New state: D <- C, C <- ROL(B, 9), B <- A, A <- TT1.
	Vd.value = [4]uint32{c, bits.RotateLeft32(b, 9), a, tt1}
}
// SM3TT1B simulates the SM3TT1B instruction (ff form of the boolean function).
//
//	imm2:  j, the lane of Vm to use
//	Vd:    sm3 state (D, C, B, A) in lanes 0..3; updated in place
//	Vn[3]: ss1
//	Vm:    W' words
func SM3TT1B(Vd, Vn, Vm *register128, imm2 byte) {
	d, c, b, a := Vd.value[0], Vd.value[1], Vd.value[2], Vd.value[3]

	ss2 := Vn.value[3] ^ bits.RotateLeft32(a, 12)
	tt1 := ff(a, b, c) + d + ss2 + Vm.value[imm2]

	// New state: D <- C, C <- ROL(B, 9), B <- A, A <- TT1.
	Vd.value = [4]uint32{c, bits.RotateLeft32(b, 9), a, tt1}
}
// SM3TT2A simulates the SM3TT2A instruction (XOR form of the boolean function).
//
//	imm2:  j, the lane of Vm to use
//	Vd:    sm3 state (H, G, F, E) in lanes 0..3; updated in place
//	Vn[3]: ss1
//	Vm:    W words
func SM3TT2A(Vd, Vn, Vm *register128, imm2 byte) {
	h, g, f, e := Vd.value[0], Vd.value[1], Vd.value[2], Vd.value[3]

	tt2 := (g ^ (e ^ f)) + h + Vn.value[3] + Vm.value[imm2]

	// New state: H <- G, G <- ROL(F, 19), F <- E, E <- P0(TT2).
	Vd.value = [4]uint32{g, bits.RotateLeft32(f, 19), e, p0(tt2)}
}
// SM3TT2B simulates the SM3TT2B instruction (gg form of the boolean function).
//
//	imm2:  j, the lane of Vm to use
//	Vd:    sm3 state (H, G, F, E) in lanes 0..3; updated in place
//	Vn[3]: ss1
//	Vm:    W words
func SM3TT2B(Vd, Vn, Vm *register128, imm2 byte) {
	h, g, f, e := Vd.value[0], Vd.value[1], Vd.value[2], Vd.value[3]

	tt2 := gg(e, f, g) + h + Vn.value[3] + Vm.value[imm2]

	// New state: H <- G, G <- ROL(F, 19), F <- E, E <- P0(TT2).
	Vd.value = [4]uint32{g, bits.RotateLeft32(f, 19), e, p0(tt2)}
}
// loadRev32Register fills r from the first 16 bytes of p, reading each
// group of four bytes as one big-endian 32-bit word (lane 0 first).
// It panics if p holds fewer than 16 bytes.
func loadRev32Register(r *register128, p []byte) {
	for lane := range r.value {
		off := 4 * lane
		r.value[lane] = uint32(p[off])<<24 |
			uint32(p[off+1])<<16 |
			uint32(p[off+2])<<8 |
			uint32(p[off+3])
	}
}
// extract32 simulates the EXT (extract vector from pair of vectors) instruction
// at 32-bit granularity.
// https://developer.arm.com/documentation/ddi0602/2021-12/SIMD-FP-Instructions/EXT--Extract-vector-from-pair-of-vectors-?lang=en
//
// imm2 counts 4-byte words, so imm2 == 3 corresponds to an EXT immediate of 12.
//
//	Vd: result
//	Vm: high half of the concatenated pair
//	Vn: low half of the concatenated pair
func extract32(Vd, Vn, Vm *register128, imm2 byte) {
	// Concatenate Vm:Vn into one 8-word window, low lanes first.
	pair := [8]uint32{
		Vn.value[0], Vn.value[1], Vn.value[2], Vn.value[3],
		Vm.value[0], Vm.value[1], Vm.value[2], Vm.value[3],
	}
	// Take the words starting at offset imm2; copy stops after four lanes.
	copy(Vd.value[:], pair[imm2:])
}
// roundA runs one compression round using the "A" instruction variants
// (SM3TT1A / SM3TT2A).
//
//	i:        lane of w and wt consumed by this round
//	t:        lane 3 holds the round constant; it is rotated left by one bit here
//	st1, st2: sm3 state halves, updated in place
//	w, wt:    W and W' words
func roundA(i byte, t, st1, st2, w, wt *register128) {
	var ss1 register128 // stands in for scratch register v5

	SM3SS1(&ss1, st1, t, st2)
	// Advance the constant for the following round.
	t.value[3] = bits.RotateLeft32(t.value[3], 1)
	SM3TT1A(st1, &ss1, wt, i)
	SM3TT2A(st2, &ss1, w, i)
}
// qroundA expands the next four message words into s4 and then compresses the
// four words held in s0, using the "A" round variant.
//
//	t:        round constant register
//	st1, st2: sm3 state, updated in place
//	s0:       W(4i)    W(4i+1)  W(4i+2)  W(4i+3)
//	s1:       W(4i+4)  W(4i+5)  W(4i+6)  W(4i+7)
//	s2:       W(4i+8)  W(4i+9)  W(4i+10) W(4i+11)
//	s3:       W(4i+12) W(4i+13) W(4i+14) W(4i+15)
//	s4:       receives the next four words
//
// The three locals stand in for scratch registers v6, v7 and v10.
func qroundA(t, st1, st2, s0, s1, s2, s3, s4 *register128) {
	var (
		w3to6   register128 // v6
		w10to13 register128 // v7
		wPrime  register128 // v10
	)

	// Message expansion.
	extract32(s4, s1, s2, 3)       // w7, w8, w9, w10
	extract32(&w3to6, s0, s1, 3)   // w3, w4, w5, w6
	extract32(&w10to13, s2, s3, 2) // w10, w11, w12, w13
	SM3PARTW1(s4, s0, s3)
	SM3PARTW2(s4, &w10to13, &w3to6)

	// W' = W(j) XOR W(j+4) for the four words being compressed.
	eor(&wPrime, s0, s1)

	// Compression: one round per lane.
	for lane := byte(0); lane < 4; lane++ {
		roundA(lane, t, st1, st2, s0, &wPrime)
	}
}
// roundB runs one compression round using the "B" instruction variants
// (SM3TT1B / SM3TT2B).
//
//	i:        lane of w and wt consumed by this round
//	t:        lane 3 holds the round constant; it is rotated left by one bit here
//	st1, st2: sm3 state halves, updated in place
//	w, wt:    W and W' words
func roundB(i byte, t, st1, st2, w, wt *register128) {
	var ss1 register128 // stands in for scratch register v5

	SM3SS1(&ss1, st1, t, st2)
	// Advance the constant for the following round.
	t.value[3] = bits.RotateLeft32(t.value[3], 1)
	SM3TT1B(st1, &ss1, wt, i)
	SM3TT2B(st2, &ss1, w, i)
}
// qroundB compresses the four words held in s0 using the "B" round variant.
// When s4 is non-nil it first expands the next four message words into s4;
// passing a nil s4 skips expansion, in which case s2 and s3 are not used.
//
// The three locals stand in for scratch registers v6, v7 and v10.
func qroundB(t, st1, st2, s0, s1, s2, s3, s4 *register128) {
	var (
		w3to6   register128 // v6
		w10to13 register128 // v7
		wPrime  register128 // v10
	)

	if s4 != nil {
		// Message expansion.
		extract32(s4, s1, s2, 3)       // w7, w8, w9, w10
		extract32(&w3to6, s0, s1, 3)   // w3, w4, w5, w6
		extract32(&w10to13, s2, s3, 2) // w10, w11, w12, w13
		SM3PARTW1(s4, s0, s3)
		SM3PARTW2(s4, &w10to13, &w3to6)
	}

	// W' = W(j) XOR W(j+4) for the four words being compressed.
	eor(&wPrime, s0, s1)

	// Compression: one round per lane.
	for lane := byte(0); lane < 4; lane++ {
		roundB(lane, t, st1, st2, s0, &wPrime)
	}
}
// revWordOrder reverses the order of the four 32-bit lanes of v in place.
// It corresponds to the instruction pair:
//
//	rev64 v.4s, v.4s
//	ext   v.16b, v.16b, v.16b, #8
func revWordOrder(v *register128) {
	v.value[0], v.value[3] = v.value[3], v.value[0]
	v.value[1], v.value[2] = v.value[2], v.value[1]
}
// printRegister writes the four lanes of v8 to standard output as
// space-separated 8-digit hex words, followed by a newline.
func printRegister(v8 *register128) {
	for _, word := range v8.value {
		fmt.Printf("%08x ", word)
	}
	fmt.Println()
}
// print2Registers writes the lanes of v8 and then v9 to standard output on a
// single line as space-separated 8-digit hex words, followed by a newline.
func print2Registers(v8, v9 *register128) {
	for _, reg := range [...]*register128{v8, v9} {
		for _, word := range reg.value {
			fmt.Printf("%08x ", word)
		}
	}
	fmt.Println()
}
// block runs the SM3 compression function over every complete 64-byte block
// at the front of p and updates sm3state in place. Trailing bytes that do not
// fill a whole block are ignored; padding is the caller's responsibility.
//
// The state is kept in two registers in reversed word order, which is the
// layout the SM3TT* instructions expect, and is reversed back before being
// stored.
func block(sm3state *[8]uint32, p []byte) {
	// ring holds the message-schedule registers v0..v4. Quad-round k uses
	// ring[k%5] as s0, ring[(k+1)%5] as s1, and so on, which reproduces the
	// register rotation of the hand-unrolled sequence.
	var ring [5]register128
	reg := func(k int) *register128 { return &ring[k%5] }

	var (
		lo, hi           register128 // v8, v9: working sm3 state
		savedLo, savedHi register128 // v15, v16: state at block entry
		tconst           register128 // v11: T constant in lane 3
	)

	copy(lo.value[:], sm3state[:4])
	copy(hi.value[:], sm3state[4:])
	revWordOrder(&lo)
	revWordOrder(&hi)

	for len(p) >= 64 {
		// Save the incoming state for the final feed-forward XOR.
		savedLo, savedHi = lo, hi

		// Load one block as sixteen big-endian words into v0..v3.
		for i := 0; i < 4; i++ {
			loadRev32Register(&ring[i], p[16*i:])
		}

		// First 16 rounds.
		tconst.value[3] = 0x79cc4519
		for k := 0; k < 4; k++ {
			qroundA(&tconst, &lo, &hi, reg(k), reg(k+1), reg(k+2), reg(k+3), reg(k+4))
		}

		// Remaining 48 rounds. The first nine quad-rounds still expand
		// the schedule; the last three only compress.
		tconst.value[3] = 0x9d8a7a87
		for k := 4; k < 13; k++ {
			qroundB(&tconst, &lo, &hi, reg(k), reg(k+1), reg(k+2), reg(k+3), reg(k+4))
		}
		for k := 13; k < 16; k++ {
			qroundB(&tconst, &lo, &hi, reg(k), reg(k+1), nil, nil, nil)
		}

		// Feed-forward: new state = compressed state XOR state at entry.
		eor(&lo, &lo, &savedLo)
		eor(&hi, &hi, &savedHi)

		p = p[64:]
	}

	revWordOrder(&lo)
	revWordOrder(&hi)
	copy(sm3state[:], lo.value[:])
	copy(sm3state[4:], hi.value[:])
}
// IV is the initial eight-word state that main loads before hashing each message.
var IV = [8]uint32{
	0x7380166f, 0x4914b2b9, 0x172442d7, 0xda8a0600,
	0xa96f30bc, 0x163138aa, 0xe38dee4d, 0xb0fb0e4e,
}
// main hashes two pre-padded messages with the simulated instructions and
// prints the resulting state words for each.
func main() {
	// show prints a label followed by the eight state words in hex.
	show := func(label string, state *[8]uint32) {
		fmt.Print(label)
		for _, word := range state {
			fmt.Printf("%08x ", word)
		}
		fmt.Println()
	}

	// Case 1: the 3-byte message "abc" in a single padded block.
	// Layout: 61 62 63 | 80 | zeros | bit length 0x18 in the last byte.
	one := make([]byte, 64)
	copy(one, []byte{0x61, 0x62, 0x63, 0x80})
	one[63] = 0x18

	state := IV
	block(&state, one)
	show("case 1 result: ", &state)

	// Case 2: "abcd" repeated 16 times (64 bytes) plus one padding block.
	// Layout: 16 x (61 62 63 64) | 80 | zeros | bit length 0x0200 at the end.
	two := make([]byte, 128)
	for off := 0; off < 64; off += 4 {
		copy(two[off:], []byte{0x61, 0x62, 0x63, 0x64})
	}
	two[64] = 0x80
	two[126] = 0x02

	state = IV
	block(&state, two)
	show("case 2 result: ", &state)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment