package com.example.whisper;

import androidx.appcompat.app.AppCompatActivity;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;

import android.Manifest;
import android.annotation.SuppressLint;
import android.content.Context;
import android.content.Intent;
import android.content.pm.PackageManager;
import android.media.AudioFormat;
import android.media.AudioRecord;
import android.media.MediaRecorder;
import android.os.Build;
import android.os.Bundle;
import android.os.Environment;
import android.os.Handler;
import android.os.HandlerThread;
import android.util.Log;
import android.view.View;
import android.widget.ArrayAdapter;
import android.widget.Button;
import android.widget.ListView;
import android.widget.TextView;

import com.musicg.wave.extension.Spectrogram;
import com.orhanobut.logger.AndroidLogAdapter;
import com.orhanobut.logger.Logger;

import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;
import org.pytorch.Tensor;

import java.io.BufferedOutputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.List;
public class MainActivity extends AppCompatActivity implements Runnable {

    private static final String[] CAMERA_PERMISSION = new String[]{Manifest.permission.CAMERA};
    private static final int CAMERA_REQUEST_CODE = 10;

    Spectrogram spectrogram;
    ListView simpleListView;
    ArrayAdapter<String> arrayAdapter;
    // backing list for the command ListView
    List<String> commandList = new ArrayList<>();

    private static final int RECORDER_BPP = 16;
    private static final String AUDIO_RECORDER_FILE_EXT_WAV = ".wav";
    private static final String AUDIO_RECORDER_FOLDER = "AudioRecorder";
    private static final String AUDIO_RECORDER_TEMP_FILE = "record_temp.raw";
    // the recorder is opened in mono (see run()), so size the buffer for mono as well
    private static final int RECORDER_CHANNELS = AudioFormat.CHANNEL_IN_MONO;
    private static final int RECORDER_AUDIO_ENCODING = AudioFormat.ENCODING_PCM_16BIT;
    int bufferSize;

    private final static int REQUEST_RECORD_AUDIO = 13;
    private final static int AUDIO_LEN_IN_SECOND = 6;
    private final static int SAMPLE_RATE = 16000;
    private final static int RECORDING_LENGTH = SAMPLE_RATE * AUDIO_LEN_IN_SECOND;
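    // 16,000 samples/s * 6 s = 96,000 samples, i.e. 192,000 bytes of 16-bit mono PCM.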
    private Module module;
    private Button mButton;
    private int mStart = 1;

    // The timer thread runs outside the activity lifecycle,
    // so it must be cleaned up properly to prevent thread leaks.
    private HandlerThread mTimerThread;
    // A Handler allows communication between the main (UI) thread and the timer thread.
    private Handler mTimerHandler;

    private Runnable mRunnable = new Runnable() {
        @SuppressLint("DefaultLocale")
        @Override
        public void run() {
            // re-post ourselves to keep the timer ticking
            mTimerHandler.postDelayed(mRunnable, 1000);
            MainActivity.this.runOnUiThread(
                    () -> {
                        // update the countdown shown on the button
                        mButton.setText(String.format("Listening - %ds left", AUDIO_LEN_IN_SECOND - mStart));
                        mStart += 1;
                        Logger.d("Hello");
                    }
            );
        }
    };
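    // Timer pattern used above: mRunnable re-posts itself on the HandlerThread's Looper once per
    // second, hops to the UI thread only for the button text, and stops when stopTimerThread()
    // calls quitSafely(), which discards the still-pending delayed post.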
    private static final String TAG = MainActivity.class.getName();

    @SuppressLint("MissingInflatedId")
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        simpleListView = (ListView) findViewById(R.id.simpleListView);
        arrayAdapter = new ArrayAdapter<String>(this,
                R.layout.command_view, R.id.itemTextView, commandList);
        simpleListView.setAdapter(arrayAdapter);

        Button enableCamera = findViewById(R.id.enableCamera);
        enableCamera.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View v) {
                if (hasCameraPermission()) {
                    enableCamera();
                } else {
                    requestPermission();
                }
            }
        });

        // for logging
        Logger.addLogAdapter(new AndroidLogAdapter());

        mButton = findViewById(R.id.btnRecognize);
        //mTextView = findViewById(R.id.tvResult);
        mButton.setOnClickListener(new View.OnClickListener() {
            @SuppressLint("DefaultLocale")
            @Override
            public void onClick(View view) {
                mButton.setText(String.format("Listening - %ds left", AUDIO_LEN_IN_SECOND));
                mButton.setEnabled(false);
                // start a new thread for recording
                Thread thread = new Thread(MainActivity.this);
                thread.start();
                Logger.d("Thread started");
                // start the timer thread
                mTimerThread = new HandlerThread("Timer");
                mTimerThread.start();
                // attach a handler to the timer thread's looper
                mTimerHandler = new Handler(mTimerThread.getLooper());
                // will call mRunnable every second; it runs in the background
                // and is stopped in onDestroy()
                mTimerHandler.postDelayed(mRunnable, 1000);
            }
        });

        requestMicrophonePermission();
    }
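    // Minimal sketch of handling the permission results requested in onCreate() (not in the
    // original gist; the permission check in run() assumes something like this exists).
    // Here the record button is simply disabled when RECORD_AUDIO is denied.
    @Override
    public void onRequestPermissionsResult(int requestCode, String[] permissions, int[] grantResults) {
        super.onRequestPermissionsResult(requestCode, permissions, grantResults);
        if (requestCode == REQUEST_RECORD_AUDIO) {
            boolean granted = grantResults.length > 0
                    && grantResults[0] == PackageManager.PERMISSION_GRANTED;
            mButton.setEnabled(granted);
        }
    }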
    private void enableCamera() {
        Intent intent = new Intent(this, CameraActivity.class);
        startActivity(intent);
    }

    private void requestMicrophonePermission() {
        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
            requestPermissions(
                    new String[]{android.Manifest.permission.RECORD_AUDIO}, REQUEST_RECORD_AUDIO);
        }
    }
    @Override
    protected void onDestroy() {
        stopTimerThread();
        super.onDestroy();
    }

    protected void stopTimerThread() {
        if (mTimerThread == null) {
            // already stopped (or never started)
            return;
        }
        mTimerThread.quitSafely();
        try {
            mTimerThread.join();
            mTimerThread = null;
            mTimerHandler = null;
            mStart = 1;
        } catch (InterruptedException e) {
            Log.e(TAG, "Error on stopping background thread", e);
        }
    }
    @Override
    public void run() {
        Logger.d("Run");
        android.os.Process.setThreadPriority(android.os.Process.THREAD_PRIORITY_AUDIO);

        // getMinBufferSize(sampleRateInHz, channelConfig, audioFormat) returns the minimum
        // buffer size, in bytes, required to create an AudioRecord successfully.
        bufferSize = AudioRecord.getMinBufferSize(SAMPLE_RATE, RECORDER_CHANNELS,
                RECORDER_AUDIO_ENCODING);

        if (ActivityCompat.checkSelfPermission(this, Manifest.permission.RECORD_AUDIO) != PackageManager.PERMISSION_GRANTED) {
            // RECORD_AUDIO is requested in onCreate(); if the user has not granted it yet, bail out.
            return;
        }
        AudioRecord record = new AudioRecord(MediaRecorder.AudioSource.MIC,
                SAMPLE_RATE,
                AudioFormat.CHANNEL_IN_MONO, AudioFormat.ENCODING_PCM_16BIT,
                bufferSize);
        if (record.getState() != AudioRecord.STATE_INITIALIZED) {
            Logger.e("Audio Record can't initialize!");
            return;
        }
        record.startRecording();

        long shortsRead = 0;
        int recordingOffset = 0;
        short[] audioBuffer = new short[bufferSize / 2];
        short[] recordingBuffer = new short[RECORDING_LENGTH];

        String filename = getTempFilename();
        FileOutputStream os = null;
        try {
            os = new FileOutputStream(filename);
        } catch (FileNotFoundException e) {
            Log.e(TAG, "Could not open temp raw file " + filename, e);
        }

        while (shortsRead < RECORDING_LENGTH) {
            int numberOfShort = record.read(audioBuffer, 0, audioBuffer.length);
            if (numberOfShort <= 0) {
                break;
            }
            // write the same samples to the raw temp file as little-endian 16-bit PCM,
            // so the WAV copy and the recognition buffer contain identical audio
            if (os != null) {
                ByteBuffer chunk = ByteBuffer.allocate(numberOfShort * 2).order(ByteOrder.LITTLE_ENDIAN);
                chunk.asShortBuffer().put(audioBuffer, 0, numberOfShort);
                try {
                    os.write(chunk.array(), 0, numberOfShort * 2);
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
            // clamp the copy so the final chunk cannot overflow recordingBuffer
            int samplesToCopy = Math.min(numberOfShort, RECORDING_LENGTH - recordingOffset);
            System.arraycopy(audioBuffer, 0, recordingBuffer, recordingOffset, samplesToCopy);
            recordingOffset += samplesToCopy;
            shortsRead += numberOfShort;
        }

        try {
            if (os != null) {
                os.close();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        stopTimerThread();

        // convert the raw recording into a proper WAV file
        copyWaveFile(filename, getFilename());

        record.stop();
        record.release();

        // run() executes on a background thread; without runOnUiThread, touching views throws
        // "Only the original thread that created a view hierarchy can touch its views."
        runOnUiThread(new Runnable() {
            @Override
            public void run() {
                mButton.setText("Recognizing...");
            }
        });

        float[] floatInputBuffer = new float[RECORDING_LENGTH];
        // feed in float values between -1.0f and 1.0f by dividing the signed 16-bit samples
        for (int i = 0; i < RECORDING_LENGTH; ++i) {
            floatInputBuffer[i] = recordingBuffer[i] / (float) Short.MAX_VALUE;
        }
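        // Assumption: the bundled wav2vec2 model (see recognize()) expects a 16 kHz mono
        // waveform of floats in [-1, 1], which is what the division above produces.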
        // run the model on the normalized waveform
        final String result = recognize(floatInputBuffer);

        runOnUiThread(new Runnable() {
            @Override
            public void run() {
                showTranslationResult(result);
                mButton.setEnabled(true);
                mButton.setText("Start");
            }
        });
    }
    private String assetFilePath(Context context, String assetName) {
        File file = new File(context.getFilesDir(), assetName);
        if (file.exists() && file.length() > 0) {
            return file.getAbsolutePath();
        }
        try (InputStream is = context.getAssets().open(assetName)) {
            try (OutputStream os = new FileOutputStream(file)) {
                byte[] buffer = new byte[4 * 1024];
                int read;
                while ((read = is.read(buffer)) != -1) {
                    os.write(buffer, 0, read);
                }
                os.flush();
            }
            return file.getAbsolutePath();
        } catch (IOException e) {
            Log.e(TAG, assetName + ": " + e.getLocalizedMessage());
        }
        return null;
    }
    private void showTranslationResult(String result) {
        //mTextView.setText(result);
        // add the result to commandList and notify the adapter of the change
        commandList.add(result);
        arrayAdapter.notifyDataSetChanged();
    }
    private String recognize(float[] floatInputBuffer) {
        if (module == null) {
            module = LiteModuleLoader.load(assetFilePath(getApplicationContext(), "wav2vec2.ptl"));
        }
        // copy the waveform into a FloatBuffer and wrap it as a [1, RECORDING_LENGTH] tensor
        FloatBuffer inTensorBuffer = Tensor.allocateFloatBuffer(RECORDING_LENGTH);
        inTensorBuffer.put(floatInputBuffer, 0, RECORDING_LENGTH);
        Tensor inTensor = Tensor.fromBlob(inTensorBuffer, new long[]{1, RECORDING_LENGTH});
        return module.forward(IValue.from(inTensor)).toStr();
    }
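    // Assumption about the model file: "wav2vec2.ptl" is a TorchScript Lite (PyTorch Mobile)
    // module shipped in the app's assets/ folder whose forward() takes a [1, N] float waveform
    // and returns the transcription as a string, matching the toStr() call above.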
    private boolean hasCameraPermission() {
        return ContextCompat.checkSelfPermission(
                this,
                Manifest.permission.CAMERA
        ) == PackageManager.PERMISSION_GRANTED;
    }

    private void requestPermission() {
        ActivityCompat.requestPermissions(
                this,
                CAMERA_PERMISSION,
                CAMERA_REQUEST_CODE
        );
    }
    private String getFilename() {
        File filepath = Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DOCUMENTS);
        File dir = new File(filepath, "new_wav_directory");
        if (!dir.exists()) {
            dir.mkdirs();
        }
        File file = new File(filepath, AUDIO_RECORDER_FOLDER);
        if (!file.exists()) {
            file.mkdirs();
        }
        return (file.getAbsolutePath() + "/" + System.currentTimeMillis() + AUDIO_RECORDER_FILE_EXT_WAV);
    }

    private String getTempFilename() {
        File filepath = Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DOCUMENTS);
        File dir = new File(filepath, "new_wav_directory");
        if (!dir.exists()) {
            dir.mkdirs();
        }
        File tempFile = new File(dir, AUDIO_RECORDER_TEMP_FILE);
        if (tempFile.exists()) {
            tempFile.delete();
        }
        return (dir.getAbsolutePath() + "/" + AUDIO_RECORDER_TEMP_FILE);
    }
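    // Note: Environment.getExternalStoragePublicDirectory() is deprecated as of API 29, and with
    // scoped storage the app may not be able to write to these paths directly on Android 10+
    // without extra setup (e.g. legacy storage or MediaStore); the code above assumes older behavior.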
    private void copyWaveFile(String inFilename, String outFilename) {
        FileInputStream in = null;
        FileOutputStream out = null;
        long totalAudioLen = 0;
        long totalDataLen = totalAudioLen + 36;
        long longSampleRate = SAMPLE_RATE;
        // we are recording mono, so 1 channel
        int channels = 1;
        long byteRate = RECORDER_BPP * SAMPLE_RATE * channels / 8;
        byte[] data = new byte[bufferSize];
        try {
            in = new FileInputStream(inFilename);
            out = new FileOutputStream(outFilename);
            totalAudioLen = in.getChannel().size();
            totalDataLen = totalAudioLen + 36;
            Logger.d("File size: " + totalDataLen);
            WriteWaveFileHeader(out, totalAudioLen, totalDataLen,
                    longSampleRate, channels, byteRate);
            int read;
            // copy only the bytes actually read, so a short final chunk is not padded with garbage
            while ((read = in.read(data)) != -1) {
                out.write(data, 0, read);
            }
            in.close();
            out.close();
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
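    // Canonical 44-byte PCM WAV header written below: "RIFF" + (36 + data size) + "WAVE",
    // a 16-byte "fmt " chunk (format 1 = PCM, channel count, sample rate,
    // byte rate = sampleRate * channels * bitsPerSample / 8, block align = channels * bitsPerSample / 8,
    // bits per sample), then "data" + data size. All multi-byte fields are little-endian.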
    private void WriteWaveFileHeader(
            FileOutputStream out, long totalAudioLen,
            long totalDataLen, long longSampleRate, int channels,
            long byteRate) throws IOException {
        byte[] header = new byte[44];
        header[0] = 'R'; // RIFF/WAVE header
        header[1] = 'I';
        header[2] = 'F';
        header[3] = 'F';
        header[4] = (byte) (totalDataLen & 0xff);
        header[5] = (byte) ((totalDataLen >> 8) & 0xff);
        header[6] = (byte) ((totalDataLen >> 16) & 0xff);
        header[7] = (byte) ((totalDataLen >> 24) & 0xff);
        header[8] = 'W';
        header[9] = 'A';
        header[10] = 'V';
        header[11] = 'E';
        header[12] = 'f'; // 'fmt ' chunk
        header[13] = 'm';
        header[14] = 't';
        header[15] = ' ';
        header[16] = 16; // 4 bytes: size of 'fmt ' chunk
        header[17] = 0;
        header[18] = 0;
        header[19] = 0;
        header[20] = 1; // format = 1 (PCM)
        header[21] = 0;
        header[22] = (byte) channels;
        header[23] = 0;
        header[24] = (byte) (longSampleRate & 0xff);
        header[25] = (byte) ((longSampleRate >> 8) & 0xff);
        header[26] = (byte) ((longSampleRate >> 16) & 0xff);
        header[27] = (byte) ((longSampleRate >> 24) & 0xff);
        header[28] = (byte) (byteRate & 0xff);
        header[29] = (byte) ((byteRate >> 8) & 0xff);
        header[30] = (byte) ((byteRate >> 16) & 0xff);
        header[31] = (byte) ((byteRate >> 24) & 0xff);
        header[32] = (byte) (channels * RECORDER_BPP / 8); // block align
        header[33] = 0;
        header[34] = RECORDER_BPP; // bits per sample
        header[35] = 0;
        header[36] = 'd';
        header[37] = 'a';
        header[38] = 't';
        header[39] = 'a';
        header[40] = (byte) (totalAudioLen & 0xff);
        header[41] = (byte) ((totalAudioLen >> 8) & 0xff);
        header[42] = (byte) ((totalAudioLen >> 16) & 0xff);
        header[43] = (byte) ((totalAudioLen >> 24) & 0xff);
        out.write(header, 0, 44);
    }
}