Skip to content

Instantly share code, notes, and snippets.

@soravux
Last active March 5, 2024 08:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save soravux/1fe0992a79fc07a23d27 to your computer and use it in GitHub Desktop.
Save soravux/1fe0992a79fc07a23d27 to your computer and use it in GitHub Desktop.
Symbolic regression example from deap... with a twist!
import operator
import math
import random
import struct
import numpy
from deap import algorithms
from deap import base
from deap import creator
from deap import tools
from deap import gp
from ctypes import *
# Hand-assembled x86-64 (System V ABI) machine code for the primitives: the
# double arguments arrive in xmm0/xmm1 and the result is returned in xmm0.
# The bytes were obtained by compiling a tiny C file, e.g. "test.c" holding
#     double add(double x, double y) { return x + y; }
# with `gcc -c -O3 ./test.c` and dumping them with `objdump -S ./test.o`.

# add: addsd %xmm1,%xmm0 ; retq
add_bytecode = bytes.fromhex("f20f58c1" "c3")

# sub: subsd %xmm1,%xmm0 ; retq
sub_bytecode = bytes.fromhex("f20f5cc1" "c3")

# mul: mulsd %xmm1,%xmm0 ; retq
mul_bytecode = bytes.fromhex("f20f59c1" "c3")

# neg: flip the sign bit by XOR-ing with a mask stored right after the code.
# The RIP-relative displacement 0x05 skips the xorpd (4 bytes) and retq
# (1 byte) that follow the movsd, landing on the mask.
neg_bytecode = bytes.fromhex(
    "f20f100d05000000"  # movsd 0x05(%rip),%xmm1
    "660f57c1"          # xorpd %xmm1,%xmm0
    "c3"                # retq
    "0000000000000080"  # data: little-endian double with only the sign bit set
)
# Handle to the C runtime; the script uses it for valloc, mprotect and free
# on the buffers that hold generated machine code.
libc = CDLL("libc.so.6")

# Memory-protection flags for mprotect(2), Linux values.
PROT_READ, PROT_WRITE, PROT_EXEC = 1, 2, 4
def executable_code(buff):
    """Return a pointer to a page-aligned executable buffer filled with *buff*.

    The bytes of *buff* are copied into memory obtained from ``valloc`` and
    the covering pages are made readable, writable and executable with
    ``mprotect``.

    :param buff: machine code to install, as ``bytes``.
    :returns: a ``c_void_p`` pointing at the copy.  The pointer should be
        freed with ``libc.free()`` when finished.
    :raises Exception: if the allocation or the protection change fails.
    """
    size = len(buff)
    # ctypes assumes every foreign function returns a C int.  valloc returns a
    # pointer, which would be silently truncated to 32 bits on x86-64 without
    # an explicit restype.  Declaring argtypes keeps size_t/pointer args exact.
    libc.valloc.restype = c_void_p
    libc.valloc.argtypes = [c_size_t]
    libc.mprotect.argtypes = [c_void_p, c_size_t, c_int]

    # Need to align to a page boundary for mprotect, so use valloc.
    raw = libc.valloc(size)
    # With restype c_void_p a NULL result comes back as None.  Check it before
    # wrapping: a c_void_p object never compares equal to 0.
    if not raw:
        raise Exception("Failed to allocate memory")
    addr = c_void_p(raw)
    memmove(addr, c_char_p(buff), size)

    # Keep PROT_WRITE: mprotect acts on whole pages and `size` is not rounded
    # up to a page, so the tail of the last page may belong to other heap
    # allocations that must stay writable.
    if 0 != libc.mprotect(addr, size, PROT_READ | PROT_WRITE | PROT_EXEC):
        libc.free(addr)  # don't leak the buffer on failure
        raise Exception("Failed to set protection on buffer")
    return addr
# ctypes prototypes for the hand-written primitives:
# double f(double, double) and double f(double).
_BinaryDoubleFn = CFUNCTYPE(c_double, c_double, c_double)
_UnaryDoubleFn = CFUNCTYPE(c_double, c_double)

# Copy each primitive into executable memory and wrap the resulting address as
# a Python-callable function pointer.  The __name__ values must match the keys
# of primitives_addr, because buildASM looks primitives up there by node.name
# (presumably DEAP names each primitive after its __name__ -- confirm).
add_code_ptr = executable_code(add_bytecode)
myAdd = cast(add_code_ptr, _BinaryDoubleFn)
myAdd.__name__ = 'add'

sub_code_ptr = executable_code(sub_bytecode)
mySub = cast(sub_code_ptr, _BinaryDoubleFn)
mySub.__name__ = 'sub'

mul_code_ptr = executable_code(mul_bytecode)
myMul = cast(mul_code_ptr, _BinaryDoubleFn)
myMul.__name__ = 'mul'

neg_code_ptr = executable_code(neg_bytecode)
myNeg = cast(neg_code_ptr, _UnaryDoubleFn)
myNeg.__name__ = 'neg'
# (name, machine code) of every primitive.  buildASM appends these routines,
# back-to-back and in this order, after each compiled individual.
primitives_bytecode = [
    ('add', add_bytecode),
    ('sub', sub_bytecode),
    ('mul', mul_bytecode),
    ('neg', neg_bytecode),
]

# Byte offset of each routine inside that concatenated blob; `dist` ends up
# holding the blob's total length.
primitives_addr = {}
dist = 0
for prim_name, prim_code in primitives_bytecode:
    primitives_addr[prim_name], dist = dist, dist + len(prim_code)
# Primitive set for a one-input expression: the four JIT-compiled arithmetic
# routines plus random integer constants in {-1, 0, 1}.
pset = gp.PrimitiveSet("MAIN", 1)
for _prim, _arity in ((myAdd, 2), (mySub, 2), (myMul, 2), (myNeg, 1)):
    pset.addPrimitive(_prim, _arity)
pset.addEphemeralConstant("rand101", lambda: random.randint(-1, 1))
pset.renameArguments(ARG0='x')

# Single-objective minimisation (the error returned by evalSymbReg) over
# expression trees.
creator.create("FitnessMin", base.Fitness, weights=(-1.0,))
creator.create("Individual", gp.PrimitiveTree, fitness=creator.FitnessMin)

# Population construction.  "compile" is registered but evalSymbReg below
# builds native code with buildASM instead of using it.
toolbox = base.Toolbox()
toolbox.register("expr", gp.genHalfAndHalf, pset=pset, min_=1, max_=2)
toolbox.register("individual", tools.initIterate, creator.Individual, toolbox.expr)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)
toolbox.register("compile", gp.compile, pset=pset)
def buildASM(individual):
    """Compile a DEAP expression tree into x86-64 machine code.

    The result is a complete ``double f(double x)`` routine (System V ABI:
    ``x`` arrives in ``xmm0`` and the result is returned in ``xmm0``).  It is
    followed by the routines from ``primitives_bytecode``, which the generated
    code reaches through relative ``call`` instructions, so the bytes can run
    from any address.

    Code is emitted while walking ``individual`` in prefix order and then
    reversed, so operands are evaluated (right to left) before the node that
    consumes them.  ``x`` is parked in ``xmm15`` for the whole evaluation.
    This relies on every primitive leaving ``xmm15`` untouched; the ones in
    ``primitives_bytecode`` only use ``xmm0``/``xmm1``.  The first operand of
    each primitive is computed directly into ``xmm0``.  The remaining operands
    go through the stack and are popped into ``xmm1``, ``xmm2``, ... right
    before the call.

    :param individual: prefix-ordered sequence of DEAP nodes (``name``,
        ``arity`` and, for terminals, ``value``).
    :returns: the machine code as ``bytes``.
    :raises ValueError: if a primitive takes more arguments than can be
        passed in registers here.
    """
    # x86-64 encodings (AT&T syntax in the comments).
    CALL = b"\xE8"                                  # call rel32 (offset follows)
    MOV_XMM0toXMM15 = b"\xF3\x44\x0F\x7E\xF8"       # movq %xmm0,%xmm15
    MOV_XMM15toXMM0 = b"\xF3\x41\x0F\x7E\xC7"       # movq %xmm15,%xmm0
    PUSH_XMM0 = b"\x66\x48\x0F\x7E\xC0\x50"         # movq %xmm0,%rax ; push %rax
    PUSH_XMM15 = b"\x66\x4C\x0F\x7E\xF8\x50"        # movq %xmm15,%rax ; push %rax
    MOVABS_RAX = b"\x48\xB8"                        # movabs $imm64,%rax (imm follows)
    PUSH_RAX = b"\x50"                              # push %rax
    MOV_RAXtoXMM0 = b"\x66\x48\x0F\x6E\xC0"         # movq %rax,%xmm0
    # pop %rax ; movq %rax,%xmmN -- loads the operands passed on the stack.
    POP_ARG = [
        b"\x58\x66\x48\x0F\x6E\xC8",                # -> xmm1
        b"\x58\x66\x48\x0F\x6E\xD0",                # -> xmm2
        b"\x58\x66\x48\x0F\x6E\xD8",                # -> xmm3
        b"\x58\x66\x48\x0F\x6E\xE0",                # -> xmm4
    ]
    RET = b"\xc3"

    chunks = []
    emitted = 0

    def emit(chunk):
        # Record a chunk and keep the running byte count in step with it
        # (avoids re-joining the whole list for every CALL offset).
        nonlocal emitted
        chunks.append(chunk)
        emitted += len(chunk)

    # `chunks` is built in *reverse* execution order and flipped at the end,
    # so the RET emitted first becomes the last instruction.
    emit(RET)
    # True when the value of the next node must be left in xmm0 (it is the
    # final result or the first operand of its parent).  False when it must
    # be pushed so that its parent can pop it into xmm1, xmm2, ...
    result_in_xmm0 = True
    for node in individual:
        if node.arity > 1 or not hasattr(node, 'value'):
            # A primitive call.
            n_stack_args = max(node.arity - 1, 0)
            if n_stack_args > len(POP_ARG):
                raise ValueError(
                    "primitive %r takes %d arguments; at most %d are supported"
                    % (node.name, node.arity, len(POP_ARG) + 1))
            if not result_in_xmm0:
                # Runs right after the call: hand the result to the parent.
                emit(PUSH_XMM0)
            # rel32 counts from the end of this CALL.  After the final
            # reversal, exactly the `emitted` bytes recorded so far sit between
            # this CALL and the primitives blob appended after the code.
            emit(CALL + struct.pack('<i', emitted + primitives_addr[node.name]))
            # Runs just before the call.  Operands are evaluated right to left,
            # so the second operand ends up on top of the stack.  Emit the pops
            # as xmmN..xmm1 so that, once reversed, xmm1 is popped first.
            for pop in reversed(POP_ARG[:n_stack_args]):
                emit(pop)
            result_in_xmm0 = True
        elif isinstance(node.value, str):
            # The input variable, kept in xmm15 for the whole evaluation.
            emit(MOV_XMM15toXMM0 if result_in_xmm0 else PUSH_XMM15)
            result_in_xmm0 = False
        else:
            # A numeric constant, embedded as a 64-bit immediate.
            load_rax = MOVABS_RAX + struct.pack('<d', node.value)
            emit(load_rax + (MOV_RAXtoXMM0 if result_in_xmm0 else PUSH_RAX))
            result_in_xmm0 = False
    # Runs first: stash the argument x (passed in xmm0) in xmm15.
    emit(MOV_XMM0toXMM15)

    primitives_blob = b"".join(code for _, code in primitives_bytecode)
    return b"".join(reversed(chunks)) + primitives_blob
def evalSymbReg(individual, points):
    """Return the mean squared error of *individual* against the target curve.

    The tree is compiled to native code with ``buildASM`` and installed in
    executable memory.  It is evaluated at every x in *points* and compared
    with x**4 + x**3 + x**2 + x.  The executable buffer is released before
    returning.

    :param individual: DEAP expression tree to evaluate.
    :param points: sample x values; must be non-empty.
    :returns: the 1-tuple ``(mse,)``, matching the single weight of FitnessMin.
    """
    # Transform the tree expression into a callable native function.
    ind_code_ptr = executable_code(buildASM(individual))
    try:
        this_ind_fct = cast(ind_code_ptr, CFUNCTYPE(c_double, c_double))
        # Evaluate the mean squared error between the expression
        # and the real function: x**4 + x**3 + x**2 + x
        sqerrors = ((this_ind_fct(x) - x**4 - x**3 - x**2 - x)**2 for x in points)
        # fsum consumes the generator here, i.e. before the buffer is freed.
        return math.fsum(sqerrors) / len(points),
    finally:
        # One buffer is allocated per evaluation.  Without this, every
        # evaluated individual leaked its executable buffer.
        libc.free(ind_code_ptr)
# Every individual is scored at x = -1.0, -0.9, ..., 0.9 (20 points).
_eval_points = [i / 10. for i in range(-10, 10)]
toolbox.register("evaluate", evalSymbReg, points=_eval_points)

# Selection: tournaments of 3.  Crossover: one-point.
toolbox.register("select", tools.selTournament, tournsize=3)
toolbox.register("mate", gp.cxOnePoint)

# Mutation draws its replacement expressions from genFull (min_=0, max_=2).
toolbox.register("expr_mut", gp.genFull, min_=0, max_=2)
toolbox.register("mutate", gp.mutUniform, expr=toolbox.expr_mut, pset=pset)
def main():
    """Run the evolutionary search.

    :returns: ``(population, logbook, hall_of_fame)`` from a seeded run of
        40 generations on a population of 300.
    """
    random.seed(318)  # fixed seed so runs are reproducible

    population = toolbox.population(n=300)
    hall_of_fame = tools.HallOfFame(1)

    # Track fitness and tree size separately, each with avg/std/min/max.
    stats = tools.MultiStatistics(
        fitness=tools.Statistics(lambda ind: ind.fitness.values),
        size=tools.Statistics(len),
    )
    for stat_name, reducer in (("avg", numpy.mean), ("std", numpy.std),
                               ("min", numpy.min), ("max", numpy.max)):
        stats.register(stat_name, reducer)

    population, logbook = algorithms.eaSimple(
        population, toolbox, cxpb=0.5, mutpb=0.1, ngen=40,
        stats=stats, halloffame=hall_of_fame, verbose=True,
    )
    return population, logbook, hall_of_fame
if __name__ == "__main__":
    try:
        main()
    finally:
        # Release the primitives' executable buffers only once the run is over,
        # even if it failed.  myAdd/mySub/myMul/myNeg point into these buffers,
        # so freeing them at import time would leave those functions dangling.
        for _code_ptr in (add_code_ptr, sub_code_ptr, mul_code_ptr, neg_code_ptr):
            libc.free(_code_ptr)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment