@siakon89
Created June 18, 2018 16:10
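import boto3
import sagemaker

# Assumes a SageMaker notebook instance; repo_name, base_job_name,
# train_input_path and validation_input_path are defined elsewhere.
sess = sagemaker.Session()
role = sagemaker.get_execution_role()
account = boto3.client('sts').get_caller_identity()['Account']
region = sess.boto_session.region_name

train_instance_type = 'ml.p3.2xlarge'  # one NVIDIA V100 GPU
gpu_count = 1
batch_size = 64

output_path = 's3://{}/{}/output'.format(sess.default_bucket(), repo_name)
image_name = '{}.dkr.ecr.{}.amazonaws.com/{}:latest'.format(account, region, repo_name)
print(output_path)
print(image_name)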
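The two format strings above follow fixed patterns: an Amazon ECR image URI (`<account>.dkr.ecr.<region>.amazonaws.com/<repo>:<tag>`) and an S3 prefix for the job output. Here is a small sketch that builds both as pure functions, so they can be checked without AWS credentials. The helper names `ecr_image_uri` and `s3_output_path` are hypothetical, and the example values are placeholders.

```python
def ecr_image_uri(account: str, region: str, repo: str, tag: str = "latest") -> str:
    """Build an ECR image URI: <account>.dkr.ecr.<region>.amazonaws.com/<repo>:<tag>."""
    return "{}.dkr.ecr.{}.amazonaws.com/{}:{}".format(account, region, repo, tag)


def s3_output_path(bucket: str, prefix: str) -> str:
    """Build the S3 prefix under which SageMaker writes the job's model artifacts."""
    return "s3://{}/{}/output".format(bucket, prefix)


if __name__ == "__main__":
    # Placeholder account, region, repository and bucket names (illustration only).
    print(ecr_image_uri("123456789012", "eu-west-1", "my-training-repo"))
    print(s3_output_path("my-bucket", "my-training-repo"))
```

Passing an explicit `tag` instead of relying on `latest` makes it clear which image build a training job used. The URI pattern assumes the standard AWS partition. China regions use a different domain suffix.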
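# SageMaker Python SDK v1 argument names (v2 renamed image_name,
# train_instance_count and train_instance_type to image_uri,
# instance_count and instance_type).
estimator = sagemaker.estimator.Estimator(
    image_name=image_name,
    base_job_name=base_job_name,
    role=role,
    train_instance_count=1,
    train_instance_type=train_instance_type,
    output_path=output_path,
    sagemaker_session=sess)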
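# Hyperparameters reach the container as strings.
estimator.set_hyperparameters(lr=0.0001, epochs=10, gpus=gpu_count, batch_size=batch_size)
# Each channel is made available under /opt/ml/input/data/<channel> in the container.
estimator.fit({'training': train_input_path, 'validation': validation_input_path})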
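Inside the training container, SageMaker follows a fixed layout:

* It writes the hyperparameters set above to `/opt/ml/input/config/hyperparameters.json`, with every value converted to a string.
* It makes each channel passed to `fit()` available under `/opt/ml/input/data/<channel>`.
* It packages whatever the script saves in `/opt/ml/model` and uploads it to `output_path`.

Below is a minimal sketch of the reading side of a training script, assuming these default paths. The function names are hypothetical, and the fallback values simply mirror the ones used in this gist.

```python
import json
import os

# Root of the directory layout SageMaker provides inside a training container.
PREFIX = "/opt/ml"


def load_hyperparameters(prefix: str = PREFIX) -> dict:
    """Read hyperparameters.json and cast the string values this gist sets."""
    path = os.path.join(prefix, "input", "config", "hyperparameters.json")
    with open(path) as f:
        raw = json.load(f)  # every value arrives as a string, e.g. "64"
    # Fallbacks mirror the values passed to set_hyperparameters() above (assumption).
    return {
        "lr": float(raw.get("lr", "0.0001")),
        "epochs": int(raw.get("epochs", "10")),
        "gpus": int(raw.get("gpus", "1")),
        "batch_size": int(raw.get("batch_size", "64")),
    }


def channel_dir(name: str, prefix: str = PREFIX) -> str:
    """Local directory holding the data for an input channel such as 'training'."""
    return os.path.join(prefix, "input", "data", name)


def model_dir(prefix: str = PREFIX) -> str:
    """Files saved here are packaged as model.tar.gz and uploaded to output_path."""
    return os.path.join(prefix, "model")
```

A training script would typically call `load_hyperparameters()` once at startup. It would then read its data from `channel_dir('training')` and `channel_dir('validation')` and save the trained weights under `model_dir()`.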