Skip to content

Instantly share code, notes, and snippets.

@yuyyuyu
Created March 8, 2020 15:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yuyyuyu/44191b2708dbe8b4674c971b1dff1e61 to your computer and use it in GitHub Desktop.
Save yuyyuyu/44191b2708dbe8b4674c971b1dff1e61 to your computer and use it in GitHub Desktop.
import os
from scipy.io import wavfile
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from keras.layers import Conv2D, MaxPool2D, Flatten, LSTM
from keras.layers import Dropout, Dense, TimeDistributed
from keras.models import Sequential
from keras.utils import to_categorical
from sklearn.utils.class_weight import compute_class_weight
from tqdm import tqdm
from python_speech_features import mfcc
import datetime
import pickle
from keras.callbacks import ModelCheckpoint
from cfg import Config
def get_conv_model():
model=Sequential()
model.add(Conv2D(16,(3,3),activation='relu',strides=(1,1),padding='same',input_shape=input_shape))
model.add(Conv2D(32,(3,3),activation='relu',strides=(1,1),padding='same'))
model.add(Conv2D(64,(3,3),activation='relu',strides=(1,1),padding='same'))
model.add(Conv2D(128,(3,3),activation='relu',strides=(1,1),padding='same'))
model.add(MaxPool2D(2,2))
model.add(Dropout(0.5))
model.add(Flatten())
model.add(Dense(128,activation='relu'))
model.add(Dense(64,activation='relu'))
model.add(Dense(10,activation='softmax'))
model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['acc'])
return model
def check_data():
if os.path.isfile(config.p_path):
print('Loading existing data for {} nodel'.format(config.mode))
with open(config.p_path,'rb') as handle:
tmp=pickle.load(handle)
return tmp
else:
return None
def build_rand_feat():
tmp=check_data()
if tmp:
return tmp.data[0],tmp.data[1]
X=[]
y=[]
_min,_max=float('inf'),-float('inf')
for _ in tqdm(range(n_samples)):
rand_class=np.random.choice(class_dist.index,p=prob_dist)
file=np.random.choice(df[df.label==rand_class].index)
rate,wav=wavfile.read('clean/'+file)
label=df.at[file,'label']
random_index=np.random.randint(0,wav.shape[0]-config.step)
sample=wav[random_index:random_index+config.step]
X_sample=mfcc(sample,rate,numcep=config.nfeat,nfilt=config.nfilt,nfft=config.nfft).T
_min=min(np.amin(X_sample),_min)
_max=max(np.amax(X_sample),_max)
X.append(X_sample if config.mode=='conv' else X_sample.T)
y.append(classes.index(label))
config.min=_min
config.max=_max
X,y=np.array(X),np.array(y)
X=(X-_min)/(_max-_min)
if config.mode=='conv':
X=X.reshape(X.shape[0],X.shape[1],X.shape[2],1)
elif config.mode=='time':
X=X.reshape(X.shape[0],X.shape[1],X.shape[2])
y=to_categorical(y,num_classes=10)
config.data=(X,y)
with open(config.p_path,'wb') as handle:
pickle.dump(config,handle,protocol=2)
return X,y
df = pd.read_csv('instruments.csv')
df.set_index('fname', inplace=True)
for f in df.index:
rate, signal = wavfile.read('clean/'+f)
df.at[f, 'length'] = signal.shape[0]/rate
classes = list(np.unique(df.label))
class_dist = df.groupby(['label'])['length'].mean()
n_samples=2*int(df['length'].sum()/0.1)
prob_dist=class_dist/class_dist.sum()
choices=np.random.choice(class_dist.index,p=prob_dist)
fig, ax = plt.subplots()
ax.set_title('Class Distribution', y=1.08)
ax.pie(class_dist, labels=class_dist.index, autopct='%1.1f%%',
shadow=False, startangle=90)
ax.axis('equal')
config=Config(mode='conv')
if config.mode=='conv':
X,y=build_rand_feat()
y_flat=np.argmax(y,axis=1)
input_shape=(X.shape[1],X.shape[2],1)
model=get_conv_model()
elif config.mode=='time':
X,y=build_rand_feat()
y_flat=np.argmac(y,axis=1)
input_shape=(X.shape[1],X.shape[2],1)
model=get_recurrent_model
class_weight=compute_class_weight('balanced',np.unique(y_flat),y_flat)
checkpoint=ModelCheckpoint(config.model_path,monitor='val_acc',verbose=1,mode='max',save_weights_only=False,period=1)
model.fit(X,y,epochs=10, batch_size=32,shuffle=True,validation_split=0.1,callbacks=[checkpoint])
model.save(config.model_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment