|
# The code in this file is derived from github.com/determined-ai/determined, |
|
# which is under the Apache2 license. |
|
|
|
import functools |
|
import os |
|
from typing import Any, Dict, List, cast, Optional, Callable |
|
|
|
import numpy as np |
|
import tensorflow as tf |
|
from tensorflow.keras.layers import Dense |
|
from tensorflow.keras.losses import mean_squared_error |
|
from tensorflow.keras.models import Sequential |
|
from tensorflow.keras.optimizers import SGD |
|
|
|
import determined as det |
|
import determined.keras |
|
import yogadl.storage |
|
|
|
from tensorflow.keras.utils import Sequence |
|
|
|
####################### |
|
# Begin YogaDL Helper # |
|
####################### |
|
|
|
class YogaDL:
    """
    YogaDL is a helper for using YogaDL directly within Determined.

    It builds the YogaDL storage backend selected by ``config["type"]`` and exposes decorators
    that cache a dataset-building function and stream it back as a sharded tf.data dataset.

    Args:
        config: The data_layer sub-config from the experiment config.

        dist: The DistributedContext for this training, usually context.distributed.

        per_slot_batch_size: The result of context.get_per_slot_batch_size().

        seed: The result of context.get_trial_seed().

        offset_records: The number of records to skip into training.

        rw_coordinator_url: When using s3 or gcs, you must run a rw coordinator service accessible
            to all training nodes, and you pass the url for accessing that service here.

        coordinator_cert_file: If you use a self-signed cert for the rw-coordinator, set the
            filename to the cert file here.

        coordinator_cert_name: If you use a self-signed cert for the rw-coordinator, set the name
            that appears on the certificate here.

    Raises:
        AssertionError: If the storage type is s3 or gcs and no rw_coordinator_url is given.
        KeyError: If a cloud storage type is selected and a required bucket key is missing
            from config.
        ValueError: If ``config["type"]`` is not a recognized storage type.
    """

    def __init__(
        self,
        config: Dict[str, Any],
        dist: det.core.DistributedContext,
        per_slot_batch_size: int,
        seed: int,
        offset_records: int,
        # rw_coordinator_url is required for cloud storage and ignored for shared_fs
        rw_coordinator_url: Optional[str] = None,
        # you may also optionally provide a self-signed rw-coordinator certificate
        coordinator_cert_file: Optional[str] = None,
        coordinator_cert_name: Optional[str] = None,
    ) -> None:
        self.dist = dist
        self.config = config
        self.per_slot_batch_size = per_slot_batch_size
        self.seed = seed
        self.offset_records = offset_records

        session_config = None
        if self.dist.size > 1:
            # For multi-GPU training, we map processes to individual GPUs. TF requires
            # that for each instantiation of `tf.Session`, the process is mapped
            # to the same GPU.
            session_config = tf.compat.v1.ConfigProto()
            session_config.gpu_options.visible_device_list = str(self.dist.local_rank)

        def get_storage_path(storage_path_key: str) -> str:
            # Fall back to the default when the key is absent OR explicitly null/empty;
            # os.makedirs() cannot handle None or "".
            default_path = os.path.expanduser("~/data/determined")
            storage_path = config.get(storage_path_key) or default_path
            os.makedirs(storage_path, exist_ok=True)
            return storage_path

        self.storage: yogadl.Storage

        typ = config.get("type")
        if typ == "shared_fs":
            storage_path = get_storage_path("container_storage_path")
            storage_config = yogadl.storage.LFSConfigurations(storage_dir_path=storage_path)
            self.storage = yogadl.storage.LFSStorage(
                storage_config, tensorflow_config=session_config
            )
        elif typ == "s3":
            assert rw_coordinator_url, (
                "for s3 storage, you must provide a read-write coordinator url "
                "(rw_coordinator_url)"
            )
            storage_path = get_storage_path("local_cache_container_path")
            storage_config = yogadl.storage.S3Configurations(
                bucket=config["bucket"],
                bucket_directory_path=config["bucket_directory_path"],
                url=rw_coordinator_url,
                local_cache_dir=storage_path,
                access_key=config.get("access_key"),
                secret_key=config.get("secret_key"),
                endpoint_url=config.get("endpoint_url"),
                coordinator_cert_file=coordinator_cert_file,
                coordinator_cert_name=coordinator_cert_name,
            )
            self.storage = yogadl.storage.S3Storage(
                storage_config, tensorflow_config=session_config
            )
        elif typ in ("gcs", "gcp"):
            # "gcs" is the name this helper's own documentation uses for the backend; "gcp" is
            # still accepted because earlier versions of this helper only matched that spelling.
            assert rw_coordinator_url, (
                "for gcs storage, you must provide a read-write coordinator url "
                "(rw_coordinator_url)"
            )
            storage_path = get_storage_path("local_cache_container_path")
            storage_config = yogadl.storage.GCSConfigurations(
                bucket=config["bucket"],
                # Passed the same way as for S3, so both cloud backends honor the configured
                # directory inside the bucket.
                bucket_directory_path=config["bucket_directory_path"],
                url=rw_coordinator_url,
                local_cache_dir=storage_path,
                coordinator_cert_file=coordinator_cert_file,
                coordinator_cert_name=coordinator_cert_name,
            )
            self.storage = yogadl.storage.GCSStorage(
                storage_config, tensorflow_config=session_config
            )
        else:
            raise ValueError(f'config.get("type") ({typ}) not recognized')

    def make_decorator(
        self,
        dataset_id: str,
        dataset_version: str,
        shuffle: bool,
        skip_shuffle_at_epoch_end: bool,
        drop_shard_remainder: bool,
        start_offset: Optional[int] = None,
    ) -> Callable:
        """
        Build a decorator that caches a dataset-building function and streams from the cache.

        The decorated function, when called, wraps the original function with
        ``self.storage.cacheable``, streams the resulting DataRef sharded by
        ``self.dist.rank`` / ``self.dist.size``, and returns the stream converted with
        ``yogadl.tensorflow.make_tf_dataset``.

        Args:
            dataset_id: Identifier of the cached dataset.
            dataset_version: Version of the cached dataset.
            shuffle: Forwarded to the stream as ``shuffle``.
            skip_shuffle_at_epoch_end: Forwarded to the stream unchanged.
            drop_shard_remainder: Forwarded to the stream unchanged.
            start_offset: Number of records to skip at the start of the stream. Defaults to
                ``self.offset_records`` when None.

        Returns:
            A decorator to apply to a function that builds the dataset.
        """

        def decorator(make_dataset_fn: Callable) -> Callable:
            @functools.wraps(make_dataset_fn)
            def decorated(*args: Any, **kwargs: Any) -> Any:
                @self.storage.cacheable(
                    dataset_id=dataset_id,
                    dataset_version=dataset_version,
                )
                def make_dataset() -> yogadl.DataRef:
                    print(f"generating dataset {dataset_id}:{dataset_version}!")
                    return make_dataset_fn(*args, **kwargs)

                # Resolve the offset at call time so the default tracks self.offset_records.
                offset = self.offset_records if start_offset is None else start_offset

                stream = make_dataset().stream(
                    start_offset=offset,
                    shuffle=shuffle,
                    skip_shuffle_at_epoch_end=skip_shuffle_at_epoch_end,
                    shuffle_seed=self.seed,
                    shard_rank=self.dist.rank,
                    num_shards=self.dist.size,
                    drop_shard_remainder=drop_shard_remainder,
                )
                return yogadl.tensorflow.make_tf_dataset(stream)

            return decorated

        return decorator

    def cache_train_dataset(
        self,
        dataset_id: str,
        dataset_version: str,
        shuffle: bool = False,
        skip_shuffle_at_epoch_end: bool = False,
    ) -> Callable:
        """
        Return a decorator that caches and streams a training dataset.

        The dataset is cached under ``dataset_version + "_train"``. The stream starts at
        ``self.offset_records`` and drops the shard remainder.
        """
        return self.make_decorator(
            dataset_id=dataset_id,
            dataset_version=dataset_version + "_train",
            shuffle=shuffle,
            skip_shuffle_at_epoch_end=skip_shuffle_at_epoch_end,
            drop_shard_remainder=True,
        )

    def cache_validation_dataset(
        self,
        dataset_id: str,
        dataset_version: str,
        shuffle: bool = False,
    ) -> Callable:
        """
        Return a decorator that caches and streams a validation dataset.

        The dataset is cached under ``dataset_version + "_val"``. Unlike training, the stream
        always starts at record 0 and keeps the shard remainder, so no validation records are
        skipped.
        """
        return self.make_decorator(
            dataset_id=dataset_id,
            dataset_version=dataset_version + "_val",
            shuffle=shuffle,
            skip_shuffle_at_epoch_end=True,
            # Dropping the remainder would silently exclude records from validation metrics.
            drop_shard_remainder=False,
            # offset_records is a training resume offset; it must not skip validation records.
            start_offset=0,
        )
|
|
|
################################## |
|
# Begin Example Model Definition # |
|
################################## |
|
|
|
def map_to_two(record):
    """Map any record to the constant 2.0; the record itself is ignored."""
    constant = 2.0
    return constant
|
|
|
|
|
def tf_data_loader(batch_size, length):
    """
    Build a zipped (x, y) tf.data dataset of constant values.

    Each side is a range of `length` elements mapped through map_to_two and grouped into
    batches of `batch_size`; the two sides are then zipped into pairs.
    """

    def constant_batches():
        # One independent pipeline per call, so x and y are separate dataset objects.
        return tf.data.Dataset.range(length).map(map_to_two).batch(batch_size)

    features = constant_batches()
    labels = constant_batches()
    return tf.data.Dataset.zip((features, labels))
|
|
|
|
|
class OneVarTrial(det.keras.TFKerasTrial):
    """
    Models a simple one-variable (y = wx) neural network with an MSE loss function.

    Both data loaders are cached and sharded through the YogaDL helper rather than through
    the trial context.
    """

    def __init__(self, context) -> None:
        self.context = context
        self.my_batch_size = self.context.get_per_slot_batch_size()  # type: int
        self.my_lr = self.context.get_hparams()["learning_rate"]

        # Instantiate the YogaDL helper.
        helper_kwargs = {
            "config": self.context.get_data_config()["data_layer"],
            "dist": context.distributed,
            "per_slot_batch_size": context.get_per_slot_batch_size(),
            "seed": context.get_trial_seed(),
            # WATCH OUT: this is a hack; context.env is a private, undocumented attribute, but
            # this is not very likely to break before TFKerasTrial is removed from the system.
            # I can't make any promises about internal fields though, of course.
            "offset_records": context.env.steps_completed * context.get_per_slot_batch_size(),
            # ACTION ITEM: if you use s3 or gcs for data layer you need to run a rw-coordinator
            # service and provide the url here.
            "rw_coordinator_url": "ws://MY-IP-ADDR:9001",
        }
        self.yogadl = YogaDL(**helper_kwargs)

    def build_training_data_loader(self):
        """Return the cached training dataset, wrapped by the trial context."""

        # ACTION ITEM: convert deprecated decorator to decorator from our YogaDL helper.
        # @self.context.experimental.cache_train_dataset("ones", "1")
        @self.yogadl.cache_train_dataset("ones", "1")
        def make_train_ds():
            return tf_data_loader(self.context.get_per_slot_batch_size(), 100)

        cached_ds = make_train_ds()
        # shard_dataset=False makes this a noop, since we already sharded inside the decorator.
        return self.context.wrap_dataset(cached_ds, shard_dataset=False)

    def build_validation_data_loader(self):
        """Return the cached validation dataset, wrapped by the trial context."""

        # ACTION ITEM: convert deprecated decorator to decorator from our YogaDL helper.
        # @self.context.experimental.cache_validation_dataset("ones", "1")
        @self.yogadl.cache_validation_dataset("ones", "1")
        def make_val_ds():
            return tf_data_loader(self.context.get_per_slot_batch_size(), 100)

        cached_ds = make_val_ds()
        # shard_dataset=False makes this a noop, since we already sharded inside the decorator.
        return self.context.wrap_dataset(cached_ds, shard_dataset=False)

    def build_model(self) -> Sequential:
        """Build, wrap and compile the single-weight linear model."""
        # A single bias-free unit initialized to zero: the only trainable value is w.
        linear = Dense(
            1, activation=None, use_bias=False, kernel_initializer="zeros", input_shape=(1,)
        )
        model = self.context.wrap_model(Sequential([linear]))

        optimizer = self.context.wrap_optimizer(SGD(learning_rate=self.my_lr))

        model.compile(optimizer, mean_squared_error, metrics=["accuracy"])
        return model