Skip to content

Instantly share code, notes, and snippets.

@tatyam-prime
Last active July 9, 2024 23:33
Show Gist options
  • Save tatyam-prime/1161da013a31632690d616016d51d743 to your computer and use it in GitHub Desktop.
Save tatyam-prime/1161da013a31632690d616016d51d743 to your computer and use it in GitHub Desktop.
【ICPC国内予選】データセットごとにプロセスを生成して並列化するためのライブラリ
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>
#include <array>
#include <cassert>
#include <csignal>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <format>
#include <iostream>
#include <string>
#include <vector>
void solve();
namespace MP {
enum class PState : std::uint8_t {
BEGIN_INPUT,
END_INPUT,
BEGIN_OUTPUT,
END_OUTPUT,
END_OF_FILE,
};
constexpr int MAX_NUM_DATASET = 1024; // 生成するプロセス数の最大値
constexpr int MAX_NUM_PROCESS = 50; // 同時に実行するプロセス数の最大値
int child_id = -1; // 子 : 自身の番号; 0 から始まる
int input_child_id = 0; // 親 : 入力中の子の番号
int output_child_id = 0; // 親 : 出力中の子の番号
std::vector<pid_t> children; // 親 : 子のプロセス番号 (動作確認に)
PState* state_children; // 親 : 各子の状態; mmap で共有メモリとして確保
PState* state; // 子 : 自身の状態
void wait() { usleep(5000); }
// 子 : child_id を付けて,std::format しながら標準エラー出力に出力
template<class... T> void debugln(std::format_string<T...> fmt, T&&... args) {
std::string s = std::format("#{:03}: ", child_id) + std::format(fmt, std::forward<T>(args)...) + "\n";
std::cerr << s << std::flush;
}
// 子 : fork された子プロセスの初期化
void init_child() {
child_id = input_child_id;
state = state_children + child_id;
*state = PState::BEGIN_INPUT;
}
// 子 : 無
void begin_input() {
debugln("begin_input()");
}
// 子 : 親に入力の終了を伝える
void end_input() {
fflush(stdin); // なぜか都合よく動く
assert(*state == PState::BEGIN_INPUT);
*state = PState::END_INPUT;
debugln("end_input()");
}
// 子 : 出力が許可されるまで sleep
void begin_output() {
assert(*state == PState::END_INPUT || *state == PState::BEGIN_OUTPUT);
while (*state != PState::BEGIN_OUTPUT) wait();
debugln("begin_output()");
}
// 子 : 親に出力の終了を伝える
void end_output() {
fflush(stdout);
assert(*state == PState::BEGIN_OUTPUT);
*state = PState::END_OUTPUT;
debugln("end_output()");
std::exit(0);
}
// 子 : 入力が終了していることを親に伝える
void end_of_file() {
assert(*state == PState::BEGIN_INPUT);
*state = PState::END_OF_FILE;
debugln("end_of_file()");
std::exit(0);
}
// 親に SIGINT したときに子を停止させるシグナルハンドラ
void sigint_handler(int sig) {
for (auto pid : children) {
kill(pid, SIGTERM);
}
std::exit(128 + sig);
}
// 親 : 子プロセスを生成
void spawn_child() {
pid_t pid = fork();
while (pid == -1) {
perror("fork");
sleep(5);
pid = fork();
}
if (pid == 0) {
init_child();
solve();
assert(false);
} else {
children.push_back(pid);
}
}
// 親 : input_child_id を可能なら進める (新しい子を作る) or 子のエラーを検知; これ以上子を作らないなら false を返す
bool next_input() {
if (input_child_id >= MAX_NUM_DATASET) return false;
if (input_child_id - output_child_id >= MAX_NUM_PROCESS) return true;
const auto st = state_children + input_child_id;
const pid_t pid = children[input_child_id];
// 子の終了をノンブロッキングで検知
int status;
if (waitpid(pid, &status, WNOHANG) == pid) {
// END_OF_FILE になっていれば子の生成を停止; そうでなければすべて terminate
if (*st == PState::END_OF_FILE) return false;
debugln("[error] dataset #{:03} exited with status {}, without end_input()\n", input_child_id, status);
sigint_handler(SIGTERM);
} else if (*st == PState::END_INPUT) {
// 次の子を生成
input_child_id++;
spawn_child();
}
return true;
}
// 親 : output_child_id を可能なら進める or 子のエラーを検知
void next_output() {
if (output_child_id >= input_child_id) return;
const auto st = state_children + output_child_id;
const pid_t pid = children[output_child_id];
// 子の終了をノンブロッキングで検知
int status;
if (waitpid(pid, &status, WNOHANG) == pid) {
// END_OUTPUT になっていれば次の子が出力を始める; そうでなければすべて terminate
if (*st == PState::END_OUTPUT) {
output_child_id++;
next_output();
} else {
debugln("[error] dataset #{:03} exited with status {}, without end_output()\n", output_child_id, status);
sigint_handler(SIGTERM);
}
} else if (*st == PState::END_INPUT) {
// 次の子の出力を許可
*st = PState::BEGIN_OUTPUT;
}
}
int main() {
// シグナルハンドラの設定
std::signal(SIGINT, sigint_handler);
// stdin を unbuffered に設定
std::setvbuf(stdin, nullptr, _IONBF, 0);
// state_children の初期化 (共有メモリを開く)
const size_t siz = sizeof(PState) * MAX_NUM_DATASET;
shm_unlink("/state_children");
int fd = shm_open("/state_children", O_RDWR | O_CREAT, 0666);
if (fd == -1) {
perror("shm_open");
return 1;
}
if (ftruncate(fd, siz) == -1) {
perror("ftruncate");
return 1;
}
state_children = (PState*)mmap(nullptr, siz, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
if (state_children == MAP_FAILED) {
perror("mmap");
return 1;
}
// 並列化開始
spawn_child();
while (next_input()) {
next_output();
wait();
}
while (output_child_id < input_child_id) {
next_output();
wait();
}
return 0;
}
} // namespace MP
using MP::debugln;
int main() {
return MP::main();
}
/************************************
使用例
*************************************/
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
void solve() {
// ios_base::sync_with_stdio(false); して動作するかは謎
MP::begin_input();
ll A, B;
cin >> A >> B;
// scanf("%lld%lld", &A, &B);
assert(cin);
debugln("{} {}", A, B);
if (A == 0 && B == 0) MP::end_of_file();
MP::end_input();
MP::begin_output();
cout << A + B << '\n';
MP::end_output();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment