Forked from spazewalker/convert_TCN_speech_recog_model.ipynb
Created
September 20, 2023 06:26
-
-
Save Zikovich/813169e25a3b15de512656113c068e29 to your computer and use it in GitHub Desktop.
Code used to export an ONNX model of the TCN-based audio-visual speech recognition model. This is part of OpenCV GSoC 2022.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
"""TCN_AVSpeech_model_comp.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1awBCZ5O6uAT32cHvufNWad5m6q26TuqQ
"""
# Commented out IPython magic to ensure Python compatibility. | |
# Clone original repo, install requirements, download models & files from google drive | |
# !git clone --recursive https://github.com/mpc001/Lipreading_using_Temporal_Convolutional_Networks.git | |
# %cd Lipreading_using_Temporal_Convolutional_Networks/ | |
# !git checkout 47872c9a7a357b70a4adc97e51658c1e43fde8d9 | |
# !pip install -r requirements.txt | |
# !gdown --id 12mHlNQKCE2AXkFHzvRyqSbsmOMEs259i | |
# !gdown --id 16asCjDdGnnP3AFJZtDlYehHe7qrQ5AXq | |
# !unzip LRW_landmarks.zip -o ./landmarks | |
# !cd models && gdown --id 1tYNYOiJhVNQgf8Rt-X64uzso3Py-RSvu | |
# !cd models && gdown --id 1h6JVCAoLlq-StCkT_a7n_-qmViHK-iUT | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import json | |
from pathlib import Path | |
from lipreading.model import * | |
import torch.onnx | |
class AVLipreading(nn.Module):
    """Fuse an audio and a video ``Lipreading`` model by weighted averaging.

    Two ``Lipreading`` networks are built from the same JSON config, one with
    ``modality='raw_audio'`` and one with ``modality='video'``, and their
    pretrained weights are loaded on the CPU. ``forward`` runs both and
    returns ``audio * comb_wt + video * (1 - comb_wt)``.

    Args:
        comb_wt: Weight given to the audio model's output; the video model
            gets ``1 - comb_wt``. Must lie in ``[0, 1]``.
        margin: Sequence length, in video frames, reported to both models.
        config_path: JSON file holding the backbone and TCN options.
        video_weights_path: Checkpoint for the video model; must contain a
            ``'model_state_dict'`` entry.
        audio_weights_path: Checkpoint for the audio model; must contain a
            ``'model_state_dict'`` entry.

    Raises:
        ValueError: If ``comb_wt`` is outside ``[0, 1]`` (NaN included).
    """

    # Rates used to convert a length in video frames into audio samples.
    AUDIO_SAMPLE_RATE = 16000
    # NOTE(review): 30 fps is the value the original code used; confirm it
    # matches the frame rate of the data the checkpoints were trained on.
    VIDEO_FPS = 30

    def __init__(
        self,
        comb_wt: float = 0.5,
        margin: int = 20,
        config_path='configs/lrw_resnet18_mstcn.json',
        video_weights_path='models/lrw_resnet18_mstcn_adamw_s3.pth.tar',
        audio_weights_path='models/lrw_resnet18_mstcn_audio_adamw.pth.tar',
    ):
        super().__init__()

        # Validate before any file I/O so a bad weight fails fast.
        if not 0 <= comb_wt <= 1:
            raise ValueError(
                f"comb_wt must be between 0 and 1 (inclusive), got {comb_wt!r}"
            )
        self.margin = margin
        self.wt = comb_wt

        with open(config_path) as fp:
            config = json.load(fp)

        tcn_options = {
            'num_layers': config['tcn_num_layers'],
            'kernel_size': config['tcn_kernel_size'],
            'dropout': config['tcn_dropout'],
            'dwpw': config['tcn_dwpw'],
            'width_mult': config['tcn_width_mult'],
        }

        self.audio_model = self._build_branch(config, tcn_options, 'raw_audio')
        self.video_model = self._build_branch(config, tcn_options, 'video')

        # Video weights are loaded first, then audio, as in the original code.
        self._load_weights(self.video_model, video_weights_path)
        self._load_weights(self.audio_model, audio_weights_path)

    @staticmethod
    def _build_branch(config, tcn_options, modality):
        """Create one ``Lipreading`` network for the given modality."""
        return Lipreading(
            num_classes=500,
            tcn_options=tcn_options,
            backbone_type=config['backbone_type'],
            relu_type=config['relu_type'],
            width_mult=config['width_mult'],
            extract_feats=False,
            modality=modality,
        )

    @staticmethod
    def _load_weights(branch, weights_path):
        """Load the ``'model_state_dict'`` of a checkpoint into ``branch``."""
        checkpoint = torch.load(Path(weights_path), map_location='cpu')
        branch.load_state_dict(checkpoint['model_state_dict'])

    def forward(self, audio_input, video_input):
        """Return the weighted average of the audio and video model outputs.

        Args:
            audio_input: Tensor passed unchanged to the audio model.
            video_input: Tensor passed unchanged to the video model.

        Returns:
            ``audio_out * comb_wt + video_out * (1 - comb_wt)``.
        """
        # Each model is told the length of its own input: audio in samples,
        # video in frames. The original code had these two swapped (audio got
        # the frame count, video got the sample count).
        # Assumes Lipreading.forward(x, lengths) expects lengths in the unit
        # of its input's time axis -- TODO confirm against lipreading.model.
        audio_length = self.margin * self.AUDIO_SAMPLE_RATE // self.VIDEO_FPS
        video_length = self.margin

        # Call the modules (not .forward) so registered hooks still run.
        audio_out = self.audio_model(audio_input, [audio_length])
        video_out = self.video_model(video_input, [video_length])

        # Here's the combining step: a weighted average of the two outputs.
        return audio_out * self.wt + video_out * (1 - self.wt)
# Build the fused audio-visual model and put it in inference mode.
model = AVLipreading().eval()

# Random dummy inputs used to run and trace the model. The audio tensor has
# 20 * 16000 // 30 values on its last axis; the video tensor is
# (1, 1, 20, 96, 96) -- presumably (batch, channel, frames, height, width),
# verify against lipreading.model.
audio_input = torch.randn(1, 1, 20 * 16000 // 30)
video_input = torch.randn(1, 1, 20, 96, 96)

# Run one eager forward pass on the dummy inputs and keep the result.
final_out = model.forward(audio_input, video_input)

# Export the final model here.
torch.onnx.export(
    model.eval(),
    (audio_input, video_input),
    'AVSpeechRecog.onnx',
    opset_version=11,
    input_names=["audio_input", "video_input"],
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment