Skip to content

Instantly share code, notes, and snippets.

@YashasSamaga
Created December 16, 2018 14:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save YashasSamaga/1725320041b5579fdf531c8ac39a2dc8 to your computer and use it in GitHub Desktop.
Save YashasSamaga/1725320041b5579fdf531c8ac39a2dc8 to your computer and use it in GitHub Desktop.
#include "main.h"

#include <utility>

#include <dlib/dnn.h>
#include <sampml/svm_classifier.hpp>

#include "tools/fixed_thread_pool.hpp"
#include "iscript.hpp"
#include "classifier.hpp"
#include "transform.hpp"
#include "dnn.hpp"
namespace classifier {
// Feature-vector type scored by both classifiers in this namespace
// (an alias for output_vector).
using sample_type = output_vector;
double test_vector_svm(const sample_type& sample) {
static thread_local sampml::trainer::svm_classifier<sample_type> svm;
static thread_local bool loaded = false;
if (loaded == false) {
svm.deserialize("models/svm_classifier.dat");
loaded = true;
}
return svm.test(sample);
}
double test_vector_dnn(const sample_type& sample) {
static thread_local aa_network_type net;
static thread_local bool loaded = false;
if (loaded == false) {
dlib::deserialize("models/dnn_classifier.dat") >> net;
loaded = true;
}
return net(sample);
}
// Routing information carried alongside a sample so its result can be
// delivered back: the player the sample belongs to and the AMX instance
// (script) that submitted it.
// Members have defaults (-1 = no player, nullptr = no script) so a
// default-constructed tag never holds indeterminate values.
struct queue_item_tag_t {
    int playerid = -1;
    AMX *amx = nullptr;
};
// One unit of work for the classification pool: the feature vector to score,
// plus the tag used to route the result back to the submitting script.
struct input_queue_item_t {
    sample_type sample;
    // Value-initialized so a default-constructed item carries no garbage tag.
    queue_item_tag_t tag{};
};
// Result of classifying one input_queue_item_t.
struct output_queue_item_t {
    // [0] is filled from test_vector_svm, [1] from test_vector_dnn
    // (both narrowed from double). Zero until a worker fills them in.
    float probabilities[2]{};
    // Copied from the originating input item.
    queue_item_tag_t tag{};
};
struct process_functor_t {
void operator()(input_queue_item_t& input, output_queue_item_t& output) {
sample_type& sample = input.sample;
output.probabilities[0] = test_vector_svm(sample);
output.probabilities[1] = test_vector_dnn(sample);
output.tag = input.tag;
}
};
fixed_thread_pool<input_queue_item_t, output_queue_item_t, process_functor_t> pool(2);
void ProcessTick() {
std::vector<output_queue_item_t> results;
pool.deqeue_all(results);
for (auto&& item : results) {
AMX* amx = item.tag.amx;
if (iscript::IsValidAmx(amx)) {
int cb_idx = -1;
if (amx_FindPublic(amx, "OnPlayerSuspectedForAimbot", &cb_idx) != AMX_ERR_NONE || cb_idx < 0) {
// OnPlayerSuspectedForAimbot(playerid, Float:probablities[2])
cell probablities[2] = { amx_ftoc(item.probabilities[0]), amx_ftoc(item.probabilities[1]) };
cell amx_addr, *phys_addr;
amx_Allot(amx, sizeof(probablities) / sizeof(cell), &amx_addr, &phys_addr);
memcpy(phys_addr, probablities, sizeof(probablities));
amx_Push(amx, amx_addr);
amx_Push(amx, item.tag.playerid);
amx_Exec(amx, NULL, cb_idx);
}
}
}
}
namespace natives {
static input_vector pawn_array_to_vector(cell data[]) {
input_vector vector;
for (int i = 0; i < input_vector::NR; i++) {
switch (i) {
case bHit:
case iShooterCameraMode:
case iShooterState:
case iShooterSpecialAction:
case bShooterInVehicle:
case bShooterSurfingVehicle:
case bShooterSurfingObject:
case iShooterWeaponID:
case iShooterSkinID:
case iShooterID:
case iVictimState:
case iVictimSpecialAction:
case bVictimInVehicle:
case bVictimSurfingVehicle:
case bVictimSurfingObject:
case iVictimWeaponID:
case iVictimSkinID:
case iVictimID:
case iHitType:
case iShooterPing:
case iVictimPing:
case iSecond:
case iTick:
vector(i) = (data[i]);
break;
default:
vector(i) = amx_ctof(data[i]);
}
}
return vector;
}
/**
 * Pawn native: submit_vector(playerid, const data[])
 *
 * Converts `data` into an input_vector and feeds it to the per-player
 * transformer. If the transformer then has a sample ready, the newest one is
 * taken and handed to the classification pool. ProcessTick() later delivers
 * the result to this script.
 *
 * @return true (1) when a sample was dispatched to the pool; false (0) when
 *         the arguments are invalid or the transformer has nothing ready.
 */
cell AMX_NATIVE_CALL submit_vector(AMX * amx, cell* params)
{
    static std::array<transformer, MAX_PLAYERS> transformers;

    // params[0] is the byte size of the argument list; both arguments are required.
    if (params[0] < static_cast<cell>(2 * sizeof(cell)))
        return false;

    // playerid comes from the script; reject values that would index out of bounds.
    const cell playerid = params[1];
    if (playerid < 0 || playerid >= MAX_PLAYERS)
        return false;

    cell *data = nullptr;
    if (amx_GetAddr(amx, params[2], &data) != AMX_ERR_NONE || data == nullptr)
        return false;

    auto vector = pawn_array_to_vector(data);
    transformer& player_transformer = transformers[playerid];
    player_transformer.submit(vector);
    if (player_transformer.pool.size() == 0)
        return false;

    input_queue_item_t item;
    item.sample = player_transformer.pool.back();
    item.tag.amx = amx;
    item.tag.playerid = playerid;
    player_transformer.pool.pop_back();

    // Hand the sample to the worker pool. Without this the item was built and
    // then discarded, so nothing was ever classified.
    pool.enqueue(std::move(item));
    return true;
}
}
}
/**
 * Fixed-size worker pool that turns InputTaskItem values into OutputTaskItem
 * values on background threads.
 *
 * How it works:
 *   - Producers call enqueue().
 *   - Each worker pops one input, value-initializes an OutputTaskItem, and
 *     calls TaskFunction(input, output) without holding the queue lock.
 *   - Finished outputs are collected with dequeue() or dequeue_all(), in
 *     completion order (not necessarily submission order).
 *   - deqeue_all() is kept as a legacy alias of dequeue_all().
 *
 * Requirements on the template arguments:
 *   - InputTaskItem must be movable.
 *   - OutputTaskItem must be value-initializable and movable.
 *   - TaskFunction must be copyable and callable as fn(InputTaskItem&, OutputTaskItem&).
 *     Each worker runs its own copy of the function object.
 *
 * Thread safety:
 *   - empty(), enqueue(), dequeue() and dequeue_all() take the internal mutex
 *     and may be called from any thread.
 *   - On destruction the workers are stopped and joined. Inputs still queued
 *     are discarded, and so are results nobody has collected.
 *   - If a task throws, its item is dropped and an error line is written to
 *     std::cerr, so the exception never escapes a worker thread.
 */
template <class InputTaskItem, class OutputTaskItem, class TaskFunction>
class fixed_thread_pool {
public:
    /// Starts `num` worker threads (none if num <= 0, in which case nothing is processed).
    fixed_thread_pool(int num, TaskFunction fn = TaskFunction())
        : task_function(std::move(fn))
    {
        try {
            start(num);
        } catch (...) {
            // Join the workers that did start; a joinable std::thread being destroyed calls std::terminate.
            stop();
            throw;
        }
    }

    ~fixed_thread_pool() { stop(); }

    // Workers capture `this`, so the pool must never be copied or moved.
    fixed_thread_pool(const fixed_thread_pool&) = delete;
    fixed_thread_pool& operator=(const fixed_thread_pool&) = delete;

    /// True when no finished results are waiting to be collected.
    bool empty() {
        std::lock_guard<std::mutex> lock(queue_lock);
        return output_queue.empty();
    }

    /// Queues one input for processing and wakes a worker.
    void enqueue(InputTaskItem&& task) {
        {
            std::lock_guard<std::mutex> lock(queue_lock);
            input_queue.push(std::move(task));
        }
        queue_not_empty.notify_one();
    }

    /// Removes and returns the oldest finished result. Precondition: !empty().
    OutputTaskItem dequeue() {
        std::lock_guard<std::mutex> lock(queue_lock);
        assert(!output_queue.empty());
        OutputTaskItem item = std::move(output_queue.front());
        output_queue.pop();
        return item;
    }

    /// Appends every finished result to `results`, oldest first.
    void dequeue_all(std::vector<OutputTaskItem>& results) {
        std::lock_guard<std::mutex> lock(queue_lock);
        while (!output_queue.empty()) {
            results.push_back(std::move(output_queue.front()));
            output_queue.pop();
        }
    }

    /// Legacy misspelled name kept for existing callers; same as dequeue_all().
    void deqeue_all(std::vector<OutputTaskItem>& results) { dequeue_all(results); }

private:
    std::vector<std::thread> threads;
    std::queue<InputTaskItem> input_queue;
    std::queue<OutputTaskItem> output_queue;
    std::condition_variable queue_not_empty;
    std::mutex queue_lock;
    bool stop_threads = false;  // guarded by queue_lock
    TaskFunction task_function;

    // Launches `num` workers running worker_loop().
    void start(int num = 2) {
        for (int i = 0; i < num; ++i)
            threads.emplace_back([this] { worker_loop(); });
    }

    // Body of each worker: wait for input, process it unlocked, publish the result.
    void worker_loop() {
        TaskFunction fn(task_function);
        std::unique_lock<std::mutex> lock(queue_lock);
        for (;;) {
            queue_not_empty.wait(lock, [this] { return stop_threads || !input_queue.empty(); });
            if (stop_threads)
                return;

            InputTaskItem input(std::move(input_queue.front()));
            input_queue.pop();
            lock.unlock();

            OutputTaskItem output{};
            bool produced = true;
            try {
                fn(input, output);
            } catch (...) {
                // A thread entry point must not leak exceptions (std::terminate).
                produced = false;
                std::cerr << "fixed_thread_pool: task threw an exception; item dropped\n";
            }

            lock.lock();
            if (produced)
                output_queue.push(std::move(output));
        }
    }

    // Signals all workers to exit and joins them. The lock is released before
    // joining so workers can reacquire it and observe the flag.
    void stop() {
        {
            std::lock_guard<std::mutex> lock(queue_lock);
            stop_threads = true;
        }
        queue_not_empty.notify_all();
        for (auto& worker : threads) {
            if (worker.joinable())
                worker.join();
        }
        threads.clear();
    }
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment