Created
December 4, 2019 20:23
-
-
Save tahuang1991/430361dc754cce7034530ad995801223 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ROOT | |
import argparse | |
import os | |
import numpy as np | |
from array import array | |
# Command-line interface: a single option naming the ROOT file to process.
parser = argparse.ArgumentParser(description='Compute event category fraction on a given sample')
# required=True lets argparse reject a missing -i/--inputfile with a usage
# message; without it the option would be None and the failure would only
# surface further down, when the path is actually used.
parser.add_argument('-i', '--inputfile', dest="inputfile", type=str, metavar='STR',
                    required=True, help='input file')
options = parser.parse_args()

# Single-argument print(...) calls produce the same text on Python 2 and 3.
# The two spaces after "inputfile" are what the Python 2 statement
# `print "inputfile ", value` emits (string, separator space, value).
print("==============================================================")
print("inputfile  %s" % options.inputfile)
print("==============================================================")
# Maps a ROOT leaf type code (the "T" in a "name/T" leaflist) to the typecode
# of the Python `array` used as the branch buffer.
_rootBranchType2PythonArray = {
    'b': 'B', 'B': 'b',
    'i': 'I', 'I': 'i',
    'F': 'f', 'D': 'd',
    'l': 'L', 'L': 'l',
    'O': 'B',
}

# --- input -----------------------------------------------------------------
treename = "evtree"
tf = ROOT.TFile(options.inputfile, "read")
# Fail early with a clear message instead of an obscure error on first use.
if not tf or tf.IsZombie():
    raise IOError("cannot open input file '%s'" % options.inputfile)
chain = tf.Get(treename)
if not chain:
    raise IOError("tree '%s' not found in '%s'" % (treename, options.inputfile))

# --- output ----------------------------------------------------------------
# Strip only the final extension.  Splitting on the first '.' would truncate
# paths such as "./sample.root" or "../dir/sample.root" to an empty or wrong
# prefix.
outfile = os.path.splitext(options.inputfile)[0]
f = ROOT.TFile(outfile + "_addbdt_B8_20k_200k_train.root", 'recreate')
# Select the output file before cloning so the new tree is created in it.
f.cd()
# CloneTree(0) copies the tree structure but no entries; entries are filled
# one by one in the event loop.
TCha2 = chain.CloneTree(0)

# Extra branch holding the evaluated BDT response, backed by a one-element
# array that is updated in place before each Fill().
rootBranchType = "F"
bdt_value = array(_rootBranchType2PythonArray[rootBranchType], [0.0])
br_bdt_value = TCha2.Branch("bdt_value", bdt_value, "bdt_value/%s" % rootBranchType)
def initBranchVals():
    """Reset the ``bdt_value`` branch buffer to the sentinel value -99.

    Writes into the first (only) slot of the module-level ``bdt_value``
    array, so an entry whose value is not overwritten afterwards is stored
    with -99 rather than with the previous entry's value.
    """
    bdt_value[0] = -99.0
#def fillBranches(): | |
# print "fillBranches(): bdt_value ",bdt_value[0]," weight ",dy_Mbtag_weight[0] | |
# br_bdt_value.Fill() | |
# br_dy_Mbtag_weight.Fill() | |
# Input variables handed to the TMVA reader, in this order.
# NOTE(review): TMVA expects the same names and order as used in the training
# that produced the weight file -- confirm against the training configuration
# whenever this list is edited.
bdt_tmva_variables = [
    "dR_l1l2",
    "dR_b1b2",
    "dphi_l1l2b1b2",
    "MT2",
    "dR_minbl",
]

# Pieces of a label naming scheme (date + method + variable list).  They are
# not used to build the label below, which is set explicitly.
date = "2019_10_15"
suffix = "_".join(bdt_tmva_variables)
label_template = "DATE_MethodLab3_SUFFIX"

# Label under which the method is booked in the reader, and the weight file
# derived from it.
bdt_label = "TMVA_SignalB8_20k_Bkg_200k"
bdt_xml_file = "dataset/weights/{}_kBDT.weights.xml".format(bdt_label)
# Two spaces after the name match the output of the former Python 2
# statement `print "bdt_xml_file ", bdt_xml_file`.
print("bdt_xml_file  %s" % bdt_xml_file)
# One single-element float buffer per input variable.  The buffers are
# registered with the reader here and overwritten in place for every event.
dict_tmva_variables = {var: array('f', [0]) for var in bdt_tmva_variables}

m_reader = ROOT.TMVA.Reader()
for var in bdt_tmva_variables:
    m_reader.AddVariable(var, dict_tmva_variables[var])

# Spectator variables are registered the same way.  The list is currently
# empty, so the dict and the loop below are no-ops; add names here if the
# weight file was trained with spectators.
sepctatorvar = []
dict_tmva_spectators = {var: array('f', [0]) for var in sepctatorvar}
for var in sepctatorvar:
    m_reader.AddSpectator(var, dict_tmva_spectators[var])

# Same text as the former Python 2 statement
# `print "bdd label ", bdt_label, " xml ", bdt_xml_file`.
print("bdd label  %s  xml  %s" % (bdt_label, bdt_xml_file))

# Book the trained method under `bdt_label`; the same tag is used later to
# evaluate it.
m_reader.BookMVA(bdt_label, bdt_xml_file)
# Event loop: evaluate the booked BDT for every entry of the input tree and
# store the response in the `bdt_value` branch of the cloned output tree.
print("Loading chain...")
entries = chain.GetEntries()
print("Done.")
# The message used to mention a "medium btagging weight"; what this script
# adds is the BDT response.
print("Adding BDT value for %d events." % entries)

for i in range(entries):
    chain.GetEntry(i)
    initBranchVals()

    if i % 1000 == 0:
        print("Event %d over %d" % (i + 1, entries))

    # Copy the current entry's inputs into the buffers registered with the
    # reader.  Each variable is read as an attribute of the tree by name.
    for var in bdt_tmva_variables:
        dict_tmva_variables[var][0] = getattr(chain, var)
    for var in sepctatorvar:
        dict_tmva_spectators[var][0] = getattr(chain, var)

    bdt_value[0] = m_reader.EvaluateMVA(bdt_label)
    TCha2.Fill()

# NOTE(review): TTree.Write() is documented to write into ROOT's current
# directory, so re-select the output file in case anything above changed it.
f.cd()
TCha2.Write()
f.Close()
tf.Close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment