Last active
May 24, 2016 16:25
-
-
Save lambday/ffc5ac685650775772e8eda0994257c5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class linalg | |
{ | |
template <class T> | |
T dot(Vector<T>* a, Vector<T>* b) | |
{ | |
if (a->onGPU() && b->onGPU()) { | |
if (this->hasGPUBackend()) { | |
// do the gpu backend dot product | |
// you shouldn't care whether it's viennacl or some other GPU backend. | |
return this->gpu_backend->dot(*static_cast<GPUVector<T>*>(a), *static_cast<GPUVector<T>*>(a)); | |
} else { | |
// either throw a RuntimeException or transfer back the data to cpu | |
throw new RuntimeException("user did not register GPU backend"); | |
} | |
} else { | |
// take care that the matricies are on the same backend | |
// do the non-gpu based default backend: | |
// this should be actually as well implemented in a separate class's function and just that being called here: | |
// like: | |
return this->cpu_backend->dot(*static_cast<CPUVector<T>*>(a), *static_cast<CPUVector<T>*>(a)); | |
} | |
} | |
bool hasGPUBackend() | |
{ | |
return gpu_backend != nullptr; | |
} | |
CPUBackend* cpu_backend; | |
GPUBackend* gpu_backend; | |
}; | |
class CPUBackend | |
{ | |
template <typename T> | |
T dot(CPUVector<T> a, CPUVector<T> b) | |
{ | |
// create eigen maps, call eigen | |
} | |
// similarly, other methods | |
}; | |
/**
 * GPU implementation of the linalg operations. Its methods exist only when
 * built with HAVE_VIENNACL; otherwise the class is empty.
 */
class GPUBackend
{
public:
#ifdef HAVE_VIENNACL
    /**
     * Dot product of two GPU-resident vectors.
     *
     * Operands are taken by const reference, which matches CPUBackend and
     * avoids copying the GPUVector objects (and whatever gpu_ptr holds).
     *
     * @param a first operand
     * @param b second operand
     * @return the inner product of a and b
     */
    template <typename T>
    T dot(const GPUVector<T>& a, const GPUVector<T>& b)
    {
        // TODO: access GPU memory and call ViennaCL's inner product.
        // NOTE(review): this stub has no return statement yet.
    }
    // similarly, other methods
#endif
};
/**
 * Abstract base for vectors that may live on the CPU or the GPU.
 * linalg dispatches on onGPU() through a Vector<T>*, so both members are public.
 */
template <typename T>
class Vector
{
public:
    // Polymorphic base: deleting a derived vector through Vector<T>* needs this.
    virtual ~Vector() = default;

    /** @return true if this vector's data resides on the GPU. */
    virtual bool onGPU() = 0;
};
template <typename T> | |
class CPUVector : public Vector<T> | |
{ | |
unique_ptr<T[]> vector; | |
// other stuffs | |
virtual bool onGPU() { return false; } | |
}; | |
template <typename T> | |
class GPUVector : public Vector<T> | |
{ | |
#ifdef HAVE_VIENNACL | |
vcl_memory_array gpu_ptr; | |
// other gpu related stuff | |
virtual bool onGPU() { return true; } | |
#endif | |
}; |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment