This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
from collections import defaultdict | |
import minitorch | |
import time | |
import sys | |
import numpy as np | |
FastTensorBackend = minitorch.TensorBackend(minitorch.FastOps) | |
GPUBackend = minitorch.TensorBackend(minitorch.CudaOps) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import plotly.graph_objects as go | |
import numpy as np | |
np.random.seed(1) | |
p = 4 | |
x = np.linspace(0,10) | |
y = x ** p | |
x_center = 5. | |
y_center = 5. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as onp | |
import jax | |
import jax.numpy as jnp | |
from jax import value_and_grad, grad, jit, random, lax | |
from jax.nn import log_softmax | |
from jax.scipy.special import logsumexp as lse | |
def init_model(rng, X, Z): | |
rng_z, rng_x = random.split(rng) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
networkx==2.4 | |
pydot==1.4.1 | |
streamlit==1.12.0 | |
watchdog==1.0.2 | |
plotly==4.14.3 | |
python-mnist |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
numpy == 1.22.1 | |
numba == 0.56 | |
pytest == 6.0.1 | |
pytest-env | |
pytest-runner == 5.2 | |
hypothesis == 4.38 | |
flake8==3.8.3 | |
black==22.3.0 | |
colorama==0.4.3 | |
pep8-naming==0.11.1 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
minitorch/operators.py | |
minitorch/module.py | |
minitorch/autodiff.py | |
minitorch/scalar.py | |
minitorch/tensor_data.py | |
minitorch/tensor_functions.py | |
minitorch/tensor_ops.py | |
project/run_scalar.py | |
project/run_tensor.py | |
tests/test_module.py |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Description: | |
Note: Make sure that both the new and old module files are in same directory! | |
This script helps you sync your previous module works with current modules. | |
It takes 2 arguments, source_dir_name and destination_dir_name. | |
All the files which will be moved are specified in files_to_sync.txt as newline separated strings | |
Usage: python sync_previous_module.py <source_dir_name> <dest_dir_name> |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/train.py b/train.py | |
index 7c8786e..3c7d3fc 100644 | |
--- a/train.py | |
+++ b/train.py | |
@@ -75,7 +75,9 @@ def main() -> None: | |
decoder = Decoder(N, K, image_shape) | |
model = CategoricalVAE(encoder, decoder) | |
- optimizer = optim.SGD(model.parameters(), lr=initial_learning_rate, momentum=0.9) | |
+ parameters = list(model.parameters()) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/train.py b/train.py | |
index 7c8786e..e7cef7f 100644 | |
--- a/train.py | |
+++ b/train.py | |
@@ -56,7 +56,7 @@ def main() -> None: | |
wandb_log_interval = 100 | |
batch_size = 100 | |
max_steps = 50_000 | |
- initial_learning_rate = 0.001 | |
+ initial_learning_rate = 0.0001 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/start.py b/start.py | |
index 6ebf861..4e69d62 100644 | |
--- a/start.py | |
+++ b/start.py | |
@@ -219,7 +219,7 @@ class Problem: | |
best_parameters = None | |
best_dislike = float('inf') | |
- total_epochs = 4000 | |
+ total_epochs = 8000 |
NewerOlder