Created June 24, 2019 15:36
-
-
Save ashlaban/142827263a51488c3098da0e456a197a to your computer and use it in GitHub Desktop.
Based on https://github.com/stwunsch/tmva_mnist/blob/master/train.py, modified to train a ResNet model instead (see https://gist.github.com/ashlaban/6be0ddc58d940d6ab1783ac0dbab19cc for the ResNet builder).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#
# Train a Keras ResNet-18 on MNIST through TMVA's PyKeras interface.
#
# NOTE: Requires that you have run `python create_dataset.py`
# from https://github.com/stwunsch/tmva_mnist first.
#
import ROOT
from resnet import ResnetBuilder

# Image geometry and number of classes. Shared by the per-pixel variable
# booking and the model definition so the two cannot drift apart.
IMAGE_ROWS = 28
IMAGE_COLS = 28
IMAGE_CHANNELS = 1
NUM_CLASSES = 10

# Setup TMVA
ROOT.TMVA.Tools.Instance()
ROOT.TMVA.PyMethodBase.PyInitialize()

output = ROOT.TFile.Open("TMVA.root", "RECREATE")
factory = ROOT.TMVA.Factory(
    "MNIST", output,
    "!V:!Silent:Color:DrawProgressBar:AnalysisType=multiclass")

# Load data
dataloader = ROOT.TMVA.DataLoader("dataset")

data = ROOT.TFile.Open("mnist.root")
# TFile.Open yields a null (falsy) object when the file cannot be opened.
# Fail here with an actionable message instead of an obscure error later.
if not data or data.IsZombie():
    raise RuntimeError(
        "Cannot open mnist.root; run `python create_dataset.py` from "
        "https://github.com/stwunsch/tmva_mnist first.")

tree_digit = []
for digit in range(NUM_CLASSES):
    tree_name = "train_digits/train_digit{}".format(digit)
    tree = data.Get(tree_name)
    if not tree:
        raise RuntimeError(
            "Tree '{}' not found in mnist.root".format(tree_name))
    # Keep a Python reference to every tree for the lifetime of the script.
    tree_digit.append(tree)
    # In multiclass mode each tree is registered as its own class "digit<N>".
    dataloader.AddTree(tree, "digit{}".format(digit))

# One input variable per pixel: x[0] ... x[rows*cols*channels - 1].
for i in range(IMAGE_ROWS * IMAGE_COLS * IMAGE_CHANNELS):
    dataloader.AddVariable("x[{}]".format(i), "x_{}".format(i), "")

dataloader.PrepareTrainingAndTestTree(
    ROOT.TCut(""), "SplitMode=Random:NormMode=None:!V")

# Define model
model = ResnetBuilder.build_resnet_18(
    (IMAGE_CHANNELS, IMAGE_ROWS, IMAGE_COLS), NUM_CLASSES)
model.compile(
    loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

# Store model to file; the PyKeras method booked below loads it via
# FilenameModel=model.h5.
model.save("model.h5")
model.summary()

# Book methods
factory.BookMethod(dataloader, ROOT.TMVA.Types.kPyKeras, "PyKeras",
                   "H:!V:FilenameModel=model.h5:NumEpochs=10:BatchSize=100")

# Run training, test and evaluation
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()

# Close the output file explicitly so TMVA.root is flushed to disk
# (same pattern as the TMVA tutorials: close the file after evaluation).
output.Close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.