This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Command-line options for the training script.
# Options that declare no default are None when omitted on the command line.
parser = argparse.ArgumentParser()
parser.add_argument('--data-folder', type=str, dest='data_folder', help='data folder mounting point')
parser.add_argument('--batch-size', type=int, dest='batch_size', default=50, help='mini batch size for training')
parser.add_argument('--x_filename', type=str, dest='x_filename', help='Filename with training data')
parser.add_argument('--y_filename', type=str, dest='y_filename', help='Filename with label data')
# NOTE(review): parsed as a string although the help text describes a size;
# kept as-is because consumers of args.training_size are not visible here.
parser.add_argument('--training_size', type=str, dest='training_size', help='Size of training dataset')
parser.add_argument('--n_epochs', type=int, dest='n_epochs', help='Number of epochs')
args = parser.parse_args()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# start an Azure ML run
run = Run.get_context()


class LogRunMetrics(Callback):
    """Forward per-epoch metrics to the Azure ML run held in ``run``.

    Only metrics that are actually present in the epoch logs are reported,
    so a training job without validation data (no ``val_loss``) or one that
    reports accuracy under a different key does not fail with ``KeyError``.
    """

    # (name reported to the run, keys under which the value may appear in the logs)
    # NOTE(review): 'accuracy' is accepted as an alternative to 'acc' on the
    # assumption that Callback is the Keras callback base class -- confirm.
    _METRIC_KEYS = (
        ('Loss', ('loss',)),
        ('Accuracy', ('acc', 'accuracy')),
        ('Val_Loss', ('val_loss',)),
    )

    # callback at the end of every epoch
    def on_epoch_end(self, epoch, log=None):
        """Log the loss, accuracy and validation loss found in ``log``.

        Args:
            epoch: Index of the epoch that just finished (unused).
            log: Mapping of metric key to value for this epoch; may be
                ``None`` or empty, in which case nothing is logged.
        """
        if not log:
            return
        for metric_name, candidate_keys in self._METRIC_KEYS:
            for key in candidate_keys:
                if key in log:
                    # logging a value repeatedly under one name creates a list
                    run.log(metric_name, log[key])
                    break
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CharVocab: | |
''' Create a Vocabulary for ''' | |
def __init__(self, type_vocab,pad_token='<PAD>', eos_token='<EOS>', unk_token='<UNK>'): #Initialization of the type of vocabulary | |
self.type = type_vocab | |
#self.int2char ={} | |
self.int2char = [] | |
if pad_token !=None: | |
self.int2char += [pad_token] | |
if eos_token !=None: | |
self.int2char += [eos_token] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def one_hot_encode(indices, dict_size):
    """One-hot encode an array of integer indices.

    Args:
        indices: Integer array (or array-like, e.g. nested lists) of any
            shape holding class indices.
        dict_size: Number of classes, i.e. the length of each one-hot vector.

    Returns:
        A ``float32`` array of shape ``(*indices.shape, dict_size)`` in which
        position ``[..., k]`` is 1.0 where the corresponding index equals
        ``k`` and 0.0 elsewhere.

    Raises:
        IndexError: If an index is ``>= dict_size`` or is not an integer.
    """
    indices = np.asarray(indices)
    if indices.size == 0:
        # An empty list is inferred as float64, which cannot be used as an index.
        indices = indices.astype(np.intp)
    # Row k of the identity matrix is the one-hot vector for class k; indexing
    # with an integer array of shape S yields shape S + (dict_size,) directly,
    # so no flatten/reshape round-trip is needed.
    return np.eye(dict_size, dtype=np.float32)[indices]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sagemaker

# Create the SageMaker session
sagemaker_session = sagemaker.Session()

# Get the bucket; in this example, the session's default bucket
bucket = sagemaker_session.default_bucket()

# Set the S3 subfolder where our data will be stored
prefix = 'sagemaker/char_level_rnn'

# Get the role for permission
role = sagemaker.get_execution_role()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
from torch.autograd import Variable | |
class RNNModel(nn.Module): | |
def __init__(self, vocab_size, embedding_size, hidden_dim, n_layers, drop_rate=0.2): | |
super(RNNModel, self).__init__() | |
# Defining some parameters |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def batch_generator_sequence(features_seq, label_seq, batch_size, seq_len): | |
"""Generator function that yields batches of data (input and target) | |
Args: | |
batch_size (int): number of examples (in this case, sentences) per batch. | |
max_length (int): maximum length of the output tensor. | |
NOTE: max_length includes the end-of-sentence character that will be added | |
to the tensor. | |
Keep in mind that the length of the tensor is always 1 + the length | |
of the original line of characters. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sagemaker.pytorch import PyTorch | |
# Select the type of instance to use for training.
# Alternatives kept for quick switching:
#   'ml.m4.4xlarge'  -> CPU instance
#   'local'
instance_type = 'ml.p2.xlarge'  # GPU instance
#Create the estimator object | |
estimator = PyTorch(entry_point="train.py", | |
source_dir="train", | |
role=role, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def sample_from_probs(probs, top_n=10):
    """Draw one index from ``probs``, restricted to its ``top_n`` largest entries.

    Truncated weighted random choice: every entry outside the ``top_n``
    largest is given weight 0 before sampling. The caller's tensor is not
    modified.

    Args:
        probs: Tensor of non-negative weights. The indexing below addresses
            the first dimension, so a 1-D tensor is expected.
        top_n: How many of the largest entries remain eligible. Values larger
            than the number of entries keep all of them.

    Returns:
        A tensor holding the single sampled index (the result of
        ``torch.multinomial(..., 1)``).

    Raises:
        ValueError: If ``top_n`` is smaller than 1.
    """
    if top_n < 1:
        # A slice of [:-0] is empty and a negative top_n would zero the wrong
        # end of the ranking, so reject these instead of sampling silently.
        raise ValueError("top_n must be at least 1, got {}".format(top_n))

    # Ascending sort: everything before the last top_n positions is outside
    # the top-n set.
    _, ascending = torch.sort(probs)

    # Zero the discarded entries on a copy so the input is left untouched.
    truncated = probs.clone()
    truncated[ascending[:-top_n]] = 0

    # multinomial treats the values as relative weights; no renormalisation
    # is required after truncation.
    return torch.multinomial(truncated, 1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sagemaker.predictor import RealTimePredictor | |
from sagemaker.pytorch import PyTorchModel | |
class StringPredictor(RealTimePredictor):
    """Predictor that talks to an endpoint using the ``text/plain`` content type."""

    def __init__(self, endpoint_name, sagemaker_session):
        """Bind the predictor to ``endpoint_name`` within ``sagemaker_session``.

        Both arguments are passed straight to the base class together with
        ``content_type='text/plain'``.
        """
        super(StringPredictor, self).__init__(endpoint_name, sagemaker_session, content_type='text/plain')
# Create a model in Sagemaker | |
model = PyTorchModel(model_data=estimator.model_data, | |
role = role, | |
framework_version='0.4.0', |
OlderNewer