Skip to content

Instantly share code, notes, and snippets.

@botev
Last active May 17, 2016 11:38
Show Gist options
  • Save botev/9addde28ba6810cb23e9157a920b6ce0 to your computer and use it in GitHub Desktop.
Save botev/9addde28ba6810cb23e9157a920b6ce0 to your computer and use it in GitHub Desktop.
Mini-batch iterator for ArrayFire arrays
#include <algorithm>
#include <iostream>
#include <iterator>
#include <map>
#include <string>
#include <vector>

#include "arrayfire.h"
class AbstractDataSource{
protected:
unsigned instance_dim;
long n;
std::map<std::string, af::array> data;
std::vector<std::string> current_datums;
public:
AbstractDataSource(unsigned instance_dim = 0): instance_dim(instance_dim) {};
void set_datums(std::vector<std::string> datums){
current_datums = datums;
}
virtual void add_data(std::string key, af::array value){
if(data.size() == 0){
n = value.dims(instance_dim);
}
if(value.dims(instance_dim) != n){
throw 2;
}
if(data.find(key) != data.end()){
throw 1;
}
data[key] = value;
}
virtual void shuffle_data(af::array order){
if(instance_dim == 0){
auto f = [&order](std::pair<std::string const, af::array>& value){value.second = value.second(order, af::span, af::span, af::span);};
std::for_each(data.begin(), data.end(), f);
} else if(instance_dim == 1){
auto f = [&order](std::pair<std::string const, af::array>& value){value.second = value.second(af::span, order, af::span, af::span);};
std::for_each(data.begin(), data.end(), f);
} else if(instance_dim == 2){
auto f = [&order](std::pair<std::string const, af::array>& value){value.second = value.second(af::span, af::span, order, af::span);};
std::for_each(data.begin(), data.end(), f);
} else if(instance_dim == 3){
auto f = [&order](std::pair<std::string const, af::array>& value){value.second = value.second(af::span, af::span, af::span, order);};
std::for_each(data.begin(), data.end(), f);
}
}
virtual void random_shuffle(){
af::array rand = af::randu(n);
af::array order;
af::sort(rand, order, rand, 0);
shuffle_data(order);
}
class iterator{
private:
AbstractDataSource & source;
int batch_size;
bool full_batches;
int index;
std::vector<af::array::array_proxy> slice;
bool updated;
public:
typedef std::vector<af::array::array_proxy> value_type;
typedef std::vector<af::array::array_proxy>& ref_type;
typedef std::vector<af::array::array_proxy>* ptr_type;
iterator(AbstractDataSource & source, int index,
int batch_size,
bool full_batches):
source(source), index(index),
batch_size(batch_size),
full_batches(full_batches), updated(false) {};
iterator(iterator const & ref):
source(ref.source), index(index),
batch_size(ref.batch_size),
full_batches(ref.full_batches), updated(false) {
};
iterator& operator++() {
index+=batch_size;
if(index > source.n or (full_batches and index + batch_size >= source.n)){
index = source.n;
}
updated = false;
return *this;
}
iterator operator++(int) {
iterator copy(*this);
++(*this);
return copy;
}
bool operator==(iterator const & ref){
return &source == &ref.source and batch_size == ref.batch_size and
index == ref.index;
}
bool operator!=(iterator const & ref){
return &source != &ref.source or batch_size != ref.batch_size or
index != ref.index;
}
ref_type operator*() {
fetch_slice();
return slice;
}
ptr_type operator->(){
fetch_slice();
return &slice;
}
void fetch_slice(){
if(not updated) {
slice.clear();
int last = (index+batch_size-1) < source.n ? (index+batch_size-1) : source.n-1;
for (auto i = 0; i < source.current_datums.size(); ++i) {
if(source.instance_dim == 0){
slice.push_back(source.data[source.current_datums[i]](af::seq(index, last), af::span, af::span, af::span));
} else if(source.instance_dim == 1){
slice.push_back(source.data[source.current_datums[i]](af::span, af::seq(index, last), af::span, af::span));
} else if(source.instance_dim == 2){
slice.push_back(source.data[source.current_datums[i]](af::span, af::span, af::seq(index, last), af::span));
} else if(source.instance_dim == 3){
slice.push_back(source.data[source.current_datums[i]](af::span, af::span, af::span, af::seq(index, last)));
}
}
updated = true;
}
}
};
iterator begin( int batch_size, bool full_batches = false){
return iterator(*this, 0, batch_size, full_batches);
}
iterator end(int batch_size){
return iterator(*this, n, batch_size, false);
}
void print(){
for(auto i = data.begin(); i != data.end(); ++i){
std::cout << i->first << std::endl;
af_print(i->second);
}
}
};
int main()
{
auto source = AbstractDataSource(1);
source.add_data("d1", af::randu(5, 7));
source.add_data("d2", af::randu(3, 7));
source.add_data("d3", af::randu(1, 7));
source.print();
source.random_shuffle();
source.print();
source.set_datums({"d2", "d3"});
for(auto i = source.begin(2, true); i != source.end(2); ++i){
std::cout << "Iteration " << std::endl;
for(auto j = 0; j < i->size(); ++j){
af_print((*i)[j]);
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment