Skip to content

Instantly share code, notes, and snippets.

@qguv
Last active November 27, 2022 15:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save qguv/ca8809401efea5f682d61471d2c7cf50 to your computer and use it in GitHub Desktop.
Save qguv/ca8809401efea5f682d61471d2c7cf50 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import collections
import functools
import itertools
import math
import primefac
import sys
# Registry of (safe-wrapped function, kwarg possibilities) pairs.
# Entries are appended by the @tf decorator and consumed by analyze().
tf_try_calls = []
class ThreshholdError(Exception):
    '''Raised by print_compose when a call sequence fails one of its checks.'''
def dict_combinations(d):
    '''
    Expand a dict of candidate-value lists into every concrete assignment.

    d maps each key to an iterable of possible values. The result is a list
    of dicts, one per element of the cartesian product of those iterables,
    each mapping every key of d to a single chosen value.
    '''
    keys = list(d.keys())
    combos = []
    for chosen in itertools.product(*d.values()):
        combos.append(dict(zip(keys, chosen)))
    return combos
def safe(fn):
    '''
    Wrap fn so that it never raises an ordinary exception.

    The returned function forwards all arguments to fn. If fn raises any
    Exception, the exception's message (str(e)) is returned in place of a
    result, so a failure becomes just another mapped value.

    functools.wraps copies fn's name, docstring and other metadata onto the
    wrapper (the original copied only __name__).
    '''
    @functools.wraps(fn)
    def _safe_fn(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            # deliberate catch-all: the error text stands in for the result
            return str(e)
    return _safe_fn
def tf(**kwarg_possibilities):
    '''
    Decorator factory that registers a function as a candidate transform.

    Each keyword maps a parameter name to the list of values worth trying.
    The decorated function is returned unchanged; a safe()-wrapped copy is
    appended to tf_try_calls together with the kwarg possibilities.
    '''
    def _register(fn):
        entry = (safe(fn), kwarg_possibilities)
        tf_try_calls.append(entry)
        return fn
    return _register
def get_digits(x, base=10):
    '''
    Return the digits of a non-negative int, most significant first.

    x must be exactly of type int (bool and other subclasses are rejected,
    as before). base must be 2, 8, 10 or 16. Each digit is returned as an
    int in range(base).

    Raises ValueError for a non-int x, an unsupported base, or a negative x.
    Negative numbers used to fail with an accidental "invalid literal"
    error, because the sign shifted the '0b'/'0o'/'0x' prefix slice; they
    are now rejected explicitly.
    '''
    if type(x) is not int:
        raise ValueError(f"can't convert {x} to base-{base} int")
    if base not in (2, 8, 10, 16):
        raise ValueError(f"can't convert {x} to base-{base} int")
    if x < 0:
        # the message omits x so all negative inputs yield the same text for
        # a given base, as the accidental error did
        raise ValueError(f"can't take base-{base} digits of a negative number")
    format_codes = {2: 'b', 8: 'o', 10: 'd', 16: 'x'}
    s = format(x, format_codes[base])
    return [int(c, base=base) for c in s]
@tf(base=[2, 8, 10, 16], holy_four=[True, False], hex_upper=[True, False])
@functools.cache
def count_holes(n, holy_four=True, hex_upper=False, base=10):
    '''
    Count the enclosed "holes" in the written digits of n in the given base.

    A digit value appears in the table once per hole, so 8 is listed twice.
    holy_four adds one hole for the digit 4. hex_upper chooses between the
    upper-case and lower-case shapes of hex digits 11 and 14.
    '''
    tally = collections.Counter(get_digits(n, base=base))
    holes = [0, 6, 8, 8, 9, 10, 13]
    if holy_four:
        holes.append(4)
    # upper case: two-hole B, no-hole E; lower case: one-hole b, one-hole e
    holes.extend([11, 11] if hex_upper else [11, 14])
    total = 0
    for digit in holes:
        total += tally[digit]
    return total
@tf(power_base=list(range(2, 17)))
@functools.cache
def is_power_of(n, power_base=2, precision=15):
    '''
    Return True if n equals power_base ** k for some integer k >= 0.

    For a positive int n and an int power_base >= 2 the answer is computed
    exactly by repeated integer division. The previous floating-point
    logarithm reported values such as 2**60 + 1 as powers of 2, because the
    conversion to float drops the low bits.

    Every other input (negative numbers, floats, non-numeric values) still
    goes through the rounded-logarithm path, so the results and exceptions
    for those inputs are unchanged. precision is the number of decimals the
    logarithm is rounded to on that path only.
    '''
    if n == 0:
        return False
    exact = (
        isinstance(n, int)
        and isinstance(power_base, int)
        and n > 0
        and power_base >= 2
    )
    if not exact:
        return round(math.log(n, power_base), precision).is_integer()
    # strip factors of power_base; a true power leaves exactly 1
    while n % power_base == 0:
        n //= power_base
    return n == 1
@tf()
@functools.cache
def count_syllables(x, log=lambda x: None):
    '''
    Return the syllable count reported for x.

    Zero is handled here ("zero", two syllables); every other value is
    delegated to _count_syllables. log is called with each word produced.
    '''
    if x:
        return _count_syllables(x, log=log)
    log('zero')
    return 2
def _count_syllables(x, suffix=None, log=lambda x: None):
syllables = 0
if x < 0:
log('negative')
x *= -1
if x > 999:
thousands = x // 1000
syllables += _count_syllables(thousands)
log('thousand')
syllables += 2
x %= 1000
if x > 99:
hundreds = x // 100
syllables += _count_syllables(hundreds)
log('hundred')
syllables += 2
x %= 100
if x > 19:
tens = x // 10
syllables += _count_syllables(tens, suffix='ty')
x %= 10
if x > 12:
syllables += _count_syllables(x % 10, suffix='teen')
return syllables
if x > 9:
log(['ten', 'eleven', 'twelve'][x % 10])
syllables += 3 if x == 11 else 1
return syllables
if x > 0:
if suffix:
log(['', '', 'twen', 'thir', 'four', 'fif', 'six', 'seven', 'eigh', 'nine'][x] + suffix)
else:
log(['', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve'][x])
syllables += 2 if x == 7 else 1
return syllables
return 0
@tf(base=[2, 8, 10, 16])
@functools.cache
def digits_are_sorted(x, base=10):
    '''Return True if the digits of x are in ascending or descending order.'''
    digits = get_digits(x, base=base)
    ascending = sorted(digits)
    if digits == ascending:
        return True
    descending = ascending[::-1]
    return digits == descending
def concat(digits, base=10):
    '''
    Combine a sequence of digits (most significant first) into one number.

    Each step multiplies the running value by base and adds the next digit.
    An empty sequence gives 0.
    '''
    value = 0
    for d in digits:
        value = value * base + d
    return value
# @tf(base=[2, 8, 10, 16], reverse=[False, True])
def sort_digits(x, base=10, reverse=False):
    '''Return the number formed by the digits of x in sorted order.'''
    ordered = sorted(get_digits(x, base=base), reverse=reverse)
    return concat(ordered, base=base)
@tf(base=[2], remove_digit=list(range(2)))
@tf(base=[8], remove_digit=list(range(8)))
@tf(base=[10], remove_digit=list(range(10)))
@tf(base=[16], remove_digit=list(range(16)))
def remove_digit(x, base=10, remove_digit=0):
    '''
    Return x with every occurrence of one digit dropped.

    The remaining digits are rejoined in the same base; if none remain the
    result is 0.
    '''
    kept = [d for d in get_digits(x, base=base) if d != remove_digit]
    return concat(kept, base=base)
@tf(base=[2], digit_from=list(range(2)), digit_to=list(range(2)))
@tf(base=[8], digit_from=list(range(8)), digit_to=list(range(8)))
@tf(base=[10], digit_from=list(range(10)), digit_to=list(range(10)))
@tf(base=[16], digit_from=list(range(16)), digit_to=list(range(16)))
def replace_digit(x, base=10, digit_from=0, digit_to=0):
    '''
    Return x with every digit equal to digit_from replaced by digit_to.

    When both digits are the same, x is returned untouched without being
    split into digits.
    '''
    if digit_from == digit_to:
        return x
    swapped = [
        digit_to if d == digit_from else d
        for d in get_digits(x, base=base)
    ]
    return concat(swapped, base=base)
@tf(base=[2, 8, 10, 16], reverse=[False, True])
def sort_unique_digits(x, base=10, reverse=False):
    '''
    Return the number formed by the distinct digits of x in sorted order.

    The digits are taken in the given base and rejoined in that same base.
    Previously they were always rejoined in base 10, so different digit sets
    could collide (hex digits [1, 15] and [2, 5] both gave 25).
    '''
    digits = set(get_digits(x, base=base))
    return concat(sorted(digits, reverse=reverse), base=base)
@tf(base=[2, 8, 10, 16])
@functools.cache
def digit_sum(x, base=10):
    '''Return the sum of the digits of x in the given base.'''
    total = 0
    for d in get_digits(x, base=base):
        total += d
    return total
@tf(base=[2, 8, 10, 16])
@functools.cache
def unique_digit_sum(x, base=10):
    '''Return the sum of the distinct digits of x in the given base.'''
    distinct = set(get_digits(x, base=base))
    return sum(distinct)
def product(xs):
    '''Multiply all elements of xs together; an empty iterable gives 1.'''
    result = 1
    for factor in xs:
        result = result * factor
    return result
@tf(base=[2, 8, 10, 16])
@functools.cache
def digit_product(x, base=10):
    '''Return the product of the digits of x in the given base.'''
    digits = get_digits(x, base=base)
    return product(digits)
@tf(base=[2, 8, 10, 16])
@functools.cache
def unique_digit_product(x, base=10):
    '''Return the product of the distinct digits of x in the given base.'''
    distinct = set(get_digits(x, base=base))
    return product(distinct)
@functools.cache
def get_prime_factors(x):
    '''
    Return everything primefac.primefac(x) yields, as a tuple (cached).

    NOTE(review): presumably the prime factors of x with multiplicity --
    confirm against the primefac documentation.
    '''
    factors = primefac.primefac(x)
    return tuple(factors)
def get_unique_prime_factors(x):
    '''Return the distinct values from get_prime_factors(x), sorted, as a tuple.'''
    distinct = set(get_prime_factors(x))
    return tuple(sorted(distinct))
@tf()
def count_prime_factors(x):
    '''Return how many entries get_prime_factors(x) contains.'''
    factors = get_prime_factors(x)
    return len(factors)
@tf()
def count_unique_prime_factors(x):
    '''Return how many distinct values get_prime_factors(x) contains.'''
    distinct = set(get_prime_factors(x))
    return len(distinct)
@tf()
@functools.cache
def is_prime(x):
    '''Return the result of primefac.isprime(x) (cached).'''
    verdict = primefac.isprime(x)
    return verdict
@tf()
@functools.cache
def prime_factors_sum(x):
    '''Return the sum of all entries of get_prime_factors(x).'''
    total = 0
    for factor in get_prime_factors(x):
        total += factor
    return total
@tf()
@functools.cache
def unique_prime_factors_sum(x):
    '''Return the sum of the distinct entries of get_prime_factors(x).'''
    distinct = set(get_prime_factors(x))
    return sum(distinct)
def has_duplicates(xs):
    '''
    Return True as soon as an element of xs repeats, otherwise False.

    Works on any iterable of hashable items and stops at the first repeat.
    '''
    seen = set()
    for item in xs:
        size_before = len(seen)
        seen.add(item)
        if len(seen) == size_before:
            # adding did not grow the set, so item was already present
            return True
    return False
@tf(base=[2, 8, 10, 16])
@functools.cache
def digits_have_duplicates(x, base=10):
    '''Return True if any digit of x occurs more than once in the given base.'''
    digits = get_digits(x, base=base)
    return has_duplicates(digits)
@tf()
@functools.cache
def prime_factors_have_duplicates(x):
    '''Return True if any entry of get_prime_factors(x) occurs more than once.'''
    factors = get_prime_factors(x)
    return has_duplicates(factors)
def reduce_digital(fn, x, base=10):
    '''
    Repeatedly replace x by fn(digits of x) until x has a single digit.

    fn receives the list of digits in the given base and must return the
    next value. A value that already has one digit is returned unchanged.
    '''
    digits = get_digits(x, base=base)
    while len(digits) != 1:
        x = fn(digits)
        digits = get_digits(x, base=base)
    return x
@tf(base=[2, 8, 10, 16])
@functools.cache
def digital_sum(x, base=10):
    '''Sum the digits of x repeatedly until a single digit remains.'''
    reduced = reduce_digital(sum, x, base=base)
    return reduced
@tf(base=[2, 8, 10, 16])
@functools.cache
def digital_product(x, base=10):
    '''
    Multiply the digits of x repeatedly until a single digit remains.

    >>> digital_product(0, base=10)
    0
    >>> digital_product(1, base=10)
    1
    >>> digital_product(24, base=10)
    8
    >>> digital_product(25, base=10)
    0
    >>> digital_product(26, base=10)
    2
    '''
    reduced = reduce_digital(product, x, base=base)
    return reduced
@tf(target_digit=list(range(10)), base=[2, 8, 10, 16])
def count_consecutive_digits(n, target_digit=0, base=10):
    '''
    Return the length of the longest run of target_digit in the digits of n.

    >>> count_consecutive_digits(1020030002001, target_digit=0, base=10)
    3
    >>> count_consecutive_digits(1, target_digit=0, base=10)
    0
    '''
    longest = 0
    # groupby yields one group per run of equal adjacent digits
    for digit, run in itertools.groupby(get_digits(n, base=base)):
        if digit == target_digit:
            longest = max(sum(1 for _ in run), longest)
    return longest
@tf(divisor=list(range(2, 16)))
def mod(x, divisor=2):
    '''Return x % divisor.'''
    remainder = x % divisor
    return remainder
# @tf(base=[16], digit=list(range(16)))
# @tf(base=[10], digit=list(range(10)))
# @tf(base=[8], digit=list(range(8)))
# @tf(base=[2], digit=list(range(2)))
def count_digit(x, base=10, digit=0):
    '''Return how many times digit occurs among the digits of x in the given base.'''
    return get_digits(x, base=base).count(digit)
def get_call_name(call):
    '''
    Return a display name for a (function, kwargs) pair.

    With no kwargs this is just the function's __name__; otherwise the
    kwargs dict's text, without its braces, is appended in parentheses.
    '''
    fn, kwargs = call
    name = fn.__name__
    if kwargs:
        return f'{name}({str(kwargs)[1:-1]})'
    return name
def get_overlap(intersection, htbn_mapped, not_htbn_mapped):
    '''
    Return the percentage of mapped values that fall in intersection.

    The values of both dicts are pooled; the result is 100 times the share
    of those values that are members of intersection. Division fails if
    both dicts are empty.
    '''
    pooled = [*htbn_mapped.values(), *not_htbn_mapped.values()]
    hits = [value for value in pooled if value in intersection]
    return 100 * len(hits) / len(pooled)
def fmt_set(xs):
    '''
    Format a collection of values for use in a rule sentence.

    A single value is rendered on its own; anything else is rendered as
    "in {a, b, c}" with the values sorted.

    Mapped values can mix ints with the error strings that safe() returns,
    and such a mix cannot be sorted directly. In that case the values are
    ordered by type name and then by their text, rather than letting the
    TypeError escape.
    '''
    try:
        ordered = sorted(xs)
    except TypeError:
        ordered = sorted(xs, key=lambda v: (type(v).__name__, str(v)))
    items = ', '.join(str(x) for x in ordered)
    if len(xs) == 1:
        return items
    return f'in {{{items}}}'
def fmt_rule(name, calls, htbn_mapped, not_htbn_mapped, intersection):
    '''
    Build the cheaper of the "if" and "unless" phrasings of a rule.

    The rule is formatted once from the htbn side and once with the two
    groups swapped; both candidates are printed with their penalties.
    Returns (sentence, penalty) for the candidate with the lower penalty,
    preferring the positive phrasing on a tie.
    '''
    positive_rule, positive_penalty = _fmt_rule(
        name, calls, htbn_mapped, not_htbn_mapped, intersection)
    print(f"positive rule ({positive_penalty}): {positive_rule}")

    negative_rule, negative_penalty = _fmt_rule(
        name, calls, not_htbn_mapped, htbn_mapped, intersection)
    print(f"negative rule ({negative_penalty}): {negative_rule}")

    if positive_penalty <= negative_penalty:
        return (f"AKHTBN if {positive_rule}", positive_penalty)
    return (f"AKHTBN unless {negative_rule}", negative_penalty)
def _fmt_rule(name, calls, htbn_mapped, not_htbn_mapped, intersection, negative=False):
    '''
    Phrase one rule candidate and score it.

    Mapped values of the first group that are not shared with the other
    group form the "in" set. Koans whose mapped value is shared are the
    exceptions: either the shared values are accepted and the other group's
    koans are listed as "is not", or the first group's koans are listed as
    "or is", whichever side has fewer exceptions.

    The penalty grows with the number of calls, the size of the "in" set,
    the number of exceptions, and the use of any base other than 10.
    The negative parameter is accepted but not used.

    Returns (rule text, penalty).
    '''
    in_set = {v for v in htbn_mapped.values() if v not in intersection}
    positive_exceptions = {k for k, v in htbn_mapped.items() if v in intersection}
    negative_exceptions = {k for k, v in not_htbn_mapped.items() if v in intersection}

    if len(positive_exceptions) > len(negative_exceptions):
        # cheaper to accept the shared values and exclude the other group's koans
        positive_exceptions = set()
        in_set.update(intersection)
        rule = f"{name} is {fmt_set(in_set)} and is not {fmt_set(negative_exceptions)}"
    elif positive_exceptions:
        # cheaper to list this group's own koans explicitly
        negative_exceptions = set()
        rule = f"{name} is {fmt_set(in_set)} or is {fmt_set(positive_exceptions)}"
    else:
        rule = f"{name} is {fmt_set(in_set)}"

    exception_count = len(positive_exceptions) + len(negative_exceptions)
    penalty = 3 * len(calls) + len(in_set) + 4 * exception_count

    # exotic bases penalty
    uses_exotic_base = any(call[1].get('base', 10) != 10 for call in calls)
    if uses_exotic_base:
        penalty += 5

    return (rule, penalty)
def print_compose(htbn, not_htbn, calls, retention_percent=50, overlap_percent=20):
    '''
    Apply a sequence of calls to both koan groups and report the outcome.

    Each (fn, kwargs) in calls is applied in order to the running mapped
    value of every koan in htbn and not_htbn. ThreshholdError is raised if
    a call changes nothing, if too many mapped values are shared between
    the groups (overlap_percent), or if either group keeps too many
    distinct values (retention_percent).

    Returns (rule, penalty, details), where details is a multi-line report
    including the table of mapped values. A one-line summary is printed.
    '''
    htbn_mapped = {koan: koan for koan in htbn}
    not_htbn_mapped = {koan: koan for koan in not_htbn}

    for fn, kwargs in calls:
        before = (htbn_mapped, not_htbn_mapped)
        htbn_mapped = {
            orig: fn(value, **kwargs) for orig, value in htbn_mapped.items()
        }
        not_htbn_mapped = {
            orig: fn(value, **kwargs) for orig, value in not_htbn_mapped.items()
        }

        # abort if any function leaves koans unchanged
        if before == (htbn_mapped, not_htbn_mapped):
            raise ThreshholdError("function leaves koans unchanged")

    htbn_unique = set(htbn_mapped.values())
    not_htbn_unique = set(not_htbn_mapped.values())
    intersection = htbn_unique & not_htbn_unique

    overlap = get_overlap(intersection, htbn_mapped, not_htbn_mapped)
    if overlap > overlap_percent:
        raise ThreshholdError(f"overlap {overlap}% exceeds threshhold {overlap_percent}%")

    # retention: share of distinct mapped values relative to the group size
    htbn_retention = 100 * len(htbn_unique) / len(htbn)
    not_htbn_retention = 100 * len(not_htbn_unique) / len(not_htbn)
    if htbn_retention > retention_percent:
        raise ThreshholdError(f"retention {htbn_retention} in koans with BN doesn't meet threshhold {retention_percent}")
    if not_htbn_retention > retention_percent:
        raise ThreshholdError(f"retention {not_htbn_retention} in koans without BN doesn't meet threshhold {retention_percent}")

    name = ', '.join(map(get_call_name, calls))
    rule, penalty = fmt_rule(name, calls, htbn_mapped, not_htbn_mapped, intersection)

    parts = [f"\n=== {name} ===\n"]
    # no shared values and one side collapsed to a single value
    if not intersection and 1 in [len(htbn_unique), len(not_htbn_unique)]:
        parts.append("#############################\n")
        parts.append("### LIKELY FOUND THE RULE ###\n")
        parts.append("#############################\n")
    parts.append(f'penalty: {penalty}\n')
    parts.append(f"htbn retention: {htbn_retention}%\n")
    parts.append(f"not_htbn retention: {not_htbn_retention}%\n")
    parts.append(f"overlap: {overlap}%\n")
    parts.append(print_table(htbn_mapped, not_htbn_mapped, intersection))
    details = ''.join(parts)

    print(f'\n{penalty} {name}\n')
    return (rule, penalty, details)
def print_table(htbn_mapped, not_htbn_mapped, intersection):
    '''
    Render the two groups side by side as a text table and return it.

    Each row shows a koan and its mapped value for both groups; the shorter
    group is padded with blanks. A "!" marks mapped values that are in
    intersection. A mapped value that is a ValueError instance is shown as
    "Err". Despite the name, nothing is printed.
    '''
    fmt = "%10s %10s%1s | %10s %10s%s\n"

    def _cells(koan, mapped):
        # label, display text and shared-value marker for one side of a row
        label = f'{koan}:' if koan is not None else ''
        marker = '!' if mapped in intersection else ''
        text = 'Err' if type(mapped) is ValueError else str(mapped)
        return label, text, marker

    rows = [fmt % ('has ', 'fn(has)', '', 'not has ', 'fn(not has)', '')]
    paired = itertools.zip_longest(
        htbn_mapped.items(), not_htbn_mapped.items(), fillvalue=(None, ''))
    for (h, hm), (nh, nhm) in paired:
        rows.append(fmt % (_cells(h, hm) + _cells(nh, nhm)))
    return ''.join(rows)
def analyze_until_break(htbn, not_htbn, calls_to_try):
    '''
    Try ever longer sequences of calls until interrupted with Ctrl-C.

    Round i tries every sequence of i entries from calls_to_try. For each
    sequence, kwargs sharing a name are restricted to the values every
    function in the sequence accepts; each remaining combination is pinned
    and passed to print_compose. Progress markers are printed: "x" for a
    sequence with no compatible kwargs, ">" for one that is run, "o" for
    each combination that produced a rule.

    The user is prompted between rounds. Returns the collected
    (rule, penalty, details) tuples when KeyboardInterrupt is raised; there
    is no other exit from the loop.
    '''
    rules = []
    try:
        for round_no in itertools.count(1):
            print(f"\nround {round_no}:")

            # choose a sequence of functions
            for calls in itertools.product(calls_to_try, repeat=round_no):

                # narrow down possible kwarg values to those suitable for
                # all functions in this sequence
                common_kwargs = {}
                for _, kwargs in calls:
                    for name, values in kwargs.items():
                        if name in common_kwargs:
                            common_kwargs[name].intersection_update(values)
                        else:
                            common_kwargs[name] = set(values)

                # if any kwarg has no possibilities left, then we can't run
                # this sequence of functions
                if not all(common_kwargs.values()):
                    print('x', end='')
                    continue
                print('>', end='')

                # choose a particular value for those kwargs...
                for pinned_kwargs in dict_combinations(common_kwargs):

                    # ...and use those values to make each call
                    pinned_calls = [
                        (fn, {k: pinned_kwargs[k] for k in kwargs.keys()})
                        for fn, kwargs in calls
                    ]
                    try:
                        result = print_compose(htbn, not_htbn, pinned_calls)
                    except ThreshholdError as e:
                        print(f'\n{e}')  # DEBUG
                    else:
                        rules.append(result)
                        print('o', end='')

            input(f"\non to round {round_no+1}?")
    except KeyboardInterrupt:
        return rules
def analyze(htbn, not_htbn, calls_to_try):
    '''
    Collect candidate rules, then page through them from lowest penalty up.

    Rules are gathered by analyze_until_break and shown one at a time after
    the user presses enter; Ctrl-C at the prompt stops the listing.
    '''
    rules = analyze_until_break(htbn, not_htbn, calls_to_try=calls_to_try)
    by_penalty = sorted(rules, key=lambda entry: entry[1])
    for rule, penalty, details in by_penalty:
        try:
            input("\nPress enter to continue...")
        except KeyboardInterrupt:
            return
        print(penalty, rule, f'\n{details}\n')
@safe
def lemon4(n, base=10):
    '''
    Check n against a hand-written hypothesis and describe any violations.

    Returns 'ok' when no check fails. Otherwise a ValueError listing the
    failed checks is raised, which the @safe wrapper turns into that
    message string.

    base is now forwarded to count_consecutive_digits as well; previously
    the zero-run check always used base 10 while the other checks used the
    requested base. Results for the default base are unchanged.
    '''
    errors = []
    if count_consecutive_digits(n, 0, base=base) > 1:
        errors.append('consecutive zero digits')
    digits = get_digits(n, base=base)
    if digits[-1] == 4:
        errors.append('mod 10 == 4')
    p = product(digits)
    if p % 2 == 1:
        errors.append('digit product odd')
    if errors:
        raise ValueError(', '.join(errors))
    return 'ok'
@safe
def lemon5(n, base=10):
    '''
    Check n against a hand-written hypothesis, else report its factor count.

    When no check fails, returns "<count> prime factors" using
    count_prime_factors. Otherwise a ValueError listing the failed checks
    is raised, which the @safe wrapper turns into that message string.

    base is now forwarded to count_consecutive_digits as well; previously
    the zero-run check always used base 10 while the digit product used the
    requested base. Results for the default base are unchanged.
    '''
    errors = []
    if count_consecutive_digits(n, 0, base=base) > 1:
        errors.append('consecutive zero digits')
    digits = get_digits(n, base=base)
    if product(digits) % 2 == 1:
        errors.append('digit product odd')
    if errors:
        raise ValueError(', '.join(errors))
    # return ','.join(str(x) for x in primefac.primefac(n))
    return f"{count_prime_factors(n)} prime factors"
def read_export(f):
    '''
    Split the numbers in an export into two lists by their flag field.

    f is an iterable of lines of the form "field0","number","flag",...
    The outer quotes are stripped and the line is split on '","'. The
    second field is parsed as an int; the number goes to the first list
    when the third field is '1' and to the second list otherwise.

    Blank lines and lines whose second field is not an integer (such as a
    header row) are skipped. Lines with fewer than three fields are now
    skipped too; they used to abort the import with an IndexError.

    Returns (htbn, not_htbn).
    '''
    htbn = []
    not_htbn = []
    for raw in f:
        raw = raw.strip()
        if not raw:
            continue
        fields = raw[1:-1].split('","')
        if len(fields) < 3:
            continue
        try:
            n = int(fields[1])
        except ValueError:
            continue
        group = htbn if fields[2] == '1' else not_htbn
        group.append(n)
    return (htbn, not_htbn)
if __name__ == "__main__":
    # The export file path is the single required argument; exit with a
    # usage message instead of an IndexError traceback when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} EXPORT_FILE")
    with open(sys.argv[1], 'r') as f:
        htbn, not_htbn = read_export(f)
    htbn.sort()
    not_htbn.sort()
    # analyze(htbn, not_htbn, [(lemon5, dict())])
    analyze(htbn, not_htbn, tf_try_calls)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment