Created
February 13, 2020 15:30
-
-
Save pawlos/ebf753484ff62c908bc3df60f50bae35 to your computer and use it in GitHub Desktop.
Solution for vv_max, emulating AVX operations with z3
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from z3 import * | |
zero = 0  # not used by the script below; kept so the module-level name still exists

# One independent symbolic 256-bit (32-byte) value per VM register r0..r31.
# Every slot gets its own symbol, so a register that the trace reads before
# writing is a distinct unknown rather than an alias of r0 (or of each other).
# Memory byte 0 of a register is the most significant byte of the bit-vector,
# matching how to_num() packs constant lists.
regs = [BitVec('r%d' % i, 32*8) for i in range(32)]

# Named aliases; these are the very same objects as regs[0]..regs[31].
(reg0, reg1, reg2, reg3, reg4, reg5, reg6, reg7,
 reg8, reg9, reg10, reg11, reg12, reg13, reg14, reg15,
 reg16, reg17, reg18, reg19, reg20, reg21, reg22, reg23,
 reg24, reg25, reg26, reg27, reg28, reg29, reg30, reg31) = regs
| import re | |
def to_num(v):
    """Pack a sequence of byte values into one big-endian integer.

    v[0] ends up most significant, e.g. [0x12, 0x34] -> 0x1234.
    An empty sequence raises IndexError (there is no leading byte).
    """
    result = v[0]
    remaining = v[1:]
    for nxt in remaining:
        result = (result << 8) + nxt
    return result
# vpermd
def perm(op1, op2):
    """Emulate vpermd: gather 32-bit lanes of regs[op1] using indices in regs[op2].

    Output lane k (k = 0 is the most significant 32 bits) takes the lane of
    regs[op1] selected by the low 3 bits of lane k of regs[op2]. Lanes are
    stored little-endian, so those 3 bits sit in the lane's top byte.
    """
    table = regs[op1]
    control = regs[op2]
    lanes = []
    for k in range(8):
        top = 256 - 32 * k
        sel = Extract(top - 6, top - 8, control)
        # Build the If-chain from the inside out: selector 0 innermost,
        # selector 7 outermost. The -1 fallback is unreachable for 3 bits.
        picked = -1
        for n in range(8):
            picked = If(sel == n, Extract(255 - 32 * n, 224 - 32 * n, table), picked)
        lanes.append(simplify(picked))
    return simplify(Concat(*lanes))
# vpsrld
def shr(op1, const):
    """Emulate vpsrld: logical right shift of every 32-bit lane of regs[op1].

    Each lane is stored little-endian. It is byte-swapped to its numeric
    value (to_dword), shifted by `const`, and swapped back (from_dword).
    A count above 31 clears every lane.
    """
    source = regs[op1]
    out = []
    for lane in range(8):
        top = 256 - 32 * lane
        raw = simplify(If(const > 31, BitVecVal(0, 32), Extract(top - 1, top - 32, source)))
        out.append(from_dword(simplify(LShR(to_dword(raw), const))))
    return simplify(Concat(out))
# vpslld
def shl(op1, const):
    """Emulate vpslld: left shift of every 32-bit lane of regs[op1].

    Each little-endian lane is byte-swapped to its numeric value (to_dword),
    shifted left by `const`, and swapped back (from_dword). A count above
    31 clears every lane.
    """
    source = regs[op1]
    out = []
    for lane in range(8):
        top = 256 - 32 * lane
        raw = simplify(If(const > 31, BitVecVal(0, 32), Extract(top - 1, top - 32, source)))
        out.append(from_dword(simplify(to_dword(raw) << const)))
    return simplify(Concat(out))
# vpxor
def xor(op1, op2):
    """Emulate vpxor: bitwise XOR of two whole registers, simplified."""
    lhs, rhs = regs[op1], regs[op2]
    return simplify(lhs ^ rhs)
# vpand
def _and(op1, op2):
    """Emulate vpand: bitwise AND of two whole registers, simplified."""
    lhs, rhs = regs[op1], regs[op2]
    return simplify(lhs & rhs)
# vpor
def _or(op1, op2):
    """Emulate vpor: bitwise OR of two whole registers, simplified."""
    lhs, rhs = regs[op1], regs[op2]
    return simplify(lhs | rhs)
# vpcmpeqb
def cmp(op1, op2):
    """Emulate vpcmpeqb: per-byte equality of regs[op1] and regs[op2].

    Result byte k is 0xFF where byte k of both operands is equal, else 0x00.
    Bytes are generated most-significant first, so every result byte lands
    in the same position as the bytes it compares. Concatenating the
    lowest byte first would mirror the whole register.
    """
    a = regs[op1]
    b = regs[op2]
    lanes = []
    for k in range(32):
        hi = 255 - 8 * k  # byte k (memory order) occupies bits hi..hi-7
        lo = hi - 7
        same = simplify(simplify(Extract(hi, lo, a)) == simplify(Extract(hi, lo, b)))
        lanes.append(If(same, BitVecVal(0xFF, 8), BitVecVal(0, 8)))
    return simplify(Concat(lanes))
def to_dword(v):
    """Reverse the four bytes of a 32-bit bit-vector and simplify.

    Turns a little-endian lane, as stored in a register, into its numeric
    value. The swap is its own inverse.
    """
    pieces = [Extract(8 * k + 7, 8 * k, v) for k in range(4)]
    return simplify(Concat(*pieces))
def from_dword(v):
    """Reverse the four bytes of a 32-bit bit-vector (numeric -> memory order), unsimplified."""
    spans = ((7, 0), (15, 8), (23, 16), (31, 24))
    return Concat(*[Extract(hi, lo, v) for hi, lo in spans])
# vpaddd
def add_dwords(op1, op2):
    """Emulate vpaddd: lane-wise 32-bit addition (wrapping) of two registers.

    Each little-endian lane is converted to its numeric value, added, and
    converted back. Lanes are emitted most significant first.
    """
    lhs, rhs = regs[op1], regs[op2]
    lanes = []
    for k in range(7, -1, -1):
        lo = 32 * k
        hi = lo + 31
        x_val = to_dword(simplify(Extract(hi, lo, lhs)))
        y_val = to_dword(simplify(Extract(hi, lo, rhs)))
        lanes.append(simplify(from_dword(x_val + y_val)))
    return simplify(Concat(lanes))
# vpaddb
def add_bytes(op1, op2):
    """Emulate vpaddb: byte-wise addition (mod 256) of two registers."""
    lhs, rhs = regs[op1], regs[op2]
    lanes = []
    for k in range(31, -1, -1):  # most significant byte first
        lo = 8 * k
        lanes.append(simplify(Extract(lo + 7, lo, lhs) + Extract(lo + 7, lo, rhs)))
    return simplify(Concat(lanes))
# vpshufb
def _shuffle_pick(src, sel, lane_top):
    """Select byte `sel` (0-15) of the 128-bit lane of `src` whose top bit is lane_top - 1.

    This is a 16-way If-chain multiplexer. The final BitVecVal(0, 8)
    fallback is unreachable for a 4-bit selector.
    """
    picked = BitVecVal(0, 8)
    for k in range(15, -1, -1):  # selector 0 ends up outermost
        hi = lane_top - 8 * k - 1
        picked = If(sel == k, Extract(hi, hi - 7, src), picked)
    return picked


def shuff(op1, op2):
    """Emulate vpshufb: per-128-bit-lane byte shuffle of regs[op1] by regs[op2].

    For every control byte of regs[op2]:
    - if its bit 7 is set, the result byte is 0;
    - otherwise its low 4 bits pick a byte from the same 128-bit lane of
      regs[op1].

    Zeroing is keyed on bit 7, not on the value 0x0F; selector 15 is a
    normal pick. Memory byte 0 is the top byte of the bit-vector, so the
    first lane is bits 255..128 and the second is bits 127..0.
    """
    table = regs[op1]
    control = regs[op2]
    out = []
    for lane_top in (256, 128):
        for j in range(16):
            hi = lane_top - 8 * j - 1  # control byte j of this lane: bits hi..hi-7
            lo = hi - 7
            sel = simplify(Extract(lo + 3, lo, control))
            zero_flag = Extract(hi, hi, control) == 1
            picked = simplify(_shuffle_pick(table, sel, lane_top))
            out.append(simplify(If(zero_flag, BitVecVal(0, 8), picked)))
    return Concat(out)
def mul_add8(op1, op2):
    """Multiply adjacent byte pairs and add them into 16-bit words (vpmaddubsw-like).

    For word k: a[2k]*b[2k] + a[2k+1]*b[2k+1], with bytes zero-extended to
    16 bits and the sum wrapping mod 2**16. The word is then stored back in
    little-endian memory order via to_16bit.

    NOTE(review): real vpmaddubsw treats one operand as signed and saturates
    the sum. This model is exact only while the bytes of the signed operand
    are < 0x80 and no sum overflows. Confirm against the traced operands.
    """
    a = regs[op1]
    b = regs[op2]
    words = []
    for k in range(16):
        base = 256 - 16 * (k + 1)  # word k occupies bits base+15..base
        a_first = simplify(ZeroExt(8, Extract(base + 15, base + 8, a)))
        b_first = simplify(ZeroExt(8, Extract(base + 15, base + 8, b)))
        a_second = simplify(ZeroExt(8, Extract(base + 7, base, a)))
        b_second = simplify(ZeroExt(8, Extract(base + 7, base, b)))
        total = a_first * b_first + a_second * b_second
        words.append(to_16bit(simplify(total)))
    return simplify(Concat(words))
def to_16bit(v):
    """Swap the two bytes of a 16-bit bit-vector (numeric <-> memory order) and simplify."""
    low_byte = Extract(7, 0, v)
    high_byte = Extract(15, 8, v)
    return simplify(Concat(low_byte, high_byte))
def to_32bit(v):
    """Byte-swap a 32-bit bit-vector and simplify (the same operation as to_dword)."""
    b0, b1, b2, b3 = (Extract(8 * k + 7, 8 * k, v) for k in range(4))
    return simplify(Concat(b0, b1, b2, b3))
def from_16bit(v):
    """Swap the two bytes of a 16-bit bit-vector (memory order -> numeric) and simplify."""
    return simplify(Concat(*(Extract(8 * k + 7, 8 * k, v) for k in (0, 1))))
def mul_add16(op1, op2):
    """Emulate vpmaddwd: multiply signed 16-bit words and add adjacent products.

    For 32-bit lane k: a[2k]*b[2k] + a[2k+1]*b[2k+1], where the words are
    little-endian in memory and are sign-extended to 32 bits. vpmaddwd is
    signed on both operands. The sum wraps mod 2**32, as on hardware. For
    words below 0x8000 the result equals the old zero-extended model.
    """
    a = regs[op1]
    b = regs[op2]
    lanes = []
    for k in range(8):
        top = 256 - 32 * k  # lane k occupies bits top-1..top-32
        # First word (memory bytes 4k, 4k+1) is the upper half of the field.
        a_first = simplify(SignExt(16, from_16bit(Extract(top - 1, top - 16, a))))
        b_first = simplify(SignExt(16, from_16bit(Extract(top - 1, top - 16, b))))
        a_second = simplify(SignExt(16, from_16bit(Extract(top - 17, top - 32, a))))
        b_second = simplify(SignExt(16, from_16bit(Extract(top - 17, top - 32, b))))
        total = a_first * b_first + a_second * b_second
        lanes.append(to_32bit(simplify(total)))
    return simplify(Concat(lanes))
# Shared z3 solver. Constraints are added to it after the trace file has
# been replayed into `regs`.
s = Solver()
def write_before(txt, r1, r2, a):
    """Debug hook: when predicate `a()` is true, print a label and registers r1 and r2."""
    if not a():
        return
    print (txt)
    print (regs[r1])
    print (regs[r2])
def write_after(txt, r, a):
    """Debug hook: when predicate `a()` is true, print a label and regs[r], then abort."""
    if not a():
        return
    print (txt)
    print (regs[r])
    sys.exit(-1)
import sys

# Trace grammar: each line is "rD = <expr>", with register numbers of
# one or two decimal digits.
_CONST_RE = re.compile('^r(\\d{1,2}) = (\\[.+\\])$')
_COPY_RE = re.compile('^r(\\d{1,2}) = r(\\d{1,2})$')
# Lines of the form "rD = rA <op> rB" or "rD = rA <shift> N". Each handler
# is called as fn(A, B) and its result is stored in regs[D]. Patterns are
# tried in this order and the first match wins.
_OP_TABLE = [(re.compile(pattern), fn) for pattern, fn in (
    ('^r(\\d{1,2}) = r(\\d{1,2}) perm r(\\d{1,2})$', perm),
    ('^r(\\d{1,2}) = r(\\d{1,2}) >> (\\d+)$', shr),
    ('^r(\\d{1,2}) = r(\\d{1,2}) \\^ r(\\d{1,2})$', xor),
    ('^r(\\d{1,2}) = r(\\d{1,2}) & r(\\d{1,2})$', _and),
    ('^r(\\d{1,2}) = r(\\d{1,2}) << (\\d+)$', shl),
    ('^r(\\d{1,2}) = r(\\d{1,2}) \\| r(\\d{1,2})$', _or),
    ('^r(\\d{1,2}) = r(\\d{1,2}) == r(\\d{1,2});.*$', cmp),
    ('^r(\\d{1,2}) = r(\\d{1,2}) \\+ r(\\d{1,2}) ;dwords.*$', add_dwords),
    ('^r(\\d{1,2}) = r(\\d{1,2}) \\+ r(\\d{1,2}) ;bytes.*$', add_bytes),
    ('^r(\\d{1,2}) = r(\\d{1,2}) shuff r(\\d{1,2})$', shuff),
    ('^r(\\d{1,2}) = r(\\d{1,2}) mul_add8 r(\\d{1,2})$', mul_add8),
    ('^r(\\d{1,2}) = r(\\d{1,2}) mul_add16 r(\\d{1,2})$', mul_add16),
)]

if len(sys.argv) < 2:
    sys.exit('usage: python ' + sys.argv[0] + ' <trace-file>')
fileName = sys.argv[1]
print ('Opening file: '+fileName)
with open(fileName) as f:
    while True:
        line = f.readline().strip()
        # Sanity check that every register is still 256 bits wide. It runs
        # before the stop test, so the last instruction is checked too.
        if not all(x.size() == 256 for x in regs):
            print ('something not right!')
            sys.exit(-1)
        # An empty line (a blank line or EOF) ends the trace.
        if line == '':
            break
        m = _CONST_RE.match(line)
        if m:
            r = int(m.group(1))
            # SECURITY(review): eval() executes arbitrary text from the trace
            # file; ast.literal_eval would be enough for a list literal.
            v = eval(m.group(2).strip())
            # r1 is deliberately kept symbolic: it is what the solver recovers.
            if r != 1:
                regs[r] = BitVecVal(to_num(v), 32*8)
            continue
        m = _COPY_RE.match(line)
        if m:
            print ('Should never match!')
            regs[int(m.group(1))] = regs[int(m.group(2))]
            continue
        for pattern, op in _OP_TABLE:
            m = pattern.match(line)
            if m:
                dst, src_a, src_b = (int(g) for g in m.groups())
                regs[dst] = op(src_a, src_b)
                break
        else:
            print('Unrecognized line: '+line)
            sys.exit(-1)
# Expected final contents of r2 (24 significant bytes followed by 8 zero bytes).
x = [0x70,0x70,0xB2,0xAC,0x01,0xD2,0x5E,0x61,0x0A,0xA7,0x2A,0xA8,0x08,0x1C,0x86,0x1A,
     0xE8,0x45,0xC8,0x29,0xB2,0xF3,0xA1,0x1E,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00]
x = BitVecVal(to_num(x), 32*8)
s.add(x == regs[2])
print (simplify(regs[2]))

# Every byte of r1 must lie in (0x30, 0x7a]. z3's >/<= on bit-vectors are
# signed, which is harmless here because both bounds are below 0x80.
# NOTE(review): `> 0x30` excludes '0' itself. Confirm this is intended.
for i in range(32):
    c = Extract(i*8+7, i*8, regs[1])
    s.add(And(c > 0x30, c <= 0x7a))

# Byte-wise r1 ^ r31 must be [0-9A-Za-z_] or a NUL byte.
for i in range(32):
    c = Extract(i*8+7, i*8, regs[1]) ^ Extract(i*8+7, i*8, regs[31])
    s.add(Or(And(c >= 0x30, c <= 0x39),
             And(c >= 0x41, c <= 0x5a),
             And(c >= 0x61, c <= 0x7a),
             c == 0x5f,
             c == 0x0))

result = s.check()
print (result)
# s.model() raises unless the last check() returned sat, so guard it.
if result == sat:
    r = s.model()
    print (r)
else:
    print ('No model available (solver returned ' + str(result) + ')')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment