Skip to content

Instantly share code, notes, and snippets.

@jart
Created April 18, 2023 03:35
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jart/6983fee9f72e2560366226f5fb04035e to your computer and use it in GitHub Desktop.
Save jart/6983fee9f72e2560366226f5fb04035e to your computer and use it in GitHub Desktop.
equiv -- checks if integral c expressions are equivalent
#!/usr/bin/env python3
# -*- python -*-
#
# SYNOPSIS
#
# equiv -- checks if integral c expressions are equivalent
#
# EXAMPLES
#
# $ equiv '(gn1 & 1) || (gn1 != gn2)' '(gn1 & 1) | (gn1 ^ gn2)'
# int... DIFFERENT! gn2=0, gn1=2
# long... DIFFERENT! gn2=0, gn1=2
# unsigned... DIFFERENT! gn2=0, gn1=2
# unsigned long... DIFFERENT! gn2=0, gn1=2
# int coerced to boolean... EQUIVALENT
# long coerced to boolean... EQUIVALENT
# unsigned coerced to boolean... EQUIVALENT
# unsigned long coerced to boolean... EQUIVALENT
#
# $ equiv 'MIN(a, 0-a)' '0-ABS(a)'
# int... EQUIVALENT
# long... EQUIVALENT
# unsigned... DIFFERENT! a=1
# unsigned long... DIFFERENT! a=1
# int coerced to boolean... EQUIVALENT
# long coerced to boolean... EQUIVALENT
# unsigned coerced to boolean... EQUIVALENT
# unsigned long coerced to boolean... EQUIVALENT
#
# $ equiv '((((x ^ y) | (~(x ^ y) + 1)) >> 63) - 1)' 'x == y'
# int... DIFFERENT! y=0, x=0
# long... DIFFERENT! y=0, x=0
# unsigned... DIFFERENT! y=0, x=0
# unsigned long... DIFFERENT! y=0, x=0
# int coerced to boolean... DIFFERENT! y=0, x=1
# long coerced to boolean... DIFFERENT! y=0, x=1
# unsigned coerced to boolean... DIFFERENT! y=0, x=1
# unsigned long coerced to boolean... EQUIVALENT
#
import os
import re
import shlex
import subprocess
import sys
import tempfile
# Lowercase C words that extract_variable_names() must never report as free
# variables.  Each reported name is declared as `T name = V[...]` in the
# generated program.  Declaring a type name such as `uint64_t` would shadow
# the typedef and break any cast written with it, e.g. `(uint64_t)x`.
# `sizeof x` (no parenthesis) would likewise be mistaken for a variable.
KEYWORDS = {
    # basic types, signedness and qualifiers usable in casts
    'int', 'float', 'double', 'char', 'short', 'long', 'unsigned',
    'signed', 'const', 'volatile', 'void',
    # operator spelled as a keyword
    'sizeof',
    # typedefs from <stdint.h> / <stddef.h> (both reachable via LIB's includes)
    'int8_t', 'int16_t', 'int32_t', 'int64_t', 'intmax_t', 'intptr_t',
    'uint8_t', 'uint16_t', 'uint32_t', 'uint64_t', 'uintmax_t', 'uintptr_t',
    'size_t', 'ptrdiff_t',
}

# C prelude for every generated test program.  The single `%s` is replaced
# with the C type under test (e.g. 'unsigned long'), which becomes `T`.
# It provides:
#   TBIT / TMIN / TMAX     bit index of the top bit, and the min/max of T
#                          (TMIN is 0 when T is unsigned)
#   bsf/bsr/popcount/bswap_*, ABS/MIN/MAX/IS2POW/ROUNDUP/ROUNDDOWN,
#   roundup2pow/rounddown2pow
#                          helpers that user expressions may call
#   V[]                    the test vector each variable is drawn from
#                          (small values, negatives and values near TMIN/TMAX)
#   crashed / OnSigfpe     a flag-setting SIGFPE handler
# NOTE: POSIX leaves behavior undefined if a handler returns normally from a
# SIGFPE that kill()/raise() did not generate (e.g. a real division by zero).
LIB = r'''
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <limits.h>
#include <stdint.h>
#include <signal.h>
#include <setjmp.h>
#define T %s
#define TBIT (sizeof(T) * CHAR_BIT - 1)
#define TMIN (((T) ~(T)0) > 1 ? (T)0 : (T)((uintmax_t)1 << TBIT))
#define TMAX (((T) ~(T)0) > 1 ? (T) ~(T)0 : (T)(((uintmax_t)1 << TBIT) - 1))
#define bsf(x) __builtin_ctzll(x)
#define bsr(x) (__builtin_clzll(x) ^ 63)
#define popcount(x) __builtin_popcountll(x)
#define bswap_16(x) __builtin_bswap16(x)
#define bswap_32(x) __builtin_bswap32(x)
#define bswap_64(x) __builtin_bswap64(x)
#define ABS(X) ((X) >= 0 ? (X) : -(X))
#define MIN(X, Y) ((Y) > (X) ? (X) : (Y))
#define MAX(X, Y) ((Y) < (X) ? (X) : (Y))
#define IS2POW(X) (!((X) & ((X)-1)))
#define ROUNDUP(X, K) (((X) + (K)-1) & -(K))
#define ROUNDDOWN(X, K) ((X) & -(K))
unsigned long long roundup2pow(unsigned long long x) { return x > 1 ? 2ull << bsr(x - 1) : x ? 1 : 0; }
unsigned long long rounddown2pow(unsigned long long x) { return x ? 1ull << bsr(x) : 0; }
T V[] = {0, 1, 2, 3, 4, -1, -2, -3, -4, TMIN, TMIN+1, TMIN+2, TMIN+3, TMAX, TMAX-1, TMAX-2, TMAX-3, TMIN/2, TMAX/2};
int crashed;
void OnSigfpe(int sig) { crashed = 1; }
'''
def identity(x):
    """Return the C expression *x* as-is (the "no coercion" transform)."""
    unchanged = x
    return unchanged
def boolify(x):
    """Wrap the C expression *x* as ``!!(x)`` so its value collapses to 0 or 1."""
    return ''.join(('!!(', x, ')'))
def extract_variable_names(code):
    """Return the set of free variable names used by the C expression *code*.

    A variable is a whole identifier that starts with a lowercase letter,
    is not called (followed by ``(``, optionally after whitespace) and is
    not in KEYWORDS.  Identifiers starting with an uppercase letter or an
    underscore are never reported, so macros such as MIN, ABS, TMIN and
    TMAX are left alone.

    The leading word boundary keeps literal suffixes and hex digits out of
    the result (``0xff`` no longer yields ``xff``, nor ``1ull`` yields
    ``ull``).  The full-identifier match keeps ``myVar`` whole instead of
    splitting it into ``my`` and ``ar``.
    """
    names = set()
    for match in re.finditer(r'\b([a-z][0-9A-Za-z_]*)\s*(\()?', code):
        if match.group(2) is None:  # not a function/macro call
            names.add(match.group(1))
    return names - KEYWORDS
def _generate_program(type_name, fmt, names, test1, test2):
    """Return C source that looks for inputs where test1 and test2 differ.

    type_name -- C type substituted for ``T`` in LIB (e.g. 'unsigned long').
    fmt       -- printf conversion that prints one value of that type.
    names     -- ordered variable names; each gets a nested loop over V[].
    test1/2   -- the (possibly boolified) C expressions to compare.

    The generated main() prints the first differing assignment and returns
    1, or returns 0 if both expressions agree on every combination from V[].
    An expression that raises SIGFPE counts as "crashed".  Two results are
    equal when both crashed, or when neither crashed and the values match.
    """
    out = [LIB % (type_name,)]
    # A handler that returns normally from a hardware SIGFPE is undefined
    # behavior under POSIX; on x86 the faulting division simply re-executes
    # forever.  So jump out of the handler instead.  LIB's OnSigfpe stays
    # defined but is not used.
    out.append('static sigjmp_buf equiv_jmp;\n')
    out.append('static void equiv_onfpe(int sig) { siglongjmp(equiv_jmp, 1); }\n')
    out.append('int main() {\n')
    out.append('signal(SIGFPE, equiv_onfpe);\n')
    for i, name in enumerate(names):
        # Numbered counters cannot clash with a user variable named `x_i`.
        out.append('for (int equiv_i%d = 0; equiv_i%d < sizeof(V) / sizeof(*V); ++equiv_i%d) {\n'
                   % (i, i, i))
        out.append('T %s = V[equiv_i%d];\n' % (name, i))
    # Track crashes with separate flags instead of a sentinel value.  This
    # stops "crashed" from comparing equal to a genuine result.  volatile
    # keeps the values well-defined across siglongjmp.
    out.append('volatile int equiv_c1 = 0, equiv_c2 = 0;\n')
    out.append('volatile T equiv_r1 = 0, equiv_r2 = 0;\n')
    out.append('if (sigsetjmp(equiv_jmp, 1)) equiv_c1 = 1; else equiv_r1 = (%s);\n' % (test1,))
    out.append('if (sigsetjmp(equiv_jmp, 1)) equiv_c2 = 1; else equiv_r2 = (%s);\n' % (test2,))
    out.append('if (equiv_c1 != equiv_c2 || (!equiv_c1 && equiv_r1 != equiv_r2)) {\n')
    values = ', '.join('%s=%s' % (name, fmt) for name in names)
    args = ''.join(', %s' % (name,) for name in names)
    # \033[51G moves the cursor to column 51 so the counterexample lands
    # after the verdict that the Python side prints at column 40.
    out.append(r' fprintf(stdout, "\033[51G' + values + '"' + args + ');\n')
    out.append(' fflush(stdout);\n')
    out.append(' return 1;\n')
    out.append('}\n')
    out.append('}\n' * len(names))
    out.append('return 0;\n')
    out.append('}\n')
    return ''.join(out)


def main(args):
    """Check whether two C expressions agree across several integer types.

    args -- argv-style list: [program, EXPR1, EXPR2].

    For each of int, long, unsigned and unsigned long, both as-is and
    coerced to boolean, this generates a C program and compiles it with
    $CC (default ``cc``).  It then runs the program and prints EQUIVALENT
    or DIFFERENT! (followed by the counterexample).  Returns 1 on a usage
    or compile error, otherwise 0.
    """
    if len(args) != 3:
        sys.stderr.write('usage: %s EXPR1 EXPR2\n' % (args[0],))
        return 1
    expr1, expr2 = args[1], args[2]
    # Sorted so counterexamples list variables in the same order every run.
    # Iteration order of a set of strings varies with hash randomization.
    names = sorted(extract_variable_names(expr1 + ' ' + expr2))
    compiler = shlex.split(os.environ.get('CC') or 'cc')
    # A private, auto-cleaned directory keeps concurrent runs from clobbering
    # each other.  It also rules out symlink tricks on fixed /tmp paths.
    with tempfile.TemporaryDirectory(prefix='equiv.') as workdir:
        source = os.path.join(workdir, 'equiv.c')
        binary = os.path.join(workdir, 'equiv')
        for xfunc, xdesc in ((identity, ''),
                             (boolify, ' coerced to boolean')):
            for type_name, fmt in (('int', '%d'),
                                   ('long', '%ld'),
                                   ('unsigned', '%u'),
                                   ('unsigned long', '%lu')):
                with open(source, 'w') as f:
                    f.write(_generate_program(type_name, fmt, names,
                                              xfunc(expr1), xfunc(expr2)))
                try:
                    # -lm: expressions may call <math.h> functions.
                    status = subprocess.call(
                        compiler + ['-w', '-O', '-o', binary, source, '-lm'])
                except OSError as e:
                    sys.stderr.write('%s: %s\n' % (compiler[0], e))
                    return 1
                if status != 0:
                    return 1
                sys.stdout.write('%s%s...' % (type_name, xdesc))
                # Flush before the child writes to the same terminal.
                sys.stdout.flush()
                if subprocess.call([binary]) == 0:
                    sys.stdout.write('\033[40G\033[1;32mEQUIVALENT\033[0m\n')
                else:
                    sys.stdout.write('\033[40G\033[1;31mDIFFERENT!\033[0m\n')
    return 0
if __name__ == '__main__':
    # Run as a script: main()'s return value becomes the process exit status.
    exit_status = main(sys.argv)
    sys.exit(exit_status)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment