Skip to content

Instantly share code, notes, and snippets.

@hmate9
Created December 4, 2016 20:34
Show Gist options
  • Save hmate9/cfbb2d2191665f10889f908aa1830a5c to your computer and use it in GitHub Desktop.
Save hmate9/cfbb2d2191665f10889f908aa1830a5c to your computer and use it in GitHub Desktop.
def _window_features(frame, productions, x, y, width, height, player_id):
    """Return ``(data, found_enemy)`` for the window centred on square (x, y).

    ``data`` has shape ``(4, look_area, look_area)`` and is indexed
    ``data[channel][row][col]``, where the row follows the y offset and the
    col follows the x offset within the window:

    * channel 0: strength / max_strength
    * channel 1: production / max_production
    * channel 2: 1 if the square is owned by ``player_id`` else 0
    * channel 3: 1 if the square is owned by another player (owner not in
      ``(0, player_id)``) else 0

    Coordinates wrap around the map edges (modulo ``width`` / ``height``).
    Relies on the module-level ``look_area``, ``look_distance``,
    ``max_strength`` and ``max_production``; assumes
    ``look_area >= 2 * look_distance + 1`` (TODO confirm at definition site).
    """
    data = np.zeros(shape=(4, look_area, look_area))
    # NOTE(review): despite its name, this flag is set by ANY square not owned
    # by player_id, neutral squares (owner 0) included.
    found_enemy = False
    span_x = range(x - look_distance, x + look_distance + 1)
    span_y = range(y - look_distance, y + look_distance + 1)
    for col, raw_x in enumerate(span_x):
        wx = raw_x % width  # Python's % is non-negative for width > 0
        for row, raw_y in enumerate(span_y):
            wy = raw_y % height
            cell = frame[wy][wx]
            cell_owner = cell[0]
            cell_strength = cell[1]
            if cell_owner != player_id:
                found_enemy = True
            data[0][row][col] = float(cell_strength) / float(max_strength)
            data[1][row][col] = float(productions[wy][wx]) / float(max_production)
            data[2][row][col] = int(cell_owner == player_id)
            data[3][row][col] = int(cell_owner not in (0, player_id))
    return data, found_enemy


def extract_train_data(winner, replayName, replay_dir='replays/'):
    """Build training samples for ``winner`` from a single replay file.

    For each processed frame, every square owned by the winner yields at most
    one sample.
    * The sample features are the window returned by ``_window_features``.
    * The sample label is the move the winner made on that square in that
      frame.
    * Still moves (``move == 0``) are randomly downsampled.
    * Squares whose window contains only the winner's own squares are skipped.

    Relies on module-level names defined elsewhere in the file: ``get_id``,
    ``limit_n_frames``, ``still_downsample`` (plus the window constants used by
    ``_window_features``).

    :param winner: player identifier passed to ``get_id`` to obtain the
        player id used in the replay's ownership values.
    :param replayName: replay file name, appended to ``replay_dir``.
    :param replay_dir: path prefix for the replay file; defaults to the
        originally hard-coded ``'replays/'``.
    :returns: ``(dataX, dataY, weights)``, three parallel lists holding the
        feature arrays, the move labels and the per-sample weights.
    """
    with open(replay_dir + replayName) as replay_file:
        game = json.load(replay_file)
    height = game['height']
    width = game['width']
    player_id = get_id(game, winner)
    # Single formatted argument prints the same text under Python 2 and 3.
    print("Winner: %s ( %s )" % (winner, player_id))
    num_frames = game['num_frames']
    frames = game['frames']
    moves = game['moves']
    productions = game['productions']

    dataX = []
    dataY = []
    weights = []
    for i, frame in enumerate(frames):
        # No moves are made at the final frame. The limit check lets frames
        # 0..limit_n_frames (inclusive) through.
        if i == num_frames - 1 or i > limit_n_frames:
            break
        frame_moves = moves[i]
        total_squares = sum(1 for row in frame for square in row
                            if square[0] == player_id)
        for y, row in enumerate(frame):
            for x, square in enumerate(row):
                if square[0] != player_id:
                    continue
                move = frame_moves[y][x]
                # randint includes both ends, so a still move survives with
                # probability 1 / (still_downsample + 1).
                if move == 0 and random.randint(0, still_downsample) != 0:
                    continue
                data, found_enemy = _window_features(
                    frame, productions, x, y, width, height, player_id)
                if found_enemy:
                    dataX.append(data)
                    dataY.append(move)
                    # TODO: This is incorrect. total_squares is not the right number
                    weights.append(1.0 / total_squares)
    return dataX, dataY, weights
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment