Skip to content

Instantly share code, notes, and snippets.

@utahka
Last active February 21, 2020 15:02
Show Gist options
  • Save utahka/23a717d41c79b904787242f202635b59 to your computer and use it in GitHub Desktop.
Save utahka/23a717d41c79b904787242f202635b59 to your computer and use it in GitHub Desktop.
deap を利用した OneMax 問題
import random
from copy import deepcopy
from deap import base, creator, tools
def eval_one_max(individual) -> tuple:
    """Return the OneMax fitness of *individual*.

    The score is the sum of the genes, i.e. the number of 1 bits for a
    0/1 genome. DEAP stores fitness values as a tuple, so the score is
    wrapped in a one-element tuple.
    """
    ones = sum(individual)
    return (ones,)
def _ensure_creator_classes():
    """Create creator.FitnessMax and creator.Individual if they do not exist yet.

    creator.create overwrites (with a RuntimeWarning) a class name that was
    already created, so guarding lets main() be called more than once in the
    same process without redefining the classes.
    """
    # Single objective, maximised.
    if not hasattr(creator, "FitnessMax"):
        creator.create("FitnessMax", base.Fitness, weights=(1.0,))
    # An individual is a list of genes carrying a FitnessMax instance.
    if not hasattr(creator, "Individual"):
        creator.create("Individual", list, fitness=creator.FitnessMax)


def _evaluate_invalid(toolbox, individuals):
    """Evaluate every individual whose fitness is not valid and store the result."""
    invalid = [ind for ind in individuals if not ind.fitness.valid]
    for ind, fit in zip(invalid, map(toolbox.evaluate, invalid)):
        ind.fitness.values = fit


def main(*, n_bits=100, pop_size=300, n_epochs=300, cxpb=0.5, mutpb=0.2,
         seed=None):
    """Solve OneMax (maximise the number of 1 bits) with a generational GA.

    Keyword Args:
        n_bits: genome length; also the optimal fitness, used as the stop target.
        pop_size: number of individuals per generation.
        n_epochs: maximum number of generations to run.
        cxpb: probability that a consecutive pair of offspring is crossed over.
        mutpb: probability that an individual offspring is mutated.
        seed: if not None, passed to random.seed for a reproducible run.

    Returns:
        The best fitness in the final population (it is also printed).

    Raises:
        ValueError: if pop_size < 1 (there would be no best individual).
    """
    if pop_size < 1:
        raise ValueError(f"pop_size must be >= 1, got {pop_size}")
    if seed is not None:
        random.seed(seed)

    _ensure_creator_classes()

    toolbox = base.Toolbox()
    # Factories for a gene, an individual of n_bits genes, and a population.
    toolbox.register("attr_bool", random.randint, 0, 1)
    toolbox.register("individual", tools.initRepeat, creator.Individual,
                     toolbox.attr_bool, n_bits)
    toolbox.register("population", tools.initRepeat, list, toolbox.individual)
    # GA operators: evaluation, crossover, mutation and selection.
    toolbox.register("evaluate", eval_one_max)
    toolbox.register("mate", tools.cxTwoPoint)
    toolbox.register("mutate", tools.mutFlipBit, indpb=0.05)
    toolbox.register("select", tools.selTournament, tournsize=3)

    # Freshly built individuals have no fitness yet, so all get evaluated.
    pop = toolbox.population(n=pop_size)
    _evaluate_invalid(toolbox, pop)
    fits = [ind.fitness.values[0] for ind in pop]

    for epoch in range(n_epochs):
        # An all-ones individual scores n_bits, the optimum: stop early.
        if max(fits) >= n_bits:
            break
        if epoch % 10 == 0:
            print(f"--- epoch {epoch} ---")

        # Selection returns references into `pop`; clone them so the in-place
        # operators below do not alter the current generation.
        offsprings = list(map(toolbox.clone, toolbox.select(pop, len(pop))))

        # Crossover on consecutive pairs (an odd trailing individual is skipped).
        for child1, child2 in zip(offsprings[::2], offsprings[1::2]):
            if random.random() < cxpb:
                toolbox.mate(child1, child2)
                # The children were modified in place; their fitness is stale.
                del child1.fitness.values
                del child2.fitness.values

        # Mutation.
        for child in offsprings:
            if random.random() < mutpb:
                toolbox.mutate(child)
                del child.fitness.values

        # Re-evaluate only the individuals changed by crossover or mutation.
        _evaluate_invalid(toolbox, offsprings)

        # offsprings are already independent clones, so no deep copy is needed.
        pop = offsprings
        fits = [ind.fitness.values[0] for ind in pop]

    best = max(fits)
    print(best)
    return best
if __name__ == "__main__":
    # Run the OneMax GA only when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment