Created
May 24, 2015 04:46
-
-
Save Diggsey/8b374bc5a3773d2d7103 to your computer and use it in GitHub Desktop.
Not your father's C++
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// AmpTest.cpp : Defines the entry point for the console application. | |
// | |
#include "stdafx.h"
#include <climits>
#include <iomanip>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <vector>
#include <boost/preprocessor/repeat.hpp>
#include <boost/preprocessor/enum.hpp>
#include <boost/preprocessor/enum_params.hpp>
#include <boost/range/counting_range.hpp>
#include <boost/range/adaptor/reversed.hpp>
using namespace concurrency; | |
// Fixed-size array usable in restrict(amp) code; specialized per size below
// via the preprocessor because AMP restricts ordinary arrays/loops.
template<typename T, int N> struct fixed_array;
// Expands to `data[0], data[1], ..., data[count-1]`.
#define ENUM_ARRAY_INDEX_M(z, n, data) data[n]
#define ENUM_ARRAY_INDEX(count, data) BOOST_PP_ENUM(count, ENUM_ARRAY_INDEX_M, data)
// Expands to the ctor-init list `m_<data>0(<data>0), m_<data>1(<data>1), ...`.
#define ENUM_FIELD_INIT_M(z, n, data) m_ ## data ## n (data ## n)
#define ENUM_FIELD_INIT(count, data) BOOST_PP_ENUM(count, ENUM_FIELD_INIT_M, data)
// Expands to member declarations `<data>0; <data>1; ...` (data carries the type).
#define ENUM_FIELD_DECL_M(z, n, data) data ## n;
#define ENUM_FIELD_DECL(count, data) BOOST_PP_REPEAT(count, ENUM_FIELD_DECL_M, data)
// Expands to `case 0: return <data>0; case 1: return <data>1; ...`.
#define ENUM_CASE_M(z, n, data) case n: return data ## n;
#define ENUM_CASE(count, data) BOOST_PP_REPEAT(count, ENUM_CASE_M, data)
// Defines fixed_array<T, n>: members m_elem0..m_elem(n-1), an n-argument
// constructor, and a switch-based operator[] (dynamic indexing is unrolled
// into a switch so it is legal in restrict(amp) code).  Note operator[] of
// the n == 0 specialization would not compile, but member functions of a
// class template are only instantiated when used.
#define DEF_SPEC_ARRAY(z, n, data) \
template<typename T> struct fixed_array<T, n> { \
public: \
ENUM_FIELD_DECL(n, T m_elem); \
\
inline fixed_array(BOOST_PP_ENUM_PARAMS(n, T const& elem)) restrict(amp, cpu) : ENUM_FIELD_INIT(n, elem) {}; \
inline T operator[](int i) const restrict(amp, cpu) { \
switch (i) { \
ENUM_CASE(n, m_elem) \
default: return m_elem0; \
} \
} \
};
// Generate specializations for sizes 0..15.
BOOST_PP_REPEAT(16, DEF_SPEC_ARRAY, ());
// Forward declaration; defined below once range_invoker is available.
template<unsigned N> class range_to;
// Sentinel template argument for param_indexer: marks a parameter that is
// passed through unchanged ("uniform") rather than indexed per element.
// INT_MIN replaces the original literal 0x80000000, whose unsigned->int
// conversion is implementation-defined (pre-C++20); the value is identical
// on MSVC, so all param_indexer specializations still match.
const int uniform_param = INT_MIN;
template<unsigned N, int I> class param_indexer { | |
public: | |
template<typename T> static auto& index(T& param) restrict(amp, cpu) { | |
return param[N + I]; | |
} | |
}; | |
template<unsigned N> class param_indexer<N, uniform_param> { | |
public: | |
template<typename T> static T& index(T& param) restrict(amp, cpu) { | |
return param; | |
} | |
}; | |
// Invokes f(i, indexed-params...) for every i in [0, N), unrolled at compile
// time by recursion on range_to<N> (AMP is restrictive about ordinary
// data-dependent loops).  The int... I pack selects, per trailing parameter,
// either an element offset or uniform_param (see param_indexer).
template<typename F>
class range_invoker {
public:
F f;
// Ascending order: recurse first, invoke second, so f runs for i = 0 .. N-1.
template<int... I, unsigned N, typename... P> inline void each(range_to<N> unused, P&... p) restrict(amp, cpu) {
each<I...>(range_to<N - 1>(), p...);
f(N - 1, param_indexer<N - 1, I>::index(p)...);
}
// Descending order: invoke first, recurse second, so f runs for i = N-1 .. 0.
template<int... I, unsigned N, typename... P> inline void eachRev(range_to<N> unused, P&... p) restrict(amp, cpu) {
f(N - 1, param_indexer<N - 1, I>::index(p)...);
eachRev<I...>(range_to<N - 1>(), p...);
}
// Base cases: an empty range invokes nothing.
template<int... I, typename... P> inline void each(range_to<0> unused, P&... p) restrict(amp, cpu) { }
template<int... I, typename... P> inline void eachRev(range_to<0> unused, P&... p) restrict(amp, cpu) { }
};
template<unsigned N> class range_to { | |
public: | |
template<int... I, typename F, typename... P> static void each(F f, P&... p) restrict(amp, cpu) { | |
range_invoker<F> invoker{ f }; | |
invoker.each<I...>(range_to<N>(), p...); | |
} | |
template<int... I, typename F, typename... P> static void eachRev(F f, P&... p) restrict(amp, cpu) { | |
range_invoker<F> invoker{ f }; | |
invoker.eachRev<I...>(range_to<N>(), p...); | |
} | |
}; | |
template<typename T, int Rank> | |
void fill(array<T, Rank>& arr, T initValue) { | |
parallel_for_each(arr.extent, [&arr, initValue](index<Rank> idx) restrict(amp) { | |
arr[idx] = initValue; | |
}); | |
} | |
// Dumps a 1-D float view to stdout, each value right-aligned in 8 columns.
// NOTE(review): refresh() tells the runtime the *CPU* copy changed, but this
// function only reads; synchronize() looks like the intended call — confirm.
void printArray(array_view<float, 1> view) {
view.refresh();
// Stage through a temporary view so the source's cached state is untouched.
array_view<float, 1> temp(view.extent);
view.copy_to(temp);
for (int i = 0; i < temp.extent[0]; ++i)
std::cout << std::setw(8) << temp[i];
std::cout << std::endl;
}
// Owning value/gradient pair stored on the accelerator, one float per element.
struct table {
array<float, 1> m_value; // forward-pass values
array<float, 1> m_gradient; // backward-pass gradients (same extent as m_value)
inline table(extent<1> size) : m_value(size), m_gradient(size) { }
// Shared extent of both arrays.
inline extent<1> extent() const { return m_value.extent; }
};
// Non-owning view over a table's value/gradient arrays; cheap to copy into
// restrict(amp) lambdas (unlike table, whose arrays cannot be captured by value).
struct table_view {
array_view<float, 1> m_value;
array_view<float, 1> m_gradient;
inline table_view(table& src) : m_value(src.m_value), m_gradient(src.m_gradient) { }
// Shared extent of both views.
inline extent<1> extent() const { return m_value.extent; }
};
class network; | |
class module { | |
protected: | |
table m_output; | |
network* m_network; | |
inline module(network* nn, extent<1> outputSize) : m_network(nn), m_output(outputSize) { } | |
static inline extent<1> scalarExtent(std::initializer_list<table_view> inputs) { | |
auto extent = inputs.begin()->extent(); | |
for (auto& input : inputs) { | |
if (input.extent() != extent) | |
throw "Extent mismatch"; | |
} | |
return extent; | |
} | |
public: | |
inline table_view getOutput() { | |
return table_view(m_output); | |
} | |
virtual void updateOutput() = 0; | |
virtual void updateGradInput() = 0; | |
inline operator table_view() { | |
return getOutput(); | |
} | |
}; | |
// Base class for modules of the form `foldl <op> [inputs...]`,
// which includes scalar arithmetic
//
// S supplies the fold as static restrict(amp) functions:
//   identity()  - seed value of the fold
//   op(a, b)    - acc' = op(acc, x)
//   dop0(a, b)  - partial derivative of op w.r.t. its first argument
//   dop1(a, b)  - partial derivative of op w.r.t. its second argument
template<int N, typename S>
class module_scalar : public module {
protected:
fixed_array<table_view, N> m_inputs; // views over the N input tables
public:
template<typename... T>
inline module_scalar(network* nn, T... inputs) : module(nn, scalarExtent({ inputs... })), m_inputs(inputs...) { }
// Forward pass: output[idx] = foldl(S::op, S::identity(), inputs[...][idx]).
virtual void updateOutput() {
auto inputs = m_inputs; // local copies so the AMP lambda captures no `this`
table_view output = m_output;
output.m_value.discard_data(); // output is write-only this pass
try {
parallel_for_each(
output.extent(),
[=](index<1> idx) restrict(amp) {
float acc = S::identity();
// uniform_param: `acc` is the same reference every iteration;
// offset 0: the i-th entry of `inputs`.
range_to<N>::each<uniform_param, 0>([=](int i, float& acc, auto& inputi) restrict(amp) {
acc = S::op(acc, inputi.m_value[idx]);
}, acc, inputs);
output.m_value[idx] = acc;
}
);
} catch (concurrency::runtime_exception& ex) {
OutputDebugStringA(ex.what());
DebugBreak();
}
}
// Backward pass: inputs[i].gradient = d(output)/d(x[i]) * output.gradient.
virtual void updateGradInput() {
auto inputs = m_inputs;
table_view output = m_output;
for (int i = 0; i < N; ++i)
inputs[i].m_gradient.discard_data();
try {
parallel_for_each(
output.extent(),
[=](index<1> idx) restrict(amp) {
// This is a fun little snippet of code which calculates all N partial
// derivatives of `foldl <op> [inputs...]` in O(N) time
// We can't use normal loops because AMP is funny about them...
// x[i] = i-th input value at this element.
float x[N];
range_to<N>::each<0, 0>([=](int i, float& xi, auto& inputi) restrict(amp) {
xi = inputi.m_value[idx];
}, x, inputs);
// acc[i] = fold of x[0..i-1] (prefix accumulators; acc[0] is the seed).
float acc[N];
acc[0] = S::identity();
range_to<N-1>::each<1, 0, 0>([=](int i, float& acci1, float& acci, auto& xi) restrict(amp) {
acci1 = S::op(acci, xi);
}, acc, acc, x);
// dacc0[i] ~ d(final result)/d(acc at step i+1), built back-to-front.
// NOTE(review): each step stores only the local dop0 and never multiplies
// by dacc0[i+1]; for ops whose dop0 is not constant (e.g. mul with N >= 3)
// the chain rule seems to require that product — verify against a numeric
// gradient before trusting this for N > 2.
float dacc0[N];
dacc0[N - 1] = 1.0f;
range_to<N-1>::eachRev<0, 1, 1>([=](int i, float& dacc0i, float& acci, auto& xi) restrict(amp) {
dacc0i = S::dop0(acci, xi);
}, dacc0, acc, x);
// y[i] = dop1(acc[i], x[i]) * dacc0[i] = d(final result)/d(x[i]).
// NOTE(review): this pass covers i in [0, N-1) only, so y[N-1] is read
// below without ever being written — looks like it should be range_to<N>.
float y[N];
range_to<N - 1>::eachRev<0, 0, 0, 0>([=](int i, float& yi, float& acci, auto& xi, float& dacc0i) restrict(amp) {
float dacc1 = S::dop1(acci, xi);
yi = dacc1*dacc0i;
}, y, acc, x, dacc0);
// Scale by the incoming gradient and scatter to each input's gradient.
float gradient = output.m_gradient[idx];
range_to<N>::each<0, 0>([=](int i, auto& inputi, float& yi) restrict(amp) {
inputi.m_gradient[idx] = gradient*yi;
}, inputs, y);
}
);
} catch (concurrency::runtime_exception& ex) {
OutputDebugStringA(ex.what());
DebugBreak();
}
}
};
template<int N = 2> | |
class module_add : public module_scalar<N, module_add<N>> { | |
public: | |
using module_scalar::module_scalar; | |
static inline float identity() restrict(amp) { | |
return 0.0f; | |
} | |
static inline float op(float a, float b) restrict(amp) { | |
return a+b; | |
} | |
static inline float dop0(float a, float b) restrict(amp) { | |
return 1.0f; | |
} | |
static inline float dop1(float a, float b) restrict(amp) { | |
return 1.0f; | |
} | |
}; | |
// Elementwise binary subtraction: output = x0 - x1.
// The double negation makes the fold work out: op(0, x0) = -x0, then
// op(-x0, x1) = -(-x0 + x1) = x0 - x1.
class module_sub : public module_scalar<2, module_sub> {
public:
    using module_scalar::module_scalar;
    static inline float identity() restrict(amp) {
        return 0.0f;
    }
    static inline float op(float a, float b) restrict(amp) {
        return -(a+b);
    }
    // Both partials of -(a+b) are constant -1.
    static inline float dop0(float a, float b) restrict(amp) {
        return -1.0f;
    }
    static inline float dop1(float a, float b) restrict(amp) {
        return -1.0f;
    }
};
// Elementwise negation of a single input: op(identity, x0) = -x0.
class module_neg : public module_scalar<1, module_neg> {
public:
    using module_scalar::module_scalar;
    static inline float identity() restrict(amp) {
        return 0.0f;
    }
    // The accumulator is ignored; only the input is negated.
    static inline float op(float a, float b) restrict(amp) {
        return -b;
    }
    static inline float dop0(float a, float b) restrict(amp) {
        return 0.0f;
    }
    static inline float dop1(float a, float b) restrict(amp) {
        return -1.0f;
    }
};
template<int N = 2> | |
class module_mul : public module_scalar<N, module_mul<N>> { | |
public: | |
using module_scalar::module_scalar; | |
static inline float identity() restrict(amp) { | |
return 1.0f; | |
} | |
static inline float op(float a, float b) restrict(amp) { | |
return a*b; | |
} | |
static inline float dop0(float a, float b) restrict(amp) { | |
return b; | |
} | |
static inline float dop1(float a, float b) restrict(amp) { | |
return a; | |
} | |
}; | |
// Elementwise binary division: output = x0 / x1.
// The reciprocal trick makes the fold work out:
// op(1, x0) = 1/x0, then op(1/x0, x1) = 1/((1/x0)*x1) = x0/x1.
class module_div : public module_scalar<2, module_div> {
public:
    using module_scalar::module_scalar;
    static inline float identity() restrict(amp) {
        return 1.0f;
    }
    static inline float op(float a, float b) restrict(amp) {
        return 1.0f / (a*b);
    }
    // d(1/(a*b))/da = -1/(a^2 * b).
    static inline float dop0(float a, float b) restrict(amp) {
        return -1.0f / (a*a*b);
    }
    // d(1/(a*b))/db = -1/(a * b^2).
    static inline float dop1(float a, float b) restrict(amp) {
        return -1.0f / (a*b*b);
    }
};
// Elementwise reciprocal of a single input: op(identity, x0) = 1/x0.
class module_rcp : public module_scalar<1, module_rcp> {
public:
    using module_scalar::module_scalar;
    static inline float identity() restrict(amp) {
        return 1.0f;
    }
    // The accumulator is ignored; only the input is inverted.
    static inline float op(float a, float b) restrict(amp) {
        return 1.0f / b;
    }
    static inline float dop0(float a, float b) restrict(amp) {
        return 0.0f;
    }
    // d(1/b)/db = -1/b^2.
    static inline float dop1(float a, float b) restrict(amp) {
        return -1.0f / (b*b);
    }
};
class module_param : public module { | |
public: | |
inline module_param(network* nn, extent<1> extent) : module(nn, extent) { } | |
virtual void updateOutput() { | |
} | |
virtual void updateGradInput() { | |
} | |
void setValue(array_view<float,1> value) { | |
if (value.extent != m_output.extent()) | |
throw "Extent mismatch"; | |
value.copy_to(m_output.m_value); | |
} | |
}; | |
class network { | |
private: | |
std::vector<std::unique_ptr<module>> m_moduleSeq; | |
template<typename T> | |
inline T* addModule(std::unique_ptr<T>&& m) { | |
T* result = m.get(); | |
m_moduleSeq.push_back(std::move(m)); | |
return result; | |
} | |
public: | |
template<typename T, typename... P> inline T* make(P&&... args) { | |
return addModule(std::make_unique<T>(this, std::forward<P>(args)...)); | |
} | |
void updateOutput() { | |
for (auto& module : m_moduleSeq) | |
module->updateOutput(); | |
} | |
void updateGradInput() { | |
for (auto& module : boost::adaptors::reverse(m_moduleSeq)) | |
module->updateGradInput(); | |
} | |
}; | |
// Hyper-parameters for rms_prop.
struct rms_config {
float learningRate = 1e-2f; // step size (not yet read by the visible code)
float alpha = 0.99f; // decay factor passed to lerp in rms_prop::step
float epsilon = 1e-8f; // presumably a divide-by-zero guard — not yet used here
};
// Linear interpolation: alpha = 0 yields a, alpha = 1 yields b.
float lerp(float alpha, float a, float b) restrict(amp) {
    const float towards = alpha * b;
    const float from = (1.0f - alpha) * a;
    return from + towards;
}
template<typename F> | |
class rms_prop { | |
private: | |
rms_config m_config; | |
F m_f; | |
array_view<float, 1>& m_x; | |
array<float, 1> m_state; | |
array<float, 1> m_loss; | |
public: | |
rms_prop(F f, array_view<float, 1>& x, rms_config const& config) : m_f(f), m_x(x), m_state(x.get_extent()), m_loss(x.get_extent()), m_config(config) { | |
fill(m_state, 0.0f); | |
fill(m_loss, 0.0f); | |
} | |
void step() { | |
parallel_for_each(x.extent, [&](index<Rank> idx) restrict(amp) { | |
m_state[idx] = lerp(m_config.alpha, m_state[idx], m_state[idx]); | |
}); | |
} | |
}; | |
int _tmain(int argc, _TCHAR* argv[]) | |
{ | |
network nn; | |
extent<1> size(8); | |
auto a = nn.make<module_param>(size); | |
auto b = nn.make<module_add<2>>(a->getOutput(), a->getOutput()); | |
auto c = nn.make<module_div>(b->getOutput(), a->getOutput()); | |
array<float, 1> data(size, boost::make_counting_iterator(1.0f)); | |
a->setValue(data); | |
nn.updateOutput(); | |
nn.updateGradInput(); | |
printArray(a->getOutput().m_value); | |
printArray(b->getOutput().m_value); | |
printArray(c->getOutput().m_value); | |
getchar(); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment