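# Flask service for a multimodal sentence classifier: BERT text encodings, a
# 1-D CNN over precomputed audio features, and an RNN over context tension
# labels are fused and exposed through a POST /predict endpoint that returns
# class probabilities.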
import os
import time
import sys
import datetime
import random
import json
import pickle
import numpy as np
from flask import Flask, session, g, request, render_template, redirect
from flask_mongoengine import MongoEngine
from nltk.tokenize import word_tokenize
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.utils.data
import torch.nn.functional as F
from pytorch_pretrained_bert import BertTokenizer
from pytorch_pretrained_bert import BertModel
from annotation.models import Sent
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
base_dir = os.path.abspath(os.path.dirname(__file__) + '/')
sys.path.append(base_dir)
app = Flask(__name__)
app.config.from_object('annotation.config.Config')
db = MongoEngine(app)

class Config:
    model_path = os.path.join(base_dir, './models/05_model/epoch_046_f1_0.614')
    sound_feature_path = os.path.join(base_dir, './data/sound_feature_pyAudio.pkl')
    max_sound_len = 500
    padding = 0
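
# Audio branch: stacked 1-D convolutions over the precomputed audio feature
# matrix (34 feature channels x Config.max_sound_len frames, loaded from
# sound_feature_pyAudio.pkl), flattened and projected down to a hidden_size vector.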
class ConvNet(nn.Module):
    def __init__(self, hidden_size=200):
        super(ConvNet, self).__init__()
        self.hidden_size = hidden_size
        self.layer1 = nn.Sequential(
            nn.Conv1d(34, 68, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm1d(68),
            nn.CELU(),
            nn.MaxPool1d(kernel_size=5, stride=2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv1d(68, 128, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm1d(128),
            nn.CELU(),
            nn.MaxPool1d(kernel_size=5, stride=2),
        )
        self.layer3 = nn.Sequential(
            nn.Conv1d(128, 256, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm1d(256),
            nn.CELU(),
            nn.MaxPool1d(kernel_size=5, stride=2),
            nn.Dropout(0.5),
        )
        self.fc = nn.Sequential(
            nn.Linear(15104, 6000),
            nn.CELU(),
            nn.Linear(6000, 2000),
            nn.Linear(2000, self.hidden_size),
            nn.CELU(),
        )

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
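
# Text branch: bert-base-uncased encodings fed through a bidirectional LSTM;
# the final time step is projected to a hidden_size vector. With finetuning
# disabled (or in eval mode) BERT runs under torch.no_grad().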
class BertNet(nn.Module):
    def __init__(self, finetuning=True, hidden_size=200):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.bert_output_size = 768
        self.hidden_size = hidden_size
        self.rnn = nn.LSTM(input_size=self.bert_output_size, hidden_size=self.hidden_size, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.drop = nn.Dropout(0.5)
        self.finetuning = finetuning

    def forward(self, x):
        if self.training and self.finetuning:
            self.bert.train()
            encoded_layers, _ = self.bert(x)
            enc1 = encoded_layers[-1]  # [batch_size, max_len, hidden_size]
        else:
            self.bert.eval()
            with torch.no_grad():
                encoded_layers, _ = self.bert(x)
                enc1 = encoded_layers[-1]
        enc, (final_hidden_state, final_cell_state) = self.rnn(enc1)  # final_hidden_state: [num_directions, batch_size, hidden_size]
        # enc: [batch_size, seq_len, num_directions * hidden_size]
        # Decode the hidden state of the last time step
        enc = enc[:, -1, :]
        logits = self.fc(enc)
        logits = self.drop(logits)
        return logits
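
# Context branch: embeds the tension labels of the surrounding sentences and
# encodes the sequence with a bidirectional LSTM.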
class TensionRNNNet(nn.Module):
    def __init__(self, num_classes=3, tension_embed_size=5, hidden_size=5):
        super().__init__()
        self.hidden_size = hidden_size
        self.tension_embed_size = tension_embed_size
        self.tension_embed = nn.Embedding(num_embeddings=num_classes, embedding_dim=self.tension_embed_size)
        self.rnn = nn.LSTM(input_size=self.tension_embed_size, hidden_size=self.hidden_size, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.drop = nn.Dropout(0.5)

    def forward(self, x):
        # x: [batch_size, tension_length]
        x = self.tension_embed(x)  # [batch_size, tension_length, tension_embed_size]
        enc, (final_hidden_state, final_cell_state) = self.rnn(x)
        # enc: [batch_size, seq_len, num_directions * hidden_size]
        # Decode the hidden state of the last time step
        enc = enc[:, -1, :]
        logits = self.fc(enc)
        logits = self.drop(logits)
        return logits
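
# Fusion model: concatenates the text, audio, and context-label encodings and
# maps them to num_classes logits.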
class MultiModal(nn.Module):
    def __init__(self, num_classes=3, hidden_size=200, tension_hidden_size=10):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.bert_net = BertNet(hidden_size=self.hidden_size)
        self.conv_net = ConvNet(hidden_size=self.hidden_size)
        self.tension_rnn_net = TensionRNNNet(num_classes=num_classes, hidden_size=tension_hidden_size)
        self.fc = nn.Sequential(
            nn.Linear(self.hidden_size * 2 + tension_hidden_size, self.hidden_size * 2),
            nn.CELU(),
            nn.Linear(self.hidden_size * 2, self.hidden_size),
            nn.Linear(self.hidden_size, self.num_classes),
        )

    def forward(self, text_x, sound_x, context_label_x):
        text_x = self.bert_net(text_x)
        sound_x = self.conv_net(sound_x)
        context_label_x = self.tension_rnn_net(context_label_x)
        out = torch.cat([text_x, sound_x, context_label_x], 1)
        out = self.fc(out)
        return out
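
# Shape sketch for a forward pass (illustrative; the tensor names are
# hypothetical, the sizes follow from Config and the layers above):
#   text_x          LongTensor  [batch, seq_len]       WordPiece token ids
#   sound_x         FloatTensor [batch, 34, 500]       audio features, 500 = Config.max_sound_len
#   context_label_x LongTensor  [batch, context_len]   tension labels in [0, num_classes)
#   logits = MultiModal()(text_x, sound_x, context_label_x)  # [batch, num_classes]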

def load_model():
    global model
    model = torch.load(Config.model_path)
    model.eval()

    global sound_feature
    sound_feature = dict()
    with open(Config.sound_feature_path, 'rb') as f:
        sounds = pickle.load(f)
        for sound in sounds:
            features = []
            for feature in sound['sound_feature']:
                # Truncate each feature sequence to max_sound_len, then pad with Config.padding up to that length.
                features.append(feature.tolist()[:Config.max_sound_len] + [Config.padding] * (Config.max_sound_len - len(feature)))
            sound_feature[sound['send_id']] = features
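
# POST /predict takes a JSON body holding the id of a Sent document, looks up
# its text and precomputed audio features, and returns softmax probabilities
# over the tension classes.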
@app.route("/predict", methods=["POST"])
def predict():
    data = request.get_json()
    sent_id = data['sent_id']
    print('sent_id :', sent_id)

    text = Sent.objects.get(id=sent_id).text
    words = ['[CLS]'] + word_tokenize(text) + ['[SEP]']
    sent_x = []
    for w in words:
        tokens = tokenizer.tokenize(w) if w not in ("[CLS]", "[SEP]") else [w]
        xx = tokenizer.convert_tokens_to_ids(tokens)
        sent_x.extend(xx)

    sent_x = np.array([sent_x])
    sound_x = np.array([sound_feature[sent_id]])
    tension_x = np.array([[1, 1, 1, 1, 1]])  # placeholder context labels; no real history is supplied at serving time

    sent_x = torch.LongTensor(sent_x)
    sound_x = torch.FloatTensor(sound_x)
    tension_x = torch.LongTensor(tension_x)

    out = model(sent_x, sound_x, tension_x)
    out = F.softmax(out, dim=1)  # explicit dim avoids the implicit-dimension deprecation warning
    out = out.tolist()
    result = out[0]
    return json.dumps({
        'result': result,
    })
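
# Example request (hypothetical id; assumes the default num_classes=3 and the
# server settings below):
#   curl -X POST http://localhost:8081/predict \
#        -H "Content-Type: application/json" \
#        -d '{"sent_id": "<sent_id>"}'
# -> {"result": [p_0, p_1, p_2]}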

if __name__ == '__main__':
    load_model()
    FLASK_DEBUG = os.getenv('FLASK_DEBUG', False)
    app.run(host="0.0.0.0", debug=True, port=8081)