Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
# Launch a SageMaker training job that runs a custom container image from ECR.
#
# Relies on names bound earlier in the file (not visible in this section):
# `sagemaker`, `sess`, `repo_name`, `account`, `region`, `role`,
# `base_job_name`, `train_input_path`, `validation_input_path`.
#
# NOTE(review): `image_name`, `train_instance_count` and `train_instance_type`
# are SageMaker Python SDK v1 keyword names; SDK v2 renamed them to
# `image_uri`, `instance_count` and `instance_type`. Confirm which SDK
# version is installed before changing them.

# Compute resources and per-step batch size for the job.
train_instance_type = 'ml.p3.2xlarge'
# Forwarded below as the `gpus` hyperparameter; presumably meant to match the
# GPU count of `train_instance_type` -- confirm if the instance type changes.
gpu_count = 1
batch_size = 64

# Training output is written under s3://<session default bucket>/<repo_name>/output.
output_path = 's3://{bucket}/{repo}/output'.format(
    bucket=sess.default_bucket(),
    repo=repo_name,
)
# Fully-qualified ECR URI of the training image, pinned to the `latest` tag.
image_name = '{acct}.dkr.ecr.{reg}.amazonaws.com/{repo}:latest'.format(
    acct=account,
    reg=region,
    repo=repo_name,
)

# Echo the resolved locations so they show up in the run's output.
print(output_path)
print(image_name)

# A single training instance of the configured type, using the session above.
estimator = sagemaker.estimator.Estimator(
    sagemaker_session=sess,
    role=role,
    image_name=image_name,
    base_job_name=base_job_name,
    train_instance_count=1,
    train_instance_type=train_instance_type,
    output_path=output_path,
)

# Hyperparameters handed to the container; presumably these keys match what
# the image's training entry point parses -- verify against that script.
estimator.set_hyperparameters(
    lr=0.0001,
    epochs=10,
    gpus=gpu_count,
    batch_size=batch_size,
)

# Launch the job with separate `training` and `validation` input channels.
estimator.fit(dict(
    training=train_input_path,
    validation=validation_input_path,
))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.