Last active
October 31, 2016 16:51
-
-
Save sherm1/fab75e7fdb41a6cedecbcfea36c87cee to your computer and use it in GitHub Desktop.
Prototype for multiple instantiations managed by a single system
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "MBTree.h" | |
#include "System.h" | |
#include <complex> | |
#include <cstdio> | |
#include <iostream> | |
#include <memory> | |
#include <utility> | |
using std::complex; | |
using std::cout; | |
using std::endl; | |
using std::unique_ptr; | |
using std::make_unique; | |
// This is a concrete System class, modeling MultibodySystem. | |
template <typename T> | |
class MBSystem : public System<MBSystem, T> { // CRTP! | |
using Super = System<::MBSystem, T>; | |
public: | |
// Construct and take over ownership of the given MBTree. | |
MBSystem(unique_ptr<MBTree<T>> tree) : tree_(std::move(tree)) {} | |
// Construct an alternate MySystem<T> with same structure as the given one. | |
// This is also the copy constructor for the T=double fundamental. | |
explicit MBSystem(const MBSystem<double>& fundamental) | |
: Super(fundamental), tree_(new MBTree<T>(*fundamental.tree_)) {} | |
// For demo purposes, report the name of the scalar type here. | |
const char* type() const { return typeid(T).name(); } | |
// Return interesting concrete System-specific stuff. | |
const MBTree<T>& get_mb_tree() const { return *tree_; } | |
private: | |
// Let all other instantiations of this class be friends with this one. | |
template <typename TT> | |
friend class MBSystem; | |
unique_ptr<MBTree<T>> tree_; | |
}; | |
int main() { | |
// Fill in the "MultibodyTree" first. | |
auto mb_tree = make_unique<MBTree<double>>(); | |
mb_tree->AddJoint(make_unique<PinJoint<double>>(1.1)); | |
mb_tree->AddJoint(make_unique<SliderJoint<double>>(2.2)); | |
mb_tree->AddJoint(make_unique<PinJoint<double>>(3.3)); | |
mb_tree->AddJoint(make_unique<SliderJoint<double>>(4.4)); | |
// Create the fundamental MBSystem (that is, type double). | |
MBSystem<double> sys(std::move(mb_tree)); | |
// Create some alternate instantiations of MBSystem (kept within the | |
// fundamental system). | |
MBSystem<complex<double>>::AddAlternate(sys); | |
MBSystem<float>::AddAlternate(sys); | |
cout << "num alternates=" << sys.get_num_alternates() << endl; | |
// Stuff happens using the fundamental system, then at some point we decide | |
// we want one of the alternate instantiations. | |
const auto& csys = sys.get_alternate<complex<double>>(); | |
const auto& fsys = sys.get_alternate<float>(); | |
cout << "my type=" << sys.type() << endl; | |
cout << "csys type=" << csys.type() << endl; | |
cout << "fsys type=" << fsys.type() << endl; | |
// Dig out the matching instantiations of the multibody tree. | |
const auto& dtree = sys.get_mb_tree(); // <double> (fundamental) | |
const auto& ftree = fsys.get_mb_tree(); // <float> (not useful) | |
// Using the fundamental system, calculate derivative df analytically. | |
Context<double> cd{0.5}; // set x=0.5 (set up context for fundamental) | |
const auto& dpin0 = dtree.GetJoint<PinJoint>(0); | |
double f = dpin0.PinFunc(cd); | |
double df = dpin0.DPinFuncDx(cd); // analytical derivative | |
// Instead, use the same joint of the complex alternate to calculate the | |
// derivative using a complex step derivative (equivalent to autodiff). | |
const auto& ctree = csys.get_mb_tree(); | |
const auto& cpin0 = ctree.GetJoint<PinJoint>(0); | |
Context<complex<double>> cc(cd); // clone context for this alternate. | |
cc.x += complex<double>(0, 1e-20); // complex step derivative | |
double cdf = cpin0.PinFunc(cc).imag() / 1e-20; | |
printf(" f(x)=%.16g;\n df(x)=%.16g (analytical)\ncdf(x)=%.16g (autodiff)\n", | |
f, df, cdf); | |
getchar(); | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once | |
#include <cmath> | |
#include <memory> | |
#include <utility> | |
#include <vector> | |
#include "System.h" | |
template <typename T> class MBTree; | |
/** Tags for the known concrete joint kinds. Each Joint reports its tag via
get_joint_type() so the right copy-from-fundamental constructor can be
selected when cloning. **/
enum class JointType {
  Pin,
  Slider
};
/** This is a mock Joint base class. It demonstrates that virtual methods | |
work fine with this style of templatizing. **/ | |
template <typename T> | |
class Joint { | |
public: | |
/** Return the index assigned to this joint, which will be the same in all | |
of the alternate instantiations. **/ | |
int get_joint_index() const {return index_;} | |
virtual int get_num_positions() const = 0; | |
virtual void get_positions(const Context<T>& context, T* positions) const = 0; | |
virtual JointType get_joint_type() const = 0; | |
/** This serves to instantiate methods for creating a Joint<T> from a | |
Joint<double>. **/ | |
static std::unique_ptr<Joint<T>> CloneFrom( | |
const Joint<double>& fundamental); | |
protected: | |
Joint() {} | |
private: | |
friend class MBTree<T>; | |
// Used only by MBTree to assign an index to this joint. | |
void set_joint_index(int index) {index_=index;} | |
void set_mb_tree(MBTree<T>* tree) {tree_=tree;} | |
// These are set after this Joint has been added to an MBTree<T>. | |
int index_; | |
MBTree<T>* tree_{}; | |
}; | |
/** This is a concrete Joint implementation. It implements the base class | |
virtuals and adds some mock PinJoint-specific functionality. **/ | |
template <typename T> | |
class PinJoint : public Joint<T> { | |
public: | |
/** Create a PinJoint using some PinJoint-specific parameter so we can | |
verify that the alternate instantiations end up with the same value. **/ | |
explicit PinJoint(const T& parameter) : parameter_{parameter} {} | |
explicit PinJoint(const PinJoint<double>& fundamental) | |
: PinJoint(static_cast<T>(fundamental.parameter_)) {} | |
/** Compute something specific to a PinJoint. **/ | |
T PinFunc(const Context<T>& context) const { | |
using std::cos; | |
return cos(context.x); | |
} | |
/** For testing, this method returns the analytic derivative of Func(); we'll | |
compare that with the automatic differentiation result we get from an | |
alternate instantiation. **/ | |
T DPinFuncDx(const Context<T>& context) const { | |
using std::sin; | |
return -sin(context.x); | |
} | |
/** Return the PinJoint-specific parameter. **/ | |
const T& get_parameter() const { return parameter_; } | |
private: | |
template <typename TT> friend class PinJoint; | |
int get_num_positions() const override {return 1;} | |
void get_positions(const Context<T>& context,T* positions) const override { | |
*positions = T(3) * context.x; // whatever; point is that it is T-dependent | |
} | |
JointType get_joint_type() const override { return JointType::Pin; } | |
T parameter_{}; | |
}; | |
/** Another concrete joint, just so we'll have two. **/ | |
template <typename T> | |
class SliderJoint : public Joint<T> { | |
public: | |
explicit SliderJoint(const T& offset) : offset_{offset} {} | |
explicit SliderJoint(const SliderJoint<double>& fundamental) | |
: SliderJoint(static_cast<T>(fundamental.offset_)) {} | |
/** Return the SliderJoint-specific parameter. **/ | |
const T& get_offset() const { return offset_; } | |
private: | |
template <typename TT> friend class SliderJoint; | |
int get_num_positions() const override {return 1;} | |
void get_positions(const Context<T>& context,T* positions) const override { | |
*positions = T(-3) * context.x; // whatever; point is that it is T-dependent | |
} | |
JointType get_joint_type() const override { return JointType::Slider; } | |
T offset_{}; | |
}; | |
// This method instantiates the concrete joint-from-fundamental constructors. | |
// TODO: I don't like this but haven't been able to figure out how to do it | |
// without a switch yet. One way to justify this is to think of it as analogous | |
// to the way we build systems from urdf or sdf files. That is kind of what's | |
// happening here: the fundamental system is the specification we're using to | |
// build another model, in the same way that a urdf is a specification for | |
// building a model. Is there a better way to do this? | |
template <typename T> | |
std::unique_ptr<Joint<T>> Joint<T>::CloneFrom( | |
const Joint<double>& fundamental) { | |
switch (fundamental.get_joint_type()) { | |
case JointType::Pin: | |
return std::make_unique<PinJoint<T>>( | |
dynamic_cast<const PinJoint<double>&>(fundamental)); | |
case JointType::Slider: | |
return std::make_unique<SliderJoint<T>>( | |
dynamic_cast<const SliderJoint<double>&>(fundamental)); | |
default: | |
assert(!"bad joint type"); | |
} | |
return std::unique_ptr<Joint<T>>(); | |
} | |
/** This is a mock MultibodyTree. It is templatized by scalar type, has | |
substructure in the form of "Joints". **/ | |
template <typename T> | |
class MBTree { | |
public: | |
MBTree() {} | |
/** Add a concrete joint to the tree, take over ownership of the joint | |
object, and return a reference to it that still has the concrete type. An | |
index is assigned and saved in the returned object. **/ | |
template <template <typename> class JointType> | |
JointType<T>* AddJoint(std::unique_ptr<JointType<T>> joint) { | |
JointType<T>* concrete = joint.get(); | |
concrete->set_joint_index((int)joints_.size()); | |
concrete->set_mb_tree(this); | |
joints_.push_back(std::unique_ptr<Joint<T>>(joint.release())); | |
return concrete; | |
} | |
/** Retrieve the concrete Joint<T> given the joint index. **/ | |
template <template <typename> class JointType> | |
const JointType<T>& GetJoint(int index) const { | |
const Joint<T>& joint = *joints_[index]; | |
return dynamic_cast<const JointType<T>&>(joint); | |
} | |
/** Create an alternate-scalar MBTree from the fundamental one. This is the | |
copy constructor for the fundamental tree; the others don't have one. **/ | |
explicit MBTree(const MBTree<double>& fundamental) { | |
for (const auto& joint : fundamental.joints_) { | |
auto new_joint = Joint<T>::CloneFrom(*joint); | |
joints_.push_back(std::move(new_joint)); | |
} | |
} | |
private: | |
// Let all other instantiations of this class be friends with this one. | |
template <typename TT> | |
friend class MBTree; | |
std::vector<std::unique_ptr<Joint<T>>> joints_; | |
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once | |
#include <cassert> | |
#include <memory> | |
#include <typeindex> | |
#include <typeinfo> | |
#include <unordered_map> | |
#include <utility> | |
#include <vector> | |
/** Type-erased base class for instantiated `System` classes. It has nothing
but a virtual destructor, so any instantiation can be owned and destroyed
through a pointer to this type. **/
class InstantiatedSystem {
 public:
  virtual ~InstantiatedSystem() = default;
};
/** Registry of the other available instantiations of a System. Only the
specialization for the fundamental scalar type `double` actually stores
anything; for every other scalar type this generic version is used, which is
empty and always reports zero alternates. **/
template <typename T>
class AlternateInstantiations {
 public:
  /** Generic (non-fundamental) case: there are never any alternates. **/
  int get_num_alternates() const {
    return 0;
  }
};
/** Specialization for double. **/ | |
template<> | |
class AlternateInstantiations<double> { | |
public: | |
void RegisterAlternate(const std::type_index& index, | |
std::unique_ptr<InstantiatedSystem> alt) { | |
auto entry = type_to_instantiation_.insert( | |
std::pair<std::type_index, size_t>(index, instantiations.size())); | |
if (!entry.second) return; // This alternate is already present. | |
instantiations.push_back(std::move(alt)); | |
} | |
const InstantiatedSystem& get_alternate(const std::type_index& index) const { | |
auto entry = type_to_instantiation_.find(index); | |
assert(entry != type_to_instantiation_.end()); | |
return *instantiations[entry->second]; | |
} | |
int get_num_alternates() const { return (int)instantiations.size(); } | |
private: | |
std::vector<std::unique_ptr<InstantiatedSystem>> instantiations; | |
// Here is how you find the right instantiation. | |
std::unordered_map<std::type_index, size_t> type_to_instantiation_; | |
}; | |
/** This is a System instantiated for a concrete System type MySys, which | |
must be derived from `System<MySystem, T>` for an Eigen Scalar type T. **/ | |
template <template <typename> class MySys, typename T> | |
class System : public InstantiatedSystem { | |
public: | |
using SystemT = System; | |
/** Default constructor creates an empty System with no alternative | |
instantiations. **/ | |
System() {} | |
/** Construction from fundamental instantiation. **/ | |
explicit System(const MySys<double>& fundamental) { | |
} | |
/** Create an alternate instantiation of the fundamental System and | |
register this one with it. **/ | |
static void AddAlternate(MySys<double>& fundamental) { | |
std::unique_ptr<MySys<T>> alternate(new MySys<T>(fundamental)); | |
// TODO(sherm) Make alternate mirror fundamental. | |
fundamental.get_mutable_alternates()->RegisterAlternate( | |
std::type_index(typeid(T)), | |
std::unique_ptr<InstantiatedSystem>(alternate.release())); | |
} | |
template <class TT> | |
const MySys<TT>& get_alternate() const { | |
const InstantiatedSystem& alt = | |
alternates_.get_alternate(std::type_index(typeid(TT))); | |
return dynamic_cast<const MySys<TT>&>(alt); | |
} | |
int get_num_alternates() const { return alternates_.get_num_alternates(); } | |
AlternateInstantiations<T>* get_mutable_alternates() { return &alternates_; } | |
protected: | |
private: | |
AlternateInstantiations<T> alternates_; | |
}; | |
/** Minimal mock context: holds the single state variable `x` in scalar
type T. **/
template <typename T>
struct Context {
  /** Create a context whose state is `value`. **/
  explicit Context(const T& value) : x(value) {}

  /** Create an alternate context from a fundamental one by converting its
  state to T. For T=double this is the (explicit) copy constructor. `x` is
  initialized directly rather than default-constructed and then assigned, so T
  need not be default-constructible or assignable. **/
  explicit Context(const Context<double>& fundamental)
      : x(static_cast<T>(fundamental.x)) {}

  T x;  // state variable
};
Instead of a copy constructor on MBSystem, could we have a virtual method System<T>::TransmogrifyFrom(System<double>)? Then Diagrams can Transmogrify generically.
Yeah, that sounds like a better idea (except maybe for the word "Transmogrify"!). CloneFrom() seems a little less terrifying.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi, @amcastro-tri. Here is what I have so far for the multiple instantiation stuff. Start with main.cpp. If you grab these 3 files you should be able to compile and build on Ubuntu. - Sherm