Skip to content

Instantly share code, notes, and snippets.

@sherm1
Last active October 31, 2016 16:51
Show Gist options
  • Save sherm1/fab75e7fdb41a6cedecbcfea36c87cee to your computer and use it in GitHub Desktop.
Save sherm1/fab75e7fdb41a6cedecbcfea36c87cee to your computer and use it in GitHub Desktop.
Prototype for multiple instantiations managed by a single system
#include "MBTree.h"
#include "System.h"
#include <complex>
#include <cstdio>
#include <iostream>
#include <memory>
#include <utility>
using std::complex;
using std::cout;
using std::endl;
using std::unique_ptr;
using std::make_unique;
// This is a concrete System class, modeling MultibodySystem. It derives from
// System<MBSystem, T> (CRTP), which lets the base class construct MBSystem<T>
// alternates from the fundamental MBSystem<double> and cast them back out.
template <typename T>
class MBSystem : public System<MBSystem, T> {  // CRTP!
  using Super = System<::MBSystem, T>;

 public:
  // Construct and take over ownership of the given MBTree.
  MBSystem(unique_ptr<MBTree<T>> tree) : tree_(std::move(tree)) {}

  // Construct an alternate MBSystem<T> with the same structure as the given
  // fundamental (double) one; the MBTree is rebuilt via MBTree<T>'s
  // converting constructor. For T=double this is also the (explicit) copy
  // constructor.
  explicit MBSystem(const MBSystem<double>& fundamental)
      : Super(fundamental),
        tree_(make_unique<MBTree<T>>(*fundamental.tree_)) {}

  // For demo purposes, report the (implementation-defined) name of the scalar
  // type here.
  const char* type() const { return typeid(T).name(); }

  // Return interesting concrete System-specific stuff.
  const MBTree<T>& get_mb_tree() const { return *tree_; }

 private:
  // Let all other instantiations of this class be friends with this one, so
  // the converting constructor can read fundamental.tree_.
  template <typename TT>
  friend class MBSystem;

  unique_ptr<MBTree<T>> tree_;
};
int main() {
  // Fill in the "MultibodyTree" first.
  auto mb_tree = make_unique<MBTree<double>>();
  mb_tree->AddJoint(make_unique<PinJoint<double>>(1.1));
  mb_tree->AddJoint(make_unique<SliderJoint<double>>(2.2));
  mb_tree->AddJoint(make_unique<PinJoint<double>>(3.3));
  mb_tree->AddJoint(make_unique<SliderJoint<double>>(4.4));

  // Create the fundamental MBSystem (that is, type double).
  MBSystem<double> sys(std::move(mb_tree));

  // Create some alternate instantiations of MBSystem (kept within the
  // fundamental system).
  MBSystem<complex<double>>::AddAlternate(sys);
  MBSystem<float>::AddAlternate(sys);
  cout << "num alternates=" << sys.get_num_alternates() << '\n';

  // Stuff happens using the fundamental system, then at some point we decide
  // we want one of the alternate instantiations.
  const auto& csys = sys.get_alternate<complex<double>>();
  const auto& fsys = sys.get_alternate<float>();
  cout << "my type=" << sys.type() << '\n';
  cout << "csys type=" << csys.type() << '\n';
  cout << "fsys type=" << fsys.type() << '\n';

  // Using the fundamental system, calculate derivative df analytically.
  const auto& dtree = sys.get_mb_tree();  // <double> (fundamental)
  Context<double> cd{0.5};  // set x=0.5 (set up context for fundamental)
  const auto& dpin0 = dtree.GetJoint<PinJoint>(0);
  const double f = dpin0.PinFunc(cd);
  const double df = dpin0.DPinFuncDx(cd);  // analytical derivative

  // Instead, use the same joint of the complex alternate to calculate the
  // derivative using a complex-step derivative (equivalent to autodiff):
  // f'(x) ~= Im(f(x + i*h)) / h. There is no subtractive cancellation, so h
  // can be tiny.
  constexpr double kComplexStep = 1e-20;
  const auto& ctree = csys.get_mb_tree();
  const auto& cpin0 = ctree.GetJoint<PinJoint>(0);
  Context<complex<double>> cc(cd);  // clone context for this alternate.
  cc.x += complex<double>(0, kComplexStep);
  const double cdf = cpin0.PinFunc(cc).imag() / kComplexStep;

  printf(" f(x)=%.16g;\n df(x)=%.16g (analytical)\ncdf(x)=%.16g (autodiff)\n",
         f, df, cdf);

  // Wait for a keypress before exiting (presumably to keep a console window
  // open when launched from an IDE).
  getchar();
}
#pragma once
#include <cmath>
#include <memory>
#include <utility>
#include <vector>
#include "System.h"
template <typename T> class MBTree;
/** Tag identifying each concrete joint class. Joint<T>::CloneFrom() switches
on this tag at run time to pick which concrete joint's copy-from-fundamental
constructor to call, so every concrete joint needs an entry here. **/
enum class JointType {
  Pin,     // PinJoint
  Slider,  // SliderJoint
};
/** This is a mock Joint base class. It demonstrates that virtual methods
work fine with this style of templatizing. **/
template <typename T>
class Joint {
 public:
  /** Concrete joints are owned and destroyed through Joint<T> pointers
  (MBTree<T> stores std::unique_ptr<Joint<T>>), so the destructor must be
  virtual; otherwise deleting a PinJoint/SliderJoint is undefined behavior. **/
  virtual ~Joint() = default;

  /** Return the index assigned to this joint by its MBTree<T> (its position
  in the tree's joint list), or -1 if no index has been assigned yet. **/
  int get_joint_index() const {return index_;}

  virtual int get_num_positions() const = 0;
  virtual void get_positions(const Context<T>& context, T* positions) const = 0;
  virtual JointType get_joint_type() const = 0;

  /** This serves to instantiate methods for creating a Joint<T> from a
  Joint<double>. **/
  static std::unique_ptr<Joint<T>> CloneFrom(
      const Joint<double>& fundamental);

 protected:
  Joint() = default;

 private:
  friend class MBTree<T>;

  // Used only by MBTree to assign an index to this joint.
  void set_joint_index(int index) {index_=index;}
  void set_mb_tree(MBTree<T>* tree) {tree_=tree;}

  // These are set by MBTree<T>; until then index_ is -1 and tree_ is null.
  int index_{-1};
  MBTree<T>* tree_{};
};
/** This is a concrete Joint implementation. It implements the base class
virtuals and adds some mock PinJoint-specific functionality. **/
template <typename T>
class PinJoint : public Joint<T> {
 public:
  /** Create a PinJoint using some PinJoint-specific parameter so we can
  verify that the alternate instantiations end up with the same value. **/
  explicit PinJoint(const T& parameter) : parameter_{parameter} {}

  /** Create a PinJoint<T> from the fundamental (double) one, converting its
  parameter to T. For T=double this is an explicit copy constructor. **/
  explicit PinJoint(const PinJoint<double>& fundamental)
      : PinJoint(static_cast<T>(fundamental.parameter_)) {}

  /** Compute something specific to a PinJoint: cos(x), where x is the
  context's state variable. **/
  T PinFunc(const Context<T>& context) const {
    // Unqualified call after `using std::cos` also lets argument-dependent
    // lookup find cos() for scalar types declared outside namespace std.
    using std::cos;
    return cos(context.x);
  }

  /** For testing, this method returns the analytic derivative of PinFunc(),
  i.e. -sin(x); we'll compare that with the automatic differentiation result
  we get from an alternate instantiation. **/
  T DPinFuncDx(const Context<T>& context) const {
    using std::sin;
    return -sin(context.x);
  }

  /** Return the PinJoint-specific parameter. **/
  const T& get_parameter() const { return parameter_; }

 private:
  // Lets PinJoint<T> read parameter_ from PinJoint<double>.
  template <typename TT> friend class PinJoint;

  int get_num_positions() const override {return 1;}
  void get_positions(const Context<T>& context,T* positions) const override {
    *positions = T(3) * context.x;  // whatever; point is that it is T-dependent
  }
  JointType get_joint_type() const override { return JointType::Pin; }

  T parameter_{};
};
/** A second concrete joint type, so the tree holds more than one kind. It
carries a single scalar "offset" parameter. **/
template <typename T>
class SliderJoint : public Joint<T> {
 public:
  /** Create a SliderJoint holding the given offset. **/
  explicit SliderJoint(const T& offset) : offset_{offset} {}

  /** Build a SliderJoint<T> from the fundamental (double) one, converting the
  offset to T. For T=double this is an explicit copy constructor. **/
  explicit SliderJoint(const SliderJoint<double>& fundamental)
      : offset_{static_cast<T>(fundamental.offset_)} {}

  /** Return the SliderJoint-specific parameter. **/
  const T& get_offset() const { return offset_; }

 private:
  // Every SliderJoint<TT> may read every other instantiation's offset_.
  template <typename TT> friend class SliderJoint;

  JointType get_joint_type() const override { return JointType::Slider; }

  int get_num_positions() const override { return 1; }

  void get_positions(const Context<T>& context, T* positions) const override {
    // Arbitrary formula (negated relative to PinJoint); all that matters is
    // that the result depends on T.
    *positions = T(-3) * context.x;
  }

  T offset_{};
};
// Definition of Joint<T>::CloneFrom(): produce the concrete Joint<T> matching
// a concrete Joint<double>, by reading its run-time JointType tag and invoking
// the corresponding concrete joint-from-fundamental constructor.
// TODO: This tag dispatch is not ideal, but no switch-free approach has been
// found yet. One justification: it is analogous to building a system from a
// urdf or sdf file -- the fundamental system acts as the specification from
// which another model is built, just as a urdf specifies how to build a
// model. Is there a better way to do this?
template <typename T>
std::unique_ptr<Joint<T>> Joint<T>::CloneFrom(
    const Joint<double>& fundamental) {
  const JointType kind = fundamental.get_joint_type();

  if (kind == JointType::Pin) {
    const auto& pin = dynamic_cast<const PinJoint<double>&>(fundamental);
    return std::make_unique<PinJoint<T>>(pin);
  }

  if (kind == JointType::Slider) {
    const auto& slider = dynamic_cast<const SliderJoint<double>&>(fundamental);
    return std::make_unique<SliderJoint<T>>(slider);
  }

  // Unknown tag: abort in debug builds, return null otherwise.
  assert(!"bad joint type");
  return std::unique_ptr<Joint<T>>();
}
/** This is a mock MultibodyTree. It is templatized by scalar type, has
substructure in the form of "Joints". **/
template <typename T>
class MBTree {
public:
MBTree() {}
/** Add a concrete joint to the tree, take over ownership of the joint
object, and return a reference to it that still has the concrete type. An
index is assigned and saved in the returned object. **/
template <template <typename> class JointType>
JointType<T>* AddJoint(std::unique_ptr<JointType<T>> joint) {
JointType<T>* concrete = joint.get();
concrete->set_joint_index((int)joints_.size());
concrete->set_mb_tree(this);
joints_.push_back(std::unique_ptr<Joint<T>>(joint.release()));
return concrete;
}
/** Retrieve the concrete Joint<T> given the joint index. **/
template <template <typename> class JointType>
const JointType<T>& GetJoint(int index) const {
const Joint<T>& joint = *joints_[index];
return dynamic_cast<const JointType<T>&>(joint);
}
/** Create an alternate-scalar MBTree from the fundamental one. This is the
copy constructor for the fundamental tree; the others don't have one. **/
explicit MBTree(const MBTree<double>& fundamental) {
for (const auto& joint : fundamental.joints_) {
auto new_joint = Joint<T>::CloneFrom(*joint);
joints_.push_back(std::move(new_joint));
}
}
private:
// Let all other instantiations of this class be friends with this one.
template <typename TT>
friend class MBTree;
std::vector<std::unique_ptr<Joint<T>>> joints_;
};
#pragma once
#include <cassert>
#include <memory>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>
/** This is the type-erased base class for instantiated `System` classes. Its
virtual destructor lets alternates of different scalar types be owned through
one pointer type and destroyed correctly. **/
class InstantiatedSystem {
 public:
  virtual ~InstantiatedSystem() = default;
};

/** For a System instantiated with the fundamental scalar type `double` this
class contains a registry of the other available instantiations of the
same System. For any other type this class is empty. **/
template <typename T>
class AlternateInstantiations {
 public:
  // Default implementation is empty.
  int get_num_alternates() const { return 0; }
};

/** Specialization for double: owns the registered alternates, keyed by the
std::type_index of their scalar type. **/
template<>
class AlternateInstantiations<double> {
 public:
  /** Take ownership of `alt` and register it under `index`. If an alternate
  is already registered under `index`, the existing one is kept and `alt` is
  destroyed. **/
  void RegisterAlternate(const std::type_index& index,
                         std::unique_ptr<InstantiatedSystem> alt) {
    if (type_to_instantiation_.find(index) != type_to_instantiation_.end())
      return;  // This alternate is already present.
    // Store the alternate before recording its slot, so the map can never
    // refer to a slot that does not exist (e.g. if push_back throws).
    instantiations_.push_back(std::move(alt));
    type_to_instantiation_.emplace(index, instantiations_.size() - 1);
  }

  /** Return the alternate registered under `index`. It is a programming
  error to ask for one that was never registered: debug builds assert, and
  release builds throw std::out_of_range instead of dereferencing end(). **/
  const InstantiatedSystem& get_alternate(const std::type_index& index) const {
    assert(type_to_instantiation_.count(index) != 0);
    return *instantiations_[type_to_instantiation_.at(index)];
  }

  int get_num_alternates() const {
    return static_cast<int>(instantiations_.size());
  }

 private:
  std::vector<std::unique_ptr<InstantiatedSystem>> instantiations_;
  // Here is how you find the right instantiation: scalar type -> slot in
  // instantiations_.
  std::unordered_map<std::type_index, size_t> type_to_instantiation_;
};
/** This is a System instantiated for a concrete System type MySys, which
must be derived from `System<MySys, T>` (CRTP) for an Eigen scalar type T
(e.g. double, float, std::complex<double>). Only the T=double ("fundamental")
instantiation holds a registry of alternates. **/
template <template <typename> class MySys, typename T>
class System : public InstantiatedSystem {
 public:
  using SystemT = System;

  /** Default constructor creates an empty System with no alternate
  instantiations. **/
  System() = default;

  /** Construction from fundamental instantiation. The fundamental's
  alternates are not copied; the new System starts with none of its own. **/
  explicit System(const MySys<double>& /*fundamental*/) {}

  /** Create an alternate instantiation of the fundamental System and
  register this one with it. If an alternate for T is already registered, the
  existing one is kept. **/
  static void AddAlternate(MySys<double>& fundamental) {
    // TODO(sherm) Make alternate mirror fundamental.
    std::unique_ptr<InstantiatedSystem> alternate =
        std::make_unique<MySys<T>>(fundamental);
    fundamental.get_mutable_alternates()->RegisterAlternate(
        std::type_index(typeid(T)), std::move(alternate));
  }

  /** Return the registered MySys<TT> alternate. Only usable on the
  fundamental (T=double) System, and TT must have been added with
  AddAlternate(). **/
  template <class TT>
  const MySys<TT>& get_alternate() const {
    const InstantiatedSystem& alt =
        alternates_.get_alternate(std::type_index(typeid(TT)));
    return dynamic_cast<const MySys<TT>&>(alt);
  }

  int get_num_alternates() const { return alternates_.get_num_alternates(); }

  AlternateInstantiations<T>* get_mutable_alternates() { return &alternates_; }

 private:
  AlternateInstantiations<T> alternates_;
};
/** Mock Context holding a single scalar state variable `x` of type T. **/
template <typename T>
struct Context{
  /** Create a context whose state is `value`. **/
  explicit Context(const T& value) : x(value) {}

  /** Create an alternate context from a fundamental one by converting its
  state to T. For T=double this is the (explicit) copy constructor.
  Initializing x in the member-initializer list (rather than assigning in the
  body) means T need not be default-constructible. **/
  explicit Context(const Context<double>& fundamental)
      : x(T(fundamental.x)) {}

  T x;  // state variable
};
@sherm1
Copy link
Author

sherm1 commented Jul 19, 2016

Hi, @amcastro-tri. Here is what I have so far for the multiple-instantiation stuff. Start with main.cpp. If you grab these 3 files (main.cpp, MBTree.h, System.h), you should be able to compile and run them on Ubuntu. - Sherm

@david-german-tri
Copy link

Instead of a copy constructor on MBSystem, could we have a virtual method System<T>::TransmogrifyFrom(System<double>)? Then Diagrams can Transmogrify generically.

@sherm1
Copy link
Author

sherm1 commented Oct 31, 2016

Yeah, that sounds like a better idea (except maybe for the word "Transmogrify"!). CloneFrom() seems a little less terrifying.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment