Skip to content

Instantly share code, notes, and snippets.

@dzonesasaki
Created December 15, 2023 02:32
Show Gist options
  • Save dzonesasaki/576c14268aae4df4e8f90d49757509a4 to your computer and use it in GitHub Desktop.
Save dzonesasaki/576c14268aae4df4e8f90d49757509a4 to your computer and use it in GitHub Desktop.
Run nue-asr on a long WAV file on the CPU
import nue_asr
import librosa
import numpy as np
import sys
# Requirements (install with pip):
#   torch
#   transformers
#   git+https://github.com/rinnakk/nue-asr
# Model files (cloned next to this script as ./nue-asr):
#   git clone https://huggingface.co/rinna/nue-asr
# Defaults used when no command-line arguments are given.
# The model and the tokenizer are read from the same directory.
modelName = './nue-asr'
tokenizerName = './nue-asr'
myDevice = 'cpu'
wavFileName = './voice_s16le16kHz.wav'

# usage : python3 runnue.py [wavfile] [nueDir]
if sys.argv[2:]:
    # Second argument: directory used for both the model and the tokenizer.
    modelName = tokenizerName = sys.argv[2]
if sys.argv[1:]:
    # First argument: the wav file to transcribe.
    wavFileName = sys.argv[1]
def simpleVAD(snd , tMin=1, tMax=16, sr=16000, wms=4, factTh=0.5):
    """Split a signal into voiced segments with a simple RMS-energy detector.

    The signal is cut into consecutive frames of ``wms`` milliseconds. A frame
    counts as voiced when its RMS exceeds ``factTh`` times the RMS of the whole
    signal. A segment opens on the first voiced frame (or on a silence-to-voice
    transition) and closes, once it has lasted at least ``tMin`` seconds, on
    the next voice-to-silence transition or when it reaches ``tMax`` seconds.

    Args:
        snd: Array of samples. Presumably 1-D mono audio -- the framing
            reshapes it to (frames, samples-per-frame).
        tMin: Minimum segment duration in seconds; transitions that occur
            earlier than this after the segment start are ignored.
        tMax: Maximum segment duration in seconds; a segment is force-closed
            when it reaches this length.
        sr: Sample rate in Hz.
        wms: Frame length in milliseconds.
        factTh: Voicing threshold as a fraction of the overall RMS.

    Returns:
        A list of ``[start, stop]`` sample-index pairs (stop exclusive),
        in chronological order. Empty when the input is empty or no frame
        exceeds the threshold.

    Note:
        After a segment is force-closed at ``tMax``, a new one only opens at
        the next silence-to-voice transition, so samples between the forced
        cut and that transition are not part of any returned segment.
    """
    snd = np.asarray(snd)
    numSamples = len(snd)
    if numSamples == 0:
        # Nothing to analyse; also avoids taking the mean of an empty array.
        return []

    # Duration limits converted from seconds to a number of frames.
    nMin = tMin / wms * 1000 // 1
    nMax = tMax / wms * 1000 // 1

    vRmsAll = (snd ** 2).mean() ** 0.5
    vThresh = vRmsAll * factTh

    lenWin = sr * wms // 1000          # samples per frame
    nWin = numSamples // lenWin        # number of complete frames
    nmod = numSamples - nWin * lenWin  # samples left over after the last full frame

    # One RMS value per complete frame, plus one for the trailing partial
    # frame when there is one. Slicing by an explicit end index (instead of
    # snd[:-nmod]) keeps this correct when nmod == 0.
    vAryRms = np.zeros(nWin + (1 if nmod else 0))
    vAryRms[:nWin] = (snd[:nWin * lenWin].reshape(nWin, lenWin) ** 2).mean(1) ** 0.5
    if nmod:
        vAryRms[nWin] = (snd[nWin * lenWin:] ** 2).mean() ** 0.5

    voiced = vAryRms > vThresh

    # Frame-wise edge markers: a start is a silence->voice transition (or a
    # voiced first frame); an end is a voice->silence transition.
    isStart = np.zeros(len(voiced), dtype=bool)
    isEnd = np.zeros(len(voiced), dtype=bool)
    isStart[0] = voiced[0]
    isStart[1:] = voiced[1:] & ~voiced[:-1]
    isEnd[1:] = ~voiced[1:] & voiced[:-1]

    aposStart = []
    aposStop = []
    bflagAct = False
    pPrevSt = 0
    for k in range(len(voiced)):
        if not bflagAct:
            if isStart[k]:
                bflagAct = True
                aposStart.append(k * lenWin)
                pPrevSt = k
            continue

        elapsed = k - pPrevSt
        if elapsed < nMin:
            # Too short so far: ignore any transition inside the minimum span.
            continue
        if elapsed >= nMax or isEnd[k]:
            bflagAct = False
            aposStop.append(k * lenWin)

    # A segment still open at the end of the signal runs to the last sample.
    if len(aposStart) > len(aposStop):
        aposStop.append(numSamples)

    return [[start, stop] for start, stop in zip(aposStart, aposStop)]
# Load the recording, asking librosa for sampleRate Hz.
sampleRate = 16000
snd, sr = librosa.load(wavFileName, sr=sampleRate)

# Segmentation settings, passed positionally to simpleVAD as
# (tMin, tMax, sr, wms, factTh).
periodSecMin = 5
periodSecMax = 16
windowMilliSec = 4
threshSoundRms = 0.5
segSp = simpleVAD(
    snd,
    periodSecMin,
    periodSecMax,
    sampleRate,
    windowMilliSec,
    threshSoundRms,
)

# Cut the waveform into one array per detected [start, stop] pair.
lstSnd = [snd[begin:end] for begin, end in segSp]

model = nue_asr.load_model(modelName, device=myDevice, fp16=False)
tokenizer = nue_asr.load_tokenizer(tokenizerName)

# Whole-file alternative, left disabled in the original script:
#   result = nue_asr.transcribe(model, tokenizer, wavFileName)
#   print(result.text)

# Transcribe the segments in order and print each text on its own line.
for snda in lstSnd:
    result = nue_asr.transcribe(model, tokenizer, snda)
    print(result.text)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment