Skip to content

Instantly share code, notes, and snippets.

@DDoSolitary
Last active April 29, 2021 15:55
Show Gist options
  • Save DDoSolitary/0a1fdcd1a3a5f715f429ed7bc11b6d6b to your computer and use it in GitHub Desktop.
Save DDoSolitary/0a1fdcd1a3a5f715f429ed7bc11b6d6b to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import functools
import json
import math
import multiprocessing
import random
import re
import subprocess
import sys
from sympy import sin, cos, diff, lambdify, expand
from sympy.abc import x
from sympy.parsing import parse_expr
# Maximum number of digits in a randomly generated integer constant
# (gen_normal_const draws between 1 and MAX_DIGIT_COUNT digits).
MAX_DIGIT_COUNT = 2
# Maximum number of factors in one generated expression; gen_expr draws
# 1..MAX_FACTOR_COUNT factors and splits them into terms.
MAX_FACTOR_COUNT = 3
# Parenthesized sub-expressions are only generated while the current depth
# is below this limit (see gen_factor); trig arguments also increase depth.
MAX_RECURSION_DEPTH = 4
# Matches a run of leading zeros in an integer literal: zeros that are not
# preceded by a digit and are followed by another digit ('007' -> '7',
# '-05' -> '-5', a lone '0' is left alone). Used to normalize text before
# sympy's parse_expr, which presumably cannot handle such literals since it
# follows Python syntax.
PATTERN_LEADING_ZEROS = re.compile(r'(?:^|(?<=[^0-9]))0+(?=[0-9])')
def gen_ws():
    """Return a random whitespace separator: '', ' ' or '\\t'.

    The empty string appears twice in the pool, so "no whitespace" is
    drawn half of the time.
    """
    pool = ('', '', ' ', '\t')
    return random.choice(pool)
def gen_opt_sign():
    """Return an optional sign: '', '+' or '-', each equally likely."""
    signs = ('', '+', '-')
    return random.choice(signs)
def gen_special_const():
    """Return an edge-case constant as a (value, literal text) pair.

    Candidates are the signed spellings of 0, -1 and 1. The weights make
    '-1' the single most likely pick (6 out of 18).
    """
    candidates = (
        (0, '-0'), (0, '0'), (0, '+0'),
        (-1, '-1'),
        (1, '1'), (1, '+1'),
    )
    weights = (2, 2, 2, 6, 3, 3)
    (picked,) = random.choices(candidates, weights=weights)
    return picked
def gen_normal_const():
    """Return a random signed integer constant as (int value, literal text).

    The literal is an optional '+'/'-' sign followed by 1..MAX_DIGIT_COUNT
    random digits. The digits may start with zeros (e.g. '-07'); the value is
    parsed from that same text.
    """
    text = gen_opt_sign()
    width = 1 + random.randrange(MAX_DIGIT_COUNT)
    for _ in range(width):
        text += str(random.randrange(10))
    return int(text), text
def gen_const():
    """Return a random constant factor as (value, text).

    Picks, with equal odds, either an edge-case constant or an ordinary one.
    """
    producer = random.choice((gen_special_const, gen_normal_const))
    return producer()
def gen_pow():
    """Return the bare variable factor as a (sympy symbol, input text) pair."""
    text = 'x'
    return x, text
def gen_sin(expr, expr_str):
    """Wrap a factor in sine.

    Returns (sympy sin(expr), text "sin(<expr_str>)"), with optional
    whitespace between the function name and the opening parenthesis.
    """
    value = sin(expr)
    text = 'sin{}({})'.format(gen_ws(), expr_str)
    return value, text
def gen_cos(expr, expr_str):
    """Wrap a factor in cosine.

    Returns (sympy cos(expr), text "cos(<expr_str>)"), with optional
    whitespace between the function name and the opening parenthesis.
    """
    value = cos(expr)
    text = 'cos{}({})'.format(gen_ws(), expr_str)
    return value, text
def gen_trig(depth):
    """Return sin(...) or cos(...) of a random factor generated one level deeper."""
    inner, inner_str = gen_factor(depth + 1)
    wrapper = random.choice([gen_sin, gen_cos])
    return wrapper(inner, inner_str)
def gen_var(depth):
    """Return a power factor as (sympy value, input text).

    The base is either the bare variable or a trig function. It is raised to
    a random integer exponent in [-50, 50]. The "** exp" text is always shown
    unless the exponent is 1, in which case it is shown half of the time.
    """
    make_base = random.choice([gen_pow, functools.partial(gen_trig, depth=depth)])
    base, text = make_base()
    power = random.randrange(-50, 51)
    # An exponent of 1 may be written implicitly; any other must be explicit.
    show_power = power != 1 or random.randrange(2) == 0
    if show_power:
        text += '{}**{}{}'.format(gen_ws(), gen_ws(), str(power))
    return base ** power, text
def gen_nested_expr(depth):
    """Return a parenthesized sub-expression generated one level deeper."""
    inner, inner_str = gen_expr(depth=depth + 1)
    return inner, '({})'.format(inner_str)
def gen_factor(depth):
    """Return a random factor as (sympy value, input text).

    A factor is a constant or a power/trig factor. While ``depth`` is below
    MAX_RECURSION_DEPTH it may also be a parenthesized sub-expression.
    """
    options = [gen_const, functools.partial(gen_var, depth=depth)]
    nesting_allowed = depth < MAX_RECURSION_DEPTH
    if nesting_allowed:
        options.append(functools.partial(gen_nested_expr, depth=depth))
    pick = random.choice(options)
    return pick()
def _gen_term(term_size, depth):
    """Build one term of ``term_size`` factors joined by '*'.

    Returns (sympy value, text). An optional leading sign is written into the
    text and folded into the value as a factor of -1 when it is '-'.
    """
    text = gen_opt_sign()
    value = -1 if text == '-' else 1
    if text:
        text += gen_ws()
    for index in range(term_size):
        factor, factor_str = gen_factor(depth)
        value *= factor
        if index:
            text += gen_ws() + '*' + gen_ws()
        text += factor_str
    return value, text


def gen_expr(depth=1):
    """Generate a random expression as (sympy value, input text).

    Between 1 and MAX_FACTOR_COUNT factors are drawn and partitioned into
    1..factor-count terms at random cut points. The first term gets an
    optional sign; later terms are joined with '+' or '-'. Random whitespace
    is sprinkled around signs, operators and the end of the text.
    """
    factor_total = random.randrange(MAX_FACTOR_COUNT) + 1
    term_total = random.randrange(factor_total) + 1
    cuts = random.sample(range(1, factor_total), term_total - 1)
    bounds = [0, *sorted(cuts), factor_total]
    sizes = [upper - lower for lower, upper in zip(bounds, bounds[1:])]

    value = 0
    pieces = []
    for position, size in enumerate(sizes):
        term, term_str = _gen_term(size, depth)
        # Only the leading term may omit its sign.
        sign = gen_opt_sign() if position == 0 else random.choice(('+', '-'))
        value = value - term if sign == '-' else value + term
        if sign:
            pieces.append(gen_ws() + sign + gen_ws())
        else:
            pieces.append(gen_ws())
        pieces.append(term_str)
    pieces.append(gen_ws())
    return value, ''.join(pieces)
def check_equal(f1, f2, *, trials=100, attempts=5, low=-10, high=10,
                rel_tol=1e-3, abs_tol=0.0):
    """Numerically compare two single-variable functions at random points.

    For each of ``trials`` rounds, up to ``attempts`` random points are drawn
    uniformly from [low, high]. Each point is tried until one gives finite
    values for both functions. A point where either function raises
    ZeroDivisionError/OverflowError or returns a non-finite value is skipped.
    The first usable point of a round is compared with ``math.isclose``. A
    round where no usable point is found is skipped silently.

    ``abs_tol`` defaults to 0.0, which keeps the historical behavior. Set it
    to a small positive value so that near-zero results produced by
    floating-point cancellation are not reported as mismatches; with only
    ``rel_tol`` an exact 0 never matches a tiny non-zero value.

    Returns:
        (True, None) if no mismatch was found, otherwise (False, point) with
        the first point at which the functions disagreed.
    """
    for _ in range(trials):
        for _ in range(attempts):
            point = random.uniform(low, high)
            try:
                val1 = f1(point)
                val2 = f2(point)
            except (ZeroDivisionError, OverflowError):
                continue
            if not (math.isfinite(val1) and math.isfinite(val2)):
                continue
            if not math.isclose(val1, val2, rel_tol=rel_tol, abs_tol=abs_tol):
                return False, point
            # One successful comparison per round is enough.
            break
    return True, None
def remove_leading_zeros(s):
    """Return ``s`` with redundant leading zeros stripped from integer literals.

    For example '007 * x ** 05' becomes '7 * x ** 5'; a lone '0' is kept.
    """
    stripped = re.sub(PATTERN_LEADING_ZEROS, '', s)
    return stripped
def _judge_subject(subject, input_expr_str, ans_expr, ans_func, config):
    """Run one subject program on ``input_expr_str`` and judge its answer.

    Returns an error-report dict describing the failure, or ``None`` when the
    subject's output is numerically equal to the expected derivative.
    """
    name = subject['name']
    ans_str = str(ans_expr)
    try:
        proc = subprocess.run(
            subject['cmd'],
            input=input_expr_str,
            capture_output=True,
            text=True,
            timeout=5
        )
    except subprocess.TimeoutExpired:
        return dict(
            name=name,
            reason='Time Limit Exceeded',
            stdin=input_expr_str,
            ans=ans_str
        )
    if proc.returncode != 0:
        return dict(
            name=name,
            reason='Runtime Error',
            stdin=input_expr_str,
            stdout=proc.stdout,
            stderr=proc.stderr,
            ans=ans_str
        )
    if config.get('debug'):
        print(f'{name}: {proc.stdout}')
    # A non-zero exit from the validator means the output is not well-formed.
    validator_proc = subprocess.run(
        config['validator'],
        input=proc.stdout,
        capture_output=True,
        text=True
    )
    if validator_proc.returncode != 0:
        return dict(
            name=name,
            reason='Wrong Answer (invalid output)',
            stdin=input_expr_str,
            stdout=proc.stdout,
            ans=ans_str,
            validator_stdout=validator_proc.stdout
        )
    # NOTE(review): sympy's parse_expr evaluates its input with eval(); it is
    # applied to subject output here, so only fuzz subjects you trust.
    try:
        output_expr = expand(parse_expr(remove_leading_zeros(proc.stdout)))
    except Exception as exc:
        # sympy raises assorted exception types on malformed text. Record the
        # failure instead of letting it escape do_fuzz, where it would abort
        # the whole pool run and lose every collected report.
        return dict(
            name=name,
            reason='Wrong Answer (unparsable output)',
            stdin=input_expr_str,
            stdout=proc.stdout,
            ans=ans_str,
            parse_error=repr(exc)
        )
    check_res, check_var = check_equal(lambdify(x, output_expr, 'numpy'), ans_func)
    if not check_res:
        return dict(
            name=name,
            reason='Wrong Answer',
            stdin=input_expr_str,
            stdout=proc.stdout,
            ans=ans_str,
            var=check_var
        )
    return None


def do_fuzz(_, config):
    """Run one fuzzing round against every configured subject.

    The first argument is ignored. It lets the function be mapped over
    ``range(count)`` with ``Pool.imap_unordered``.

    In manual mode (``config['manual']``) the input expression is read from
    stdin; otherwise a random one is generated. The expected answer is the
    expanded sympy derivative of the input.

    Returns:
        A list of error-report dicts, one per failing subject. The list is
        empty when all subjects pass or when the input is degenerate (sympy
        evaluates it to nan/zoo).
    """
    if config.get('manual'):
        input_expr_str = input()
        input_expr = parse_expr(remove_leading_zeros(input_expr_str))
    else:
        input_expr, input_expr_str = gen_expr()
    input_sympy_str = str(input_expr)
    # Skip inputs that sympy evaluates to NaN or complex infinity (e.g. a
    # negative power of sin(0)); they have no meaningful derivative.
    if 'nan' in input_sympy_str or 'zoo' in input_sympy_str:
        return []
    ans_expr = expand(diff(input_expr))
    ans_func = lambdify(x, ans_expr, 'numpy')
    if config.get('debug'):
        print(input_expr_str)
        print(ans_expr)
    errors = []
    for subject in config['subjects']:
        error = _judge_subject(subject, input_expr_str, ans_expr, ans_func, config)
        if error is not None:
            errors.append(error)
    return errors
def compile_rule_item(k, v):
    """Prepare one (key, value) entry of a filter rule.

    The 'action' entry is a plain keyword and is passed through unchanged.
    Every other entry is a regex source string and is compiled.
    """
    return (k, v) if k == 'action' else (k, re.compile(v))
def compile_rule(rule):
    """Return a copy of a filter-rule dict with every pattern entry compiled."""
    compiled = {}
    for key, value in rule.items():
        new_key, new_value = compile_rule_item(key, value)
        compiled[new_key] = new_value
    return compiled
def _field_matches(value, pattern):
    """Return True if a report field is present and ``pattern`` finds a match in it."""
    # Report fields are not all strings (e.g. 'var' holds a float), and
    # Pattern.search only accepts str, so match against the text form.
    return value is not None and pattern.search(str(value)) is not None


def filter_error(err, rules):
    """Decide whether an error report should be kept.

    Rules are tried in order. A rule matches when every non-'action' key's
    compiled pattern finds a match in the corresponding report field. The
    first matching rule whose action is 'accept' keeps the report (True), and
    'ignore' drops it (False). A matching rule with any other action falls
    through to the next rule. A report that no rule decides is kept.

    Raises:
        KeyError: if a matching rule has no 'action' entry.
    """
    for rule in rules:
        matched = all(
            _field_matches(err.get(key), pattern)
            for key, pattern in rule.items()
            if key != 'action'
        )
        if not matched:
            continue
        action = rule['action']
        if action == 'accept':
            return True
        if action == 'ignore':
            return False
    return True
def main():
    """Command-line entry point: ``<script> CONFIG.json``.

    Loads the JSON config and runs the fuzzer. In manual mode it runs once,
    reading the input from stdin; otherwise it runs ``config['count']``
    rounds in a process pool and prints a progress line per round. Error
    reports are dropped according to ``config['filters']``, and the rest are
    printed as a JSON list.
    """
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} CONFIG.json')
    # JSON is UTF-8 by specification; don't depend on the locale encoding.
    with open(sys.argv[1], encoding='utf-8') as f:
        config = json.load(f)
    # Compile filters up front so an invalid regex fails before the
    # (potentially long) fuzzing run rather than after it.
    filter_rules = [compile_rule(rule) for rule in config.get('filters', [])]
    if config.get('manual'):
        errors = do_fuzz(None, config=config)
    else:
        errors = []
        worker = functools.partial(do_fuzz, config=config)
        with multiprocessing.Pool() as pool:
            results = pool.imap_unordered(worker, range(config['count']))
            for idx, res in enumerate(results):
                print(f'#{idx}: {len(res)}')
                errors.extend(res)
    errors = [err for err in errors if filter_error(err, rules=filter_rules)]
    print(json.dumps(errors, indent=2))
if __name__ == '__main__':
    # Only run when executed as a script. multiprocessing.Pool (used by
    # main) may re-import this module in worker processes on spawn-based
    # platforms, and those imports must not re-run main().
    main()
# vim: ts=4:sw=4:noet
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment