Last active
August 16, 2019 20:13
-
-
Save viveret/300a4f16880de9d65932a4486f58d5c2 to your computer and use it in GitHub Desktop.
Text Sentiment Analysis using Tiny DNN and bag-of-words corpus of 12,000 words and 3 classes (positive, negative, neutral). Also uses JSON for modern C++ and Windows Sockets
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdio.h>
#include <stdlib.h>

#include <cassert>
#include <climits>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <fstream>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include <vector>

// The three headers below are kept in their original relative order.
// NOTE(review): <ws2tcpip.h> pulls in Windows headers that may define min/max
// macros, so moving tiny_dnn after it could break the build - confirm before reordering.
#include "tiny-dnn/tiny_dnn/tiny_dnn.h"
#include <ws2tcpip.h>
#include "nlohmann/json.hpp"
#pragma comment (lib, "Ws2_32.lib") | |
using namespace tiny_dnn; | |
using namespace tiny_dnn::activation; | |
void trainNN(network<sequential>* nn, adagrad opt); | |
void testNN(network<sequential>* nn, adagrad opt); | |
void daemonNN(network<sequential>* nn); | |
void incomingRequestNN(network<sequential>* nn, SOCKET client); | |
void loadData(); | |
const int corpus_count = 12000; | |
const int word_count = 1; | |
const int sentiment_count = 3; | |
const auto past_experience_path = "past_experience.bin"; | |
std::string train_path = "training-1.csv"; | |
std::string DEFAULT_PORT = "1111"; | |
tiny_dnn::tensor_t inputs; | |
tiny_dnn::tensor_t outputs; | |
int main() | |
{ | |
WORD versionWanted = MAKEWORD(1, 1); | |
WSADATA wsaData; | |
if (WSAStartup(versionWanted, &wsaData)) { | |
std::cout << "Error starting WSA: " << WSAGetLastError() << std::endl; | |
return -1; | |
} | |
auto should_continue = true; | |
std::cout << "Creating NN..." << std::endl; | |
network<sequential>* nn = new network<sequential>(); | |
adagrad opt; | |
using pool = max_pooling_layer; | |
using fc = fully_connected_layer; | |
using relu = relu_layer; | |
using softmax = softmax_layer; | |
*nn << fc(corpus_count * word_count, corpus_count / 2) | |
<< fc(corpus_count / 2, 100) | |
<< convolutional_layer(100, 1, 10, 1, 1, 3, padding::valid, true, 10) << tanh_layer() | |
<< max_pooling_layer(10, 1, 3, 10, 10, false) << tanh_layer(); | |
std::cout << "Done creaiting NN, loading past experience..." << std::endl; | |
if (std::ifstream(past_experience_path).good()) { | |
nn->load(past_experience_path, content_type::weights); | |
} | |
else { | |
std::cout << "No past experience, starting new!" << std::endl; | |
} | |
std::cout << "Done loading past experience." << std::endl; | |
while (should_continue) { | |
std::cout << std::endl << std::endl << "Commands:" << std::endl | |
<< "train\t\t- Training neural network" << std::endl | |
<< "accuracy\t- Compare neural network to training data" << std::endl | |
<< "save\t\t- Save neural network to default path " << past_experience_path << std::endl | |
<< "saveas\t\t- Save neural network to custom path" << std::endl | |
<< "daemon\t\t- Allow local JSON TCP connections to port " << DEFAULT_PORT << " to classify data" << std::endl | |
<< "quit\t\t- Quit program. Can also use 'q' or 'exit'." << std::endl << "> "; | |
std::string nextCommand; | |
std::cin >> nextCommand; | |
if (nextCommand._Equal("train")) { | |
if (inputs.empty()) { | |
loadData(); | |
} | |
trainNN(nn, opt); | |
std::cout << "Success!" << std::endl; | |
} | |
else if (nextCommand._Equal("accuracy")) { | |
if (inputs.empty()) { | |
loadData(); | |
} | |
testNN(nn, opt); | |
std::cout << "Success!" << std::endl; | |
} | |
else if (nextCommand._Equal("save")) { | |
nn->save(past_experience_path, content_type::weights); | |
std::cout << "Success!" << std::endl; | |
} | |
else if (nextCommand._Equal("saveas")) { | |
std::string save_path; | |
std::cin >> save_path; | |
nn->save(save_path, content_type::weights); | |
std::cout << "Success!" << std::endl; | |
} | |
else if (nextCommand._Equal("daemon")) { | |
daemonNN(nn); | |
} | |
else if (nextCommand._Equal("q") || nextCommand._Equal("quit") || nextCommand._Equal("exit")) { | |
should_continue = false; | |
} | |
else { | |
std::cout << "Invalid command." << std::endl; | |
} | |
} | |
std::cout << std::endl << "Done!" << std::endl; | |
} | |
// Returns the number of sep-delimited fields in row, i.e. one more than the
// number of separator characters. An empty row therefore counts as one field.
int countFloats(std::string row, char sep) {
    int fields = 1;
    for (const char ch : row) {
        fields += (ch == sep) ? 1 : 0;
    }
    return fields;
}
// Parses up to `length` numbers from `row`, skipping one separator character
// after each value.
//
// Returns a heap array of `length` elements allocated with new[]; the caller
// owns it and must release it with delete[]. Every element the row does not
// supply (row too short, or parsing stopped at an unreadable token) is 0
// rather than indeterminate. A negative `length` is treated as 0.
float_t* parseFloats(std::string row, int length) {
    const std::size_t count = length > 0 ? static_cast<std::size_t>(length) : 0;
    // The trailing () value-initializes the array, so unparsed slots are zero.
    float_t* cols = new float_t[count]();
    std::stringstream ss(row);
    for (std::size_t i = 0; i < count; ++i) {
        if (!(ss >> cols[i])) {
            // Once extraction fails the stream stays failed; leave the rest at 0.
            cols[i] = 0;
            break;
        }
        ss.ignore(1);  // skip the single separator character
    }
    return cols;
}
std::vector<float_t*>* getAllLinesOfCsv(std::string path, char sep, int output_size) { | |
std::vector<float_t*>* ret = new std::vector<float_t*>(); | |
std::ifstream file(path); | |
std::string row; | |
file >> row; | |
auto headerCount = countFloats(row, sep); | |
if (headerCount < 2) { | |
assert("header count must have width and count"); | |
} | |
ret->push_back(parseFloats(row, headerCount)); | |
tiny_dnn::progress_display disp((int)ret->front()[1]); | |
tiny_dnn::timer t; | |
while (file >> row) { | |
auto numItems = countFloats(row, sep); | |
if (numItems != output_size) { | |
std::cout << "Row " << ret->size() << " has different size: " << numItems << " (should be " << output_size << ")" << std::endl; | |
} | |
ret->push_back(parseFloats(row, numItems)); | |
disp += 1; | |
} | |
std::cout << "Reading " << ret->size() << " rows complete. " << t.elapsed() << "s elapsed." << std::endl; | |
return ret; | |
} | |
void loadData() { | |
if (train_path.empty()) { | |
std::cout << "path to wordbag -> sentiment list: "; | |
std::cin >> train_path; | |
} | |
auto stringData = getAllLinesOfCsv(train_path, ';', corpus_count + 1); | |
auto header = stringData->at(0); | |
inputs.resize((int)header[1]); | |
outputs.resize((int)header[1]); | |
std::cout << "Loading " << inputs.size() << " rows:" << std::endl; | |
tiny_dnn::progress_display disp(inputs.size()); | |
tiny_dnn::timer t; | |
for (int i = 1; i < stringData->size(); i++) { | |
auto entry = stringData->at(i); | |
auto sentiment = (int)entry[corpus_count]; | |
vec_t input(corpus_count, 0.0f); | |
vec_t fitTo(sentiment_count); | |
for (int w = 0; w < corpus_count; w++) { | |
input[w] = entry[w]; | |
} | |
for (int i = 0; i < sentiment_count; i++) { | |
fitTo[i] = i == (sentiment - 1) ? 1.0f : 0.0f; | |
} | |
inputs[i - 1] = input; | |
outputs[i - 1] = fitTo; | |
disp += 1; | |
} | |
std::cout << "Loading " << inputs.size() << " rows complete. " << t.elapsed() << "s elapsed." << std::endl; | |
} | |
// Interactively trains nn on the global inputs / outputs with MSE loss.
//
// Prompts on stdin for batch size, epoch count and learning rate. Training is
// skipped with a message when no data is loaded or when any of the three
// values is unreadable or not positive.
//
// nn  - network to train in place.
// opt - optimizer, taken by value; its alpha is overwritten with the entered
//       learning rate and changes to it do not reach the caller.
void trainNN(network<sequential>* nn, adagrad opt) {
    if (inputs.empty() || inputs.size() != outputs.size()) {
        std::cout << "No training data loaded." << std::endl;
        return;
    }

    // Initialized so that a failed read can never leave them indeterminate.
    size_t batch_size = 0;
    size_t epochs = 0;
    double learning_rate = 0.0;
    std::cout << "batch size: ";
    std::cin >> batch_size;
    std::cout << "epochs: ";
    std::cin >> epochs;
    std::cout << "learning rate: ";
    std::cin >> learning_rate;

    if (!std::cin || batch_size == 0 || epochs == 0 || !(learning_rate > 0.0)) {
        // Drop the bad line so the next prompt starts from clean input. The
        // parentheses around max keep a Windows max() macro from expanding.
        std::cin.clear();
        std::cin.ignore((std::numeric_limits<std::streamsize>::max)(), '\n');
        std::cout << "Invalid training parameters." << std::endl;
        return;
    }
    opt.alpha = learning_rate;

    // Created after the prompts so the progress bar is not printed in between
    // them and the timer does not include time spent typing.
    tiny_dnn::progress_display disp(inputs.size());
    tiny_dnn::timer t;
    int epoch = 1;

    // Called by fit after each epoch: report timing and loss, then reset the
    // progress display and timer for the next epoch.
    auto on_enumerate_epoch = [&]() {
        std::cout << "Epoch " << epoch << "/" << epochs << " finished. "
            << t.elapsed() << "s elapsed." << std::endl;
        ++epoch;
        nn->set_netphase(tiny_dnn::net_phase::test);
        float loss = nn->get_loss<mse>(inputs, outputs) / inputs.size();
        nn->set_netphase(tiny_dnn::net_phase::train);
        std::cout << "mse: " << loss << std::endl;
        // continue training.
        disp.restart(inputs.size());
        t.restart();
    };
    // Called by fit after each mini-batch.
    auto on_enumerate_minibatch = [&]() { disp += batch_size; };

    auto ret = nn->fit<mse>(opt, inputs, outputs, batch_size, epochs, on_enumerate_minibatch, on_enumerate_epoch, false, 8);
    if (!ret) {
        std::cout << "Error while training" << std::endl;
    }
    else {
        std::cout << "Done training" << std::endl;
    }
}
// Runs nn over every sample in the global inputs and reports how often the
// highest-scoring output class matches the labelled class in outputs.
//
// The labelled class is the last index of the target vector whose value is
// above 0.5. A sample with no such index is reported and counted as wrong.
// A running tally is printed after every sample and a total at the end.
//
// nn  - network to evaluate.
// opt - unused; kept so the signature matches trainNN.
void testNN(network<sequential>* nn, adagrad opt) {
    (void)opt;
    if (inputs.empty()) {
        // Avoids a 0/0 percentage in the summary below.
        std::cout << "No data loaded." << std::endl;
        return;
    }

    std::size_t totalCorrectPrediction = 0;
    for (std::size_t testIndex = 0; testIndex < inputs.size(); ++testIndex) {
        // References: copying a corpus_count-wide vector per sample is wasteful.
        const auto& input = inputs.at(testIndex);
        const auto& expected = outputs.at(testIndex);
        const auto predicted = nn->predict(input);

        if (predicted.empty() || predicted.size() != expected.size()) {
            std::cout << "Error while testing: prediction output wrong size" << std::endl;
            continue;
        }

        int expectedClass = -1;
        // Seed the argmax with class 0 instead of a 0.0 threshold: the output
        // layer is tanh, so every score may be zero or negative.
        std::size_t predictedClass = 0;
        for (std::size_t categoryIndex = 0; categoryIndex < expected.size(); ++categoryIndex) {
            if (expected.at(categoryIndex) > 0.5f) {
                expectedClass = static_cast<int>(categoryIndex);
            }
            if (predicted.at(categoryIndex) > predicted.at(predictedClass)) {
                predictedClass = categoryIndex;
            }
        }

        if (expectedClass >= 0) {
            if (expectedClass == static_cast<int>(predictedClass)) {
                totalCorrectPrediction++;
            }
        }
        else {
            std::cout << "Error while testing: data id " << testIndex << " has no labelled class" << std::endl;
        }

        // Samples seen so far, including this one; never zero.
        const std::size_t processed = testIndex + 1;
        std::cout << totalCorrectPrediction * 100.0f / processed << "% (" << totalCorrectPrediction << " right and "
            << (processed - totalCorrectPrediction) << " wrong out of " << processed << " so far)" << std::endl;
    }

    std::cout << "TOTAL: " << totalCorrectPrediction * 100.0f / inputs.size() << "% (" << totalCorrectPrediction << " right and "
        << (inputs.size() - totalCorrectPrediction) << " wrong out of " << inputs.size() << ")" << std::endl;
}
// Prompts for a TCP port, listens on it (IPv4) and serves classification
// requests one client at a time through incomingRequestNN.
//
// The accept loop only ends when a socket call fails; the function then
// releases what it opened and returns to the caller. Winsock itself is left
// initialised: this function does not call WSAStartup, so it must not call
// WSACleanup either, otherwise later socket use in the process would fail.
//
// nn - network used to answer requests.
void daemonNN(network<sequential>* nn) {
    std::string port;
    std::cout << "port: ";
    std::cin >> port;

    addrinfo hints = {};
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_protocol = IPPROTO_TCP;
    hints.ai_flags = AI_PASSIVE;

    // Resolve the local address and port to be used by the server
    addrinfo* result = nullptr;
    auto r = getaddrinfo(nullptr, port.c_str(), &hints, &result);
    if (r != 0) {
        std::cout << "getaddrinfo failed: " << r << std::endl;
        return;
    }

    // Create the listening socket
    auto ListenSocket = socket(result->ai_family, result->ai_socktype, result->ai_protocol);
    if (ListenSocket == INVALID_SOCKET) {
        std::cout << "socket failed with error: " << WSAGetLastError() << std::endl;
        freeaddrinfo(result);
        return;
    }

    r = bind(ListenSocket, result->ai_addr, static_cast<int>(result->ai_addrlen));
    // The address list is not needed after bind, whether or not it succeeded.
    const int bindError = (r == SOCKET_ERROR) ? WSAGetLastError() : 0;
    freeaddrinfo(result);
    if (r == SOCKET_ERROR) {
        std::cout << "bind failed with error: " << bindError << std::endl;
        closesocket(ListenSocket);
        return;
    }

    r = listen(ListenSocket, SOMAXCONN);
    if (r == SOCKET_ERROR) {
        std::cout << "listen failed with error: " << WSAGetLastError() << std::endl;
        closesocket(ListenSocket);
        return;
    }

    while (true) {
        // Accept a client socket
        auto ClientSocket = accept(ListenSocket, nullptr, nullptr);
        if (ClientSocket == INVALID_SOCKET) {
            std::cout << "accept failed with error: " << WSAGetLastError() << std::endl;
            closesocket(ListenSocket);
            return;
        }
        try {
            incomingRequestNN(nn, ClientSocket);
        }
        catch (const std::exception& e) {
            // A malformed request must not take the whole daemon down.
            std::cout << "Request failed: " << e.what() << std::endl;
            closesocket(ClientSocket);
        }
    }
}
// Reads one request from client and decodes it as BSON.
//
// Wire format, as consumed here: an unsigned long long in host byte order
// giving the number of payload bytes, followed by the BSON payload.
//
// Returns the decoded document, or a null json value when the header is
// missing, the announced length is not below INT_MAX, a recv call fails, or
// the payload is not valid BSON. On every failure path the client socket has
// already been closed. Winsock is never shut down here.
nlohmann::json readJson(SOCKET client) {
    tiny_dnn::timer t;
    const int recvbuflen = 256 * 4;
    char recvbuf[recvbuflen];
    std::vector<std::uint8_t> inputJson;
    unsigned long long contentLength = 0;
    const std::size_t headerSize = sizeof(contentLength);

    // TCP is a byte stream: keep reading until the whole length header is in.
    while (inputJson.size() < headerSize) {
        const int received = recv(client, recvbuf, recvbuflen, 0);
        if (received == SOCKET_ERROR) {
            // Checked as a signed value: comparing -1 against sizeof() would
            // silently turn it into a huge positive number.
            std::cout << "recv failed with error: " << WSAGetLastError() << std::endl;
            closesocket(client);
            return nullptr;
        }
        if (received == 0) {
            std::cout << "missing content length" << std::endl;
            closesocket(client);
            return nullptr;
        }
        const std::uint8_t* bytes = reinterpret_cast<const std::uint8_t*>(recvbuf);
        inputJson.insert(inputJson.end(), bytes, bytes + received);
    }

    // memcpy avoids an unaligned, type-punned read of the receive buffer.
    std::memcpy(&contentLength, inputJson.data(), headerSize);
    inputJson.erase(inputJson.begin(), inputJson.begin() + static_cast<std::ptrdiff_t>(headerSize));
    std::cout << "Content length: " << contentLength << std::endl;

    if (contentLength >= static_cast<unsigned long long>(INT_MAX)) {
        // Untrusted value: refuse rather than try to buffer it.
        std::cout << "content length too large" << std::endl;
        closesocket(client);
        return nullptr;
    }

    const std::size_t expected = static_cast<std::size_t>(contentLength);
    if (inputJson.size() < expected) {
        tiny_dnn::progress_display disp(expected - inputJson.size());
        while (inputJson.size() < expected) {
            const std::size_t missing = expected - inputJson.size();
            const int wanted = missing < static_cast<std::size_t>(recvbuflen)
                ? static_cast<int>(missing) : recvbuflen;
            const int received = recv(client, recvbuf, wanted, 0);
            if (received > 0) {
                const std::uint8_t* bytes = reinterpret_cast<const std::uint8_t*>(recvbuf);
                inputJson.insert(inputJson.end(), bytes, bytes + received);
                disp += static_cast<std::size_t>(received);
            }
            else if (received == 0) {
                // Peer finished early; parse whatever arrived.
                std::cout << "Connection closing..." << std::endl;
                break;
            }
            else {
                std::cout << "recv failed with error: " << WSAGetLastError() << std::endl;
                closesocket(client);
                return nullptr;
            }
        }
    }

    std::cout << "Request read finished. " << t.elapsed() << "s elapsed. Parsing..." << std::endl;
    t.restart();
    try {
        auto ret = nlohmann::json::from_bson(inputJson);
        std::cout << "Request parse finished. " << t.elapsed() << "s elapsed." << std::endl;
        return ret;
    }
    catch (const std::exception& e) {
        // from_bson throws on truncated or malformed input.
        std::cout << "Request parse failed: " << e.what() << std::endl;
        closesocket(client);
        return nullptr;
    }
}
// Serves one classification request on an accepted client socket.
//
// The request must be a BSON object whose "test" member is an array of exactly
// corpus_count numbers. The reply is a BSON object holding either
// {"test": <index of the highest-scoring output>} or {"error": <message>}.
//
// When readJson yields a null document this function returns without
// replying; readJson is responsible for closing the socket when it fails (it
// does so on its recv-error path). On every other path the socket is closed
// here before returning. Winsock is never shut down here.
//
// nn     - network used for prediction.
// client - connected socket.
void incomingRequestNN(network<sequential>* nn, SOCKET client) {
    {
        // Ask the socket for its peer address; the SOCKET handle itself holds
        // no address information.
        sockaddr_in peer = {};
        int peerLen = sizeof(peer);
        char addrStr[INET_ADDRSTRLEN] = {};
        const char* label = "unknown";
        if (getpeername(client, reinterpret_cast<sockaddr*>(&peer), &peerLen) == 0
            && inet_ntop(AF_INET, &peer.sin_addr, addrStr, INET_ADDRSTRLEN) != nullptr) {
            label = addrStr;
        }
        std::cout << "Client " << label << " connected." << std::endl;
    }

    auto requestJson = readJson(client);
    if (requestJson.is_null()) {
        return;
    }

    nlohmann::json returnJson;
    if (requestJson.is_object()) {
        const auto testIt = requestJson.find("test");
        bool valid = testIt != requestJson.end() && testIt->is_array()
            && testIt->size() == static_cast<std::size_t>(corpus_count);

        vec_t input(corpus_count, 0.0f);
        if (valid) {
            std::size_t i = 0;
            for (const auto& element : *testIt) {
                if (!element.is_number()) {
                    // Converting a non-number would throw; reject it instead.
                    valid = false;
                    break;
                }
                input[i++] = static_cast<vec_t::value_type>(element.get<double>());
            }
        }

        if (valid) {
            const auto ret = nn->predict(input);
            // Seed the argmax with class 0 rather than a 0.0 threshold: the
            // output layer is tanh, so every score may be zero or negative.
            int actualOutput = ret.empty() ? -1 : 0;
            for (std::size_t categoryIndex = 1; categoryIndex < ret.size(); ++categoryIndex) {
                if (ret.at(categoryIndex) > ret.at(static_cast<std::size_t>(actualOutput))) {
                    actualOutput = static_cast<int>(categoryIndex);
                }
            }
            returnJson["test"] = actualOutput;
        }
        else {
            returnJson["error"] = "Invalid test";
        }
    }
    else {
        returnJson["error"] = "Invalid request (must be an object)";
    }

    const std::vector<std::uint8_t> outputJson = nlohmann::json::to_bson(returnJson);

    // send may accept fewer bytes than requested, so loop until all are out.
    std::size_t sent = 0;
    while (sent < outputJson.size()) {
        const int n = send(client,
            reinterpret_cast<const char*>(outputJson.data()) + sent,
            static_cast<int>(outputJson.size() - sent), 0);
        if (n == SOCKET_ERROR) {
            std::cout << "send failed with error: " << WSAGetLastError() << std::endl;
            closesocket(client);
            return;
        }
        sent += static_cast<std::size_t>(n);
    }
    std::cout << "Response sent, " << sent << " bytes" << std::endl;

    // shutdown the connection since we're done
    if (shutdown(client, SD_SEND) == SOCKET_ERROR) {
        std::cout << "shutdown failed with error: " << WSAGetLastError() << std::endl;
        closesocket(client);
        return;
    }

    // cleanup
    closesocket(client);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment