Last active
May 2, 2024 11:11
-
-
Save Gerold103/3e400ebdcd08824d6df44d408e9ce7a5 to your computer and use it in GitHub Desktop.
CTR Predictor Test
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
cmake_minimum_required(VERSION 3.10)
project(predictor)

set(CMAKE_CXX_STANDARD 17)
# Fail configuration instead of silently falling back to an older standard.
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Append Torch path so find_package() can locate the bundled libtorch.
list(APPEND CMAKE_PREFIX_PATH "${CMAKE_SOURCE_DIR}/libtorch")

# Find Torch
find_package(Torch REQUIRED)
message(STATUS "Torch found: ${Torch_FOUND}")

# TORCH_CXX_FLAGS carries ABI-critical flags (e.g. _GLIBCXX_USE_CXX11_ABI)
# that must match the libtorch build, per the official libtorch CMake example.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

# Define the executable
add_executable(predictor main.cpp)

# Link libraries
target_link_libraries(predictor ${TORCH_LIBRARIES})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <torch/script.h>

#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>
// Campaign object like in the adserver code.
struct Campaign
{
	// Real CTR of the campaign for the past week. Zero-initialized so a
	// default-constructed campaign never carries an indeterminate value.
	float weeklyCtr_ = 0;
};
// Content unit object like in the adserver code.
struct ContentUnit
{
// Encoded bitmask of which learning tags the content unit has. Bit i is
// true when the unit carries tag i; the width is whatever the CSV row had.
std::vector<bool> tagBits_;
};
// Campaign and content unit containers, also like in the adserver code.
// Both maps are keyed by the numeric IDs read from the CSV mapping files.
struct Mappings
{
// Campaign ID -> campaign.
std::map<uint32_t, Campaign> campaigns_;
// Content unit ID -> content unit.
std::map<uint32_t, ContentUnit> contentUnits_;
};
// Build one model-input row as a 1-D float tensor laid out as
// [weeklyCtr, tagBit0, tagBit1, ...]. The layout must match the model input.
static torch::Tensor
torchPrepareRow(const Campaign& camp, const ContentUnit& cu)
{
	std::vector<float> inputRow;
	inputRow.reserve(1 + cu.tagBits_.size());
	inputRow.push_back(camp.weeklyCtr_);
	// Float literals avoid the double->float narrowing of 1.0/0.0.
	for (bool bit : cu.tagBits_)
	{
		inputRow.push_back(bit ? 1.0f : 0.0f);
	}
	// torch::tensor() copies the data, so no point in moving the vector.
	return torch::tensor(inputRow);
}
static std::vector<c10::IValue> | |
torchConsolidateRows(torch::Tensor&& row) | |
{ | |
return {torch::stack(std::move(row))}; | |
} | |
// Stack a batch of 1-D input rows into a single [N, D] tensor and wrap it
// into the IValue argument list expected by Module::forward().
static std::vector<c10::IValue>
torchConsolidateRows(std::vector<torch::Tensor>&& rows)
{
// NOTE(review): torch::stack() takes a non-owning TensorList view, so the
// std::move here does not actually transfer ownership; it is harmless.
return {torch::stack(std::move(rows))};
}
// One prediction request: a (campaign, content unit) pair plus the output
// probability filled in by Model::predict().
// NOTE: holds references — the objects (here: the Mappings container) must
// outlive the Request.
struct Request
{
Request(const Campaign& camp, const ContentUnit& cu) : camp_(camp), cu_(cu), probability_(0) {}
const Campaign& camp_;
const ContentUnit& cu_;
// Predicted CTR probability; 0 until predict() runs.
float probability_;
};
// Model object. It encapsulates the implementation. For any given campaign and content | |
// unit it simply returns a float. All the tensor stuff is done inside, hidden from the | |
// class user. | |
struct Model | |
{ | |
Model() | |
: bin_(torch::jit::load("model.pt")) | |
{ | |
} | |
void | |
predict(Request* req) | |
{ | |
torch::Tensor output = bin_.forward(torchConsolidateRows(torchPrepareRow(req->camp_, req->cu_))).toTensor(); | |
if (output.size(0) != 1 or output.size(1) != 1) | |
{ | |
std::cout << "Couldn't predict\n"; | |
abort(); | |
} | |
req->probability_ = output[0][0].item<float>(); | |
} | |
void | |
predict(std::vector<Request>& reqs) | |
{ | |
std::vector<torch::Tensor> inputRows; | |
inputRows.reserve(reqs.size()); | |
for (Request& r : reqs) | |
{ | |
inputRows.emplace_back(torchPrepareRow(r.camp_, r.cu_)); | |
} | |
torch::Tensor output = bin_.forward(torchConsolidateRows(std::move(inputRows))).toTensor(); | |
if (output.size(0) != reqs.size() or output.size(1) != 1) | |
{ | |
std::cout << "Couldn't predict\n"; | |
abort(); | |
} | |
for (size_t i = 0; i < reqs.size(); ++i) | |
{ | |
float res = output[i][0].item<float>(); | |
if (res < 0 or res > 1) | |
{ | |
std::cout << "Invalid probability\n"; | |
abort(); | |
} | |
reqs[i].probability_ = res; | |
} | |
} | |
torch::jit::Module bin_; | |
}; | |
// One view sample from the real word recording of views. For testing.
struct Sample
{
// Content unit ID shown in this view.
uint32_t cuID_;
// Campaign ID shown in this view.
uint32_t campID_;
};
// Split 's' on 'delimiter'. Mirrors std::getline() semantics: consecutive
// delimiters yield empty tokens, but a trailing delimiter does NOT produce a
// trailing empty token (callers rely on that for optional last CSV columns).
static std::vector<std::string>
split(const std::string &s, char delimiter) {
	std::vector<std::string> parts;
	std::istringstream stream(s);
	for (std::string piece; std::getline(stream, piece, delimiter);) {
		parts.emplace_back(piece);
	}
	return parts;
}
static void | |
readCampaignRates(Mappings& maps) | |
{ | |
std::ifstream in("mapping_campaign_rates.csv"); | |
if (not in.is_open()) | |
{ | |
std::cout << "Couldn't open campaign rates\n"; | |
abort(); | |
} | |
std::string line; | |
if (not std::getline(in, line)) | |
{ | |
std::cout << "Couldn't read campaign rates header\n"; | |
abort(); | |
} | |
if (line != ",campaign_id,rate") | |
{ | |
std::cout << "Bad campaign rates header\n"; | |
abort(); | |
} | |
size_t lineIdx = 1; | |
while (std::getline(in, line)) | |
{ | |
++lineIdx; | |
std::vector<std::string> columns = split(line, ','); | |
if (columns.size() != 2 and columns.size() != 3) | |
{ | |
std::cout << "Bad campaign rates line " << lineIdx << '\n'; | |
std::cout << line << '\n'; | |
abort(); | |
} | |
unsigned long id = std::stoul(columns[1]); | |
if (id > UINT32_MAX) | |
{ | |
std::cout << "Too big campaign ID\n"; | |
abort(); | |
} | |
Campaign camp; | |
if (columns.size() == 3) | |
{ | |
camp.weeklyCtr_ = std::stof(columns[2]); | |
} | |
else | |
{ | |
camp.weeklyCtr_ = 0.005; | |
} | |
auto [it, ok] = maps.campaigns_.emplace(id, std::move(camp)); | |
if (not ok) | |
{ | |
std::cout << "Campaign ID must be unique\n"; | |
abort(); | |
} | |
} | |
} | |
static void | |
readContentUnits(Mappings& maps) | |
{ | |
std::ifstream in("mapping_contentunit_tags.csv"); | |
if (not in.is_open()) | |
{ | |
std::cout << "Couldn't open content unit tags\n"; | |
abort(); | |
} | |
std::string line; | |
if (not std::getline(in, line)) | |
{ | |
std::cout << "Couldn't read content unit tags header\n"; | |
abort(); | |
} | |
if (line.rfind("contentunit_id,", 0) != 0) | |
{ | |
std::cout << "Bad content unit tags header\n"; | |
abort(); | |
} | |
while (std::getline(in, line)) | |
{ | |
std::vector<std::string> columns = split(line, ','); | |
if (columns.size() < 2) | |
{ | |
std::cout << "Bad content unit tags line\n"; | |
abort(); | |
} | |
unsigned long id = std::stoul(columns[0]); | |
if (id > UINT32_MAX) | |
{ | |
std::cout << "Too big campaign ID\n"; | |
abort(); | |
} | |
ContentUnit unit; | |
unit.tagBits_.resize(columns.size() - 1); | |
for (size_t i = 1; i < columns.size(); ++i) | |
{ | |
float bit = std::stof(columns[i]); | |
if (bit != 0 and bit != 1) | |
{ | |
std::cout << "Bad tag bit\n"; | |
abort(); | |
} | |
unit.tagBits_[i - 1] = bit == 0 ? false : true; | |
} | |
auto [it, ok] = maps.contentUnits_.emplace(id, std::move(unit)); | |
if (not ok) | |
{ | |
std::cout << "Content unit ID must be unique\n"; | |
abort(); | |
} | |
} | |
} | |
// Load both CSV mapping files into 'maps'. Aborts on any error.
static void
readFiles(Mappings& maps)
{
readCampaignRates(maps);
readContentUnits(maps);
}
static std::vector<Sample> | |
readSampleData() | |
{ | |
std::ifstream inputFile("sample_data.csv"); | |
if (not inputFile.is_open()) | |
{ | |
std::cout << "Couldn't open input file\n"; | |
abort(); | |
} | |
std::string line; | |
if (not std::getline(inputFile, line)) | |
{ | |
std::cout << "Couldn't read input file header\n"; | |
abort(); | |
} | |
std::vector<std::string> columnNames = split(line, ';'); | |
ssize_t campIDIdx = -1; | |
ssize_t cuIDIdx = -1; | |
for (size_t i = 0; i < columnNames.size(); ++i) | |
{ | |
if (columnNames[i] == "campaign_id") | |
{ | |
if (campIDIdx != -1) | |
{ | |
std::cout << "Found second campaign ID column in input file\n"; | |
abort(); | |
} | |
campIDIdx = i; | |
continue; | |
} | |
if (columnNames[i] == "contentunit_id") | |
{ | |
if (cuIDIdx != -1) | |
{ | |
std::cout << "Found second content unit ID column in input file\n"; | |
abort(); | |
} | |
cuIDIdx = i; | |
continue; | |
} | |
} | |
if (campIDIdx == -1 or cuIDIdx == -1) | |
{ | |
std::cout << "Didn't find needed columns in input file\n"; | |
abort(); | |
} | |
std::vector<Sample> result; | |
size_t lineIdx = 1; | |
while (std::getline(inputFile, line)) | |
{ | |
++lineIdx; | |
std::vector<std::string> columns = split(line, ';'); | |
if (columns.size() <= cuIDIdx or columns.size() <= campIDIdx) | |
{ | |
std::cout << "Too short line " << lineIdx << " in the input file\n"; | |
std::cout << line << '\n'; | |
abort(); | |
} | |
unsigned long campID = std::stoul(columns[campIDIdx]); | |
unsigned long cuID = std::stoul(columns[cuIDIdx]); | |
if (campID > UINT32_MAX or cuID > UINT32_MAX) | |
{ | |
std::cout << "Too large IDs in the input file\n"; | |
abort(); | |
} | |
result.emplace_back(); | |
Sample& s = result.back(); | |
s.campID_ = campID; | |
s.cuID_ = cuID; | |
} | |
return result; | |
} | |
int | |
main() | |
{ | |
Mappings maps; | |
readFiles(maps); | |
std::vector<Sample> input = readSampleData(); | |
std::cout << "Samples: " << input.size() << '\n'; | |
std::chrono::steady_clock::time_point t1 = std::chrono::steady_clock::now(); | |
Model model; | |
std::chrono::steady_clock::time_point t2 = std::chrono::steady_clock::now(); | |
std::chrono::milliseconds d = std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1); | |
std::cout << "Model loading time: " << d.count() << " ms\n"; | |
std::vector<Request> requests; | |
requests.reserve(input.size()); | |
for (const Sample& s : input) | |
{ | |
const ContentUnit* cu; | |
{ | |
auto it = maps.contentUnits_.find(s.cuID_); | |
if (it == maps.contentUnits_.end()) | |
{ | |
std::cout << "Couldn't find a content unit by ID\n"; | |
abort(); | |
} | |
cu = &it->second; | |
} | |
const Campaign* camp; | |
{ | |
auto it = maps.campaigns_.find(s.campID_); | |
if (it == maps.campaigns_.end()) | |
{ | |
std::cout << "Couldn't find a campaign by ID\n"; | |
abort(); | |
} | |
camp = &it->second; | |
} | |
requests.emplace_back(*camp, *cu); | |
} | |
std::cout << '\n'; | |
////////////////////////////////////////////////////////////////////////////////////// | |
std::cout << "Predict one request at at time\n"; | |
t1 = std::chrono::steady_clock::now(); | |
for (Request& r : requests) | |
{ | |
model.predict(&r); | |
} | |
t2 = std::chrono::steady_clock::now(); | |
d = std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1); | |
std::cout | |
<< "\tTotal prediction time: " << d.count() << " ms\n" | |
<< "\tRPS: " << (input.size() * 1000 / d.count()) << '\n' | |
<< "\tTime per request: " << (d.count() + 0.0) / input.size() << " ms\n\n"; | |
////////////////////////////////////////////////////////////////////////////////////// | |
std::cout << "Predict all requests at once\n"; | |
t1 = std::chrono::steady_clock::now(); | |
model.predict(requests); | |
t2 = std::chrono::steady_clock::now(); | |
d = std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1); | |
std::cout | |
<< "\tTotal prediction time: " << d.count() << " ms\n" | |
<< "\tRPS: " << (input.size() * 1000 / d.count()) << '\n' | |
<< "\tTime per request: " << (d.count() + 0.0) / input.size() << " ms\n\n"; | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment