Skip to content

Instantly share code, notes, and snippets.

@chengscott
Last active March 26, 2024 17:04
Show Gist options
  • Save chengscott/6769e9ff1a2fc2c9b058ad80e2837bd2 to your computer and use it in GitHub Desktop.
Save chengscott/6769e9ff1a2fc2c9b058ad80e2837bd2 to your computer and use it in GitHub Desktop.
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>
#include <utility>
/// Running arithmetic mean of a stream of float samples (no synchronization).
class MeanTracker {
  size_t total_ = 0;  // number of samples folded in so far (same type as `n`)
  float mean_ = 0.f;  // mean of those samples; 0 while total_ == 0
public:
  /// Returns the current mean, or 0 if no samples have been added yet.
  float mean() const { return mean_; }
  /// Folds `n` samples whose *sum* is `v` into the mean; with the default
  /// n == 1, `v` is simply the new sample.
  ///
  /// A call that would leave the sample count at zero (n == 0 on an empty
  /// tracker) is ignored: the formula would otherwise divide by zero and
  /// leave mean_ permanently inf/NaN.
  void update(float v, size_t n = 1) {
    if (total_ + n == 0) {
      return;
    }
    mean_ = (total_ * mean_ + v) / (total_ + n);
    total_ += n;
  }
};
/// Running arithmetic mean that may be shared between threads: every read
/// and write of the tracked state happens while holding mtx_.
class MeanTrackerWithMtx {
  mutable std::mutex mtx_;  // mutable so the const reader mean() can lock it
  size_t total_ = 0;        // number of samples folded in so far
  float mean_ = 0.f;        // mean of those samples; 0 while total_ == 0
public:
  /// Returns the current mean (0 before any samples), read under the lock so
  /// it cannot race with a concurrent update().
  float mean() const {
    std::lock_guard lk(mtx_);
    return mean_;
  }
  /// Folds `n` samples whose *sum* is `v` into the mean (for n == 1, `v` is
  /// the sample itself). A call that would leave the count at zero is
  /// ignored instead of dividing by zero.
  void update(float v, size_t n = 1) {
    std::lock_guard lk(mtx_);
    if (total_ + n == 0) {
      return;
    }
    mean_ = (total_ * mean_ + v) / (total_ + n);
    total_ += n;
  }
};
using MeanTrackerPtr = std::shared_ptr<MeanTrackerWithMtx>;
// Demo: two threads feed the disjoint ranges [0, 100) and [100, 200) into one
// shared tracker (each thread holds its own shared_ptr copy), then the final
// mean is printed once both threads have been joined.
int main() {
  auto tracker = std::make_shared<MeanTrackerWithMtx>();
  const auto feed = [](MeanTrackerPtr target, size_t first, size_t last) {
    for (size_t value = first; value < last; ++value) {
      target->update(value);
    }
  };
  std::thread lower(feed, tracker, size_t{0}, size_t{100});
  std::thread upper(feed, tracker, size_t{100}, size_t{200});
  lower.join();
  upper.join();
  std::cout << tracker->mean() << '\n';
}
#include <iostream>
#include <mutex>
/// CRTP mixin that gives T a mutex-guarded entry point to T::update.
///
/// update_with_mtx() holds the mixin's private mutex for the duration of the
/// call and perfect-forwards its arguments to T::update. T's own parameter
/// list and default arguments therefore apply, and the mixin is not tied to
/// one particular update() signature. Calling T::update directly bypasses
/// this lock.
template <class T>
class MutexMixin {
  std::mutex mtx;
public:
  template <class... Args>
  void update_with_mtx(Args &&...args) {
    std::lock_guard<std::mutex> lk(mtx);
    static_cast<T *>(this)->update(std::forward<Args>(args)...);
  }
};
/// Running arithmetic mean with a mutex-guarded update path supplied by
/// MutexMixin: update_with_mtx() locks and then calls update() below.
/// update() and mean() themselves take no lock.
class MeanTrackerWithMtx : public MutexMixin<MeanTrackerWithMtx> {
  size_t total_ = 0;  // number of samples folded in so far (same type as `n`)
  float mean_ = 0.f;  // mean of those samples; 0 while total_ == 0
public:
  /// Returns the current mean as a float. The previous `int` return type
  /// truncated fractional means, e.g. 1.5 -> 1.
  /// NOTE: unsynchronized; MutexMixin only guards update().
  float mean() const { return mean_; }
  /// Folds `n` samples whose *sum* is `v` into the mean (for n == 1, `v` is
  /// the sample itself). A call that would leave the count at zero is
  /// ignored instead of dividing by zero.
  void update(float v, size_t n = 1) {
    if (total_ + n == 0) {
      return;
    }
    mean_ = (total_ * mean_ + v) / (total_ + n);
    total_ += n;
  }
};
// Demo: feed two samples through the mixin's locked entry point and print the
// resulting mean.
int main() {
  MeanTrackerWithMtx mt;
  // Go through MutexMixin::update_with_mtx so the demo actually exercises the
  // locked path; calling mt.update() directly would bypass the mutex.
  mt.update_with_mtx(1);
  mt.update_with_mtx(3);
  std::cout << mt.mean() << '\n';
}
#include <iostream>
#include <mutex>
/// Running arithmetic mean whose synchronization is injected by a policy.
///
/// `policy_t` must be default-constructible and provide `prologue()` and
/// `epilogue()`. Every read and write of the tracked state is bracketed by
/// that pair (lock/unlock for mutex_policy_t, no-ops for base_policy_t).
template<class policy_t>
class MeanTracker_t {
  // Calls prologue() on construction and epilogue() on destruction, so the
  // pair stays balanced on every exit path (including early returns).
  class scoped_policy {
    policy_t &ref_;
  public:
    explicit scoped_policy(policy_t &p) : ref_(p) { ref_.prologue(); }
    ~scoped_policy() { ref_.epilogue(); }
    scoped_policy(const scoped_policy &) = delete;
    scoped_policy &operator=(const scoped_policy &) = delete;
  };

  // mutable so that the const reader mean() can run the policy as well.
  mutable policy_t policy_;
  size_t total_ = 0;  // number of samples folded in so far (same type as `n`)
  float mean_ = 0.f;  // mean of those samples; 0 while total_ == 0
public:
  /// Returns the current mean as a float (the previous `int` return type
  /// truncated it). Read under the policy.
  float mean() const {
    scoped_policy guard(policy_);
    return mean_;
  }
  /// Folds `n` samples whose *sum* is `v` into the mean (for n == 1, `v` is
  /// the sample itself), under the policy. A call that would leave the count
  /// at zero is ignored instead of dividing by zero.
  void update(float v, size_t n = 1) {
    scoped_policy guard(policy_);
    if (total_ + n == 0) {
      return;
    }
    mean_ = (total_ * mean_ + v) / (total_ + n);
    total_ += n;
  }
};
/// No-op policy for MeanTracker_t: both hooks do nothing, so a tracker
/// instantiated with it performs no synchronization.
struct base_policy_t {
  void prologue() {
    // intentionally empty: nothing to acquire
  }

  void epilogue() {
    // intentionally empty: nothing to release
  }
};
/// Locking policy for MeanTracker_t: prologue() acquires an internal mutex
/// and epilogue() releases it. Calls must be paired, and prologue() must not
/// be re-entered from the same thread (std::mutex is not recursive).
///
/// Both hooks are const; that is what the `mutable` mutex is for. It lets
/// const readers of the guarded state synchronize too.
class mutex_policy_t {
  mutable std::mutex mtx;
public:
  void prologue() const { mtx.lock(); }
  void epilogue() const { mtx.unlock(); }
};
using MeanTracker = MeanTracker_t<base_policy_t>;
using MeanTrackerWithMtx = MeanTracker_t<mutex_policy_t>;
int main() {
MeanTrackerWithMtx mt;
mt.update(1);
mt.update(3);
std::cout << mt.mean() << '\n';
}
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>
/// Running arithmetic mean of a stream of float samples (no synchronization).
class MeanTracker {
  size_t total_ = 0;  // number of samples folded in so far (same type as `n`)
  float mean_ = 0.f;  // mean of those samples; 0 while total_ == 0
public:
  /// Returns the current mean, or 0 if no samples have been added yet.
  float mean() const { return mean_; }
  /// Folds `n` samples whose *sum* is `v` into the mean; with the default
  /// n == 1, `v` is simply the new sample.
  ///
  /// A call that would leave the sample count at zero (n == 0 on an empty
  /// tracker) is ignored: the formula would otherwise divide by zero and
  /// leave mean_ permanently inf/NaN.
  void update(float v, size_t n = 1) {
    if (total_ + n == 0) {
      return;
    }
    mean_ = (total_ * mean_ + v) / (total_ + n);
    total_ += n;
  }
};
/// Supplies a mutex to a derived class. The mutex is `mutable` and the
/// accessor is const, so const members of the derived class (readers) can
/// lock it as well as mutators.
class MutexMixin {
  mutable std::mutex mtx;
protected:
  std::mutex &get_mtx() const { return mtx; }
};
/// MeanTracker whose update() and mean() are both serialized by the mixin's
/// mutex.
///
/// NOTE: MeanTracker is a public, non-virtual base. Code that calls
/// update()/mean() through a MeanTracker& or MeanTracker* reaches the
/// unlocked base versions.
struct MeanTrackerWithMtx final : private MutexMixin, public MeanTracker {
  void update(float v, size_t n = 1) {
    std::lock_guard lk(get_mtx());
    MeanTracker::update(v, n);
  }
  /// Reads the mean under the lock so it cannot race with update(); the
  /// inherited MeanTracker::mean() takes no lock.
  float mean() const {
    std::lock_guard lk(get_mtx());
    return MeanTracker::mean();
  }
};
using MeanTrackerPtr = std::shared_ptr<MeanTrackerWithMtx>;
int main() {
auto mt = std::make_shared<MeanTrackerWithMtx>();
std::thread t1([](MeanTrackerPtr mt) {
for (size_t i = 0; i < 100; ++i) {
mt->update(i);
}
}, mt);
std::thread t2([](MeanTrackerPtr mt) {
for (size_t i = 100; i < 200; ++i) {
mt->update(i);
}
}, mt);
t1.join();
t2.join();
std::cout << mt->mean() << '\n';
MeanTracker m = *mt;
std::cout << m.mean() << '\n';
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment