This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
if __name__ == '__main__': | |
gin.parse_config(dqn_config, skip_unknown=False) | |
# train our runner | |
dqn_runner = run_experiment.create_runner(DQN_PATH, schedule='continuous_train') | |
print('Will train DQN agent, please be patient, may be a while...') | |
dqn_runner.run_experiment() | |
print('Done training!') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from absl import app | |
from dopamine.utils import example_viz_lib | |
def main(_): | |
example_viz_lib.run(agent='dqn', | |
game='Pong', | |
num_steps=100, | |
root_dir='C:/Users/Holm/Documents/dopamine', | |
restore_ckpt='C:/Users/Holm/Documents/dopamine/models/tf_ckpt-7', | |
use_legacy_checkpoint=False) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# import libraries | |
import os | |
import matplotlib.pyplot as plt | |
import torch | |
import torchvision | |
from torch.nn import functional as F | |
from torch.utils.data import DataLoader | |
import pytorch_lightning as pl | |
from pytorch_lightning import Trainer | |
from multiprocessing import Process |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
from models.ConvLSTMCell import ConvLSTMCell | |
class EncoderDecoderConvLSTM(nn.Module): | |
def __init__(self, nf, in_chan): | |
super(EncoderDecoderConvLSTM, self).__init__() | |
""" ARCHITECTURE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.nn as nn | |
import torch | |
class ConvLSTMCell(nn.Module): | |
def __init__(self, input_dim, hidden_dim, kernel_size, bias): | |
""" | |
Initialize ConvLSTM cell. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Verifying my Blockstack ID is secured with the address 1TvufPrebfps5n6Ti78vQP8d4W8mKL9p5 https://explorer.blockstack.org/address/1TvufPrebfps5n6Ti78vQP8d4W8mKL9p5 |