Skip to content

Instantly share code, notes, and snippets.

@jogardi
Last active June 14, 2022 16:45
Show Gist options
  • Save jogardi/b7e5e820d0e041b51bda077476c81046 to your computer and use it in GitHub Desktop.
Save jogardi/b7e5e820d0e041b51bda077476c81046 to your computer and use it in GitHub Desktop.
Example of managing resources and datasets for ML
import joblib, os, torch
from functools import cached_property
from google.cloud import storage
from google.oauth2 import service_account
import torchvision.transforms as transforms
def prep_dir(path: str) -> str:
    """Ensure that the directory ``path`` exists and return it.

    Missing parent directories are created as well, and a directory that
    already exists is left untouched.

    Args:
        path: Directory to create if it is missing.

    Returns:
        ``path`` unchanged, so the call can be used inline in an expression.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    # makedirs(exist_ok=True) handles nested paths and avoids the
    # check-then-create race of os.path.exists() followed by os.mkdir().
    os.makedirs(path, exist_ok=True)
    return path
class Resources:
    """Local cache directories, Google Cloud Storage access and datasets.

    The class-level attributes are evaluated when the class is defined, so
    the cache directories are created on import. Clients, credentials and
    dataset paths are ``cached_property`` values and are only resolved on
    first access.
    """

    # Per-user root directory for everything this project stores locally.
    project_dir = prep_dir(os.path.expanduser("~/.<your project name>"))
    # Parent directory of all locally cached artifacts.
    cached_files = prep_dir(f"{project_dir}/cached_files")
    # Local mirror of objects downloaded from the resources bucket.
    bucket_files_dir = prep_dir(f"{cached_files}/bucket_files")
    # joblib on-disk cache rooted inside ``cached_files``.
    memcache = joblib.Memory(f"{cached_files}/joblibcache")

    @cached_property
    def gstorage_client(self) -> storage.Client:
        """Return a storage client built from ``gcreds`` and its project id."""
        return storage.Client(
            credentials=self.gcreds,
            project=self.gcreds.project_id,
        )

    @cached_property
    def gcreds(self):
        """Load service-account credentials from the project directory.

        The key file is read from
        ``<project_dir>/gcloud_service_credentials.json``.
        """
        return service_account.Credentials.from_service_account_file(
            filename=f"{self.project_dir}/gcloud_service_credentials.json"
        )

    @cached_property
    def res_bucket(self) -> storage.Bucket:
        """Return the ``res1`` bucket fetched through ``gstorage_client``."""
        return self.gstorage_client.get_bucket("res1")

    def from_gcp(self, key: str) -> str:
        """Return the local path for bucket object ``key``, downloading it once.

        The object is stored at ``<bucket_files_dir>/<key>``. If that path
        already exists, the bucket is not contacted.

        Args:
            key: Object name inside ``res_bucket``; may contain ``/``.

        Returns:
            The local file path. If the object does not exist in the bucket,
            nothing is downloaded and the returned path will not exist.
        """
        file_path = f"{self.bucket_files_dir}/{key}"
        if not os.path.exists(file_path):
            # Keys may be nested several levels deep, so create every
            # missing parent directory rather than only the last one.
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            source_blob = self.res_bucket.blob(key)
            if source_blob.exists():
                source_blob.download_to_filename(file_path)
        return file_path

    def from_gcp_dir(self, key: str, bucket=None) -> str:
        """Download every object under prefix ``key`` into one local directory.

        The local directory is ``<bucket_files_dir>/<key with '/' replaced by
        '_'>``. Each object is written inside it at its path relative to
        ``key``, so the returned directory actually contains the files. If
        the directory already exists, nothing is downloaded.

        Args:
            key: Object-name prefix to mirror.
            bucket: Bucket to list; defaults to ``res_bucket``.

        Returns:
            Path of the local directory holding the downloaded objects.
        """
        if bucket is None:
            bucket = self.res_bucket
        file_path = f"{self.bucket_files_dir}/{key.replace('/', '_')}"
        if not os.path.exists(file_path):
            os.makedirs(file_path)
            for blob in bucket.list_blobs(prefix=key):
                relative_name = blob.name[len(key):].lstrip("/")
                # Skip names that do not map to a file inside the directory:
                # an object named exactly ``key`` or a name ending in "/".
                if not relative_name or relative_name.endswith("/"):
                    continue
                subpath = f"{file_path}/{relative_name}"
                os.makedirs(os.path.dirname(subpath), exist_ok=True)
                blob.download_to_filename(subpath)
        return file_path

    def _load_image_folder(self, transform) -> torch.utils.data.Dataset:
        """Build an ``ImageFolder`` over ``ex_dataset_path`` with ``transform``."""
        # Imported here because the file only binds ``torchvision.transforms``
        # (as ``transforms``); the bare name ``torchvision`` is not in scope.
        from torchvision.datasets import ImageFolder

        return ImageFolder(self.ex_dataset_path, transform=transform)

    def load_transformed_ex_dataset(self) -> torch.utils.data.Dataset:
        """Return the example dataset with images converted to tensors."""
        return self._load_image_folder(
            transforms.Compose([
                transforms.ToTensor(),
            ])
        )

    def load_transformed_ex_dataset_cropped(self) -> torch.utils.data.Dataset:
        """Return the example dataset as tensors center-cropped to 256."""
        return self._load_image_folder(
            transforms.Compose([
                transforms.ToTensor(),
                transforms.CenterCrop(256),
            ])
        )

    @cached_property
    def ex_dataset_path(self) -> str:
        """Return the local path of the example dataset, fetched on first use."""
        # NOTE(review): the key below is a placeholder. ImageFolder presumably
        # needs a directory root, so a real key may have to go through
        # from_gcp_dir instead - confirm against the actual bucket layout.
        return self.from_gcp("path to the file in your resources bucket")
# Module-level Resources instance, created once when this module is imported.
project_res = Resources()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment