@edumunozsala
Last active September 25, 2020 18:18
Train the model on SageMaker CLTG
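from sagemaker.pytorch import PyTorch

# This snippet uses SageMaker Python SDK v1 argument names
# (train_instance_count, train_instance_type); SDK v2 renamed them to
# instance_count and instance_type.
# `role` (the IAM execution role) and `input_data` (the S3 location of the
# uploaded training data) are assumed to be defined in earlier cells.

# Select the type of instance to use for training
#instance_type = 'ml.m4.4xlarge'  # CPU instance
instance_type = 'ml.p2.xlarge'    # GPU instance
#instance_type = 'local'          # local mode: run the training container on this machine (needs Docker)

# Create the estimator object
estimator = PyTorch(entry_point="train.py",  # training script
                    source_dir="train",      # directory that contains train.py
                    role=role,
                    framework_version='0.4.0',
                    train_instance_count=1,
                    train_instance_type=instance_type,
                    # Forwarded to train.py as command-line arguments
                    hyperparameters={
                        'epochs': 50,
                        'hidden_dim': 512,
                        'n_layers': 2,
                    })

# Launch the training job, exposing input_data to the container as the 'training' channel
estimator.fit({'training': input_data})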
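The hyperparameters dictionary above is not passed to Python directly. SageMaker hands each entry to the entry-point script as a command-line flag (for example `--epochs 50`), and it exposes the local path of the `training` channel through the `SM_CHANNEL_TRAINING` environment variable. The sketch below shows one way `train.py` could receive them. It is an assumption about the script, not the gist author's actual `train.py`: the `parse_args` helper, the `--data-dir`/`--model-dir` flags, and all default values are hypothetical.

```python
import argparse
import os


def parse_args(argv=None):
    """Hypothetical argument parsing for train.py (defaults are illustrative only)."""
    parser = argparse.ArgumentParser()
    # Each key in the estimator's `hyperparameters` dict arrives as --key value
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--hidden_dim", type=int, default=256)
    parser.add_argument("--n_layers", type=int, default=2)
    # SageMaker sets SM_CHANNEL_TRAINING to the local directory holding the
    # 'training' channel data and SM_MODEL_DIR to where the model should be saved;
    # the fallbacks here only matter when running outside SageMaker.
    parser.add_argument("--data-dir", type=str,
                        default=os.environ.get("SM_CHANNEL_TRAINING", "data"))
    parser.add_argument("--model-dir", type=str,
                        default=os.environ.get("SM_MODEL_DIR", "model"))
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Simulate the flags SageMaker would pass for the estimator above
    args = parse_args(["--epochs", "50", "--hidden_dim", "512", "--n_layers", "2"])
    print(args.epochs, args.hidden_dim, args.n_layers, args.data_dir)
```

Reading the data path from the environment variable, rather than hard-coding it, keeps the same script usable in local mode and on a managed training instance.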