Created
August 4, 2018 12:27
-
-
Save patricoferris/7248819337dd5ed60d76c91f6d38a823 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _pairable_groups(groups, pool_name):
    """Return the members of *groups* that hold at least two artist ids.

    A group with fewer than two entries cannot supply two distinct positions,
    so the pair-sampling rejection loop in ``gen_batch`` would never terminate
    on it. Input order is preserved, so when every group qualifies the result
    indexes exactly like the input.

    Raises:
        ValueError: if no group in *groups* has at least two entries.
    """
    pairable = [group for group in groups if len(group) >= 2]
    if not pairable:
        raise ValueError(
            "gen_batch: no {} group contains at least two artists".format(pool_name)
        )
    return pairable


# Our batch generator - the bias controls how much of the context comes from
# genre groups and how much from score groups.
def gen_batch(genres, scores, size, bias, lookup=None):
    """Sample a batch of (target, context) artist-id pairs.

    For each of the ``size`` examples, ``np.random.randint(10) < bias`` decides
    the source. For integer ``bias`` in [0, 10], that is a genre group with
    probability ``bias / 10`` and a score group otherwise. A group is drawn
    uniformly from that source. Two distinct positions are then drawn uniformly
    from the group: the id at the first becomes the target and the id at the
    second the context.

    Groups with fewer than two artists are skipped, because no distinct pair
    can be drawn from them. When every group has at least two artists, the
    sequence of ``np.random`` draws matches drawing from the full set of groups.

    Args:
        genres: container of genre groups. Group ``i`` is
            ``genres[lookup[i]]`` for ``i`` in ``range(len(genres))``, and each
            group is a sequence of integer artist ids.
        scores: indexable by ``0 .. len(scores) - 1``. Each entry is a sequence
            of integer artist ids.
        size: number of pairs to generate.
        bias: threshold compared against ``np.random.randint(10)``.
            Examples with a draw below it come from ``genres``.
        lookup: maps an integer index to a key of ``genres``. Defaults to the
            module-level ``genre_lookup``.

    Returns:
        ``(xs, ys)``. ``xs`` is an int32 array of shape ``(size,)`` holding the
        targets. ``ys`` is an int32 array of shape ``(size, 1)`` holding the
        matching contexts.

    Raises:
        ValueError: if a source that is actually drawn from has no group with
            at least two artists.
    """
    if lookup is None:
        # Resolved only when needed, so callers passing ``lookup`` don't rely
        # on the global.
        lookup = genre_lookup

    # Both arrays are fully overwritten below, so uninitialised storage is fine.
    xs = np.empty(size, dtype=np.int32)
    ys = np.empty((size, 1), dtype=np.int32)

    # Pools are built on first use. A bias that never selects a source
    # (e.g. 0 -> only scores, 10 -> only genres) never touches it.
    genre_pool = None
    score_pool = None

    for idx in range(size):
        if np.random.randint(10) < bias:
            if genre_pool is None:
                genre_pool = _pairable_groups(
                    [genres[lookup[i]] for i in range(len(genres))], "genre"
                )
            group = genre_pool[np.random.randint(len(genre_pool))]
        else:
            if score_pool is None:
                score_pool = _pairable_groups(
                    [scores[i] for i in range(len(scores))], "score"
                )
            group = score_pool[np.random.randint(len(score_pool))]

        first = np.random.randint(len(group))
        second = np.random.randint(len(group))
        # Rejection-sample a different position. This terminates because every
        # pooled group has at least two entries.
        while second == first:
            second = np.random.randint(len(group))

        xs[idx] = group[first]
        ys[idx, 0] = group[second]
    return xs, ys
# Inverse of artist_lookup: maps each unique integer artist id back to the
# artist name. If two names ever shared an id, the later one would win.
artist_decode = {artist_id: name for name, artist_id in artist_lookup.items()}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment