@haifenghuang
Forked from DarinM223/magic.md
Created January 29, 2024 06:19
C++ magical template examples:

Generic constraints

In Rust it is easy to constrain a generic type to only those types that implement some trait. For example:

pub trait DoSomething {
    fn do_something(&self);
}

pub fn do_many_things<T>(things: Vec<T>) where T: DoSomething {
    for thing in things {
        thing.do_something();
    }
}

pub struct Things<T: DoSomething> {
    things: Vec<T>,
}

impl<T> Things<T> where T: DoSomething {
    pub fn new(things: Vec<T>) -> Things<T> {
        Things { things: things }
    }
    
    pub fn do_things(&self) {
        for thing in self.things.iter() {
            thing.do_something();
        }
    }
}

Then it is easy to use the do_many_things function and the Things struct with any type that implements the DoSomething trait:

pub struct PrintSomething<'a>(&'a str);
impl<'a> DoSomething for PrintSomething<'a> {
    fn do_something(&self) {
        println!("Printing: {}", self.0);
    }
}

pub struct SayHello;
impl DoSomething for SayHello {
    fn do_something(&self) {
        println!("Hello!");
    }
}

do_many_things(vec![PrintSomething("a"), PrintSomething("b")]); // Compiles
do_many_things(vec![SayHello, SayHello, SayHello]); // Compiles
do_many_things(vec![1, 2, 3]); // Doesn't compile

let things = Things::new(vec![SayHello]); // Compiles
let things = Things::new(vec![PrintSomething("b")]); // Compiles
let things = Things::new(vec![1]); // Doesn't compile

Before C++20 introduced concepts, C++ had no language feature for this, but template magic can emulate the behavior:

template <typename D, typename B>
struct IsDerivedFrom {
  static void Constraints(D *p) { B *bp = p; }
  IsDerivedFrom() { void (*p)(D*) = Constraints; }
};

It works by defining a static function that assigns a D* to a B*. That implicit conversion only compiles if D inherits publicly (and unambiguously) from B. The constructor takes the address of Constraints, which forces the compiler to instantiate Constraints whenever the constructor itself is instantiated, so constructing an IsDerivedFrom<D, B> with an unrelated D becomes a compile error.
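The check that Constraints performs, an implicit D* to B* conversion, is exactly what the standard trait std::is_convertible (from <type_traits>) tests. Below is a minimal sketch of the same constraint written as a static_assert, assuming a C++17 compiler for the _v shorthand. DoSomething and SayHello mirror the types used later in this gist; the isDerivedFrom variable template, the calls counter, and the by-reference parameter (so the effect is observable) are illustrative additions.

```cpp
#include <type_traits>
#include <vector>

struct DoSomething {
  virtual ~DoSomething() = default;
  virtual void doSomething() = 0;
};

struct SayHello : public DoSomething {
  int calls = 0;  // counts calls so the effect can be observed
  void doSomething() override { ++calls; }
};

// Same test as `B *bp = p;` inside IsDerivedFrom::Constraints:
// true only if a D* implicitly converts to a B*.
template <typename D, typename B>
constexpr bool isDerivedFrom = std::is_convertible_v<D*, B*>;

template <typename T>
void doManyThings(std::vector<T>& things) {
  // Fires when doManyThings<T> is instantiated, with a readable message.
  static_assert(isDerivedFrom<T, DoSomething>,
                "T must derive publicly from DoSomething");
  for (auto& thing : things) {
    thing.doSomething();
  }
}
```

A practical difference is the diagnostic: doManyThings<int> fails with the custom message rather than an error pointing at a pointer assignment inside a helper struct.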

As written, this definition of IsDerivedFrom triggers unused-variable warnings, which become errors under flags such as -Werror. An easy way to silence them is to add (void)variable; after declaring the variable. This updated definition removes the warnings:

template <typename D, typename B>
struct IsDerivedFrom {
  static void Constraints(D* p) {
    B* bp = p;
    (void)bp;
  }
  IsDerivedFrom() {
    void (*p)(D*) = Constraints;
    (void)p;
  }
};
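With C++17, the standard [[maybe_unused]] attribute is an alternative to the (void) casts: it tells the compiler that a variable may legitimately go unused. A sketch of the same struct using the attribute; the constraint itself works exactly as before.

```cpp
template <typename D, typename B>
struct IsDerivedFrom {
  static void Constraints(D* p) {
    // Compiles only if D* implicitly converts to B*.
    [[maybe_unused]] B* bp = p;
  }
  IsDerivedFrom() {
    // Taking Constraints' address forces it to be instantiated.
    [[maybe_unused]] void (*p)(D*) = Constraints;
  }
};
```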

This struct can be used to constrain generic types in both functions and classes. To constrain a function, create a temporary of the struct at the top of the function body:

#include <vector>

struct DoSomething {
  virtual ~DoSomething() = default;
  virtual void doSomething() = 0;
};

template <typename T>
void doManyThings(std::vector<T> things) {
  IsDerivedFrom<T, DoSomething>{}; // Enforce constraint
  for (auto &thing : things) {
    thing.doSomething();
  }
}

And to constrain a class, inherit from the struct (a derived class's constructor implicitly calls its base class's constructor, which performs the check):

#include <utility>
#include <vector>

template <typename T>
class Things : public IsDerivedFrom<T, DoSomething> { // Enforce constraint
 public:
  Things(std::vector<T> things) : things_(std::move(things)) {}
  void doThings() {
    for (auto &thing : things_) {
      thing.doSomething();
    }
  }

 private:
  std::vector<T> things_;
};

Both doManyThings and Things will constrain the types similarly to the Rust code:

#include <iostream>
#include <string>
#include <utility>

struct PrintSomething : public DoSomething {
  std::string s;
  PrintSomething(std::string s) : s(std::move(s)) {}
  void doSomething() override {
    std::cout << "Printing: " << s << "\n";
  }
};

struct SayHello : public DoSomething {
  void doSomething() override {
    std::cout << "Hello!\n";
  }
};

doManyThings(std::vector<PrintSomething>{
    PrintSomething("a"), PrintSomething("b")}); // Compiles
doManyThings(std::vector<SayHello>{SayHello{}, SayHello{}, SayHello{}}); // Compiles
doManyThings(std::vector<int>{1, 2, 3}); // Doesn't compile

Things<PrintSomething> things1{
    std::vector<PrintSomething>{PrintSomething("a"), PrintSomething("b")}}; // Compiles
Things<SayHello> things2{std::vector<SayHello>{SayHello{}, SayHello{}, SayHello{}}}; // Compiles
Things<int> things3{std::vector<int>{1, 2, 3}}; // Doesn't compile

Variant matching

Sum types and pattern matching are built into Rust. They make it easy to add new operations over a fixed set of variants. This is the opposite of traditional dynamic dispatch, where it is easy to add new subtypes for a fixed set of operations.

An example is a message type for a game. A message can be represented by using an interface and dynamic dispatch:

struct Message {
  virtual ~Message() = default; // Needed if messages are deleted through a Message pointer.
  virtual void apply(Game &game) = 0; // Apply message to game.
};

struct ActorRemove : public Message {
  size_t id;
  
  ActorRemove(size_t id) : id(id) {}

  void apply(Game &game) override {
    game.actors.erase(game.actors.begin() + id);
  }
};

struct ActorAdd : public Message {
  std::string name;
    
  ActorAdd(std::string name) : name(name) {}

  void apply(Game &game) override {
    game.actors.emplace_back(name);
  }
};

// ...etc

The game will then apply a message like this:

void Game::doSomething() {
  auto msg = getMessage();
  msg->apply(*this);
}

Implementing Message like this has some drawbacks. First, it mixes Game implementation details into the Message subtypes. That makes the message type hard to reuse with anything other than Game, and if Game has private members that need to be modified, every message class has to be declared a friend of Game to access them.

Second, the code that runs inside the Game method is hard to follow, because the implementations can be spread across many files, some of which may even live in different libraries. An IDE can list every implementation of the method, but skimming through them is still harder than reading a single switch statement.

This is where sum types and pattern matching help. Sum types are like enums in C++ or Java, but each variant can also carry its own data. Here is the Message type written as a Rust sum type (an enum):

pub enum Message {
    ActorRemove(usize),
    ActorAdd(String),
    // ...etc
}

Inside Game, we can then use pattern matching to implement each message. Pattern matching is similar to a switch statement, but it also lets you destructure nested fields and match ranges and other patterns.

impl Game {
    pub fn do_something(&mut self) {
        let msg = get_message();
        match msg {
            Message::ActorRemove(id) => {
                self.actors.remove(id);
            }
            Message::ActorAdd(name) => {
                self.actors.push(Actor { name: name });
            }
            // ...etc
        }
    }
}

With pattern matching, the code lives where we want it: Game keeps its implementation details in its own method instead of spreading them across many files, and the message type can be used anywhere, not just in Game.

Now that we know the benefits of sum types, it would be nice if C++ had them too. Luckily, C++17 added a sum type to the standard library: std::variant. Here is the game message example using std::variant:

#include <cstddef>
#include <string>
#include <variant>

struct ActorRemove {
  size_t id;

  ActorRemove(size_t id) : id(id) {}
};

struct ActorAdd {
  std::string name;

  ActorAdd(std::string name) : name(name) {}
};

using Message = std::variant<ActorRemove, ActorAdd>;
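Visitors are not the only way to consume a std::variant. The standard library also provides std::holds_alternative, std::get_if and the index() member for inspecting the active alternative directly, which reads like a chain of if statements. A sketch, where describe is a hypothetical helper and the message structs are simplified to aggregates:

```cpp
#include <cstddef>
#include <string>
#include <variant>

struct ActorRemove {
  std::size_t id;
};

struct ActorAdd {
  std::string name;
};

using Message = std::variant<ActorRemove, ActorAdd>;

// Hypothetical helper: describe a message without a visitor.
std::string describe(const Message& msg) {
  // get_if returns a pointer to the alternative, or nullptr if it is not active.
  if (const auto* remove = std::get_if<ActorRemove>(&msg)) {
    return "remove " + std::to_string(remove->id);
  }
  if (const auto* add = std::get_if<ActorAdd>(&msg)) {
    return "add " + add->name;
  }
  return "empty";  // only reachable if the variant is valueless_by_exception
}
```

The drawback compared with a visitor is that nothing forces the chain to cover every alternative: adding a new message type still compiles, whereas std::visit with a visitor that lacks a matching overload does not.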

The next step is to use pattern matching inside Game to perform a different action for each alternative of the variant. However, C++ has no pattern-matching syntax and uses visitors instead. A visitor is a callable object that overloads operator() once for each alternative. It is applied through std::visit, which takes a visitor and a variant and calls the overload matching the variant's active alternative. Here's a visitor that just prints the type of the message:

void Game::doSomething() {
  auto msg = getMessage();

  struct visitor {
    void operator()(const ActorRemove &) { std::cout << "ActorRemove\n"; }
    void operator()(const ActorAdd &) { std::cout << "ActorAdd\n"; }
  };
  std::visit(visitor{}, msg);
}

This visitor syntax works, but the visitor is a separate class: its overloads cannot see the local variables of doSomething or the Game object unless they are passed in explicitly. It would be better if each overload could capture whatever variables it needs to modify.
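One option that needs no helper type is a single generic lambda: like any lambda it can capture variables, and inside it `if constexpr` can branch on the type of the active alternative (C++17). A sketch, where Actor and applyMessage are hypothetical stand-ins for the game state and the Game method:

```cpp
#include <cstddef>
#include <string>
#include <type_traits>
#include <variant>
#include <vector>

struct ActorRemove { std::size_t id; };
struct ActorAdd { std::string name; };
using Message = std::variant<ActorRemove, ActorAdd>;

struct Actor { std::string name; };  // hypothetical actor type

// Applies one message to the actors using a single capturing generic lambda.
void applyMessage(std::vector<Actor>& actors, const Message& msg) {
  std::visit(
      [&actors](const auto& m) {
        // Strip const and reference to compare against the alternative types.
        using T = std::decay_t<decltype(m)>;
        if constexpr (std::is_same_v<T, ActorRemove>) {
          actors.erase(actors.begin() + static_cast<std::ptrdiff_t>(m.id));
        } else if constexpr (std::is_same_v<T, ActorAdd>) {
          actors.push_back(Actor{m.name});
        }
      },
      msg);
}
```

This keeps everything in one place, but an alternative with no matching branch silently does nothing, and the type tests are more verbose than one lambda per alternative.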

One way to do this is to use template magic to build a visitor class out of several lambdas, one per alternative. Lambdas can capture variables from the surrounding scope, so each lambda can choose which variables it modifies. Here's what we want the lambda syntax to look like:

void Game::doSomething() {
  auto msg = getMessage();

  std::visit(overloaded {
    [this](ActorRemove msg) { actors.erase(actors.begin() + msg.id); },
    [this](ActorAdd msg) { actors.emplace_back(msg.name); }
  }, msg);
}

In this example, std::visit receives a single visitor built from two lambdas that capture this, so both branches can mutate the game's actors, just like the Rust version.

How can this syntax be created? By abusing template metaprogramming. Here's the full code needed to create the overloaded { } syntax used in the example above (it relies on C++17 features):

template <typename... fns>
struct overloaded : fns... {
  using fns::operator()...;
};

template <typename... Ts>
overloaded(Ts...) -> overloaded<Ts...>;
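Putting the pieces together, here is a compilable sketch of the Game example, assuming C++17. Game, Actor and getMessage are not defined in the original, so the versions below are hypothetical stand-ins; the message is passed to a handle method instead of coming from getMessage so the behavior is deterministic.

```cpp
#include <cstddef>
#include <string>
#include <variant>
#include <vector>

template <typename... fns>
struct overloaded : fns... {
  using fns::operator()...;
};

template <typename... Ts>
overloaded(Ts...) -> overloaded<Ts...>;

struct ActorRemove {
  std::size_t id;
};

struct ActorAdd {
  std::string name;
};

using Message = std::variant<ActorRemove, ActorAdd>;

// Hypothetical stand-ins for types the gist leaves undefined.
struct Actor {
  std::string name;
};

class Game {
 public:
  // Takes the message as a parameter instead of calling getMessage().
  void handle(const Message& msg) {
    std::visit(
        overloaded{
            [this](const ActorRemove& m) {
              actors_.erase(actors_.begin() +
                            static_cast<std::ptrdiff_t>(m.id));
            },
            [this](const ActorAdd& m) { actors_.push_back(Actor{m.name}); }},
        msg);
  }

  const std::vector<Actor>& actors() const { return actors_; }

 private:
  std::vector<Actor> actors_;  // private, yet no friend declarations are needed
};
```

Because the lambdas are written inside a Game member function, they can modify the private actors_ member directly, which removes the friend-declaration problem of the dynamic-dispatch version.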

To understand this code, it's best to break it into parts, starting with the struct definition:

template <typename... fns>
struct overloaded : fns... {
  using fns::operator()...;
};

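This builds a visitor class from multiple lambda types. The ... after typename makes fns a template parameter pack: a placeholder for any number of types. Writing ... after a pattern that uses fns (a pack expansion) repeats that pattern once per type, so fns... in the base-class list inherits from every lambda type, and using fns::operator()...; brings in every lambda's call operator. So overloaded with three lambda types fn1, fn2, and fn3 expands into the following struct: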

template <typename fn1, typename fn2, typename fn3>
struct overloaded : fn1, fn2, fn3 {
    using fn1::operator();
    using fn2::operator();
    using fn3::operator();
};

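Each lambda has its own unique closure type with its own operator(). The using declarations bring all of those operator() functions into overloaded's scope, so they form a single overload set; without them, a call would be ambiguous because the name operator() would be found in several base classes. That makes overloaded a proper visitor class, and thanks to the ... expansion it can take as many lambdas as the variant needs.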
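Because overloaded is an ordinary overload set, it can also mix specific lambdas with a generic catch-all. When a non-template and a template candidate match equally well, overload resolution prefers the non-template one, so a generic lambda behaves like a default branch in a match. A sketch, where Ping, Move, Quit, Command and commandName are hypothetical:

```cpp
#include <string>
#include <variant>

template <typename... fns>
struct overloaded : fns... {
  using fns::operator()...;
};

template <typename... Ts>
overloaded(Ts...) -> overloaded<Ts...>;

// Hypothetical message types for this example.
struct Ping {};
struct Move { int dx; };
struct Quit {};

using Command = std::variant<Ping, Move, Quit>;

std::string commandName(const Command& cmd) {
  return std::visit(
      overloaded{
          [](const Ping&) { return std::string("ping"); },
          [](const Move& m) { return "move " + std::to_string(m.dx); },
          // Generic catch-all: used only when no specific lambda matches.
          [](const auto&) { return std::string("other"); }},
      cmd);
}
```

The trade-off is exhaustiveness: once a catch-all is present, adding a new alternative to the variant no longer produces a compile error in this visitor.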

Now for the second definition:

template <typename... Ts>
overloaded(Ts...) -> overloaded<Ts...>;

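This is a user-defined deduction guide for class template argument deduction (CTAD). It tells the compiler that an expression like overloaded { a, b, c } produces the type overloaded<A, B, C>, where A, B and C are the types of a, b and c. The braces themselves work because overloaded is an aggregate (since C++17, aggregates may have base classes), so each lambda initializes the matching base. Just like the first part, the guide accepts as many types as the visitor needs. In C++20, CTAD for aggregates makes this guide unnecessary, but C++17 requires it.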
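To see what the deduction guide buys, compare spelling the template arguments out by hand with decltype against letting CTAD deduce them. A minimal sketch; the onInt and onString lambdas are arbitrary examples.

```cpp
#include <string>
#include <type_traits>

template <typename... fns>
struct overloaded : fns... {
  using fns::operator()...;
};

template <typename... Ts>
overloaded(Ts...) -> overloaded<Ts...>;

inline auto onInt = [](int) { return 1; };
inline auto onString = [](const std::string&) { return 2; };

// Without relying on the guide: every lambda type is written out by hand.
inline overloaded<decltype(onInt), decltype(onString)> spelledOut{onInt, onString};

// With the guide: the template arguments are deduced from the initializers.
inline overloaded deduced{onInt, onString};

// Both declarations name exactly the same type.
static_assert(std::is_same_v<decltype(spelledOut), decltype(deduced)>);
```

In C++17, removing the guide leaves only the spelled-out form compiling, since the aggregate has no constructors from which the lambda types could be deduced.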
