Skip to content

Instantly share code, notes, and snippets.

@kinchungwong
Last active October 25, 2023 10:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kinchungwong/c6dd8016957ab8ca79d873449980f57a to your computer and use it in GitHub Desktop.
MITx 6.86x Project 3 Google Colab helper code
%load_ext autoreload
%autoreload 2
import os
from os.path import join as pathjoin
from os.path import isdir, isfile
import shutil
import subprocess
import datetime
import functools
from functools import lru_cache
import contextlib
from contextlib import contextmanager
import filecmp
from google.colab import drive
# URL of the course-provided project assets archive (redacted for sharing).
assets_URL = "https://courses.edx.org/__REDACTED__/__FILENAME__.tar.gz"
# Ephemeral Colab-local directory where the project is extracted and run.
runtime_basedir = "/content/project3_runtime"
# Persistent backup location inside the mounted Google Drive.
backup_basedir = "/content/drive/MyDrive/Colab/__COURSENAME__"
# Name prefix for timestamped backup directories created under backup_basedir.
backup_prefix = "project3_"
# Source directories for each part of Project 3 within the runtime tree.
part2_nn_dir = pathjoin(runtime_basedir, "mnist/part2-nn")
part2_mnist_dir = pathjoin(runtime_basedir, "mnist/part2-mnist")
part2_twodigit_dir = pathjoin(runtime_basedir, "mnist/part2-twodigit")
def run_part2_nn():
    """Run the Part 2 hand-crafted neural network script in its own directory.

    NOTE: unlike the cells at the bottom of this file, this changes the
    working directory permanently (it does not use the cwd() context
    manager, so the previous working directory is not restored).
    """
    os.chdir(part2_nn_dir)
    # IPython magic: executes the script in the notebook's namespace (-i).
    %run -i 'neural_nets.py'
def make_timestamp_filename():
    """Return the current local time as a string safe for use in file names.

    Format: YYYY-MM-DD_HH-MM-SS_microseconds.
    """
    now = datetime.datetime.now()
    return now.strftime("%Y-%m-%d_%H-%M-%S_%f")
@contextmanager
def cwd(path):
    """Temporarily make *path* the working directory for a code block,
    restoring the previous working directory on exit (even on error).

    Credits:
    https://stackoverflow.com/a/37996581
    """
    previous = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        # Always restore, whether the block completed or raised.
        os.chdir(previous)
def recursive_files(path, dir_filter = None, file_filter = None):
    """Breadth-first scan of *path*, returning the paths of matching files.

    Arguments:
        path: root directory to scan; returned paths are built by joining
            entry names onto it (so they are relative if *path* is).
        dir_filter: optional callable(dir_path) -> bool; subdirectories
            for which it returns False are not descended into.
        file_filter: optional callable(file_path) -> bool; only files for
            which it returns True are returned.
    Returns:
        A list of file paths, in breadth-first discovery order.
    Raises:
        TypeError: if dir_filter or file_filter is neither None nor
            callable.  (TypeError subclasses Exception, so existing
            callers catching Exception are unaffected.)
    """
    if dir_filter is None:
        dir_filter = lambda _: True
    elif not callable(dir_filter):
        raise TypeError("recursive_files(): bad dir_filter")
    if file_filter is None:
        file_filter = lambda _: True
    elif not callable(file_filter):
        raise TypeError("recursive_files(): bad file_filter")
    # Pending directories form a FIFO queue consumed by index, so the
    # scan is breadth-first without repeated list pops.
    pending = [path]
    files = []
    next_index = 0
    while next_index < len(pending):
        current = pending[next_index]
        next_index += 1
        with os.scandir(current) as entries:
            # os.scandir() never yields "." or "..", so no extra
            # filtering of those entries is needed.
            for entry in entries:
                full = pathjoin(current, entry.name)
                if entry.is_file() and file_filter(full):
                    files.append(full)
                elif entry.is_dir() and dir_filter(full):
                    pending.append(full)
    return files
def sync_files(srcdir, dstdir, relative_files):
    """Copy each file listed in relative_files from srcdir into dstdir.

    The file paths in relative_files may contain the path separator; the
    caller must ensure they combine correctly with srcdir and dstdir via
    os.path.join().  Missing source files are reported and skipped, and
    destination parent directories are created on demand.
    """
    def _clean(p):
        # Collapse the "/./" produced when joining onto "./"-style paths.
        return p.replace("/./", "/")

    for filename in relative_files:
        srcfull = _clean(pathjoin(srcdir, filename))
        dstfull = _clean(pathjoin(dstdir, filename))
        if not isfile(srcfull):
            print(f"Source file {srcfull} not found; skipped")
            continue
        dst_parent = os.path.dirname(dstfull)
        if not isdir(dst_parent):
            print(f"Creating destination directory {dst_parent}")
            os.makedirs(dst_parent, exist_ok=True)
        print(f"Copying from {srcfull} to {dstfull}")
        shutil.copyfile(srcfull, dstfull)
def diff_files(srcdir, dstdir, relative_files):
    """For each file path in relative_files, print a unified diff of the
    copies found in srcdir and dstdir (whitespace-insensitive).

    Files missing on either side are reported and skipped.
    The file paths in relative_files can contain the path separator.
    The caller is responsible for ensuring that the file paths can be
    correctly combined with either srcdir or dstdir using os.path.join().
    """
    # Loop-invariant: the diff option string is the same for every file,
    # so build it once instead of on every iteration.
    cmdopts = " ".join([
        "--text",
        "--strip-trailing-cr",
        "--unified=10",
        "--ignore-trailing-space",
        "--ignore-tab-expansion",
        "--ignore-all-space",
        "--ignore-blank-lines",
    ])
    for filename in relative_files:
        srcfull = pathjoin(srcdir, filename).replace("/./", "/")
        dstfull = pathjoin(dstdir, filename).replace("/./", "/")
        if not isfile(srcfull):
            print(f"Source file {srcfull} not found; skipped")
            continue
        if not isfile(dstfull):
            print(f"Dest file {dstfull} not found; skipped")
            continue
        cmd = f"diff {cmdopts} {srcfull} {dstfull}"
        # Plain-Python replacement for the original IPython "!{cmd}" cell
        # magic, consistent with how every other command in this file runs.
        # NOTE(security): paths are interpolated into a shell command; fine
        # for trusted Colab paths, not for untrusted input.
        subprocess.call(cmd, shell=True)
def connect_to_drive():
    """Mount Google Drive at /content/drive unless it is already mounted.

    The mount prompts for user authorization on first use in a Colab session.
    """
    if not isdir("/content/drive"):
        drive.mount('/content/drive')
def extract_project_template():
    """Download the course's project assets archive and extract it into
    runtime_basedir, removing the temporary archive afterwards.
    """
    print("Downloading project data archive...")
    # makedirs with exist_ok covers both the fresh and the rerun case.
    os.makedirs(runtime_basedir, exist_ok=True)
    temp_tgz = "download_" + make_timestamp_filename() + ".tgz"
    steps = [
        (None, f"cd {runtime_basedir} && wget -O {temp_tgz} {assets_URL}"),
        ("Extracting data files...", f"cd {runtime_basedir} && tar xvf {temp_tgz}"),
        ("Cleaning up downloads...", f"cd {runtime_basedir} && rm {temp_tgz}"),
    ]
    for message, command in steps:
        if message is not None:
            print(message)
        subprocess.call(command, shell=True)
    print("Done.")
class project_file_filter_cls:
    """Callable predicate selecting the project's Python files worth
    backing up: the shared utils module and the Part 2 sources.
    """

    def __init__(self):
        # Matched by suffix: the shared utilities module.
        self.utils_path = [
            "mnist/utils.py"
        ]
        # Matched by substring: the Part 2 source directories.
        self.part2_dirs = [
            "mnist/part2-nn",
            "mnist/part2-mnist",
            "mnist/part2-twodigit",
        ]

    def __call__(self, file_path: str):
        """Return True when file_path is a .py file inside the project."""
        if not file_path.endswith(".py"):
            return False
        if any(file_path.endswith(p) for p in self.utils_path):
            return True
        return any(d in file_path for d in self.part2_dirs)
@lru_cache
def scan_project_files():
    """Return the sorted list of backup-eligible project source files.

    Paths are relative to runtime_basedir (prefixed with "./").
    Raises NotADirectoryError when the runtime directory is absent.
    NOTE: the result is cached for the life of the process via lru_cache,
    so files added after the first call are not picked up.
    """
    if not isdir(runtime_basedir):
        raise NotADirectoryError(f"Project runtime directory not found: {runtime_basedir}")
    with cwd(runtime_basedir):
        found: list[str] = recursive_files(
            "./",
            file_filter=project_file_filter_cls(),
        )
    return sorted(found)
# Sentinel values accepted by resolve_backup_dir() in place of a
# substring/callable lookup.
BACKUP_TO_NEW = "__BACKUP_TO_NEW__"  # resolve to a new timestamped backup dir
RESTORE_FROM_EARLIEST = "__RESTORE_FROM_EARLIEST__"  # oldest existing backup
RESTORE_FROM_LATEST = "__RESTORE_FROM_LATEST__"  # newest existing backup
RUNTIME_BASEDIR = "__RUNTIME_BASEDIR__"  # resolve to the runtime directory itself
# Defaults used by backup_files() and restore_files().
DEFAULT_BACKUP = BACKUP_TO_NEW
DEFAULT_RESTORE = RESTORE_FROM_LATEST
def list_backup_dirs() -> list[str]:
    """Return the sorted names (not full paths) of backup directories
    directly under backup_basedir.

    Raises:
        NotADirectoryError: if backup_basedir does not exist (e.g. the
            Google Drive has not been mounted).
    """
    if not isdir(backup_basedir):
        raise NotADirectoryError(backup_basedir)
    # os.listdir() never includes "." or "..", so the original "nofollow"
    # filtering was dead code; listing the path directly also avoids an
    # unnecessary working-directory round-trip.
    return sorted(os.listdir(backup_basedir))
def resolve_backup_dir(ident: str) -> str:
    """Resolves a backup directory (in Google Drive) for a
    backup or restore operation.

    The special values DEFAULT_BACKUP, DEFAULT_RESTORE, BACKUP_TO_NEW,
    RESTORE_FROM_EARLIEST, RESTORE_FROM_LATEST and RUNTIME_BASEDIR
    can be used.

    Arguments:
        ident: a sentinel string, a substring to match against backup
            directory names, or a callable(name) -> bool predicate.
    Returns:
        The full path of the resolved backup directory.  For
        BACKUP_TO_NEW the directory name is generated but NOT created.
    Exceptions:
        NotADirectoryError if the backup base directory or the runtime
        directory does not exist (e.g. Google Drive is not mounted).
        TypeError if ident is neither a string nor a callable.
        Exception if the ident does not select exactly one backup
        directory (matches none, or more than one).
    """
    if not isdir(backup_basedir):
        raise NotADirectoryError(backup_basedir)
    if not isdir(runtime_basedir):
        raise NotADirectoryError(runtime_basedir)
    if ident == RUNTIME_BASEDIR:
        return runtime_basedir
    if ident == BACKUP_TO_NEW:
        # Fresh timestamped name under the backup base; caller creates it.
        backup_dir = backup_prefix + make_timestamp_filename()
        backup_dir = pathjoin(backup_basedir, backup_dir)
        return backup_dir.replace("/./", "/")
    backup_dirs = list_backup_dirs()
    if ident == RESTORE_FROM_EARLIEST:
        if len(backup_dirs) >= 1:
            backup_dirs = [backup_dirs[0]]
        else:
            raise Exception("No backup directories available to restore from.")
    elif ident == RESTORE_FROM_LATEST:
        if len(backup_dirs) >= 1:
            backup_dirs = [backup_dirs[-1]]
        else:
            raise Exception("No backup directories available to restore from.")
    elif isinstance(ident, str):
        backup_dirs = [p for p in backup_dirs if (ident in p)]
    elif callable(ident):
        backup_dirs = [p for p in backup_dirs if (ident(p))]
    else:
        # BUGFIX: the message previously named "restore_files()".
        raise TypeError(
            "resolve_backup_dir(): bad lookup type " + type(ident).__name__)
    if len(backup_dirs) != 1:
        # BUGFIX: the original concatenated ident into the message with "+",
        # which itself raised TypeError when ident was a callable.
        raise Exception(
            f"resolve_backup_dir(): the lookup {ident!r} does not identify "
            "any unique backup directory to restore from.")
    backup_dir = pathjoin(backup_basedir, backup_dirs[0])
    return backup_dir.replace("/./", "/")
def is_backup_needed() -> bool:
    """Return True when the runtime directory differs from every existing
    backup directory, i.e. the working copy contains unsaved changes.
    """
    filecmp.clear_cache()
    project_files = scan_project_files()

    def _same_as_runtime(backup_dir):
        # True when every project file exists in both trees and compares
        # byte-for-byte equal (shallow=False forces a content comparison).
        for relname in project_files:
            srcfull = pathjoin(runtime_basedir, relname).replace("/./", "/")
            dstfull = pathjoin(backup_dir, relname).replace("/./", "/")
            if not isfile(srcfull) or not isfile(dstfull):
                return False
            if not filecmp.cmp(srcfull, dstfull, shallow=False):
                return False
        return True

    for name in list_backup_dirs():
        backup_dir = pathjoin(backup_basedir, name).replace("/./", "/")
        if _same_as_runtime(backup_dir):
            print("Runtime directory is identical to ")
            print("\t", backup_dir)
            print("No backup needed.")
            return False
    # None of the backup dirs matches exactly with the runtime base dir.
    print("Backup is needed, runtime directory contains new changes.")
    return True
def backup_files() -> str:
    """Backs up Python source files from project directory into Google Drive.

    Returns:
        The newly-created, timestamped backup directory in Google Drive.
    Raises:
        NotADirectoryError: if the runtime directory is missing or the
            Google Drive is not mounted.
    """
    for required in (runtime_basedir, "/content/drive"):
        if not isdir(required):
            raise NotADirectoryError(required)
    print("Backing up project source code...")
    backup_dir = resolve_backup_dir(DEFAULT_BACKUP)
    if not isdir(backup_dir):
        print(f"Creating backup directory {backup_dir}")
        os.makedirs(backup_dir, exist_ok=True)
    sync_files(runtime_basedir, backup_dir, scan_project_files())
    return backup_dir
def restore_files(
    ident: str = DEFAULT_RESTORE
) -> None:
    """Restores Python source files from a Google Drive backup into the
    project runtime directory.

    Arguments:
        ident -- any substring that help uniquely identify the timestamped
                 backup folder in Google Drive.
                 Special value "DEFAULT_RESTORE" automatically picks the
                 most recent backup.
    Raises:
        NotADirectoryError if the runtime directory is missing or the
        Google Drive is not mounted.
    """
    if not isdir(runtime_basedir):
        raise NotADirectoryError(runtime_basedir)
    if not isdir("/content/drive"):
        raise NotADirectoryError("/content/drive")
    backup_dir = resolve_backup_dir(ident)
    print("Restoring project source code from " + backup_dir + " ...")
    project_files = scan_project_files()
    # Note the direction: backup_dir is the source, runtime_basedir the dest.
    sync_files(backup_dir, runtime_basedir, project_files)
def change_runtime_dir():
    """Make runtime_basedir the current working directory, raising
    NotADirectoryError when it does not exist.
    """
    if isdir(runtime_basedir):
        os.chdir(runtime_basedir)
    else:
        raise NotADirectoryError(runtime_basedir)
###
### Paste this code into a separate cell.
### The following cell should be run once, when connecting to
### a *** NEW *** instance of Google Colab.
###
if True:
    connect_to_drive()
if True:
    extract_project_template()
# Prime the (cached) project-file scan while the tree is fresh.
scan_project_files()
###
### When needing to restore previous work:
###     Set the condition to False so that restore_files() will be executed.
### When needing to save work:
###     Set the condition to True so that backup_files() will be executed.
###
if True:
    # Save mode: back up only when the working copy has new changes.
    if is_backup_needed():
        backup_files()
else:
    # Restore mode: pull the most recent backup into the runtime directory.
    restore_files()
###
### Prints the list of project files that will be backed up.
###
### Note that this is determined based on files that exist in the project
### template. In other words, any additional files not present in the project
### template will be ignored.
###
scan_project_files()
###
### Compares the current working copy with each of the backup copies
### to determine if there are unsaved changes.
###
print("is_backup_needed() = ", is_backup_needed())
###
### Prints out two source diff reports, the first one with respect to
### the earliest backup, and the second one with respect to the latest
### backup copy.
###
if True:
    # Earliest backup (left) vs current working copy (right).
    diff_files(
        resolve_backup_dir(RESTORE_FROM_EARLIEST),
        resolve_backup_dir(RUNTIME_BASEDIR),
        scan_project_files())
if True:
    # Latest backup (left) vs current working copy (right).
    diff_files(
        resolve_backup_dir(RESTORE_FROM_LATEST),
        resolve_backup_dir(RUNTIME_BASEDIR),
        scan_project_files())
###
### Set each block to True to execute the main script corresponding
### to each part of the project.
###
### Each main script will be executed in the same namespace (scope)
### as the Colab notebook namespace; therefore, variables can be
### inspected from the notebook itself.
###
### While this is convenient, do remember to reset the Colab runtime
### when moving on to a different main script as you work through
### each part of the project.
###
if False:
    # Project 3, Part 1, hand-crafted neural network (tabs 1-5)
    with cwd(part2_nn_dir):
        %run -i 'neural_nets.py'
if False:
    # Project 3, Part 2, PyTorch fully-connected network (tabs 8)
    with cwd(part2_mnist_dir):
        %run -i 'nnet_fc.py'
if False:
    # Project 3, Part 2, PyTorch convolutional network (tabs 9)
    with cwd(part2_mnist_dir):
        %run -i 'nnet_cnn.py'
if False:
    # Project 3, Part 3, Two digit recognition, MLP (tabs 10)
    with cwd(part2_twodigit_dir):
        %run -i 'mlp.py'
if True:
    # Project 3, Part 3, Two digit recognition, CNN (tabs 10)
    with cwd(part2_twodigit_dir):
        %run -i 'conv.py'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment