Skip to content

Instantly share code, notes, and snippets.

@tahuang1991
Created December 4, 2019 20:23
Show Gist options
  • Save tahuang1991/430361dc754cce7034530ad995801223 to your computer and use it in GitHub Desktop.
Save tahuang1991/430361dc754cce7034530ad995801223 to your computer and use it in GitHub Desktop.
import ROOT
import argparse
import os
import numpy as np
from array import array
# Command line: one input ROOT file whose tree gets a BDT score branch added.
parser = argparse.ArgumentParser(
    description='Evaluate a TMVA BDT on every event of the input tree and '
                'write a copy of the tree with an extra "bdt_value" branch')
# required=True: without -i the script would otherwise continue with a None
# path and only fail later, with an unrelated-looking error.
parser.add_argument('-i', '--inputfile', dest="inputfile", type=str,
                    metavar='STR', required=True, help='input file')
options = parser.parse_args()
print("==============================================================")
print("inputfile  %s" % options.inputfile)
print("==============================================================")
# Maps a ROOT leaf-type code (the "/T" suffix in a Branch leaf list) to the
# Python array typecode used for the branch buffer.
_rootBranchType2PythonArray = { 'b':'B', 'B':'b', 'i':'I', 'I':'i', 'F':'f', 'D':'d', 'l':'L', 'L':'l', 'O':'B' }
treename = "evtree"

# Open the input tree; fail early with a clear message instead of letting a
# zombie file / null tree object blow up later with an AttributeError.
tf = ROOT.TFile(options.inputfile, "read")
if not tf or tf.IsZombie():
    raise IOError("Could not open input file %s" % options.inputfile)
chain = tf.Get( treename )
if not chain:
    raise RuntimeError("Tree '%s' not found in %s" % (treename, options.inputfile))

# Output is written next to the input: <input path minus extension>_addbdt_...root.
# os.path.splitext strips only the final extension; the previous
# split('.')[0] cut at the first dot anywhere in the path (e.g. "./x.root" -> "").
outfile = os.path.splitext(options.inputfile)[0]
f = ROOT.TFile(outfile + "_addbdt_B8_20k_200k_train.root", 'recreate')
# Make the output file the current directory before cloning so the (empty)
# clone is created there rather than in the read-only input file.
f.cd()
TCha2 = chain.CloneTree(0)

# New branch holding the MVA score; bdt_value is its single-element buffer.
rootBranchType = "F"
bdt_value = array(_rootBranchType2PythonArray[rootBranchType], [0.0])
br_bdt_value = TCha2.Branch("bdt_value", bdt_value, "bdt_value/%s" % (rootBranchType))
def initBranchVals():
    """Reset the new-branch buffer before processing an event.

    Writes the placeholder value -99 into ``bdt_value[0]``; the event loop
    overwrites it with the evaluated MVA score.
    """
    bdt_value[0] = -99.0
# Total number of events stored in the input tree.
entries = chain.GetEntries()
# Input features of the trained classifier. TMVA::Reader expects the same
# variable names, in the same order, as used in the training (see the
# <Variables> section of the weights XML). Every entry carries a trailing
# comma so re-enabling a commented-out name cannot silently concatenate two
# adjacent string literals.
bdt_tmva_variables = [
    #"recomass",
    "dR_l1l2",
    "dR_b1b2",
    #"dR_l1l2b1b2",
    "dphi_l1l2b1b2",
    "MT2",
    "dR_minbl",
    #"mass_l1l2",
    #"mass_b1b2",
    #"mass_trans",
    #"dR_bl",
    #"pt_b1b2",
]

# Method label used for BookMVA/EvaluateMVA, and the matching weights file.
bdt_label = "TMVA_SignalB8_20k_Bkg_200k"
bdt_xml_file = "dataset/weights/{}_kBDT.weights.xml".format(bdt_label)
print("bdt_xml_file  %s" % bdt_xml_file)
if not os.path.isfile(bdt_xml_file):
    raise IOError("TMVA weights file not found: %s" % bdt_xml_file)

# One float buffer per input variable. The reader keeps the buffer addresses,
# so values must be written in place (buf[0] = value), never rebound.
dict_tmva_variables = { var: array('f', [0]) for var in bdt_tmva_variables }
m_reader = ROOT.TMVA.Reader()
for var in bdt_tmva_variables:
    m_reader.AddVariable(var, dict_tmva_variables[var])

# Spectator variables to register with the reader (e.g. ["recomass"] if the
# weights file declares it); none for this training. The misspelled name is
# kept because the event loop reads it.
sepctatorvar = []
dict_tmva_spectators = { var: array('f', [0]) for var in sepctatorvar }
for var in sepctatorvar:
    m_reader.AddSpectator(var, dict_tmva_spectators[var])

print("bdt label  %s  xml  %s" % (bdt_label, bdt_xml_file))
m_reader.BookMVA(bdt_label, bdt_xml_file)
print("Loading chain...")
entries = chain.GetEntries()
print("Done.")
print("Adding BDT value for %d events." % entries)
for i in range(entries):
    # TTree::GetEntry returns -1 on an I/O error; continuing would evaluate
    # and store stale values from the previous entry.
    if chain.GetEntry(i) < 0:
        raise IOError("I/O error while reading entry %d of tree %s" % (i, treename))
    initBranchVals()
    if i % 1000 == 0:
        print("Event %d over %d" % (i + 1, entries))
    # Copy this event's inputs into the buffers registered with the reader.
    for var in bdt_tmva_variables:
        dict_tmva_variables[var][0] = getattr(chain, var)
    for var in sepctatorvar:
        dict_tmva_spectators[var][0] = getattr(chain, var)
    bdt_value[0] = m_reader.EvaluateMVA(bdt_label)
    # The clone made with CloneTree(0) stays connected to the input tree's
    # branch buffers, so Fill() stores entry i's original branches plus
    # the new bdt_value.
    TCha2.Fill()
# Persist the new tree and close both files.
# TTree::Write writes the tree header into the current directory (gDirectory),
# so make the output file current in case anything since the clone was created
# changed it. kOverwrite replaces earlier autosave cycles instead of leaving
# stale "evtree;N" keys next to the final one.
f.cd()
TCha2.Write("", ROOT.TObject.kOverwrite)
f.Close()
tf.Close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment