@MinaGabriel
Created November 7, 2022 11:08
package com.example.whisper;
import androidx.appcompat.app.AppCompatActivity;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;
import android.Manifest;
import android.annotation.SuppressLint;
import android.content.Context;
import android.content.Intent;
import android.content.pm.PackageManager;
import android.media.AudioFormat;
import android.media.AudioRecord;
import android.media.MediaRecorder;
import android.os.Build;
import android.os.Bundle;
import android.os.Environment;
import android.os.Handler;
import android.os.HandlerThread;
import android.util.Log;
import android.view.View;
import android.widget.ArrayAdapter;
import android.widget.Button;
import android.widget.ListView;
import android.widget.TextView;
import com.musicg.wave.extension.Spectrogram;
import com.orhanobut.logger.AndroidLogAdapter;
import com.orhanobut.logger.Logger;
import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;
import org.pytorch.Tensor;
import java.io.BufferedOutputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.List;
public class MainActivity extends AppCompatActivity implements Runnable {
private static final String[] CAMERA_PERMISSION = new String[]{Manifest.permission.CAMERA};
private static final int CAMERA_REQUEST_CODE = 10;
Spectrogram spectrogram;
ListView simpleListView;
ArrayAdapter<String> arrayAdapter;
// Backing list for the ListView that displays recognized commands.
List<String> commandList = new ArrayList<>();
private static final int RECORDER_BPP = 16;
private static final String AUDIO_RECORDER_FILE_EXT_WAV = ".wav";
private static final String AUDIO_RECORDER_FOLDER = "AudioRecorder";
private static final String AUDIO_RECORDER_TEMP_FILE = "record_temp.raw";
// Mono recording, matching the AudioRecord configuration and the 1-channel WAV header below.
private static final int RECORDER_CHANNELS = AudioFormat.CHANNEL_IN_MONO;
private static final int RECORDER_AUDIO_ENCODING = AudioFormat.ENCODING_PCM_16BIT;
int bufferSize;
private final static int REQUEST_RECORD_AUDIO = 13;
private final static int AUDIO_LEN_IN_SECOND = 6;
private final static int SAMPLE_RATE = 16000;
private final static int RECORDING_LENGTH = SAMPLE_RATE * AUDIO_LEN_IN_SECOND;
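// Buffer math for the constants above: 16,000 samples/s * 6 s = 96,000 16-bit samples,
// i.e. roughly 192,000 bytes of raw mono PCM per recording.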
private Module module;
private Button mButton;
private int mStart = 1;
// The timer thread runs outside the activity lifecycle,
// so it must be shut down explicitly (see stopTimerThread) to prevent thread leaks.
private HandlerThread mTimerThread;
// A Handler bound to the timer thread's Looper schedules work on that thread;
// UI updates from it go back through runOnUiThread.
private Handler mTimerHandler;
private Runnable mRunnable = new Runnable() {
@SuppressLint("DefaultLocale")
@Override
public void run() {
// Re-post this Runnable so it runs again in one second, keeping the countdown ticking.
mTimerHandler.postDelayed(mRunnable, 1000);
MainActivity.this.runOnUiThread(
() -> {
// Update the countdown shown on the record button.
mButton.setText(String.format("Listening - %ds left", AUDIO_LEN_IN_SECOND - mStart));
mStart += 1;
Logger.d("Hello");
}
);
}
};
private static final String TAG = MainActivity.class.getName();
@SuppressLint("MissingInflatedId")
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
simpleListView = (ListView) findViewById(R.id.simpleListView);
arrayAdapter = new ArrayAdapter<String>(this,
R.layout.command_view, R.id.itemTextView, commandList);
simpleListView.setAdapter(arrayAdapter);
Button enableCamera = findViewById(R.id.enableCamera);
enableCamera.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View v) {
if (hasCameraPermission()) {
enableCamera();
} else {
requestPermission();
}
}
});
//for logging
Logger.addLogAdapter(new AndroidLogAdapter());
mButton = findViewById(R.id.btnRecognize);
//mTextView = findViewById(R.id.tvResult);
mButton.setOnClickListener(new View.OnClickListener() {
@SuppressLint("DefaultLocale")
@Override
public void onClick(View view) {
mButton.setText(String.format("Listening - %ds left", AUDIO_LEN_IN_SECOND));
mButton.setEnabled(false);
//start a new thread for recording
Thread thread = new Thread(MainActivity.this);
thread.start();
Logger.d("Thread started");
//start the new thread
mTimerThread = new HandlerThread("Timer");
mTimerThread.start();
//start the new handler
mTimerHandler = new Handler(mTimerThread.getLooper());
// mRunnable re-posts itself, so this effectively ticks every second on the timer thread.
// The thread is shut down in stopTimerThread(), which onDestroy() also calls.
mTimerHandler.postDelayed(mRunnable, 1000);
}
});
requestMicrophonePermission();
}
private void enableCamera() {
Intent intent = new Intent(this, CameraActivity.class);
startActivity(intent);
}
private void requestMicrophonePermission() {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
requestPermissions(
new String[]{android.Manifest.permission.RECORD_AUDIO}, REQUEST_RECORD_AUDIO);
}
}
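// A minimal sketch of the runtime-permission callback for the RECORD_AUDIO request above:
// it only logs the outcome; a real app would likely disable the record button when the
// permission is denied.
@Override
public void onRequestPermissionsResult(int requestCode, String[] permissions, int[] grantResults) {
super.onRequestPermissionsResult(requestCode, permissions, grantResults);
if (requestCode == REQUEST_RECORD_AUDIO) {
boolean granted = grantResults.length > 0
&& grantResults[0] == PackageManager.PERMISSION_GRANTED;
Log.d(TAG, "RECORD_AUDIO permission granted: " + granted);
}
}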
@Override
protected void onDestroy() {
stopTimerThread();
super.onDestroy();
}
protected void stopTimerThread() {
// May be called both after recording finishes and again from onDestroy(); guard against the second call.
if (mTimerThread == null) {
return;
}
mTimerThread.quitSafely();
try {
mTimerThread.join();
mTimerThread = null;
mTimerHandler = null;
mStart = 1;
} catch (InterruptedException e) {
Log.e(TAG, "Error on stopping background thread", e);
}
}
@Override
public void run() {
Logger.d("Run");
android.os.Process.setThreadPriority(android.os.Process.THREAD_PRIORITY_AUDIO);
//getMinBufferSize(int sampleRateInHz, int channelConfig, int audioFormat)
//Returns the minimum buffer size required for the successful creation of an AudioRecord object, in byte units.
bufferSize = AudioRecord.getMinBufferSize(SAMPLE_RATE, RECORDER_CHANNELS,
RECORDER_AUDIO_ENCODING);
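// getMinBufferSize returns a device-dependent size in bytes; the audioBuffer below
// therefore holds bufferSize / 2 16-bit samples.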
if (ActivityCompat.checkSelfPermission(this, Manifest.permission.RECORD_AUDIO) != PackageManager.PERMISSION_GRANTED) {
// RECORD_AUDIO was requested in onCreate(); if the user denied it there is nothing to record, so bail out.
return;
}
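// Capture 16 kHz, mono, 16-bit PCM: this matches the single-channel, 16-bit WAV header
// written later and the 16 kHz input the wav2vec2 model expects.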
AudioRecord record = new AudioRecord(MediaRecorder.AudioSource.MIC,
SAMPLE_RATE,
AudioFormat.CHANNEL_IN_MONO, AudioFormat.ENCODING_PCM_16BIT,
bufferSize);
if (record.getState() != AudioRecord.STATE_INITIALIZED) {
Logger.e("Audio Record can't initialize!");
return;
}
record.startRecording();
long shortsRead = 0;
int recordingOffset = 0;
short[] audioBuffer = new short[bufferSize / 2];
short[] recordingBuffer = new short[RECORDING_LENGTH];
String filename = getTempFilename();
FileOutputStream os;
try {
os = new FileOutputStream(filename);
} catch (FileNotFoundException e) {
Log.e(TAG, "Could not open temp PCM file " + filename, e);
record.stop();
record.release();
return;
}
while (shortsRead < RECORDING_LENGTH) {
// Read one chunk of 16-bit samples; a single read feeds both the temp file and the model buffer.
int numberOfShort = record.read(audioBuffer, 0, audioBuffer.length);
if (numberOfShort <= 0) {
break;
}
// Convert the shorts to little-endian PCM bytes for the raw temp file.
ByteBuffer byteBuffer = ByteBuffer.allocate(numberOfShort * 2).order(ByteOrder.LITTLE_ENDIAN);
byteBuffer.asShortBuffer().put(audioBuffer, 0, numberOfShort);
try {
os.write(byteBuffer.array());
} catch (IOException e) {
e.printStackTrace();
}
// Clamp the copy so the final chunk cannot overflow recordingBuffer.
int copyLength = (int) Math.min(numberOfShort, RECORDING_LENGTH - shortsRead);
System.arraycopy(audioBuffer, 0, recordingBuffer, recordingOffset, copyLength);
shortsRead += numberOfShort;
recordingOffset += copyLength;
}
try {
os.close();
} catch (IOException e) {
e.printStackTrace();
}
stopTimerThread();
//----------------
copyWaveFile(filename, getFilename());
//----------------
record.stop();
record.release();
// We are on a background thread here; view updates must go through runOnUiThread,
// otherwise Android throws "Only the original thread that created a view hierarchy can touch its views."
runOnUiThread(new Runnable() {
@Override
public void run() {
mButton.setText("Recognizing...");
}
});
float[] floatInputBuffer = new float[RECORDING_LENGTH];
// Normalize the signed 16-bit samples to floats in [-1.0f, 1.0f] by dividing by Short.MAX_VALUE.
for (int i = 0; i < RECORDING_LENGTH; ++i) {
floatInputBuffer[i] = recordingBuffer[i] / (float) Short.MAX_VALUE;
}
// Run the wav2vec2 model on the normalized audio.
final String result = recognize(floatInputBuffer);
runOnUiThread(new Runnable() {
@Override
public void run() {
showTranslationResult(result);
mButton.setEnabled(true);
mButton.setText("Start");
}
});
}
private String assetFilePath(Context context, String assetName) {
File file = new File(context.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}
try (InputStream is = context.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
} catch (IOException e) {
Log.e(TAG, assetName + ": " + e.getLocalizedMessage());
}
return null;
}
private void showTranslationResult(String result) {
//mTextView.setText(result);
//put in commandList and notify adaptor of changes.
commandList.add(result);
arrayAdapter.notifyDataSetChanged();
}
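// Runs the bundled TorchScript Lite model ("wav2vec2.ptl") on the 16 kHz mono waveform.
// The exported module is assumed to take a [1, RECORDING_LENGTH] float tensor and to
// return the decoded transcript directly as a string, hence the IValue.toStr() call.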
private String recognize(float[] floatInputBuffer) {
if (module == null) {
module = LiteModuleLoader.load(assetFilePath(getApplicationContext(), "wav2vec2.ptl"));
}
// Copy the normalized samples straight into the tensor buffer (no intermediate double[] needed).
FloatBuffer inTensorBuffer = Tensor.allocateFloatBuffer(RECORDING_LENGTH);
inTensorBuffer.put(floatInputBuffer);
Tensor inTensor = Tensor.fromBlob(inTensorBuffer, new long[]{1, RECORDING_LENGTH});
final String result = module.forward(IValue.from(inTensor)).toStr();
return result;
}
private boolean hasCameraPermission() {
return ContextCompat.checkSelfPermission(
this,
Manifest.permission.CAMERA
) == PackageManager.PERMISSION_GRANTED;
}
private void requestPermission() {
ActivityCompat.requestPermissions(
this,
CAMERA_PERMISSION,
CAMERA_REQUEST_CODE
);
}
private String getFilename() {
File filepath = Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DOCUMENTS);
File file = new File(filepath, AUDIO_RECORDER_FOLDER);
if (!file.exists()) {
file.mkdirs();
}
return (file.getAbsolutePath() + "/" + System.currentTimeMillis() + AUDIO_RECORDER_FILE_EXT_WAV);
}
private String getTempFilename() {
File filepath = Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DOCUMENTS);
File dir = new File(filepath, "new_wav_directory");
if (!dir.exists()) {
dir.mkdirs();
}
File tempFile = new File(dir, AUDIO_RECORDER_TEMP_FILE);
if (tempFile.exists()) {
tempFile.delete();
}
return (dir.getAbsolutePath() + "/" + AUDIO_RECORDER_TEMP_FILE);
}
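// Wraps the raw PCM temp file in a playable WAV file: a 44-byte RIFF/WAVE header is written
// first (see WriteWaveFileHeader below), followed by the unmodified PCM bytes.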
private void copyWaveFile(String inFilename, String outFilename) {
FileInputStream in = null;
FileOutputStream out = null;
long totalAudioLen = 0;
long totalDataLen = totalAudioLen + 36;
long longSampleRate = SAMPLE_RATE;
// The recording is mono, so the WAV header declares a single channel.
int channels = 1;
long byteRate = RECORDER_BPP * SAMPLE_RATE * channels / 8;
byte[] data = new byte[bufferSize];
try {
in = new FileInputStream(inFilename);
out = new FileOutputStream(outFilename);
totalAudioLen = in.getChannel().size();
totalDataLen = totalAudioLen + 36;
Logger.d("File size: " + totalDataLen);
WriteWaveFileHeader(out, totalAudioLen, totalDataLen,
longSampleRate, channels, byteRate);
int length;
while ((length = in.read(data)) != -1) {
// Write only the bytes actually read so a short final chunk is not padded with stale data.
out.write(data, 0, length);
}
in.close();
out.close();
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
}
}
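// Standard 44-byte canonical WAV header layout (multi-byte fields are little-endian):
//   0-3   "RIFF"                     4-7   total data length (36 + audio data size)
//   8-11  "WAVE"                     12-15 "fmt "
//   16-19 fmt chunk size (16 for PCM)
//   20-21 audio format (1 = PCM)     22-23 channel count
//   24-27 sample rate                28-31 byte rate (sampleRate * channels * bitsPerSample / 8)
//   32-33 block align (channels * bitsPerSample / 8)
//   34-35 bits per sample
//   36-39 "data"                     40-43 audio data length in bytes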
private void WriteWaveFileHeader(
FileOutputStream out, long totalAudioLen,
long totalDataLen, long longSampleRate, int channels,
long byteRate) throws IOException {
byte[] header = new byte[44];
header[0] = 'R'; // RIFF/WAVE header
header[1] = 'I';
header[2] = 'F';
header[3] = 'F';
header[4] = (byte)(totalDataLen & 0xff);
header[5] = (byte)((totalDataLen >> 8) & 0xff);
header[6] = (byte)((totalDataLen >> 16) & 0xff);
header[7] = (byte)((totalDataLen >> 24) & 0xff);
header[8] = 'W';
header[9] = 'A';
header[10] = 'V';
header[11] = 'E';
header[12] = 'f'; // 'fmt ' chunk
header[13] = 'm';
header[14] = 't';
header[15] = ' ';
header[16] = 16; // 4 bytes: size of 'fmt ' chunk
header[17] = 0;
header[18] = 0;
header[19] = 0;
header[20] = 1; // format = 1
header[21] = 0;
header[22] = (byte) channels;
header[23] = 0;
header[24] = (byte)(longSampleRate & 0xff);
header[25] = (byte)((longSampleRate >> 8) & 0xff);
header[26] = (byte)((longSampleRate >> 16) & 0xff);
header[27] = (byte)((longSampleRate >> 24) & 0xff);
header[28] = (byte)(byteRate & 0xff);
header[29] = (byte)((byteRate >> 8) & 0xff);
header[30] = (byte)((byteRate >> 16) & 0xff);
header[31] = (byte)((byteRate >> 24) & 0xff);
header[32] = (byte) (channels * RECORDER_BPP / 8); // block align = channels * bitsPerSample / 8
header[33] = 0;
header[34] = RECORDER_BPP; // bits per sample
header[35] = 0;
header[36] = 'd';
header[37] = 'a';
header[38] = 't';
header[39] = 'a';
header[40] = (byte)(totalAudioLen & 0xff);
header[41] = (byte)((totalAudioLen >> 8) & 0xff);
header[42] = (byte)((totalAudioLen >> 16) & 0xff);
header[43] = (byte)((totalAudioLen >> 24) & 0xff);
out.write(header, 0, 44);
}
}
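// Deployment note: the manifest presumably needs android.permission.RECORD_AUDIO and
// android.permission.CAMERA; because the WAV files are written under the public Documents
// directory, older Android versions may also need external-storage write permission, while
// Android 10+ scoped storage may call for an app-specific or MediaStore output location instead.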