@utilForever
Created February 12, 2019 11:11
CubbyDNN graph sample code
using Tensor = Tensor<float>;
Graph g;
auto& input = g.Input(TensorShape(64, 32));
auto& oneHot = g.Reshape(input, TensorShape(1, 2048));
auto& output = g.DropOut(oneHot, 10);
auto& answer = g.Max(output);
// Load the training data once, before the loop.
auto data = LoadData("dataPath");
for (int i = 0; i < trainNum; ++i)
{
    g.ImportData(data[i]);
    Tensor prediction = g.Run();
    std::cout << "the answer is " << prediction << std::endl;
}
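
The Reshape call above only makes sense if both shapes hold the same number of elements (64 × 32 = 2048). A minimal sketch of that check follows, assuming a hypothetical `TensorShape` stand-in built from a brace list and a hypothetical `CanReshape` helper; neither is taken from CubbyDNN itself.

```cpp
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <vector>

// Hypothetical stand-in for TensorShape; the real CubbyDNN type may differ.
class TensorShape
{
 public:
    TensorShape(std::initializer_list<std::size_t> dims) : m_dims(dims)
    {
        // Assumption: a shape needs at least one dimension.
        assert(!m_dims.empty());
    }

    // Total number of elements: the product of all dimensions.
    std::size_t NumElements() const
    {
        std::size_t count = 1;
        for (std::size_t dim : m_dims)
        {
            count *= dim;
        }
        return count;
    }

 private:
    std::vector<std::size_t> m_dims;
};

// Hypothetical helper: a reshape is valid only when it keeps the
// element count unchanged.
bool CanReshape(const TensorShape& from, const TensorShape& to)
{
    return from.NumElements() == to.NumElements();
}
```

With this sketch, `CanReshape(TensorShape{64, 32}, TensorShape{1, 2048})` returns true, while a target of `TensorShape{1, 2000}` would be rejected.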

jwkim98 commented Feb 13, 2019

For example,

using Tensor = Tensor<float>;
const int numCategories = 10;

Graph g;

Tensor batchSize = g.PlaceHolder(TensorShape(1), "batch_size");
Tensor input = g.PlaceHolder(TensorShape(64, 32, batchSize), "input");
Tensor oneHot = g.Reshape(input, TensorShape(1, 2048, batchSize));
Tensor output = g.Dense(oneHot, numCategories);
Tensor result = g.DropOut(output, 0.5);
Tensor softmax = g.SoftMax(result);
Tensor answer = g.Max(softmax, "answer");

Tensor label = g.PlaceHolder(TensorShape(1, numCategories), "label");
// The loss compares the softmax distribution with the label, not the reduced Max output.
Tensor crossEntropy = g.CrossEntropy(softmax, label);
Tensor optimizer = g.AdamOptimizer(crossEntropy);


// Open each stream once; every fit call then pulls the next batch from it.
// The Stream class should provide operator()(size_t batchSize) that returns the next batch.
const int trainBatchSize = 100;
Stream data = LoadData("dataPath", trainBatchSize);
Stream labelData = LoadData("labelPath", trainBatchSize);

for (int i = 0; i < trainNum; ++i)
{
    g.fit(optimizer, pair{"input", data}, pair{"label", labelData}, pair{"batch_size", trainBatchSize});
}

Stream testData = LoadData("testDataPath", 1);

Tensor trainedAnswer = g.run(answer, pair{"input", testData}, pair{"batch_size", 1});
int num = trainedAnswer.castToInt();
std::cout << "derived answer for data: " << num << std::endl;
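
The Stream requirement mentioned in the code comment (a call operator that takes a batch size and returns the next batch) could look roughly like the sketch below. Everything here is an assumption for illustration: samples are plain floats held in memory, and the stream wraps around to the start when the data runs out. The real class would presumably read from a file and return tensors.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-in for the Stream class; not CubbyDNN's actual type.
class Stream
{
 public:
    explicit Stream(std::vector<float> samples) : m_samples(std::move(samples))
    {
        // Assumption: an empty stream is a usage error.
        assert(!m_samples.empty());
    }

    // Returns the next batchSize samples, wrapping around to the start
    // of the data when the end is reached.
    std::vector<float> operator()(std::size_t batchSize)
    {
        std::vector<float> batch;
        batch.reserve(batchSize);
        for (std::size_t i = 0; i < batchSize; ++i)
        {
            batch.push_back(m_samples[m_cursor]);
            m_cursor = (m_cursor + 1) % m_samples.size();
        }
        return batch;
    }

 private:
    std::vector<float> m_samples;
    std::size_t m_cursor = 0;
};
```

Because the cursor lives inside the stream, the training loop can keep one Stream object alive and simply call it once per step, instead of reloading the data on every iteration.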

Author

utilForever commented Feb 13, 2019

@Ravenwater I have not yet decided which operations the Graph class needs. I referred to the test code in the following library:
https://github.com/AcrylicShrimp/TinNet/blob/master/Example_Graph_Basic/Run.cpp
I'm curious about what you are suggesting. This part needs a lot of discussion, and I would like to gather a few more comments and improve the design.
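
As one possible starting point for that discussion, the named-feed call style from the example above needs only two operations: registering a placeholder by name, and binding (name, value) pairs at run time. The sketch below is purely illustrative. The class body, the scalar `float` values, and the "sum of feeds" evaluation are all placeholders for real tensors and operations, and it takes a brace list of pairs rather than the variadic `pair{...}` arguments.

```cpp
#include <cassert>
#include <initializer_list>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical minimal graph; names and behavior are assumptions.
class Graph
{
 public:
    // Registers a named input slot.
    void PlaceHolder(const std::string& name)
    {
        assert(!name.empty());
        m_feeds[name] = 0.0f;
    }

    // Binds each (name, value) pair to its placeholder, then evaluates.
    // "Evaluation" here just sums the fed values in place of real operations.
    float Run(std::initializer_list<std::pair<std::string, float>> feeds)
    {
        for (const auto& feed : feeds)
        {
            auto it = m_feeds.find(feed.first);
            if (it == m_feeds.end())
            {
                throw std::invalid_argument("unknown placeholder: " + feed.first);
            }
            it->second = feed.second;
        }

        float sum = 0.0f;
        for (const auto& entry : m_feeds)
        {
            sum += entry.second;
        }
        return sum;
    }

 private:
    std::map<std::string, float> m_feeds;
};
```

A call would then read `g.Run({{"input", 2.0f}, {"batch_size", 1.0f}})`, and feeding a name that was never registered throws instead of silently being ignored.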
