Kiritan AI
from flask import Flask, request
import json
from annoy import AnnoyIndex
import camphr
from urllib.parse import urlparse

app = Flask(__name__)

# Load a Japanese BERT pipeline for sentence embeddings.
nlp = camphr.load(
    """
    lang:
        name: ja_mecab  # language name
    pipeline:
        transformers_model:
            trf_name_or_path: bert-base-japanese  # model name
    """
)

# Nearest-neighbour index over the conversation embeddings.
u = AnnoyIndex(768, 'angular')
u.load("index.ann")

conv_pairs = json.load(open("conversations.json"))

@app.route('/kiritan/talk_to')
def get_string_reply():
    message = request.args.get('message')
    closest = u.get_nns_by_vector(nlp(message).vector.tolist(), 1)[0]
    reply = conv_pairs[closest]
    audio_url = 'http://localhost:20020/static/audio/' + reply['audio_name']
    return {
        "reply": reply['alice_reply'],
        "audio_url": audio_url
    }
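
A quick way to sanity-check the endpoint above is to hit it with a GET request once the Flask app is running. The snippet below is a rough sketch, assuming the app is served on port 20020 (the port the generated audio URLs and the Unity client point at) and that the requests library is available; the example message is one of the indexed utterances.

import requests

# Hypothetical smoke test for the /kiritan/talk_to endpoint.
resp = requests.get(
    "http://localhost:20020/kiritan/talk_to",
    params={"message": "おやすみなさい。"},
)
print(resp.json())  # expected keys: "reply" and "audio_url"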
/**
* (C) Copyright IBM Corp. 2015, 2020.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
#pragma warning disable 0649
using UnityEngine;
using System.Collections;
using System.Collections.Generic;
using UnityEngine.UI;
using IBM.Watson.SpeechToText.V1;
using IBM.Cloud.SDK;
using IBM.Cloud.SDK.Authentication;
using IBM.Cloud.SDK.Authentication.Iam;
using IBM.Cloud.SDK.Utilities;
using IBM.Cloud.SDK.DataTypes;
using UnityEngine.Networking;
namespace IBM.Watsson.Examples
{
    [System.Serializable]
    public struct ConversationPair
    {
        public string audio_url;
        public string reply;
    }

    public class ExampleStreaming : MonoBehaviour
    {
        #region PLEASE SET THESE VARIABLES IN THE INSPECTOR
        [Space(10)]
        [Tooltip("The service URL (optional). This defaults to \"https://stream.watsonplatform.net/speech-to-text/api\"")]
        [SerializeField]
        private string _serviceUrl;
        // [Tooltip("Text field to display the results of streaming.")]
        // public Text ResultsField;
        // [Tooltip("The field to show reply message.")]
        // public Text ReplyField;
        public AudioSource audioSource;
        [Header("IAM Authentication")]
        [Tooltip("The IAM apikey.")]
        [SerializeField]
        private string _iamApikey;
        [Header("Parameters")]
        // https://www.ibm.com/watson/developercloud/speech-to-text/api/v1/curl.html?curl#get-model
        [Tooltip("The Model to use. This defaults to en-US_BroadbandModel")]
        [SerializeField]
        private string _recognizeModel;
        #endregion

        private int _recordingRoutine = 0;
        private string _microphoneID = null;
        private AudioClip _recording = null;
        private int _recordingBufferSize = 1;
        private int _recordingHZ = 22050;
        private const string _talkUrl = "http://192.168.1.102:20020/kiritan/talk_to?message=";
        AudioClip myClip;
        private SpeechToTextService _service;

        void Start()
        {
            StartCoroutine(GetPermission());
            // audioSource = GetComponent<AudioSource>();
            LogSystem.InstallDefaultReactors();
            Runnable.Run(CreateService());
        }

        private IEnumerator GetPermission()
        {
            yield return Application.RequestUserAuthorization(UserAuthorization.Microphone);
            if (Application.HasUserAuthorization(UserAuthorization.Microphone))
            {
                Debug.Log("Microphone found");
            }
            else
            {
                Debug.Log("Microphone not found");
            }
        }
        private IEnumerator CreateService()
        {
            if (string.IsNullOrEmpty(_iamApikey))
            {
                throw new IBMException("Please provide an IAM ApiKey for the service.");
            }

            IamAuthenticator authenticator = new IamAuthenticator(apikey: _iamApikey);

            // Wait for token data
            while (!authenticator.CanAuthenticate())
                yield return null;

            _service = new SpeechToTextService(authenticator);
            if (!string.IsNullOrEmpty(_serviceUrl))
            {
                _service.SetServiceUrl(_serviceUrl);
            }
            _service.StreamMultipart = true;

            Active = true;
            StartRecording();
        }
        public bool Active
        {
            get { return _service.IsListening; }
            set
            {
                if (value && !_service.IsListening)
                {
                    _service.RecognizeModel = (string.IsNullOrEmpty(_recognizeModel) ? "en-US_BroadbandModel" : _recognizeModel);
                    _service.DetectSilence = true;
                    _service.EnableWordConfidence = true;
                    _service.EnableTimestamps = true;
                    _service.SilenceThreshold = 0.01f;
                    _service.MaxAlternatives = 1;
                    _service.EnableInterimResults = true;
                    _service.OnError = OnError;
                    _service.InactivityTimeout = -1;
                    _service.ProfanityFilter = false;
                    _service.SmartFormatting = true;
                    _service.SpeakerLabels = false;
                    _service.WordAlternativesThreshold = null;
                    _service.EndOfPhraseSilenceTime = null;
                    _service.StartListening(OnRecognize, OnRecognizeSpeaker);
                }
                else if (!value && _service.IsListening)
                {
                    _service.StopListening();
                }
            }
        }

        private void StartRecording()
        {
            if (_recordingRoutine == 0)
            {
                UnityObjectUtil.StartDestroyQueue();
                _recordingRoutine = Runnable.Run(RecordingHandler());
            }
        }

        private void StopRecording()
        {
            if (_recordingRoutine != 0)
            {
                Microphone.End(_microphoneID);
                Runnable.Stop(_recordingRoutine);
                _recordingRoutine = 0;
            }
        }

        private void OnError(string error)
        {
            Active = false;
            Log.Debug("ExampleStreaming.OnError()", "Error! {0}", error);
        }

        private IEnumerator RecordingHandler()
        {
            Log.Debug("ExampleStreaming.RecordingHandler()", "devices: {0}", Microphone.devices);
            _recording = Microphone.Start(_microphoneID, true, _recordingBufferSize, _recordingHZ);
            yield return null; // let _recordingRoutine get set

            if (_recording == null)
            {
                StopRecording();
                yield break;
            }

            bool bFirstBlock = true;
            int midPoint = _recording.samples / 2;
            float[] samples = null;

            while (_recordingRoutine != 0 && _recording != null)
            {
                int writePos = Microphone.GetPosition(_microphoneID);
                if (writePos > _recording.samples || !Microphone.IsRecording(_microphoneID))
                {
                    Log.Error("ExampleStreaming.RecordingHandler()", "Microphone disconnected.");
                    StopRecording();
                    yield break;
                }

                if ((bFirstBlock && writePos >= midPoint)
                    || (!bFirstBlock && writePos < midPoint))
                {
                    // front block is recorded, make a RecordClip and pass it onto our callback.
                    samples = new float[midPoint];
                    _recording.GetData(samples, bFirstBlock ? 0 : midPoint);

                    AudioData record = new AudioData();
                    record.MaxLevel = Mathf.Max(Mathf.Abs(Mathf.Min(samples)), Mathf.Max(samples));
                    record.Clip = AudioClip.Create("Recording", midPoint, _recording.channels, _recordingHZ, false);
                    record.Clip.SetData(samples, 0);

                    _service.OnListen(record);
                    bFirstBlock = !bFirstBlock;
                }
                else
                {
                    // calculate the number of samples remaining until we are ready for a block of audio,
                    // and wait the amount of time it will take to record that.
                    int remaining = bFirstBlock ? (midPoint - writePos) : (_recording.samples - writePos);
                    float timeRemaining = (float)remaining / (float)_recordingHZ;
                    yield return new WaitForSeconds(timeRemaining);
                }
            }

            yield break;
        }
        IEnumerator OnFinalRecognized(string text)
        {
            // Build a GET request to the reply server (escape the message so
            // non-ASCII text survives in the query string).
            UnityWebRequest webRequest = UnityWebRequest.Get(_talkUrl + UnityWebRequest.EscapeURL(text));
            // Send the request and wait for the response.
            yield return webRequest.SendWebRequest();
            // Check whether an error occurred.
            if (webRequest.isNetworkError)
            {
                // Request failed.
                Debug.Log(webRequest.error);
            }
            else
            {
                // Request succeeded.
                Debug.Log(webRequest.downloadHandler.text);
                // ReplyField.text = webRequest.downloadHandler.text;
                ConversationPair pair = JsonUtility.FromJson<ConversationPair>(webRequest.downloadHandler.text);
                StartCoroutine("GetAudioClip", pair.audio_url);
            }
        }
        IEnumerator GetAudioClip(string audio_url)
        {
            using (UnityWebRequest www = UnityWebRequestMultimedia.GetAudioClip(audio_url, AudioType.WAV))
            {
                yield return www.SendWebRequest();
                // if (www.isError)
                if (www.isNetworkError)
                {
                    Debug.Log(www.error);
                }
                else
                {
                    myClip = DownloadHandlerAudioClip.GetContent(www);
                    audioSource.clip = myClip;
                    audioSource.Play();
                }
            }
        }
        private void OnRecognize(SpeechRecognitionEvent result)
        {
            if (result != null && result.results.Length > 0)
            {
                foreach (var res in result.results)
                {
                    foreach (var alt in res.alternatives)
                    {
                        string text = string.Format("{0} ({1}, {2:0.00})\n", alt.transcript, res.final ? "Final" : "Interim", alt.confidence);
                        Log.Debug("ExampleStreaming.OnRecognize()", text);
                        // ResultsField.text = text;
                        if (res.final)
                        {
                            // Send the raw transcript (not the formatted log string) to the reply server.
                            StartCoroutine("OnFinalRecognized", alt.transcript);
                        }
                    }

                    if (res.keywords_result != null && res.keywords_result.keyword != null)
                    {
                        foreach (var keyword in res.keywords_result.keyword)
                        {
                            Log.Debug("ExampleStreaming.OnRecognize()", "keyword: {0}, confidence: {1}, start time: {2}, end time: {3}", keyword.normalized_text, keyword.confidence, keyword.start_time, keyword.end_time);
                        }
                    }

                    if (res.word_alternatives != null)
                    {
                        foreach (var wordAlternative in res.word_alternatives)
                        {
                            Log.Debug("ExampleStreaming.OnRecognize()", "Word alternatives found. Start time: {0} | EndTime: {1}", wordAlternative.start_time, wordAlternative.end_time);
                            foreach (var alternative in wordAlternative.alternatives)
                                Log.Debug("ExampleStreaming.OnRecognize()", "\t word: {0} | confidence: {1}", alternative.word, alternative.confidence);
                        }
                    }
                }
            }
        }
        private void OnRecognizeSpeaker(SpeakerRecognitionEvent result)
        {
            if (result != null)
            {
                foreach (SpeakerLabelsResult labelResult in result.speaker_labels)
                {
                    Log.Debug("ExampleStreaming.OnRecognizeSpeaker()", string.Format("speaker result: {0} | confidence: {3} | from: {1} | to: {2}", labelResult.speaker, labelResult.from, labelResult.to, labelResult.confidence));
                }
            }
        }
    }
}
import camphr
import json
import pickle
from annoy import AnnoyIndex

# Load the same Japanese BERT pipeline used by the server.
nlp = camphr.load(
    """
    lang:
        name: ja_mecab  # language name
    pipeline:
        transformers_model:
            trf_name_or_path: bert-base-japanese  # model name
    """
)

# (what Bob says, Alice's reply, audio file of the reply)
conversation_examples = [
    # ("ただいま", "おかえり", ""),
    ("おやすみなさい。", "おやすみなさい。", "おやすみなさい。.wav"),
    ("疲れた", "お疲れ様", "otsukaresama.wav"),
    ("いってきます", "仕事頑張ってね。", "お仕事頑張ってね。.wav")
]

conversation_data = []
for bob_talk, alice_reply, audio_name in conversation_examples:
    doc = nlp(bob_talk)
    conversation_data.append(dict(
        embedding=doc.vector.tolist(),
        bob_talk=bob_talk,
        audio_name=audio_name,
        alice_reply=alice_reply))

json.dump(conversation_data, open("conversations.json", 'w'))

vec_dimension = len(conversation_data[0]['embedding'])
t = AnnoyIndex(vec_dimension, 'angular')  # Length of item vector that will be indexed
for i, conv_exam in enumerate(conversation_data):
    t.add_item(i, conv_exam['embedding'])
print(f"Dimension is {vec_dimension}.")
t.build(10)  # 10 trees
t.save('index.ann')
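
As a quick sanity check of the saved index, the lookup below mirrors what the Flask server does. It is a sketch meant to run right after the script above (so nlp, vec_dimension, and conversation_data are still in scope); for an utterance that was indexed, it should return the stored reply.

# Reload the index and find the nearest stored utterance for a known input.
u = AnnoyIndex(vec_dimension, 'angular')
u.load('index.ann')
closest = u.get_nns_by_vector(nlp("疲れた").vector.tolist(), 1)[0]
print(conversation_data[closest]['alice_reply'])  # expected: お疲れ様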