Kiritan AI
The Flask server: embeds an incoming message with a Japanese BERT model (via camphr), finds the nearest recorded utterance in an Annoy index, and returns the canned reply together with the URL of its audio file.
from flask import Flask, request
import json
from annoy import AnnoyIndex
import camphr

app = Flask(__name__)

# Japanese BERT pipeline used to embed incoming messages.
nlp = camphr.load(
    """
    lang:
        name: ja_mecab  # language name
    pipeline:
        transformers_model:
            trf_name_or_path: bert-base-japanese  # model name
    """
)

# Approximate-nearest-neighbor index over the conversation embeddings
# built by the indexing script below (768 = BERT hidden size).
u = AnnoyIndex(768, 'angular')
u.load("index.ann")

with open("conversations.json") as f:
    conv_pairs = json.load(f)

@app.route('/kiritan/talk_to')
def get_string_reply():
    # Embed the incoming message and look up the closest recorded utterance.
    message = request.args.get('message')
    closest = u.get_nns_by_vector(nlp(message).vector.tolist(), 1)[0]
    reply = conv_pairs[closest]
    audio_url = 'http://localhost:20020/static/audio/' + reply['audio_name']
    return {
        "reply": reply['alice_reply'],
        "audio_url": audio_url
    }
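Once the server is running, the endpoint can be smoke-tested from any HTTP client. A minimal sketch using the `requests` library, assuming the app is served locally on port 20020 (the port used in the audio URLs and by the Unity client below) and querying one of the indexed example messages:

import requests

# Query /kiritan/talk_to with one of the indexed utterances.
resp = requests.get(
    "http://localhost:20020/kiritan/talk_to",
    params={"message": "おやすみなさい。"},
)
resp.raise_for_status()
data = resp.json()
print(data["reply"])      # the canned reply text
print(data["audio_url"])  # URL of the matching WAV file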
The Unity client: a modified version of IBM's Watson Speech to Text streaming example. It streams microphone audio to Watson, and whenever a final transcript arrives it queries the Flask server and plays back the reply audio.
/**
 * (C) Copyright IBM Corp. 2015, 2020.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */
#pragma warning disable 0649
using UnityEngine;
using System.Collections;
using System.Collections.Generic;
using UnityEngine.UI;
using IBM.Watson.SpeechToText.V1;
using IBM.Cloud.SDK;
using IBM.Cloud.SDK.Authentication;
using IBM.Cloud.SDK.Authentication.Iam;
using IBM.Cloud.SDK.Utilities;
using IBM.Cloud.SDK.DataTypes;
using UnityEngine.Networking;

namespace IBM.Watson.Examples
{
    // JSON payload returned by the Flask server's /kiritan/talk_to endpoint.
    [System.Serializable]
    public struct ConversationPair
    {
        public string audio_url;
        public string reply;
    }

    public class ExampleStreaming : MonoBehaviour
    {
        #region PLEASE SET THESE VARIABLES IN THE INSPECTOR
        [Space(10)]
        [Tooltip("The service URL (optional). This defaults to \"https://stream.watsonplatform.net/speech-to-text/api\"")]
        [SerializeField]
        private string _serviceUrl;
        // [Tooltip("Text field to display the results of streaming.")]
        // public Text ResultsField;
        // [Tooltip("The field to show reply message.")]
        // public Text ReplyField;
        public AudioSource audioSource;
        [Header("IAM Authentication")]
        [Tooltip("The IAM apikey.")]
        [SerializeField]
        private string _iamApikey;
        [Header("Parameters")]
        // https://www.ibm.com/watson/developercloud/speech-to-text/api/v1/curl.html?curl#get-model
        [Tooltip("The Model to use. This defaults to en-US_BroadbandModel")]
        [SerializeField]
        private string _recognizeModel;
        #endregion
        private int _recordingRoutine = 0;
        private string _microphoneID = null;
        private AudioClip _recording = null;
        private int _recordingBufferSize = 1;
        private int _recordingHZ = 22050;
        // Address of the Flask reply server (see the Python file above).
        private const string _talkUrl = "http://192.168.1.102:20020/kiritan/talk_to?message=";
        AudioClip myClip;
        private SpeechToTextService _service;

        void Start()
        {
            StartCoroutine(GetPermission());
            // audioSource = GetComponent<AudioSource>();
            LogSystem.InstallDefaultReactors();
            Runnable.Run(CreateService());
        }

        private IEnumerator GetPermission()
        {
            yield return Application.RequestUserAuthorization(UserAuthorization.Microphone);
            if (Application.HasUserAuthorization(UserAuthorization.Microphone))
            {
                Debug.Log("Microphone found");
            }
            else
            {
                Debug.Log("Microphone not found");
            }
        }

        private IEnumerator CreateService()
        {
            if (string.IsNullOrEmpty(_iamApikey))
            {
                throw new IBMException("Please provide an IAM ApiKey for the service.");
            }
            IamAuthenticator authenticator = new IamAuthenticator(apikey: _iamApikey);
            // Wait for token data.
            while (!authenticator.CanAuthenticate())
                yield return null;
            _service = new SpeechToTextService(authenticator);
            if (!string.IsNullOrEmpty(_serviceUrl))
            {
                _service.SetServiceUrl(_serviceUrl);
            }
            _service.StreamMultipart = true;
            Active = true;
            StartRecording();
        }
        public bool Active
        {
            get { return _service.IsListening; }
            set
            {
                if (value && !_service.IsListening)
                {
                    _service.RecognizeModel = (string.IsNullOrEmpty(_recognizeModel) ? "en-US_BroadbandModel" : _recognizeModel);
                    _service.DetectSilence = true;
                    _service.EnableWordConfidence = true;
                    _service.EnableTimestamps = true;
                    _service.SilenceThreshold = 0.01f;
                    _service.MaxAlternatives = 1;
                    _service.EnableInterimResults = true;
                    _service.OnError = OnError;
                    _service.InactivityTimeout = -1;
                    _service.ProfanityFilter = false;
                    _service.SmartFormatting = true;
                    _service.SpeakerLabels = false;
                    _service.WordAlternativesThreshold = null;
                    _service.EndOfPhraseSilenceTime = null;
                    _service.StartListening(OnRecognize, OnRecognizeSpeaker);
                }
                else if (!value && _service.IsListening)
                {
                    _service.StopListening();
                }
            }
        }

        private void StartRecording()
        {
            if (_recordingRoutine == 0)
            {
                UnityObjectUtil.StartDestroyQueue();
                _recordingRoutine = Runnable.Run(RecordingHandler());
            }
        }

        private void StopRecording()
        {
            if (_recordingRoutine != 0)
            {
                Microphone.End(_microphoneID);
                Runnable.Stop(_recordingRoutine);
                _recordingRoutine = 0;
            }
        }

        private void OnError(string error)
        {
            Active = false;
            Log.Debug("ExampleStreaming.OnError()", "Error! {0}", error);
        }
        private IEnumerator RecordingHandler()
        {
            Log.Debug("ExampleStreaming.RecordingHandler()", "devices: {0}", Microphone.devices);
            _recording = Microphone.Start(_microphoneID, true, _recordingBufferSize, _recordingHZ);
            yield return null;  // let _recordingRoutine get set
            if (_recording == null)
            {
                StopRecording();
                yield break;
            }
            bool bFirstBlock = true;
            int midPoint = _recording.samples / 2;
            float[] samples = null;
            while (_recordingRoutine != 0 && _recording != null)
            {
                int writePos = Microphone.GetPosition(_microphoneID);
                if (writePos > _recording.samples || !Microphone.IsRecording(_microphoneID))
                {
                    Log.Error("ExampleStreaming.RecordingHandler()", "Microphone disconnected.");
                    StopRecording();
                    yield break;
                }
                if ((bFirstBlock && writePos >= midPoint)
                    || (!bFirstBlock && writePos < midPoint))
                {
                    // The front block has been recorded; wrap it in an AudioClip and pass it to the service.
                    samples = new float[midPoint];
                    _recording.GetData(samples, bFirstBlock ? 0 : midPoint);
                    AudioData record = new AudioData();
                    record.MaxLevel = Mathf.Max(Mathf.Abs(Mathf.Min(samples)), Mathf.Max(samples));
                    record.Clip = AudioClip.Create("Recording", midPoint, _recording.channels, _recordingHZ, false);
                    record.Clip.SetData(samples, 0);
                    _service.OnListen(record);
                    bFirstBlock = !bFirstBlock;
                }
                else
                {
                    // Calculate the number of samples remaining until the current block is full,
                    // and wait the amount of time it will take to record them.
                    int remaining = bFirstBlock ? (midPoint - writePos) : (_recording.samples - writePos);
                    float timeRemaining = (float)remaining / (float)_recordingHZ;
                    yield return new WaitForSeconds(timeRemaining);
                }
            }
            yield break;
        }
        IEnumerator OnFinalRecognized(string text)
        {
            // Build a GET request for the reply server, URL-escaping the transcript.
            UnityWebRequest webRequest = UnityWebRequest.Get(_talkUrl + UnityWebRequest.EscapeURL(text));
            // Send the request and wait for the response.
            yield return webRequest.SendWebRequest();
            if (webRequest.isNetworkError)
            {
                // Request failed.
                Debug.Log(webRequest.error);
            }
            else
            {
                // Request succeeded: parse the reply and fetch its audio.
                Debug.Log(webRequest.downloadHandler.text);
                // ReplyField.text = webRequest.downloadHandler.text;
                ConversationPair pair = JsonUtility.FromJson<ConversationPair>(webRequest.downloadHandler.text);
                StartCoroutine("GetAudioClip", pair.audio_url);
            }
        }

        IEnumerator GetAudioClip(string audio_url)
        {
            using (UnityWebRequest www = UnityWebRequestMultimedia.GetAudioClip(audio_url, AudioType.WAV))
            {
                yield return www.SendWebRequest();
                if (www.isNetworkError)
                {
                    Debug.Log(www.error);
                }
                else
                {
                    // Play the downloaded reply audio.
                    myClip = DownloadHandlerAudioClip.GetContent(www);
                    audioSource.clip = myClip;
                    audioSource.Play();
                }
            }
        }
        private void OnRecognize(SpeechRecognitionEvent result)
        {
            if (result != null && result.results.Length > 0)
            {
                foreach (var res in result.results)
                {
                    foreach (var alt in res.alternatives)
                    {
                        string text = string.Format("{0} ({1}, {2:0.00})\n", alt.transcript, res.final ? "Final" : "Interim", alt.confidence);
                        Log.Debug("ExampleStreaming.OnRecognize()", text);
                        // ResultsField.text = text;
                        if (res.final)
                        {
                            // Send the raw transcript (not the formatted log string) to the reply server.
                            StartCoroutine("OnFinalRecognized", alt.transcript);
                        }
                    }
                    if (res.keywords_result != null && res.keywords_result.keyword != null)
                    {
                        foreach (var keyword in res.keywords_result.keyword)
                        {
                            Log.Debug("ExampleStreaming.OnRecognize()", "keyword: {0}, confidence: {1}, start time: {2}, end time: {3}", keyword.normalized_text, keyword.confidence, keyword.start_time, keyword.end_time);
                        }
                    }
                    if (res.word_alternatives != null)
                    {
                        foreach (var wordAlternative in res.word_alternatives)
                        {
                            Log.Debug("ExampleStreaming.OnRecognize()", "Word alternatives found. Start time: {0} | EndTime: {1}", wordAlternative.start_time, wordAlternative.end_time);
                            foreach (var alternative in wordAlternative.alternatives)
                                Log.Debug("ExampleStreaming.OnRecognize()", "\t word: {0} | confidence: {1}", alternative.word, alternative.confidence);
                        }
                    }
                }
            }
        }
        private void OnRecognizeSpeaker(SpeakerRecognitionEvent result)
        {
            if (result != null)
            {
                foreach (SpeakerLabelsResult labelResult in result.speaker_labels)
                {
                    Log.Debug("ExampleStreaming.OnRecognizeSpeaker()", string.Format("speaker result: {0} | confidence: {3} | from: {1} | to: {2}", labelResult.speaker, labelResult.from, labelResult.to, labelResult.confidence));
                }
            }
        }
    }
}
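For debugging outside Unity, the client's round trip (query the reply server with a transcript, then download the returned WAV) can be approximated in a few lines of Python. This is a rough sketch, not part of the project: it assumes the server address hard-coded in _talkUrl above, uses the `requests` library, and writes the audio to an illustrative local file instead of playing it through an AudioSource:

import requests

TALK_URL = "http://192.168.1.102:20020/kiritan/talk_to"  # same host/port as _talkUrl

def talk_to(text):
    # Ask the server for the closest canned reply, as OnFinalRecognized() does.
    reply = requests.get(TALK_URL, params={"message": text}).json()
    print("reply:", reply["reply"])
    # Download the reply audio, as GetAudioClip() does.
    wav = requests.get(reply["audio_url"]).content
    with open("reply.wav", "wb") as f:  # "reply.wav" is an arbitrary output path
        f.write(wav)

talk_to("疲れた")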
The indexing script: embeds each example utterance, saves the conversation data to conversations.json, and builds the Annoy index queried by the server.
import camphr
import json
from annoy import AnnoyIndex

# Japanese BERT pipeline used to embed the example utterances.
nlp = camphr.load(
    """
    lang:
        name: ja_mecab  # language name
    pipeline:
        transformers_model:
            trf_name_or_path: bert-base-japanese  # model name
    """
)

# (Bob's utterance, Alice's reply, filename of the recorded reply audio)
conversation_examples = [
    # ("ただいま", "おかえり", ""),
    ("おやすみなさい。", "おやすみなさい。", "おやすみなさい。.wav"),
    ("疲れた", "お疲れ様", "otsukaresama.wav"),
    ("いってきます", "仕事頑張ってね。", "お仕事頑張ってね。.wav")
]

conversation_data = []
for bob_talk, alice_reply, audio_name in conversation_examples:
    doc = nlp(bob_talk)
    conversation_data.append(dict(
        embedding=doc.vector.tolist(),
        bob_talk=bob_talk,
        audio_name=audio_name,
        alice_reply=alice_reply))

with open("conversations.json", 'w') as f:
    json.dump(conversation_data, f)

vec_dimension = len(conversation_data[0]['embedding'])
t = AnnoyIndex(vec_dimension, 'angular')  # length of the item vectors that will be indexed
for i, conv_exam in enumerate(conversation_data):
    t.add_item(i, conv_exam['embedding'])
print(f"Dimension is {vec_dimension}.")
t.build(10)  # 10 trees
t.save('index.ann')
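To sanity-check the index right after building it, one can reload it and query with a paraphrase of one of the examples. A minimal sketch that continues the script above; the test utterance is illustrative:

# Reload the saved index and query it with a paraphrase of "疲れた".
u = AnnoyIndex(vec_dimension, 'angular')
u.load('index.ann')
query = nlp("とても疲れました").vector.tolist()
closest = u.get_nns_by_vector(query, 1)[0]
print(conversation_data[closest]['alice_reply'])  # expected: お疲れ様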