Last active
October 25, 2023 10:25
-
-
Save kinchungwong/c6dd8016957ab8ca79d873449980f57a to your computer and use it in GitHub Desktop.
MITx 6.86x Project 3 Google Colab helper code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
%load_ext autoreload | |
%autoreload 2 | |
import os | |
from os.path import join as pathjoin | |
from os.path import isdir, isfile | |
import shutil | |
import subprocess | |
import datetime | |
import functools | |
from functools import lru_cache | |
import contextlib | |
from contextlib import contextmanager | |
import filecmp | |
from google.colab import drive | |
# Course assets archive (URL and filename redacted for public sharing).
assets_URL = "https://courses.edx.org/__REDACTED__/__FILENAME__.tar.gz"
# Where the project template is extracted on the Colab VM (presumably
# ephemeral VM storage, lost when the runtime resets — hence the backups).
runtime_basedir = "/content/project3_runtime"
# Persistent backup root on the mounted Google Drive (course name redacted).
backup_basedir = "/content/drive/MyDrive/Colab/__COURSENAME__"
# Prefix for timestamped backup folder names created under backup_basedir.
backup_prefix = "project3_"
# Working directories for each part of the project inside the template.
part2_nn_dir = pathjoin(runtime_basedir, "mnist/part2-nn")
part2_mnist_dir = pathjoin(runtime_basedir, "mnist/part2-mnist")
part2_twodigit_dir = pathjoin(runtime_basedir, "mnist/part2-twodigit")
def run_part2_nn(): | |
os.chdir(part2_nn_dir) | |
%run -i 'neural_nets.py' | |
def make_timestamp_filename():
    """Return the current local time as a filesystem-safe string.

    Example: "2023-10-25_10-25-00_123456" (microsecond precision, so
    consecutive calls produce distinct names).
    """
    now = datetime.datetime.now()
    return now.strftime("%Y-%m-%d_%H-%M-%S_%f")
@contextmanager
def cwd(path):
    """Temporarily switch the working directory for a ``with`` block.

    The previous working directory is restored on exit, even if the
    block raises.

    Credits:
        https://stackoverflow.com/a/37996581
    """
    previous = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(previous)
def recursive_files(path, dir_filter=None, file_filter=None):
    """Breadth-first scan of *path*, returning file paths that pass the filters.

    Arguments:
        path: root directory to scan.
        dir_filter: optional callable(dir_path) -> bool; subdirectories
            failing the predicate are not descended into.  Default: accept all.
        file_filter: optional callable(file_path) -> bool; files failing
            the predicate are omitted.  Default: accept all.

    Returns:
        List of file paths (each joined onto *path*) in breadth-first order.

    Raises:
        TypeError: if a filter is supplied but not callable.
            (Bug fix: the original raised a bare Exception; TypeError is
            the conventional type and is still caught by `except Exception`.)
    """
    if dir_filter is None:
        dir_filter = lambda _: True
    elif not callable(dir_filter):
        raise TypeError("recursive_files(): bad dir_filter")
    if file_filter is None:
        file_filter = lambda _: True
    elif not callable(file_filter):
        raise TypeError("recursive_files(): bad file_filter")
    pending = [path]        # FIFO queue of directories still to scan
    files = []
    next_index = 0          # index into pending (avoids O(n) pop(0))
    while next_index < len(pending):
        current = pending[next_index]
        next_index += 1
        with os.scandir(current) as entries:
            for entry in entries:
                # os.scandir never yields "." or "..", but keep the guard
                # from the original for safety.
                if entry.name in (".", ".."):
                    continue
                full_path = pathjoin(current, entry.name)
                if entry.is_file() and file_filter(full_path):
                    files.append(full_path)
                elif entry.is_dir() and dir_filter(full_path):
                    pending.append(full_path)
    return files
def sync_files(srcdir, dstdir, relative_files):
    """Copy each file listed in relative_files from srcdir into dstdir.

    Missing source files are reported and skipped; destination parent
    directories are created on demand.  The relative paths may contain
    path separators; the caller must ensure they combine cleanly with
    srcdir/dstdir via os.path.join().
    """
    for relpath in relative_files:
        source = pathjoin(srcdir, relpath).replace("/./", "/")
        target = pathjoin(dstdir, relpath).replace("/./", "/")
        if not isfile(source):
            print(f"Source file {source} not found; skipped")
            continue
        parent = os.path.dirname(target)
        if not isdir(parent):
            print(f"Creating destination directory {parent}")
            os.makedirs(parent, exist_ok=True)
        print(f"Copying from {source} to {target}")
        shutil.copyfile(source, target)
def diff_files(srcdir, dstdir, relative_files): | |
"""For each file path in relative_files, locate the file in srcdir | |
and dstdir, and if both copies exist, print a summary of diffs. | |
The file paths in relative_files can contain the path separator. | |
The caller is responsible for ensuring that the file paths can be | |
correctly combined with either srcdir or dstdir using os.path.join(). | |
""" | |
for filename in relative_files: | |
srcfull = pathjoin(srcdir, filename).replace("/./", "/") | |
dstfull = pathjoin(dstdir, filename).replace("/./", "/") | |
if not isfile(srcfull): | |
print(f"Source file {srcfull} not found; skipped") | |
continue | |
if not isfile(dstfull): | |
print(f"Dest file {dstfull} not found; skipped") | |
continue | |
cmdopts = " ".join([ | |
"--text", | |
"--strip-trailing-cr", | |
"--unified=10", | |
"--ignore-trailing-space", | |
"--ignore-tab-expansion", | |
"--ignore-all-space", | |
"--ignore-blank-lines", | |
]) | |
cmd = f"diff {cmdopts} {srcfull} {dstfull}" | |
# subprocess.call(cmd, shell=True) | |
!{cmd} | |
def connect_to_drive():
    """Mount Google Drive at /content/drive, unless it is already mounted."""
    if isdir("/content/drive"):
        return
    drive.mount('/content/drive')
def extract_project_template():
    """Download the course data archive into runtime_basedir and unpack it.

    The archive is fetched with wget under a unique timestamped temp name,
    extracted with tar, and then deleted.
    """
    def _shell(command):
        # Helper: run one shell command line.
        subprocess.call(command, shell=True)

    print("Downloading project data archive...")
    if not isdir(runtime_basedir):
        os.makedirs(runtime_basedir, exist_ok=True)
    temp_tgz = "download_" + make_timestamp_filename() + ".tgz"
    _shell(f"cd {runtime_basedir} && wget -O {temp_tgz} {assets_URL}")
    print("Extracting data files...")
    _shell(f"cd {runtime_basedir} && tar xvf {temp_tgz}")
    print("Cleaning up downloads...")
    _shell(f"cd {runtime_basedir} && rm {temp_tgz}")
    print("Done.")
class project_file_filter_cls:
    """Predicate selecting the project .py files eligible for backup/restore.

    A path is accepted when it is a Python file AND it is either the
    shared mnist/utils.py module or contains one of the part2 directory
    names anywhere in its path.
    """

    def __init__(self):
        # Exact path suffixes to always include.
        self.utils_path = [
            "mnist/utils.py"
        ]
        # Directory-name substrings that mark a file as part of Part 2.
        self.part2_dirs = [
            "mnist/part2-nn",
            "mnist/part2-mnist",
            "mnist/part2-twodigit",
        ]

    def __call__(self, file_path: str):
        if not file_path.endswith(".py"):
            return False
        if any(file_path.endswith(suffix) for suffix in self.utils_path):
            return True
        return any(part in file_path for part in self.part2_dirs)
@lru_cache
def scan_project_files():
    """Return the sorted list of backup-eligible project files.

    Paths are relative to runtime_basedir (prefixed with "./").  The
    result is cached (lru_cache), so the file list is frozen after the
    first call.

    Raises:
        NotADirectoryError: if the runtime directory does not exist.
    """
    if not isdir(runtime_basedir):
        raise NotADirectoryError(f"Project runtime directory not found: {runtime_basedir}")
    with cwd(runtime_basedir):
        found: list[str] = recursive_files(
            "./",
            file_filter=project_file_filter_cls(),
        )
    found.sort()
    return found
# Sentinel identifiers understood by resolve_backup_dir():
BACKUP_TO_NEW = "__BACKUP_TO_NEW__"                   # create a fresh timestamped backup dir
RESTORE_FROM_EARLIEST = "__RESTORE_FROM_EARLIEST__"   # select the oldest existing backup
RESTORE_FROM_LATEST = "__RESTORE_FROM_LATEST__"       # select the newest existing backup
RUNTIME_BASEDIR = "__RUNTIME_BASEDIR__"               # refer to the live runtime directory
# Defaults used by backup_files() and restore_files().
DEFAULT_BACKUP = BACKUP_TO_NEW
DEFAULT_RESTORE = RESTORE_FROM_LATEST
def list_backup_dirs() -> list[str]:
    """Return the sorted names of entries under backup_basedir.

    Raises:
        NotADirectoryError: if backup_basedir is missing (i.e. Google
        Drive has not been mounted).
    """
    skip = (".", "..")
    if not isdir(backup_basedir):
        raise NotADirectoryError(backup_basedir)
    with cwd(backup_basedir):
        entries = os.listdir("./")
    return sorted(name for name in entries if name not in skip)
def resolve_backup_dir(ident: str) -> str:
    """Resolve a backup directory (in Google Drive) for a backup or
    restore operation.

    The special values DEFAULT_BACKUP and DEFAULT_RESTORE can be used.

    Arguments:
        ident: a sentinel constant, a substring, or a predicate callable
            used to select the backup directory for the operation.

    Returns:
        The resolved directory path.

    Raises:
        NotADirectoryError: if the backup base directory does not exist
            (Google Drive not mounted) or the runtime directory is missing.
        TypeError: if ident is neither a string nor callable.
        Exception: if ident does not select exactly one backup directory
            (matches nothing, or matches more than one).
    """
    if not isdir(backup_basedir):
        raise NotADirectoryError(backup_basedir)
    if not isdir(runtime_basedir):
        raise NotADirectoryError(runtime_basedir)
    if ident == RUNTIME_BASEDIR:
        return runtime_basedir
    if ident == BACKUP_TO_NEW:
        # A backup always goes into a brand-new timestamped directory.
        backup_dir = backup_prefix + make_timestamp_filename()
        backup_dir = pathjoin(backup_basedir, backup_dir)
        return backup_dir.replace("/./", "/")
    backup_dirs = list_backup_dirs()
    if ident in (RESTORE_FROM_EARLIEST, RESTORE_FROM_LATEST):
        if not backup_dirs:
            raise Exception("No backup directories available to restore from.")
        index = 0 if ident == RESTORE_FROM_EARLIEST else -1
        backup_dirs = [backup_dirs[index]]
    elif isinstance(ident, str):
        backup_dirs = [p for p in backup_dirs if ident in p]
    elif callable(ident):
        backup_dirs = [p for p in backup_dirs if ident(p)]
    else:
        # Bug fix: the original error message named restore_files().
        raise TypeError(
            "resolve_backup_dir(): bad lookup type " + type(ident).__name__)
    if len(backup_dirs) != 1:
        # Bug fix: the original concatenated ident directly into a str,
        # which itself raised TypeError when ident was a callable; format
        # via f-string so any ident renders.
        raise Exception(
            f'resolve_backup_dir(): the lookup string "{ident}" does not '
            "identify any unique backup directory to restore from.")
    backup_dir = pathjoin(backup_basedir, backup_dirs[0])
    return backup_dir.replace("/./", "/")
def is_backup_needed() -> bool:
    """Return True when the runtime sources differ from every saved backup.

    Every project file is compared byte-for-byte (filecmp, shallow=False)
    against each backup directory; the first backup that matches all files
    means no new backup is required.
    """
    filecmp.clear_cache()
    project_files = scan_project_files()
    candidates = [
        pathjoin(backup_basedir, name).replace("/./", "/")
        for name in list_backup_dirs()
    ]
    for backup_dir in candidates:
        for relpath in project_files:
            live_copy = pathjoin(runtime_basedir, relpath).replace("/./", "/")
            saved_copy = pathjoin(backup_dir, relpath).replace("/./", "/")
            if not isfile(live_copy) or not isfile(saved_copy):
                break  # file missing on one side -> this backup differs
            if not filecmp.cmp(live_copy, saved_copy, shallow=False):
                break  # content changed -> this backup differs
        else:
            # Inner loop ran to completion: every file matched this backup.
            print("Runtime directory is identical to ")
            print("\t", backup_dir)
            print("No backup needed.")
            return False
    # None of the backup dirs matches exactly with the runtime base dir.
    print("Backup is needed, runtime directory contains new changes.")
    return True
def backup_files() -> str:
    """Back up project Python sources into a new timestamped Drive folder.

    Returns:
        The newly-created, timestamped backup directory in Google Drive.

    Raises:
        NotADirectoryError: if the runtime directory or the mounted
        Drive root is missing.
    """
    for required in (runtime_basedir, "/content/drive"):
        if not isdir(required):
            raise NotADirectoryError(required)
    print("Backing up project source code...")
    backup_dir = resolve_backup_dir(DEFAULT_BACKUP)
    if not isdir(backup_dir):
        print(f"Creating backup directory {backup_dir}")
        os.makedirs(backup_dir, exist_ok=True)
    sync_files(runtime_basedir, backup_dir, scan_project_files())
    return backup_dir
def restore_files(
    ident: str = DEFAULT_RESTORE
) -> None:
    """Restores Python source files from a Google Drive backup into the
    project runtime directory.

    (Doc fix: the original docstring stated the direction backwards, and
    the annotation claimed -> str although nothing is returned.)

    Arguments:
        ident -- any substring that helps uniquely identify the timestamped
            backup folder in Google Drive.
            The default DEFAULT_RESTORE automatically picks the
            most recent backup.

    Raises:
        NotADirectoryError if the runtime directory or the mounted Drive
        root does not exist.
    """
    if not isdir(runtime_basedir):
        raise NotADirectoryError(runtime_basedir)
    if not isdir("/content/drive"):
        raise NotADirectoryError("/content/drive")
    backup_dir = resolve_backup_dir(ident)
    print("Restoring project source code from " + backup_dir + " ...")
    project_files = scan_project_files()
    sync_files(backup_dir, runtime_basedir, project_files)
def change_runtime_dir():
    """Make runtime_basedir the notebook's current working directory.

    Raises:
        NotADirectoryError: if the runtime directory does not exist.
    """
    if isdir(runtime_basedir):
        os.chdir(runtime_basedir)
    else:
        raise NotADirectoryError(runtime_basedir)
###
### Paste this code in a separate cell.
### The following cell should be run once, when connecting to
### a *** NEW *** instance of Google Colab.
###
if True:
    connect_to_drive()
if True:
    extract_project_template()
    # Prime the lru_cache so the project file list is fixed to the
    # freshly-extracted template contents.
    scan_project_files()
###
### Backup/restore cell.
### If the runtime files differ from every existing backup, a new
### timestamped backup is written; otherwise the most recent backup
### is restored into the runtime directory.
### Set the outer condition to False to skip this cell entirely.
###
if True:
    if is_backup_needed():
        backup_files()
    else:
        restore_files()
###
### Prints the list of project files that will be backed up.
###
### Note that this is determined based on files that exist in the project
### template. In other words, any additional files not present in the project
### template will be ignored.
###
### NOTE: scan_project_files() is lru_cached, so this returns the list
### captured on its first call in this session.
###
scan_project_files()
###
### Compares the current working copy with each of the backup copies
### to determine if there are unsaved changes.
###
print("is_backup_needed() = ", is_backup_needed())
###
### Prints out two source diff reports, the first one with respect to
### the earliest backup, and the second one with respect to the latest
### backup copy.
###
if True:
    diff_files(
        resolve_backup_dir(RESTORE_FROM_EARLIEST),
        resolve_backup_dir(RUNTIME_BASEDIR),
        scan_project_files())
if True:
    diff_files(
        resolve_backup_dir(RESTORE_FROM_LATEST),
        resolve_backup_dir(RUNTIME_BASEDIR),
        scan_project_files())
###
### Set each block to True to execute the main script corresponding
### to each part of the project.
###
### Each main script will be executed in the same namespace (scope)
### as the Colab notebook namespace (%run -i); therefore, variables can
### be inspected from the notebook itself.
###
### While this is convenient, do remember to reset the Colab runtime
### when moving on to a different main script as you work through
### each part of the project.
###
if False:
    # Project 3, Part 1, hand-crafted neural network (tabs 1-5)
    with cwd(part2_nn_dir):
        %run -i 'neural_nets.py'
if False:
    # Project 3, Part 2, PyTorch fully-connected network (tabs 8)
    with cwd(part2_mnist_dir):
        %run -i 'nnet_fc.py'
if False:
    # Project 3, Part 2, PyTorch convolutional network (tabs 9)
    with cwd(part2_mnist_dir):
        %run -i 'nnet_cnn.py'
if False:
    # Project 3, Part 3, Two digit recognition, MLP (tabs 10)
    with cwd(part2_twodigit_dir):
        %run -i 'mlp.py'
if True:
    # Project 3, Part 3, Two digit recognition, CNN (tabs 10)
    with cwd(part2_twodigit_dir):
        %run -i 'conv.py'
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment