Skip to content

Instantly share code, notes, and snippets.

@Gerold103
Last active May 2, 2024 11:11
Show Gist options
  • Save Gerold103/3e400ebdcd08824d6df44d408e9ce7a5 to your computer and use it in GitHub Desktop.
CTR Predictor Test
cmake_minimum_required(VERSION 3.10)
project(predictor)

# Require C++17 strictly - without CMAKE_CXX_STANDARD_REQUIRED the build
# could silently fall back to an older standard.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Append Torch path
list(APPEND CMAKE_PREFIX_PATH "${CMAKE_SOURCE_DIR}/libtorch")

# Find Torch
find_package(Torch REQUIRED)
message(STATUS "Torch found: ${Torch_FOUND}")

# LibTorch ships compiler/ABI flags that must be applied to anything linking
# against it (per the official LibTorch CMake instructions).
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

# Define the executable
add_executable(predictor main.cpp)

# Link libraries
target_link_libraries(predictor ${TORCH_LIBRARIES})
#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

#include <torch/script.h>
// Campaign object like in the adserver code.
struct Campaign
{
	// Real CTR of the campaign for the past week. Zero-initialized so a
	// default-constructed Campaign never carries an indeterminate value.
	float weeklyCtr_ = 0.0f;
};
// Content unit object like in the adserver code.
struct ContentUnit
{
// Encoded bitmask of which learning tags the content unit has. Bit i feeds
// model feature column i + 1 (column 0 is the campaign CTR, see
// torchPrepareRow()).
std::vector<bool> tagBits_;
};
// Campaign and content unit containers, also like in the adserver code.
struct Mappings
{
// Campaigns keyed by numeric ID, loaded from mapping_campaign_rates.csv.
std::map<uint32_t, Campaign> campaigns_;
// Content units keyed by numeric ID, loaded from mapping_contentunit_tags.csv.
std::map<uint32_t, ContentUnit> contentUnits_;
};
// Builds the model's feature row for one (campaign, content unit) pair:
// element 0 is the campaign's weekly CTR, the rest are the tag bits as
// 0.0f/1.0f. The vector is moved into the resulting 1-D tensor.
static torch::Tensor
torchPrepareRow(const Campaign& camp, const ContentUnit& cu)
{
	std::vector<float> features;
	features.reserve(1 + cu.tagBits_.size());
	features.push_back(camp.weeklyCtr_);
	for (bool bit : cu.tagBits_)
	{
		features.push_back(bit ? 1.0f : 0.0f);
	}
	return torch::tensor(std::move(features));
}
// Wraps a single feature row into the forward() input list. A lone row only
// needs a leading batch dimension: unsqueeze(0) yields the same [1, N] shape
// torch::stack() would, but as a view instead of a copy.
static std::vector<c10::IValue>
torchConsolidateRows(torch::Tensor&& row)
{
	return {std::move(row).unsqueeze(0)};
}
// Stacks N feature rows into one [N, row_size] batch tensor and wraps it
// into the forward() input list.
static std::vector<c10::IValue>
torchConsolidateRows(std::vector<torch::Tensor>&& rows)
{
	std::vector<c10::IValue> batch;
	batch.emplace_back(torch::stack(std::move(rows)));
	return batch;
}
// One prediction job: a (campaign, content unit) pair whose click
// probability the model fills in via Model::predict().
struct Request
{
Request(const Campaign& camp, const ContentUnit& cu) : camp_(camp), cu_(cu), probability_(0) {}
// NOTE(review): reference members make Request non-copy-assignable; it is
// only ever emplace-constructed into a vector in main(), so that is fine.
const Campaign& camp_;
const ContentUnit& cu_;
// Model output in [0, 1]; stays 0 until predict() has run.
float probability_;
};
// Model object. It encapsulates the implementation. For any given campaign and content
// unit it simply returns a float. All the tensor stuff is done inside, hidden from the
// class user.
struct Model
{
	// Loads the TorchScript module. The path defaults to the previously
	// hard-coded "model.pt", so existing callers are unaffected.
	explicit Model(const std::string& path = "model.pt")
		: bin_(torch::jit::load(path))
	{
	}

	// Predicts the click probability for one request and stores it in
	// req->probability_. Aborts on an unexpected output shape or on a
	// probability outside [0, 1].
	void
	predict(Request* req)
	{
		torch::Tensor output = bin_.forward(torchConsolidateRows(torchPrepareRow(req->camp_, req->cu_))).toTensor();
		if (output.size(0) != 1 or output.size(1) != 1)
		{
			std::cout << "Couldn't predict\n";
			abort();
		}
		float res = output[0][0].item<float>();
		// Validate the value like the batch overload does - a probability
		// outside [0, 1] means the model or its input is broken.
		if (res < 0 or res > 1)
		{
			std::cout << "Invalid probability\n";
			abort();
		}
		req->probability_ = res;
	}

	// Batch prediction: builds one stacked input for all requests, runs a
	// single forward pass, and writes each probability back.
	// NOTE(review): an empty `reqs` would make torch::stack() throw; the
	// current caller never passes one.
	void
	predict(std::vector<Request>& reqs)
	{
		std::vector<torch::Tensor> inputRows;
		inputRows.reserve(reqs.size());
		for (Request& r : reqs)
		{
			inputRows.emplace_back(torchPrepareRow(r.camp_, r.cu_));
		}
		torch::Tensor output = bin_.forward(torchConsolidateRows(std::move(inputRows))).toTensor();
		// Cast to avoid a signed/unsigned comparison: Tensor::size() returns
		// int64_t, std::vector::size() is unsigned.
		if (output.size(0) != static_cast<int64_t>(reqs.size()) or output.size(1) != 1)
		{
			std::cout << "Couldn't predict\n";
			abort();
		}
		for (size_t i = 0; i < reqs.size(); ++i)
		{
			float res = output[i][0].item<float>();
			if (res < 0 or res > 1)
			{
				std::cout << "Invalid probability\n";
				abort();
			}
			reqs[i].probability_ = res;
		}
	}

	torch::jit::Module bin_;
};
// One view sample from the real-world recording of views. For testing.
struct Sample
{
// ID of the content unit that was shown, key into Mappings::contentUnits_.
uint32_t cuID_;
// ID of the shown campaign, key into Mappings::campaigns_.
uint32_t campID_;
};
// Splits `s` on `delimiter` using std::getline() semantics: empty fields
// between delimiters are kept, a trailing delimiter produces no extra empty
// field, and an empty input yields an empty vector.
static std::vector<std::string>
split(const std::string &s, char delimiter) {
	std::vector<std::string> parts;
	std::istringstream stream(s);
	for (std::string piece; std::getline(stream, piece, delimiter);)
		parts.push_back(std::move(piece));
	return parts;
}
// Weekly CTR assigned when a row has no rate column.
static constexpr float kDefaultWeeklyCtr = 0.005f;

// Loads "mapping_campaign_rates.csv" (header ",campaign_id,rate") into
// maps.campaigns_. Aborts the process on any I/O or format error.
static void
readCampaignRates(Mappings& maps)
{
	std::ifstream in("mapping_campaign_rates.csv");
	if (not in.is_open())
	{
		std::cout << "Couldn't open campaign rates\n";
		abort();
	}
	std::string line;
	if (not std::getline(in, line))
	{
		std::cout << "Couldn't read campaign rates header\n";
		abort();
	}
	if (line != ",campaign_id,rate")
	{
		std::cout << "Bad campaign rates header\n";
		abort();
	}
	size_t lineIdx = 1;
	while (std::getline(in, line))
	{
		++lineIdx;
		std::vector<std::string> columns = split(line, ',');
		if (columns.size() != 2 and columns.size() != 3)
		{
			std::cout << "Bad campaign rates line " << lineIdx << '\n';
			std::cout << line << '\n';
			abort();
		}
		Campaign camp;
		unsigned long id;
		// std::stoul()/std::stof() throw on malformed numbers; catch them so
		// a broken CSV yields the diagnostic below instead of an uncaught
		// exception and std::terminate().
		try
		{
			id = std::stoul(columns[1]);
			camp.weeklyCtr_ = columns.size() == 3 ? std::stof(columns[2]) : kDefaultWeeklyCtr;
		}
		catch (const std::exception&)
		{
			std::cout << "Bad campaign rates line " << lineIdx << '\n';
			std::cout << line << '\n';
			abort();
		}
		if (id > UINT32_MAX)
		{
			std::cout << "Too big campaign ID\n";
			abort();
		}
		if (not maps.campaigns_.emplace(id, camp).second)
		{
			std::cout << "Campaign ID must be unique\n";
			abort();
		}
	}
}
// Loads "mapping_contentunit_tags.csv" (header "contentunit_id,<tags...>")
// into maps.contentUnits_. Every tag column must be exactly 0 or 1. Aborts
// the process on any I/O or format error.
static void
readContentUnits(Mappings& maps)
{
	std::ifstream in("mapping_contentunit_tags.csv");
	if (not in.is_open())
	{
		std::cout << "Couldn't open content unit tags\n";
		abort();
	}
	std::string line;
	if (not std::getline(in, line))
	{
		std::cout << "Couldn't read content unit tags header\n";
		abort();
	}
	if (line.rfind("contentunit_id,", 0) != 0)
	{
		std::cout << "Bad content unit tags header\n";
		abort();
	}
	while (std::getline(in, line))
	{
		std::vector<std::string> columns = split(line, ',');
		if (columns.size() < 2)
		{
			std::cout << "Bad content unit tags line\n";
			abort();
		}
		unsigned long id;
		ContentUnit unit;
		unit.tagBits_.resize(columns.size() - 1);
		// Parse inside a try-block: std::stoul()/std::stof() throw on
		// garbage and would otherwise kill the process via std::terminate()
		// without any diagnostic.
		try
		{
			id = std::stoul(columns[0]);
			for (size_t i = 1; i < columns.size(); ++i)
			{
				float bit = std::stof(columns[i]);
				if (bit != 0 and bit != 1)
				{
					std::cout << "Bad tag bit\n";
					abort();
				}
				unit.tagBits_[i - 1] = bit != 0;
			}
		}
		catch (const std::exception&)
		{
			std::cout << "Bad content unit tags line\n";
			std::cout << line << '\n';
			abort();
		}
		if (id > UINT32_MAX)
		{
			// The original said "campaign ID" here - copy-paste error.
			std::cout << "Too big content unit ID\n";
			abort();
		}
		if (not maps.contentUnits_.emplace(id, std::move(unit)).second)
		{
			std::cout << "Content unit ID must be unique\n";
			abort();
		}
	}
}
// Populates the mappings from the two CSV files in the working directory.
// Any read or format error aborts the process inside the callees.
static void
readFiles(Mappings& maps)
{
readCampaignRates(maps);
readContentUnits(maps);
}
static std::vector<Sample>
readSampleData()
{
std::ifstream inputFile("sample_data.csv");
if (not inputFile.is_open())
{
std::cout << "Couldn't open input file\n";
abort();
}
std::string line;
if (not std::getline(inputFile, line))
{
std::cout << "Couldn't read input file header\n";
abort();
}
std::vector<std::string> columnNames = split(line, ';');
ssize_t campIDIdx = -1;
ssize_t cuIDIdx = -1;
for (size_t i = 0; i < columnNames.size(); ++i)
{
if (columnNames[i] == "campaign_id")
{
if (campIDIdx != -1)
{
std::cout << "Found second campaign ID column in input file\n";
abort();
}
campIDIdx = i;
continue;
}
if (columnNames[i] == "contentunit_id")
{
if (cuIDIdx != -1)
{
std::cout << "Found second content unit ID column in input file\n";
abort();
}
cuIDIdx = i;
continue;
}
}
if (campIDIdx == -1 or cuIDIdx == -1)
{
std::cout << "Didn't find needed columns in input file\n";
abort();
}
std::vector<Sample> result;
size_t lineIdx = 1;
while (std::getline(inputFile, line))
{
++lineIdx;
std::vector<std::string> columns = split(line, ';');
if (columns.size() <= cuIDIdx or columns.size() <= campIDIdx)
{
std::cout << "Too short line " << lineIdx << " in the input file\n";
std::cout << line << '\n';
abort();
}
unsigned long campID = std::stoul(columns[campIDIdx]);
unsigned long cuID = std::stoul(columns[cuIDIdx]);
if (campID > UINT32_MAX or cuID > UINT32_MAX)
{
std::cout << "Too large IDs in the input file\n";
abort();
}
result.emplace_back();
Sample& s = result.back();
s.campID_ = campID;
s.cuID_ = cuID;
}
return result;
}
int
main()
{
Mappings maps;
readFiles(maps);
std::vector<Sample> input = readSampleData();
std::cout << "Samples: " << input.size() << '\n';
std::chrono::steady_clock::time_point t1 = std::chrono::steady_clock::now();
Model model;
std::chrono::steady_clock::time_point t2 = std::chrono::steady_clock::now();
std::chrono::milliseconds d = std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1);
std::cout << "Model loading time: " << d.count() << " ms\n";
std::vector<Request> requests;
requests.reserve(input.size());
for (const Sample& s : input)
{
const ContentUnit* cu;
{
auto it = maps.contentUnits_.find(s.cuID_);
if (it == maps.contentUnits_.end())
{
std::cout << "Couldn't find a content unit by ID\n";
abort();
}
cu = &it->second;
}
const Campaign* camp;
{
auto it = maps.campaigns_.find(s.campID_);
if (it == maps.campaigns_.end())
{
std::cout << "Couldn't find a campaign by ID\n";
abort();
}
camp = &it->second;
}
requests.emplace_back(*camp, *cu);
}
std::cout << '\n';
//////////////////////////////////////////////////////////////////////////////////////
std::cout << "Predict one request at at time\n";
t1 = std::chrono::steady_clock::now();
for (Request& r : requests)
{
model.predict(&r);
}
t2 = std::chrono::steady_clock::now();
d = std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1);
std::cout
<< "\tTotal prediction time: " << d.count() << " ms\n"
<< "\tRPS: " << (input.size() * 1000 / d.count()) << '\n'
<< "\tTime per request: " << (d.count() + 0.0) / input.size() << " ms\n\n";
//////////////////////////////////////////////////////////////////////////////////////
std::cout << "Predict all requests at once\n";
t1 = std::chrono::steady_clock::now();
model.predict(requests);
t2 = std::chrono::steady_clock::now();
d = std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1);
std::cout
<< "\tTotal prediction time: " << d.count() << " ms\n"
<< "\tRPS: " << (input.size() * 1000 / d.count()) << '\n'
<< "\tTime per request: " << (d.count() + 0.0) / input.size() << " ms\n\n";
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment