Skip to content

Instantly share code, notes, and snippets.

@nevercast
Created December 28, 2021 05:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nevercast/7ee7cb53783db03a3eca11bf49743e37 to your computer and use it in GitHub Desktop.
Save nevercast/7ee7cb53783db03a3eca11bf49743e37 to your computer and use it in GitHub Desktop.
Neural Network visualisations
# https://stackoverflow.com/questions/52468956/how-do-i-visualize-a-net-in-pytorch
import os
import numpy as np
import torch
from torch.distributions import Categorical
import torch.nn.functional as F
from rlgym.utils.gamestates import GameState, PlayerData, PhysicsObject
from training.obs import NectoObsBuilder
cur_dir = os.path.dirname(os.path.realpath(__file__))
# Load the policy network saved next to this script.
# map_location="cpu": the observation tensors built later with
# torch.from_numpy live on the CPU, so force the weights there too; otherwise a
# checkpoint saved from a GPU fails with a device-mismatch error at inference.
# NOTE(review): torch.load unpickles arbitrary objects - only load trusted
# files. On torch >= 2.6 `weights_only` defaults to True, which rejects a fully
# pickled nn.Module; pass weights_only=False there if loading fails.
model = torch.load(os.path.join(cur_dir, "necto-model.pt"), map_location="cpu")
# Inference/export mode: disables dropout and freezes batch-norm statistics,
# if the network has any.
model.eval()
# Snapshot of a game, flattened into the float list that rlgym's GameState
# constructor decodes. The builder below is configured for two players.
# TODO(review): confirm the packed layout against the rlgym version in use.
state = GameState([8.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -3968.989990234375, -725.5199584960938, 164.6199951171875, -65.81099700927734, -279.9809875488281, 381.6910095214844, 4.006410121917725, -3.778909921646118, -0.6741099953651428, 3968.989990234375, 725.5199584960938, 164.6199951171875, 65.81099700927734, 279.9809875488281, 381.6910095214844, -4.006410121917725, 3.778909921646118, -0.6741099953651428, 1.0, 0.0, -2956.60986328125, -1385.469970703125, 17.0, -0.15378732979297638, -0.004750117193907499, -0.0007363896002061665, 0.9880921244621277, -1480.6510009765625, -464.58099365234375, 0.3309999704360962, 0.0, 0.0002099999983329326, 0.03440999612212181, 2956.60986328125, 1385.469970703125, 17.0, 0.9880921244621277, -0.0007363896002061665, 0.004750117193907499, 0.15378732979297638, 1480.6510009765625, 464.58099365234375, 0.3309999704360962, -0.0, -0.0002099999983329326, 0.03440999612212181, 0.0, 0.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 1.0, 0.24686768651008606, 5.0, 1.0, -3007.469970703125, 2445.449951171875, 17.010000228881836, -0.3034341037273407, -0.004550510086119175, -0.001470081857405603, 0.952840268611908, -877.9209594726562, -505.760986328125, 0.23099999129772186, 0.001509999972768128, 0.0008099999977275729, 2.3045098781585693, 3007.469970703125, -2445.449951171875, 17.010000228881836, 0.952840268611908, -0.001470081857405603, 0.004550510086119175, 0.3034341037273407, 877.9209594726562, 505.760986328125, 0.23099999129772186, -0.001509999972768128, -0.0008099999977275729, 2.3045098781585693, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.7634262442588806])
obs_builder = NectoObsBuilder(n_players=2)

# Placeholder action handed to build_obs as its third argument. Entries are
# discrete indices in the order throttle, steer, pitch, yaw, roll, jump, boost,
# handbrake ("skrrt"). Per the author's notes, index 1 on the five analog axes
# stands for the neutral value 0, and 0 on the three buttons means released.
# NOTE(review): confirm NectoObsBuilder reads these as discrete indices; if it
# expects continuous controls in [-1, 1], a 1 here means full deflection.
action = np.array([1, 1, 1, 1, 1, 0, 0, 0])

# Prime the builder with the snapshot, then build the first player's view.
obs_builder.reset(state)
first_player = state.players[0]
obs = obs_builder.build_obs(first_player, state, action)
# (In the bot this step corresponds to: self.action = self.agent.act(obs))
# Convert every array returned by build_obs into a float32 tensor. `state` is
# rebound to this tuple of model inputs; the ONNX export reuses it.
state = tuple(torch.from_numpy(s).float() for s in obs)

# Forward pass only - no autograd graph is needed.
with torch.no_grad():
    out = model(state)

# `out` holds one logit tensor per output (presumably one per action head), and
# their last dimensions may differ. Right-pad the narrower ones with -inf, which
# Categorical treats as zero probability, so they can all be stacked together.
max_shape = max(o.shape[-1] for o in out)
padded = [
    head
    if head.shape[-1] == max_shape
    else F.pad(head, pad=(0, max_shape - head.shape[-1]), value=float("-inf"))
    for head in out
]
# Stacking puts the head axis first; swap it behind the batch axis (assuming each
# head is (batch, width) - TODO confirm), then drop singleton dimensions.
logits = torch.stack(padded).swapdims(0, 1).squeeze()

# Sample one index per head and regroup into rows of five sampled heads.
dist = Categorical(logits=logits)
actions = dist.sample().numpy()
actions = actions.reshape((-1, 5))
# The first two heads (presumably 3-way throttle/steer) are sampled as
# {0, 1, 2}; shift them to {-1, 0, 1}.
actions[:, :2] -= 1
# Expand the five sampled heads into eight controls, ordered throttle, steer,
# pitch, yaw, roll, jump, boost, handbrake. Pitch reuses the throttle value.
# Steer is routed to yaw when the fifth head is 0 and to roll when it is 1
# (assuming that head only takes 0/1).
throttle, steer, jump, boost, handbrake = actions.T
parsed = np.zeros((actions.shape[0], 8))
parsed[:, 0] = throttle
parsed[:, 1] = steer
parsed[:, 2] = throttle  # pitch mirrors throttle
parsed[:, 3] = steer * (1 - handbrake)  # yaw
parsed[:, 4] = steer * handbrake  # roll
parsed[:, 5] = jump
parsed[:, 6] = boost
parsed[:, 7] = handbrake
print(parsed[0])
# Export the policy to ONNX using the same input as the inference call:
# `state` (the tuple of observation tensors) wrapped in a 1-tuple, so the model
# receives it as a single positional argument.
# TODO: give the inputs explicit names; with None the exporter picks defaults.
input_names = None
output_names = ['o0', 'o1', 'o2', 'o3', 'o4']  # one name per model output
torch.onnx.export(
    model,
    (state,),
    'necto.onnx',
    input_names=input_names,
    output_names=output_names,
)
# from torchviz import make_dot
# make_dot(out, params=dict(list(model.named_parameters()))).render("torchviz", format="png")
# import hiddenlayer as hl
# GraphViz binaries go in ./bin; add that directory as a new PATH entry (note the separator)
# os.environ['PATH'] += os.pathsep + os.path.abspath(os.path.join(os.path.curdir, "bin"))
# transforms = [ hl.transforms.Prune('Constant') ] # Removes Constant nodes from graph.
# graph = hl.build_graph(model, (state,))#, transforms=transforms)
# graph.theme = hl.graph.THEMES['blue'].copy()
# graph.save('layers', format='png')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment