Skip to content

Instantly share code, notes, and snippets.

@soumith
Created May 28, 2019 04:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save soumith/af07e5cf66174566d56c07bd983bcd89 to your computer and use it in GitHub Desktop.
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <torch/torch.h>

#include <iostream>
#include <limits>
#include <vector>
// using namespace at;
using namespace torch;
/// Greedy submodular selection step: among all points still flagged as
/// candidates, find the one whose features, when added to the already-selected
/// set, maximize the facility-location-style score
/// sum over feature dims of max(features_done, features[i]).
///
/// @param candidate_points 1-D 0/1 flag tensor (length = num data points);
///                         1 means the point is still selectable.
/// @param features_done    features of the points selected so far.
/// @param features         full feature matrix, one row per data point.
///
/// Currently prints the best value/index; the return is commented out
/// upstream, so the signature stays void.
void submodular_select(Tensor candidate_points, Tensor features_done, Tensor features)
{
int max_idx = -1;
// BUG FIX: was -1e-9 (a near-zero value), which silently rejected every
// candidate with a negative score. Start from the lowest float instead.
float max_value = std::numeric_limits<float>::lowest();
for (int i = 0; i < candidate_points.size(0); i++)
{
// BUG FIX: was candidate_points.item<int>(), which throws on a
// multi-element tensor; we want the flag of point i.
if (candidate_points[i].item<int>() == 1)
{
std::vector<Tensor> temp;
temp.push_back(features_done);
// BUG FIX: was features[candidate_points[i]] — that indexes by the
// 0/1 flag (always 1 inside this branch), not by the point index.
temp.push_back(features[i]);
auto stacked_temp = stack(temp);
// Element-wise max over the stacked rows, then sum = marginal score.
// (Removed leftover per-iteration debug print of the whole tensor.)
float value = std::get<0>(stacked_temp.max(1, false)).sum().item<float>();
if (value > max_value)
{
max_value = value;
max_idx = i;
}
}
}
std::cout << "Max Value" << max_value << std::endl;
std::cout << "Max Index" << max_idx << std::endl;
// return max_idx;
}
/// Driver: builds a random feature matrix, marks a random subset of points
/// as already sampled, and runs one submodular-selection pass per batch slot.
int main()
{
const int num_data_points = 6000;
const int num_features = 256;
const int batch_size = 64;
Tensor features = torch::randn({num_data_points, num_features}, dtype(kFloat));
// Already-sampled point indices. NOTE(review): randint samples WITH
// replacement, so `done` may contain duplicates — confirm that is intended.
Tensor done = torch::randint(0, num_data_points, {batch_size * 4}, dtype(kLong));
// NOTE(review): done_index is a plain arange over the first batch_size*3
// rows, unrelated to `done`, and its length (batch_size*3) does not match
// done's (batch_size*4) — looks like a placeholder; verify against intent.
Tensor done_index = torch::arange(0, batch_size * 3, kLong).squeeze();
// BUG FIX: Tensor::index(Tensor) is not a supported overload in the
// public C++ API; row-gathering is spelled index_select(dim, indices).
Tensor features_done = features.index_select(0, done_index);
// 1 = still a candidate; zero out the already-sampled positions.
Tensor candidate_points = torch::ones(num_data_points, dtype(kLong));
auto scatter_val = torch::zeros(num_data_points, dtype(kLong));
candidate_points = candidate_points.scatter_(0, done, scatter_val);
for (int batch = 0; batch < batch_size; batch++)
{
submodular_select(candidate_points, features_done, features);
// TODO: submodular_select should return max_idx so the selection can
// be appended to `done` / `features_done` (see commented lines below).
// done[batch_size*3+1] = max_idx;
// features_done = features[done[batch_size*3+batch+1]];
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment