Gist utilForever/1e0a93263ff5d52ae35304906e9c980a, created February 12, 2019 11:11
CubbyDNN graph sample code
#include <iostream>

using Tensor = Tensor<float>;

Graph g;
auto& input = g.Input(TensorShape(64, 32));
auto& oneHot = g.Reshape(input, TensorShape(1, 2048));  // 64 * 32 = 2048 values
auto& output = g.DropOut(oneHot, 10);
auto& answer = g.Max(output);  // max of the dropout output

auto data = LoadData("dataPath");  // load once; each iteration imports one sample
for (int i = 0; i < trainNum; ++i)
{
    g.ImportData(data[i]);
    Tensor result = g.Run();
    std::cout << "the answer is " << result << std::endl;
}
For example,
#include <cstddef>
#include <iostream>
#include <utility>

using Tensor = Tensor<float>;
const int numCategories = 10;
const std::size_t trainBatchSize = 100;  // value fed to the "batch_size" placeholder during training

Graph g;
Tensor batchSize = g.PlaceHolder(TensorShape(1), "batch_size");
Tensor input = g.PlaceHolder(TensorShape(64, 32, batchSize), "input");
Tensor oneHot = g.Reshape(input, TensorShape(1, 2048, batchSize));  // 64 * 32 = 2048
Tensor output = g.Dense(oneHot, numCategories);
Tensor result = g.DropOut(output, 0.5);
Tensor softmax = g.SoftMax(result);
Tensor answer = g.Max(softmax, "answer");
Tensor label = g.PlaceHolder(TensorShape(1, numCategories), "label");
// Cross-entropy compares the predicted probabilities (softmax) with the label, not the max value
Tensor crossEntropy = g.CrossEntropy(softmax, label);
Tensor optimizer = g.AdamOptimizer(crossEntropy);

for (int i = 0; i < trainNum; ++i)
{
    // Stream should provide operator()(size_t batchSize) that returns the next batch
    Stream trainData = LoadData("dataPath", trainBatchSize);  // training data, one batch at a time
    Stream trainLabels = LoadData("labelPath", trainBatchSize);
    g.fit(optimizer, std::pair{"input", trainData}, std::pair{"label", trainLabels},
          std::pair{"batch_size", trainBatchSize});
}

Stream testData = LoadData("testDataPath", 1);
Tensor trainedAnswer = g.run(answer, std::pair{"input", testData}, std::pair{"batch_size", 1});
int num = trainedAnswer.castToInt();
std::cout << "derived answer for data: " << num << std::endl;
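The training loop assumes a Stream type with an operator that takes a batch size and returns the next batch. Below is a minimal sketch of such a class, assuming samples are already in memory as flat float vectors. The Sample and Batch aliases and the wrap-around at the end of the data are illustrative choices, not part of the proposal or of CubbyDNN.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical Stream (not CubbyDNN's API): holds samples in memory and
// serves them batch by batch, wrapping around at the end of the data.
class Stream
{
 public:
    using Sample = std::vector<float>;  // one flattened input, e.g. 64 * 32 floats
    using Batch = std::vector<Sample>;  // batchSize samples

    explicit Stream(std::vector<Sample> samples) : m_samples(std::move(samples)) {}

    // Returns the next batchSize samples, continuing from where the
    // previous call stopped.
    Batch operator()(std::size_t batchSize)
    {
        assert(batchSize > 0 && "a batch needs at least one sample");
        Batch batch;
        if (m_samples.empty())
            return batch;  // no data: empty batch
        batch.reserve(batchSize);
        for (std::size_t i = 0; i < batchSize; ++i)
        {
            batch.push_back(m_samples[m_cursor]);
            m_cursor = (m_cursor + 1) % m_samples.size();
        }
        return batch;
    }

 private:
    std::vector<Sample> m_samples;
    std::size_t m_cursor = 0;
};
```

With three samples and a batch size of 2, the first call returns samples 0 and 1, and the second returns samples 2 and 0. An epoch-based policy, which stops or reshuffles at the end of the data, would be an equally valid choice.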
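Cross-entropy is normally computed from the full softmax distribution over the numCategories classes and a one-hot label. Taking the max first keeps only one number, and an argmax index in particular has zero gradient almost everywhere, so an optimizer gets no learning signal from it. The standalone sketch below shows these steps on plain std::vector<double> values: softmax, cross-entropy, and argmax for the final prediction. The function names are hypothetical and independent of CubbyDNN's Tensor type.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <vector>

// Softmax: exp(z_i - max z) / sum_j exp(z_j - max z).
// Subtracting the max logit avoids overflow without changing the result.
std::vector<double> Softmax(const std::vector<double>& logits)
{
    assert(!logits.empty());
    const double maxLogit = *std::max_element(logits.begin(), logits.end());
    std::vector<double> probs(logits.size());
    double sum = 0.0;
    for (std::size_t i = 0; i < logits.size(); ++i)
    {
        probs[i] = std::exp(logits[i] - maxLogit);
        sum += probs[i];
    }
    for (double& p : probs)
        p /= sum;
    return probs;
}

// Cross-entropy -sum_i label_i * log(prob_i); with a one-hot label this is
// -log(probability assigned to the true class).
double CrossEntropy(const std::vector<double>& probs, const std::vector<double>& oneHotLabel)
{
    assert(probs.size() == oneHotLabel.size());
    double loss = 0.0;
    for (std::size_t i = 0; i < probs.size(); ++i)
        if (oneHotLabel[i] != 0.0)
            loss -= oneHotLabel[i] * std::log(probs[i]);
    return loss;
}

// Index of the largest probability: the predicted category, used only at inference.
std::size_t ArgMax(const std::vector<double>& probs)
{
    return static_cast<std::size_t>(
        std::distance(probs.begin(), std::max_element(probs.begin(), probs.end())));
}
```

For equal logits {0, 0}, softmax gives {0.5, 0.5}, and the cross-entropy against the label {1, 0} is log 2.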
@Ravenwater I have not yet decided which operations the Graph class requires. I based this on the test code in the following library:
https://github.com/AcrylicShrimp/TinNet/blob/master/Example_Graph_Basic/Run.cpp
I'm curious about what you are suggesting. This part needs a lot of discussion, and I'd like to gather a few comments and refine it.
The idea of having a Graph structure looks good, since Graph will manage operations and tensors.
But it's quite unclear what will happen when we call g.Run(); or g.ImportData(data[i]);.
In my opinion, instead of g.Run(), we could write something like Run(answer); and feed(input, data[i]);.
Since answer and input are members of Graph 'g', it's clear that 'g' is the graph being executed.
When we want to save or import the graph, we can call g1.save("path"); and Graph g1 = import("path");.
There's no reason to be restricted by 'Graph' when we design the 'run' and 'evaluate' interfaces; we should use it only where it's more convenient.
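The suggestion above has two parts. First, free functions Run(answer) and feed(input, data) work because each node knows its owning graph. Second, save/import handles persistence. The sketch below illustrates both on a toy scalar graph. Everything here is hypothetical: Node, Graph, Feed, Run, Save, and Import are illustrative names, and the only ops are scalar input, add, and mul. Import also reads from a std::istream rather than a path, so a std::stringstream can stand in for a file.

```cpp
#include <cassert>
#include <cstddef>
#include <istream>
#include <map>
#include <ostream>
#include <sstream>  // std::stringstream can stand in for a file when testing Save/Import
#include <string>
#include <vector>

class Graph;

// Hypothetical handle to one operation. It remembers its owning Graph, so
// Run(node) and Feed(node, value) can be free functions.
struct Node
{
    Graph* graph;
    std::size_t id;
};

// Toy scalar graph: ops are recorded when built and computed only in Evaluate.
class Graph
{
 public:
    Node Input() { return Push({"input", 0, 0}); }
    Node Add(Node a, Node b) { return Push({"add", a.id, b.id}); }
    Node Mul(Node a, Node b) { return Push({"mul", a.id, b.id}); }
    Node NodeAt(std::size_t id) { return Node{this, id}; }

    void SetValue(std::size_t id, double value) { m_fed[id] = value; }

    double Evaluate(std::size_t id) const
    {
        const Op& op = m_ops.at(id);
        if (op.kind == "input")
            return m_fed.at(id);  // throws std::out_of_range if never fed
        const double lhs = Evaluate(op.lhs);
        const double rhs = Evaluate(op.rhs);
        return op.kind == "add" ? lhs + rhs : lhs * rhs;
    }

    // Writes one line per op: "<kind> <lhs> <rhs>". Fed values are not saved.
    void Save(std::ostream& out) const
    {
        for (const Op& op : m_ops)
            out << op.kind << ' ' << op.lhs << ' ' << op.rhs << '\n';
    }

    static Graph Import(std::istream& in)
    {
        Graph g;
        Op op{};
        while (in >> op.kind >> op.lhs >> op.rhs)
            g.m_ops.push_back(op);
        return g;
    }

 private:
    struct Op
    {
        std::string kind;
        std::size_t lhs;
        std::size_t rhs;
    };

    Node Push(Op op)
    {
        m_ops.push_back(op);
        return Node{this, m_ops.size() - 1};
    }

    std::vector<Op> m_ops;
    std::map<std::size_t, double> m_fed;
};

// The graph being executed is implied by the node itself.
void Feed(Node node, double value)
{
    assert(node.graph != nullptr);
    node.graph->SetValue(node.id, value);
}

double Run(Node node)
{
    assert(node.graph != nullptr);
    return node.graph->Evaluate(node.id);
}
```

Because Node stores a raw Graph*, a handle must not outlive its graph or be used after the graph is moved. After Import, handles are recovered with NodeAt using the saved IDs. A real design might prefer shared ownership or stable IDs.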