Created
June 29, 2021 21:54
-
-
Save cjmcgraw/ff39d16babd5258db8874439631c2843 to your computer and use it in GitHub Desktop.
Projecting CYOA into a higher dimensional space
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow.keras import layers | |
from tensorflow.keras.layers.experimental.preprocessing import StringLookup | |
import logging as log | |
import numpy as np | |
import sys | |
# Route all log records (DEBUG and above) to standard output.
log.basicConfig(stream=sys.stdout, level=log.DEBUG)
if __name__ == '__main__':
    # Read the three string columns of cyoa.csv (header row skipped), strip
    # double quotes from the first two, and turn the third column into the
    # numeric training target: its length in UTF-8 characters.
    examples = (
        tf.data.experimental.CsvDataset(
            filenames="cyoa.csv",
            record_defaults=["", "", ""],
            header=True,
        )
        .map(lambda name, team, stars: (
            tf.strings.regex_replace(name, '"', ''),
            tf.strings.regex_replace(team, '"', ''),
            tf.cast(tf.strings.length(stars, unit='UTF8_CHAR'), tf.float32),
        ))
        # Use explicit tensor ops here: a Python `and` between symbolic
        # tensors only works when AutoGraph manages to rewrite the lambda.
        .filter(lambda name, team, stars: tf.logical_and(
            tf.logical_and(
                tf.strings.length(name) > 0,
                tf.strings.length(team) > 0,
            ),
            stars > 0.0,
        ))
        .map(lambda name, team, stars: ((name, team), stars))
    )

    # Training input: the same rows, shuffled on every pass and batched.
    dataset = (
        examples
        .shuffle(buffer_size=100, reshuffle_each_iteration=True)
        .batch(64)
    )

    # Collect the vocabulary from the unshuffled rows; the values are the
    # raw bytes objects returned by `.numpy()`.
    seen_names, seen_teams = set(), set()
    for (row_name, row_team), _ in examples:
        seen_names.add(row_name.numpy())
        seen_teams.add(row_team.numpy())
    if not seen_names:
        raise SystemExit("cyoa.csv contains no usable rows")
    unique_names = sorted(seen_names)
    unique_teams = sorted(seen_teams)

    print(f"unique_names = {unique_names}")
    print(f"unique_teams = {unique_teams}")
    print(list(dataset))

    # String -> integer id -> 6-dimensional embedding, one table per column.
    # The embedding tables are sized vocabulary + 2 to leave room for the
    # extra ids the lookup layers reserve.
    name_lookup = StringLookup(vocabulary=unique_names)
    team_lookup = StringLookup(vocabulary=unique_teams)
    name_embedding = layers.Embedding(input_dim=len(unique_names) + 2, output_dim=6)
    team_embedding = layers.Embedding(input_dim=len(unique_teams) + 2, output_dim=6)
    flatten = layers.Flatten()
    dot_product = layers.Dot(axes=1)

    # The model predicts the target as the dot product of the two embeddings.
    name_input = tf.keras.Input(shape=(1,), dtype=tf.string, name='name')
    team_input = tf.keras.Input(shape=(1,), dtype=tf.string, name='team')
    name_vector = flatten(name_embedding(name_lookup(name_input)))
    team_vector = flatten(team_embedding(team_lookup(team_input)))
    y = dot_product([name_vector, team_vector])

    model = tf.keras.Model(
        inputs=[name_input, team_input],
        outputs=y,
    )
    model.compile(
        optimizer='adam',
        loss='mean_squared_error',
        metrics=['mae'],
    )
    # summary() prints the table itself and returns None, so it is not
    # wrapped in print().
    model.summary()

    # NOTE: the dataset is repeated 30 times *and* fit runs 30 epochs, so
    # every epoch makes 30 passes over the CSV rows.
    model.fit(
        dataset.repeat(30),
        epochs=30,
    )

    # Dump the learned embedding of every name, then of every team.
    for name in unique_names:
        _id = name_lookup([name])
        vec = flatten(name_embedding(_id))
        print(f"{name},{vec.numpy()[0]}")
    for team in unique_teams:
        _id = team_lookup([team])
        vec = flatten(team_embedding(_id))
        print(f"{team},{vec.numpy()[0]}")
def get_user_vector(name):
    """Return the flattened learned embedding for ``name``.

    Relies on the module-level ``name_lookup``, ``name_embedding`` and
    ``flatten`` layers.
    """
    token_id = name_lookup(name)
    embedded = name_embedding(token_id)
    return flatten(embedded)
def get_team_vector(team):
    """Return the flattened learned embedding for ``team``.

    Relies on the module-level ``team_lookup``, ``team_embedding`` and
    ``flatten`` layers.
    """
    token_id = team_lookup(team)
    embedded = team_embedding(token_id)
    return flatten(embedded)
# Cache one embedding vector per vocabulary entry so the ranking helpers
# below do not have to re-run the lookup/embedding layers on every comparison.
team_vectors = {}
for known_team in unique_teams:
    team_vectors[known_team] = get_team_vector(known_team)

name_vectors = {}
for known_name in unique_names:
    name_vectors[known_name] = get_user_vector(known_name)
def dot_vecs(a, b):
    """Return the sum of the element-wise product of ``a`` and ``b``."""
    elementwise = tf.multiply(a, b)
    return tf.reduce_sum(elementwise)
def get_best_names(name=None, team=None):
    """Rank every known name against a reference name or team, best first.

    The score of a candidate is the dot product between its cached embedding
    (``name_vectors``) and the reference embedding.

    Args:
        name: Key into ``name_vectors`` to use as the reference. Takes
            precedence over ``team`` when both are given.
        team: Key into ``team_vectors`` to use as the reference.

    Returns:
        All entries of ``unique_names`` sorted by descending score, so that
        slicing the front of the list yields the strongest matches. When
        ranking against a name, that name itself is part of the result.

    Raises:
        ValueError: If neither ``name`` nor ``team`` is provided.
    """
    if name:
        reference = name_vectors[name]
    elif team:
        reference = team_vectors[team]
    else:
        raise ValueError("get_best_names() needs either a name or a team")

    # reverse=True: a larger dot product means a better match, and callers
    # take the first few entries as "best". float() turns the scalar tensor
    # into a plain number so the sort compares ordinary Python floats.
    return sorted(
        unique_names,
        key=lambda candidate: float(dot_vecs(reference, name_vectors[candidate])),
        reverse=True,
    )
def get_best_teams(name=None, team=None):
    """Rank every known team against a reference name or team, best first.

    The score of a candidate is the dot product between its cached embedding
    (``team_vectors``) and the reference embedding.

    Args:
        name: Key into ``name_vectors`` to use as the reference. Takes
            precedence over ``team`` when both are given.
        team: Key into ``team_vectors`` to use as the reference.

    Returns:
        All entries of ``unique_teams`` sorted by descending score, so that
        slicing the front of the list yields the strongest matches. When
        ranking against a team, that team itself is part of the result.

    Raises:
        ValueError: If neither ``name`` nor ``team`` is provided.
    """
    if name:
        reference = name_vectors[name]
    elif team:
        reference = team_vectors[team]
    else:
        raise ValueError("get_best_teams() needs either a name or a team")

    # reverse=True: a larger dot product means a better match, and callers
    # take the first few entries as "best". float() turns the scalar tensor
    # into a plain number so the sort compares ordinary Python floats.
    return sorted(
        unique_teams,
        key=lambda candidate: float(dot_vecs(reference, team_vectors[candidate])),
        reverse=True,
    )
# Report: for every name, then every team, print the first five entries of
# each ranking returned by the helpers above.
print("")
for name in unique_names:
    best_names = get_best_names(name=name)
    best_teams = get_best_teams(name=name)
    print(name)
    for heading, ranking in (("best teams:", best_teams), ("best names:", best_names)):
        print(heading)
        print(ranking[:5])
    print("")

for team in unique_teams:
    best_names = get_best_names(team=team)
    best_teams = get_best_teams(team=team)
    print(team)
    for heading, ranking in (("best names:", best_names), ("best teams:", best_teams)):
        print(heading)
        print(ranking[:5])
    print("")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment