Skip to content

Instantly share code, notes, and snippets.

@ashlaban
Created June 24, 2019 15:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ashlaban/142827263a51488c3098da0e456a197a to your computer and use it in GitHub Desktop.
Save ashlaban/142827263a51488c3098da0e456a197a to your computer and use it in GitHub Desktop.
#
# NOTE: Requires that you have run `python create_dataset.py`
# from https://github.com/stwunsch/tmva_mnist first.
#
import ROOT
from resnet import ResnetBuilder
# Setup TMVA: initialise the TMVA tools singleton and the Python method bridge
# (needed by the PyKeras method booked later), then create the output file and
# a multiclass Factory that writes its results into it.
ROOT.TMVA.Tools.Instance()
ROOT.TMVA.PyMethodBase.PyInitialize()
output = ROOT.TFile.Open("TMVA.root", "RECREATE")
factory_options = ":".join(
    ["!V", "!Silent", "Color", "DrawProgressBar", "AnalysisType=multiclass"]
)
factory = ROOT.TMVA.Factory("MNIST", output, factory_options)
# Load data: register one training tree per digit class ("digit0" .. "digit9")
# with the DataLoader, declare one input variable per image pixel, and split
# the events randomly into training and test samples.
dataloader = ROOT.TMVA.DataLoader("dataset")
data = ROOT.TFile.Open("mnist.root")
# TFile.Open returns a null object (falsy in PyROOT) when the file is missing;
# fail loudly here rather than handing null trees to the DataLoader.
if not data or data.IsZombie():
    raise RuntimeError(
        "Could not open mnist.root; run `python create_dataset.py` from "
        "https://github.com/stwunsch/tmva_mnist first.")
# Keep Python references to the trees for as long as the DataLoader uses them.
tree_digit = []
for i in range(10):
    tree = data.Get("train_digits/train_digit{}".format(i))
    if not tree:
        raise RuntimeError(
            "Tree train_digits/train_digit{} not found in mnist.root".format(i))
    tree_digit.append(tree)
    dataloader.AddTree(tree, "digit{}".format(i))
# One variable per pixel of a 28x28 single-channel image.
# NOTE(review): assumes each tree has an array branch `x` of length 784 —
# confirm against create_dataset.py.
for i in range(28 * 28 * 1):
    dataloader.AddVariable("x[{}]".format(i), "x_{}".format(i), "")
dataloader.PrepareTrainingAndTestTree(
    ROOT.TCut(""), "SplitMode=Random:NormMode=None:!V")
# Define model: a ResNet-18 classifier with 10 output classes, built for
# single-channel 28x28 images, compiled with Adam and categorical cross-entropy.
# NOTE(review): TMVA feeds PyKeras the 784 variables declared on the
# DataLoader; confirm the ResnetBuilder network accepts that flat input (a
# Keras Reshape layer in front may be required) — TODO verify.
image_shape = (1, 28, 28)
n_classes = 10
model = ResnetBuilder.build_resnet_18(image_shape, n_classes)
model.compile(optimizer="adam",
              loss="categorical_crossentropy",
              metrics=["accuracy"])
# Persist the compiled model; the PyKeras method is booked with
# FilenameModel=model.h5 and loads it from this file.
model.save("model.h5")
model.summary()
# Book the PyKeras method, pointing it at the compiled model saved in model.h5
# and training for 10 epochs with a batch size of 100.
factory.BookMethod(dataloader, ROOT.TMVA.Types.kPyKeras, "PyKeras",
                   "H:!V:FilenameModel=model.h5:NumEpochs=10:BatchSize=100")
# Run training, test and evaluation
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()
# Close the output file explicitly so the results written by the Factory are
# flushed deterministically instead of relying on interpreter-shutdown cleanup.
output.Close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment