Skip to content

Instantly share code, notes, and snippets.

@crcrpar
Created March 12, 2021 01:13
Show Gist options
  • Save crcrpar/0f19954a66ff71bb8d70fc9bda71f01c to your computer and use it in GitHub Desktop.
Save crcrpar/0f19954a66ff71bb8d70fc9bda71f01c to your computer and use it in GitHub Desktop.
diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py
index e87cd46e..bfd59914 100644
--- a/torchvision/datasets/mnist.py
+++ b/torchvision/datasets/mnist.py
@@ -131,6 +131,11 @@ class MNIST(VisionDataset):
     def class_to_idx(self) -> Dict[str, int]:
         return {_class: i for i, _class in enumerate(self.classes)}
 
+    def _check_raw_data_exists(self) -> bool:
+        """Return True if every file listed in ``self.resources`` exists in ``raw_folder``."""
+        return all(os.path.exists(os.path.join(self.raw_folder, filename))
+                   for filename, _ in self.resources)
+
     def _check_exists(self) -> bool:
         return (os.path.exists(os.path.join(self.processed_folder,
                                             self.training_file)) and
@@ -140,33 +145,40 @@ class MNIST(VisionDataset):
     def download(self) -> None:
         """Download the MNIST data if it doesn't exist in processed_folder already."""
 
         if self._check_exists():
             return
 
-        os.makedirs(self.raw_folder, exist_ok=True)
+        # Create processed_folder unconditionally: the processing step below saves
+        # into it even when the raw files are already present and nothing is downloaded.
         os.makedirs(self.processed_folder, exist_ok=True)
 
-        # download files
-        for filename, md5 in self.resources:
-            for mirror in self.mirrors:
-                url = "{}{}".format(mirror, filename)
-                try:
-                    print("Downloading {}".format(url))
-                    download_and_extract_archive(
-                        url, download_root=self.raw_folder,
-                        filename=filename,
-                        md5=md5
-                    )
-                except URLError as error:
-                    print(
-                        "Failed to download (trying next):\n{}".format(error)
-                    )
-                    continue
-                finally:
-                    print()
-                break
-            else:
-                raise RuntimeError("Error downloading {}".format(filename))
+        # Only hit the network when some raw file is missing, so raw files that are
+        # already on disk (e.g. placed there manually) are processed without a download.
+        if not self._check_raw_data_exists():
+            print('Downloading raw MNIST...')
+            os.makedirs(self.raw_folder, exist_ok=True)
+
+            # download files
+            for filename, md5 in self.resources:
+                for mirror in self.mirrors:
+                    url = "{}{}".format(mirror, filename)
+                    try:
+                        print("Downloading {}".format(url))
+                        download_and_extract_archive(
+                            url, download_root=self.raw_folder,
+                            filename=filename,
+                            md5=md5
+                        )
+                    except URLError as error:
+                        print(
+                            "Failed to download (trying next):\n{}".format(error)
+                        )
+                        continue
+                    finally:
+                        print()
+                    break
+                else:
+                    raise RuntimeError("Error downloading {}".format(filename))
 
         # process and save as torch files
         print('Processing...')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment