Skip to content

Instantly share code, notes, and snippets.

@joydeb28
Last active May 6, 2020 04:46
Show Gist options
  • Save joydeb28/6add8e80d657e27c87f86cc51110c2d3 to your computer and use it in GitHub Desktop.
Save joydeb28/6add8e80d657e27c87f86cc51110c2d3 to your computer and use it in GitHub Desktop.
class DesignModel():
    """Build, train and save a Keras classifier on top of a BERT module.

    The class relies on several names that are not defined in this block and
    must already exist at module level when its methods run:
    ``train_input_ids``, ``train_input_masks``, ``train_segment_ids``,
    ``train_labels``, ``bert_model_obj`` and ``load_data_obj``, plus the
    TensorFlow/Keras symbols ``tf``, ``Input``, ``GlobalAveragePooling1D``,
    ``Dropout``, ``Dense`` and ``Model``.
    """

    def __init__(self):
        """Capture the training arrays from module-level globals."""
        # Populated by bert_model(); stays None until that method is called.
        self.model = None
        # Order matches the [ids, masks, segments] input list built in
        # bert_model().
        self.train_data = [train_input_ids, train_input_masks, train_segment_ids]
        self.train_labels = train_labels

    def bert_model(self, max_seq_length):
        """Build and compile the classification model, storing it on ``self.model``.

        Args:
            max_seq_length: Length of each of the three int32 input
                sequences (input ids, input masks, segment ids).
        """
        in_id = Input(shape=(max_seq_length,), dtype=tf.int32, name="input_ids")
        in_mask = Input(shape=(max_seq_length,), dtype=tf.int32, name="input_masks")
        in_segment = Input(shape=(max_seq_length,), dtype=tf.int32, name="segment_ids")
        bert_inputs = [in_id, in_mask, in_segment]

        # bert_module returns a (pooled, sequence) pair; only the sequence
        # output is used here, so the pooled output is deliberately ignored.
        _pooled_out, bert_sequence_out = bert_model_obj.bert_module(bert_inputs)

        # NOTE(review): presumably the sequence output is
        # (batch, seq_len, hidden) and is averaged over seq_len -- confirm
        # against bert_module.
        out = GlobalAveragePooling1D()(bert_sequence_out)
        out = Dropout(0.2)(out)
        # One output unit per entry in load_data_obj.cat_to_intent; softmax
        # turns the outputs into class probabilities.
        out = Dense(
            len(load_data_obj.cat_to_intent),
            activation="softmax",
            name="dense_output",
        )(out)

        self.model = Model(inputs=bert_inputs, outputs=out)
        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(1e-5),
            # The output layer already applies softmax, so the loss receives
            # probabilities, not logits: from_logits must be False.
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="acc")],
        )
        self.model.summary()

    def model_train(self, batch_size, num_epoch):
        """Fit ``self.model`` on the captured training data.

        20% of the training data is held out for validation
        (``validation_split=0.2``) and shuffling is enabled.
        ``bert_model()`` must have been called first so that ``self.model``
        is set.

        Args:
            batch_size: Batch size passed to ``fit``.
            num_epoch: Number of epochs passed to ``fit``.
        """
        print("Fitting to model")
        self.model.fit(
            self.train_data,
            self.train_labels,
            epochs=num_epoch,
            batch_size=batch_size,
            validation_split=0.2,
            shuffle=True,
        )
        print("Model Training complete.")

    def save_model(self, model, model_name):
        """Save ``self.model`` to ``<model_name>.h5``.

        Args:
            model: Unused; kept for backward compatibility with existing
                callers. The model that gets saved is always ``self.model``.
            model_name: Path/name of the output file without the ``.h5``
                extension.
        """
        self.model.save(model_name + ".h5")
        print("Model saved to Model folder.")
# Build the classifier for the sequence length configured on bert_model_obj,
# then train it for a single epoch using batches of 32 examples.
model_obj = DesignModel()
model_obj.bert_model(max_seq_length=bert_model_obj.max_len)
model_obj.model_train(batch_size=32, num_epoch=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment