Text sentiment analysis using tiny-dnn with a bag-of-words corpus of 12,000 words and 3 classes (positive, negative, neutral). Also uses JSON for Modern C++ (nlohmann/json) and Windows Sockets.
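// Interactive sentiment classifier. Commands are read from stdin: train, accuracy,
// save, saveas, daemon, and quit. In daemon mode the program accepts local TCP
// connections; each request is an 8-byte content length (host byte order) followed
// by that many bytes of BSON containing {"test": [<corpus_count> word counts]},
// and the response is BSON containing either {"test": <predicted class index>}
// or {"error": <message>}.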
#include <cstdlib>
#include <iostream>
#include <vector>
#include "tiny-dnn/tiny_dnn/tiny_dnn.h"
#include <ws2tcpip.h>
#include <stdlib.h>
#include <stdio.h>
#include <limits>
#include <string>
#include <sstream>
#include <fstream>
#include <cassert>
#include "nlohmann/json.hpp"
#pragma comment (lib, "Ws2_32.lib")
using namespace tiny_dnn;
using namespace tiny_dnn::activation;
void trainNN(network<sequential>* nn, adagrad opt);
void testNN(network<sequential>* nn, adagrad opt);
void daemonNN(network<sequential>* nn);
void incomingRequestNN(network<sequential>* nn, SOCKET client);
void loadData();
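// Corpus/model dimensions and default paths. Each row of the training CSV holds one
// bag-of-words vector (corpus_count values) followed by a trailing sentiment label.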
const int corpus_count = 12000;
const int word_count = 1;
const int sentiment_count = 3;
const auto past_experience_path = "past_experience.bin";
std::string train_path = "training-1.csv";
std::string DEFAULT_PORT = "1111";
tiny_dnn::tensor_t inputs;
tiny_dnn::tensor_t outputs;
int main()
{
WORD versionWanted = MAKEWORD(1, 1);
WSADATA wsaData;
if (WSAStartup(versionWanted, &wsaData)) {
std::cout << "Error starting WSA: " << WSAGetLastError() << std::endl;
return -1;
}
auto should_continue = true;
std::cout << "Creating NN..." << std::endl;
network<sequential>* nn = new network<sequential>();
adagrad opt;
using pool = max_pooling_layer;
using fc = fully_connected_layer;
using relu = relu_layer;
using softmax = softmax_layer;
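// Network topology: two fully connected layers compress the 12,000-dim bag-of-words
// vector down to 100 features, followed by a strided 1-D convolution and max pooling
// with tanh activations, ending in one output per sentiment class.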
*nn << fc(corpus_count * word_count, corpus_count / 2)
<< fc(corpus_count / 2, 100)
<< convolutional_layer(100, 1, 10, 1, 1, 3, padding::valid, true, 10) << tanh_layer()
<< max_pooling_layer(10, 1, 3, 10, 10, false) << tanh_layer();
std::cout << "Done creaiting NN, loading past experience..." << std::endl;
if (std::ifstream(past_experience_path).good()) {
nn->load(past_experience_path, content_type::weights);
}
else {
std::cout << "No past experience, starting new!" << std::endl;
}
std::cout << "Done loading past experience." << std::endl;
while (should_continue) {
std::cout << std::endl << std::endl << "Commands:" << std::endl
<< "train\t\t- Training neural network" << std::endl
<< "accuracy\t- Compare neural network to training data" << std::endl
<< "save\t\t- Save neural network to default path " << past_experience_path << std::endl
<< "saveas\t\t- Save neural network to custom path" << std::endl
<< "daemon\t\t- Allow local JSON TCP connections to port " << DEFAULT_PORT << " to classify data" << std::endl
<< "quit\t\t- Quit program. Can also use 'q' or 'exit'." << std::endl << "> ";
std::string nextCommand;
std::cin >> nextCommand;
if (nextCommand._Equal("train")) {
if (inputs.empty()) {
loadData();
}
trainNN(nn, opt);
std::cout << "Success!" << std::endl;
}
else if (nextCommand._Equal("accuracy")) {
if (inputs.empty()) {
loadData();
}
testNN(nn, opt);
std::cout << "Success!" << std::endl;
}
else if (nextCommand._Equal("save")) {
nn->save(past_experience_path, content_type::weights);
std::cout << "Success!" << std::endl;
}
else if (nextCommand._Equal("saveas")) {
std::string save_path;
std::cin >> save_path;
nn->save(save_path, content_type::weights);
std::cout << "Success!" << std::endl;
}
else if (nextCommand._Equal("daemon")) {
daemonNN(nn);
}
else if (nextCommand._Equal("q") || nextCommand._Equal("quit") || nextCommand._Equal("exit")) {
should_continue = false;
}
else {
std::cout << "Invalid command." << std::endl;
}
}
std::cout << std::endl << "Done!" << std::endl;
}
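// Counts the number of separator-delimited fields in a CSV row (separators + 1).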
int countFloats(std::string row, char sep) {
auto numItems = 1;
for (int i = 0; i < row.length(); i++) {
if (row[i] == sep) {
numItems++;
}
}
return numItems;
}
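// Parses `length` separator-delimited floats from a CSV row into a newly allocated
// array; the caller owns (and must delete[]) the returned buffer.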
float_t* parseFloats(std::string row, int length) {
float_t* cols = new float_t[length];
std::stringstream ss(row);
for (int i = 0; i < length; i++) {
ss >> cols[i];
ss.ignore(1);
}
return cols;
}
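// Reads an entire CSV file into newly allocated float rows. The first (header) row is
// expected to contain at least the row width and the row count; rows are read with
// operator>>, so fields must not contain whitespace.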
std::vector<float_t*>* getAllLinesOfCsv(std::string path, char sep, int output_size) {
std::vector<float_t*>* ret = new std::vector<float_t*>();
std::ifstream file(path);
std::string row;
file >> row;
auto headerCount = countFloats(row, sep);
// The header row must provide at least the row width and the row count.
assert(headerCount >= 2 && "header must have width and count");
ret->push_back(parseFloats(row, headerCount));
tiny_dnn::progress_display disp((int)ret->front()[1]);
tiny_dnn::timer t;
while (file >> row) {
auto numItems = countFloats(row, sep);
if (numItems != output_size) {
std::cout << "Row " << ret->size() << " has different size: " << numItems << " (should be " << output_size << ")" << std::endl;
}
ret->push_back(parseFloats(row, numItems));
disp += 1;
}
std::cout << "Reading " << ret->size() << " rows complete. " << t.elapsed() << "s elapsed." << std::endl;
return ret;
}
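// Loads the training CSV into the global `inputs` (bag-of-words vectors) and
// `outputs` (one-hot sentiment labels; classes are numbered 1..sentiment_count in the file).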
void loadData() {
if (train_path.empty()) {
std::cout << "path to wordbag -> sentiment list: ";
std::cin >> train_path;
}
auto stringData = getAllLinesOfCsv(train_path, ';', corpus_count + 1);
auto header = stringData->at(0);
inputs.resize((int)header[1]);
outputs.resize((int)header[1]);
std::cout << "Loading " << inputs.size() << " rows:" << std::endl;
tiny_dnn::progress_display disp(inputs.size());
tiny_dnn::timer t;
for (int i = 1; i < stringData->size(); i++) {
auto entry = stringData->at(i);
auto sentiment = (int)entry[corpus_count];
vec_t input(corpus_count, 0.0f);
vec_t fitTo(sentiment_count);
for (int w = 0; w < corpus_count; w++) {
input[w] = entry[w];
}
for (int s = 0; s < sentiment_count; s++) {
fitTo[s] = s == (sentiment - 1) ? 1.0f : 0.0f;
}
inputs[i - 1] = input;
outputs[i - 1] = fitTo;
disp += 1;
}
std::cout << "Loading " << inputs.size() << " rows complete. " << t.elapsed() << "s elapsed." << std::endl;
}
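// Prompts for batch size, epoch count, and learning rate, then trains with MSE loss,
// printing the per-sample MSE on the training set after every epoch.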
void trainNN(network<sequential>* nn, adagrad opt) {
size_t batch_size = 0;
size_t epochs = 0;
std::cout << "batch size: ";
std::cin >> batch_size;
std::cout << "epochs: ";
std::cin >> epochs;
tiny_dnn::progress_display disp(inputs.size());
tiny_dnn::timer t;
double learning_rate;
std::cout << "learning rate: ";
std::cin >> learning_rate;
opt.alpha = learning_rate;
int epoch = 1;
// create callback
auto on_enumerate_epoch = [&]() {
std::cout << "Epoch " << epoch << "/" << epochs << " finished. "
<< t.elapsed() << "s elapsed." << std::endl;
++epoch;
nn->set_netphase(tiny_dnn::net_phase::test);
float loss = nn->get_loss<mse>(inputs, outputs) / inputs.size();
nn->set_netphase(tiny_dnn::net_phase::train);
std::cout << "mse: " << loss << std::endl;
// continue training.
disp.restart(inputs.size());
t.restart();
};
auto on_enumerate_minibatch = [&]() { disp += batch_size; };
auto ret = nn->fit<mse>(opt, inputs, outputs, batch_size, epochs, on_enumerate_minibatch, on_enumerate_epoch, false, 8);
if (!ret) {
std::cout << "Error while training" << std::endl;
}
else {
std::cout << "Done training" << std::endl;
}
}
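// Evaluates the network on the loaded training data: the arg-max of each prediction is
// compared against the one-hot label, and running/overall accuracy is printed.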
void testNN(network<sequential>* nn, adagrad opt) {
int totalCorrectPrediction = 0;
for (int testIndex = 0; testIndex < inputs.size(); testIndex++) {
auto input = inputs.at(testIndex);
auto output = outputs.at(testIndex);
auto ret = nn->predict(input);
if (ret.size() == output.size()) {
int expectedClass = -1, predictedClass = -1;
float_t maxConfidence = 0.0f;
for (int categoryIndex = 0; categoryIndex < output.size(); categoryIndex++) {
if (output.at(categoryIndex) > 0.5f) {
expectedClass = categoryIndex;
}
auto confidence = ret.at(categoryIndex);
if (confidence > maxConfidence) {
predictedClass = categoryIndex;
maxConfidence = confidence;
}
}
if (expectedClass >= 0) {
if (predictedClass == expectedClass) {
totalCorrectPrediction++;
}
}
else {
std::cout << "Error: data id " << testIndex << " has no expected category in its label" << std::endl;
}
auto processed = testIndex + 1;
std::cout << totalCorrectPrediction * 100.0f / processed << "% (" << totalCorrectPrediction << " right and "
<< (processed - totalCorrectPrediction) << " wrong out of " << processed << " so far)" << std::endl;
}
else {
std::cout << "Error while training: prediction output wrong size" << std::endl;
}
}
std::cout << "TOTAL: " << totalCorrectPrediction * 100.0f / inputs.size() << "% (" << totalCorrectPrediction << " right and "
<< (inputs.size() - totalCorrectPrediction) << " wrong out of " << inputs.size() << ")" << std::endl;
}
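// Binds a local TCP listening socket on the requested port and serves classification
// requests one client at a time.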
void daemonNN(network<sequential>* nn) {
std::string port;
std::cout << "port: ";
std::cin >> port;
struct addrinfo* result = NULL, * ptr = NULL, hints;
ZeroMemory(&hints, sizeof(hints));
hints.ai_family = AF_INET;
hints.ai_socktype = SOCK_STREAM;
hints.ai_protocol = IPPROTO_TCP;
hints.ai_flags = AI_PASSIVE;
// Resolve the local address and port to be used by the server
auto r = getaddrinfo(NULL, port.c_str(), &hints, &result);
if (r != 0) {
std::cout << "getaddrinfo failed: " << r;
WSACleanup();
return;
}
// Create a SOCKET for connecting to server
auto ListenSocket = socket(result->ai_family, result->ai_socktype, result->ai_protocol);
if (ListenSocket == INVALID_SOCKET) {
std::cout << "socket failed with error: " << WSAGetLastError();
freeaddrinfo(result);
WSACleanup();
return;
}
r = bind(ListenSocket, result->ai_addr, (int)result->ai_addrlen);
if (r == SOCKET_ERROR) {
std::cout << "bind failed with error: " << WSAGetLastError() << std::endl;
freeaddrinfo(result);
closesocket(ListenSocket);
WSACleanup();
return;
}
freeaddrinfo(result);
r = listen(ListenSocket, SOMAXCONN);
if (r == SOCKET_ERROR) {
std::cout << "listen failed with error: " << WSAGetLastError() << std::endl;
closesocket(ListenSocket);
WSACleanup();
return;
}
while (true) {
// Accept a client socket
auto ClientSocket = accept(ListenSocket, NULL, NULL);
if (ClientSocket == INVALID_SOCKET) {
std::cout << "accept failed with error: " << WSAGetLastError() << std::endl;
closesocket(ListenSocket);
WSACleanup();
return;
}
incomingRequestNN(nn, ClientSocket);
}
closesocket(ListenSocket);
}
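// Reads a single request from the client socket: an 8-byte content length followed by
// a BSON document, which is decoded with nlohmann::json. Returns JSON null on failure.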
nlohmann::json readJson(SOCKET client) {
tiny_dnn::timer t;
auto iResult = 0;
unsigned long long bytesRemaining = 0;
const int recvbuflen = sizeof(std::uint8_t) * 256 * 4;
char recvbuf[recvbuflen];
std::vector<std::uint8_t> inputJson;
// Get content length
iResult = recv(client, recvbuf, recvbuflen, 0);
if (iResult >= (int)sizeof(unsigned long long)) {
bytesRemaining = *((unsigned long long*)recvbuf);
std::cout << "Content length: " << bytesRemaining << std::endl;
bytesRemaining -= iResult - sizeof(unsigned long long);
for (int i = sizeof(unsigned long long); i < iResult / sizeof(std::uint8_t); i++) {
inputJson.push_back(((std::uint8_t*)recvbuf)[i]);
}
}
else {
std::cout << "missing content length" << std::endl;
return nullptr;
}
if (bytesRemaining > 0 && bytesRemaining < INT_MAX) {
tiny_dnn::progress_display disp(bytesRemaining);
do {
iResult = recv(client, recvbuf, recvbuflen, 0);
if (iResult > 0) {
bytesRemaining -= iResult;
disp += iResult;
for (int i = 0; i < iResult / sizeof(std::uint8_t); i++) {
inputJson.push_back(((std::uint8_t*)recvbuf)[i]);
}
}
else if (iResult == 0)
std::cout << "Connection closing..." << std::endl;
else {
std::cout << "recv failed with error: " << WSAGetLastError() << std::endl;
closesocket(client);
WSACleanup();
return nullptr;
}
} while (iResult > 0 && bytesRemaining > 0 && bytesRemaining < INT_MAX);
}
std::cout << "Request read finished. " << t.elapsed() << "s elapsed. Parsing..." << std::endl;
t.restart();
nlohmann::json ret;
try {
ret = nlohmann::json::from_bson(inputJson);
}
catch (const nlohmann::json::parse_error& e) {
// A malformed request should not bring the daemon down.
std::cout << "BSON parse failed: " << e.what() << std::endl;
return nullptr;
}
std::cout << "Request parse finished. " << t.elapsed() << "s elapsed." << std::endl;
return ret;
}
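// Handles one client: expects a JSON object with a "test" array of corpus_count values,
// runs the network, and replies with the predicted class index (or an error) as BSON.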
void incomingRequestNN(network<sequential>* nn, SOCKET client) {
{
// Use getpeername to obtain the client's address; the SOCKET handle itself is not a sockaddr.
struct sockaddr_in addr;
int addrLen = sizeof(addr);
char addrStr[INET_ADDRSTRLEN] = "unknown";
if (getpeername(client, (struct sockaddr*)&addr, &addrLen) == 0) {
inet_ntop(AF_INET, &addr.sin_addr, addrStr, INET_ADDRSTRLEN);
}
std::cout << "Client " << addrStr << " connected." << std::endl;
}
vec_t input(corpus_count, 0.0f);
auto requestJson = readJson(client);
if (requestJson == nullptr) {
return;
}
nlohmann::json returnJson;
if (requestJson.is_object()) {
auto testVal = requestJson["test"];
if (testVal != nullptr && testVal.is_array() && testVal.size() == corpus_count) {
int i = 0;
for (nlohmann::json::iterator it = testVal.begin(); it != testVal.end(); ++it) {
input[i] = *it;
i++;
}
auto ret = nn->predict(input);
int actualOutput = -1;
float_t maxConfidence = 0.0f;
for (int categoryIndex = 0; categoryIndex < ret.size(); categoryIndex++) {
auto confidence = ret.at(categoryIndex);
if (confidence > maxConfidence) {
actualOutput = categoryIndex;
maxConfidence = confidence;
}
}
returnJson["test"] = actualOutput;
}
else {
returnJson["error"] = "Invalid test";
}
}
else {
returnJson["error"] = "Invalid request (must be an object)";
}
std::vector<std::uint8_t> outputJson = nlohmann::json::to_bson(returnJson);
// Send the BSON-encoded classification result back to the client
auto iSendResult = send(client, (const char*)outputJson.data(), (int)outputJson.size(), 0);
if (iSendResult == SOCKET_ERROR) {
std::cout << "send failed with error: " << WSAGetLastError() << std::endl;
closesocket(client);
return;
}
std::cout << "Response sent, " << iSendResult << " bytes" << std::endl;
// shutdown the connection since we're done
auto iResult = shutdown(client, SD_SEND);
if (iResult == SOCKET_ERROR) {
std::cout << "shutdown failed with error: " << WSAGetLastError() << std::endl;
closesocket(client);
return;
}
// cleanup
closesocket(client);
}