Skip to content

Instantly share code, notes, and snippets.

@dongguosheng
Last active October 11, 2016 07:43
Show Gist options
  • Save dongguosheng/5031ff6f41ec978cf01199a855ef1b8c to your computer and use it in GitHub Desktop.
Save dongguosheng/5031ff6f41ec978cf01199a855ef1b8c to your computer and use it in GitHub Desktop.
# -*- coding: gbk -*-
import math
def cal_ndcg(r_list, k):
    """Return NDCG@k for a list of relevance labels given in ranked order.

    The ideal DCG is obtained by scoring the same labels sorted from the
    highest to the lowest.  When the ideal DCG is (numerically) zero, i.e.
    nothing relevant can be ranked, the function returns 1.0.
    """
    ideal_order = sorted(r_list, reverse=True)
    best_dcg = cal_dcg(ideal_order, k)
    if best_dcg <= 1e-10:
        # No achievable gain: every ordering counts as perfect.
        return 1.0
    actual_dcg = cal_dcg(r_list, k)
    return actual_dcg / best_dcg
def cal_dcg(r_list, k):
    """Return DCG@k for integer relevance labels given in ranked order.

    Each of the first k labels contributes (2**label - 1) / log2(rank + 2),
    where rank is the zero-based position in r_list.
    """
    dcg = 0.0
    for rank, grade in enumerate(r_list):
        if rank >= k:
            return dcg
        gain = (1 << grade) - 1
        discount = math.log(rank + 2.0, 2)
        dcg += gain / discount
    return dcg
def cal_rmse(pred_list, label_list):
    """Return the root mean squared error between predictions and labels.

    Args:
        pred_list: sequence of predicted values.
        label_list: sequence of target values, same length as pred_list.

    Returns:
        sqrt(mean((pred - label) ** 2)), or 0.0 when both inputs are empty.

    Raises:
        ValueError: if the two sequences have different lengths.  zip()
            would otherwise drop the extra items silently while the mean
            was still taken over len(pred_list).
    """
    n = len(pred_list)
    if n != len(label_list):
        raise ValueError('pred_list and label_list differ in length: %d vs %d'
                         % (n, len(label_list)))
    if n == 0:
        # Nothing to compare; avoid dividing by zero.
        return 0.0
    sq_err = 0.0
    for pred, label in zip(pred_list, label_list):
        sq_err += (pred - label) ** 2
    return math.sqrt(sq_err / n)
def main():
    """Print DCG@5 and NDCG@5 for a small hard-coded relevance list."""
    cutoff = 5
    relevance = [0, 1, 1, 0, 2, 0, 0]
    # A parenthesised single argument prints identically under Python 2.
    print('dcg@%d: %.4f' % (cutoff, cal_dcg(relevance, cutoff)))
    print('ndcg@%d: %.4f' % (cutoff, cal_ndcg(relevance, cutoff)))


if __name__ == '__main__':
    main()
# -*- coding: gbk -*-
from lxml import etree
from Queue import Queue
from operator import itemgetter
from itertools import groupby
import random
import math
from metric import cal_rmse, cal_ndcg
# Dense feature vectors are padded with zeros up to and including this id.
MAX_FID = 25


class DFeature(object):
    """One training sample: a dense feature vector plus its target.

    Attributes are plain fields:
      f    -- list of float feature values indexed by feature id
      y    -- target value (the trainers overwrite it with the gradient)
      hess -- second-order weight of the sample, 1 by default
    """

    def __init__(self):
        self.f = []
        self.y = 0
        self.hess = 1

    def from_str(self, line):
        """Fill this sample from a sparse text line 'y fid:val fid:val ...'.

        Feature ids that are not listed are stored as 0.0.  The resulting
        vector has max(MAX_FID, largest listed id) + 1 entries.  Values are
        placed by id, so the ids may appear in any order; if an id is
        repeated the last value wins.  Any previous content of self.f is
        replaced.

        Raises:
            ValueError: if a feature id is negative or a token is malformed.
            IndexError: if the line contains no tokens at all.
        """
        tokens = line.rstrip().split()
        self.y = float(tokens[0])
        sparse = []
        top_fid = MAX_FID
        for token in tokens[1:]:
            fid_str, val_str = token.split(':')
            fid = int(fid_str)
            if fid < 0:
                raise ValueError('negative feature id: %d' % fid)
            sparse.append((fid, float(val_str)))
            if fid > top_fid:
                top_fid = fid
        dense = [0.0] * (top_fid + 1)
        for fid, val in sparse:
            dense[fid] = val
        self.f = dense
class QNode(object):
    """Queue entry describing a tree node that is waiting to be split."""

    def __init__(self, node_id, cnt, loss):
        # node_id: index of the node inside the owning tree's node list
        # cnt:     count stored for this node
        # loss:    loss value stored for this node
        self.node_id, self.cnt, self.loss = node_id, cnt, loss

    def __str__(self):
        fields = (self.node_id, self.cnt, self.loss)
        return 'node_id: %d, cnt: %d, loss: %.2f' % fields
class SplitInfo(object):
    """Best split found so far for one node, plus running scan statistics.

    The *_pair lists hold [left, right] statistics of the best split.
    cnt0 / sum0 / sum0_hess / ss0 / last_val are scratch accumulators that
    a scan over sorted feature values updates while it walks the samples.
    """

    def __init__(self):
        self.feature_id = -1          # -1 means "no split found"
        self.split_val = 0.0
        self.cnt_pair = [0, 0]
        self.sum_hess_pair = [0.0, 0.0]
        self.sum_grad_pair = [0.0, 0.0]
        self.sum_sqr_grad_pair = [0.0, 0.0]
        self.loss = 0.0
        self.cnt0 = 0
        self.sum0 = 0.0
        self.sum0_hess = 0.0
        self.ss0 = 0.0
        self.last_val = 0.0

    def update(self, split_info):
        """Copy the best-split fields (not the scan accumulators) from another SplitInfo.

        A message is printed and nothing is copied when the argument is not
        a SplitInfo.
        """
        if not isinstance(split_info, SplitInfo):
            print('split_info type error.')
            return
        self.feature_id = split_info.feature_id
        # The original read the non-existent attribute `split_info.split_info`.
        self.split_val = split_info.split_val
        # Copy the lists so later edits of one object do not leak into the other.
        self.cnt_pair = list(split_info.cnt_pair)
        self.sum_hess_pair = list(split_info.sum_hess_pair)
        self.sum_grad_pair = list(split_info.sum_grad_pair)
        self.sum_sqr_grad_pair = list(split_info.sum_sqr_grad_pair)
        self.loss = split_info.loss

    def __str__(self):
        # split_val is a float threshold, so it is printed with %.2f rather
        # than being truncated by %d.
        return 'fid: %d, split_val: %.2f, cnt: %s, sum_grad: %s, sum_hess: %s, sum_sqr_grad: %s, loss: %.2f, cnt0: %d, sum0: %.2f, ss0: %.2f' % \
            (self.feature_id, self.split_val, str(self.cnt_pair), str(self.sum_grad_pair), str(self.sum_hess_pair),
             str(self.sum_sqr_grad_pair), self.loss, self.cnt0, self.sum0, self.ss0)
class Node(object):
    """A tree node: either an internal split or a leaf holding an output value.

    feature_id stays -1 for leaves.  ch holds the indices of the two child
    nodes inside the owning tree's node list.
    """

    def __init__(self):
        self.feature_id = -1
        self.split_val = 0
        self.ch = [-1, -1]
        self.value = 0
        self.sum_grad = 0
        self.sum_hess = 0
        self.sum_sqr_grad = 0

    def load_from_element(self, element):
        """Read this node from an XML <split> element and return its child <split> elements.

        An element with exactly one child is read as a leaf (<output>); one
        with exactly four children is read as a split (<feature> and
        <threshold>).  Any other child count only prints a message.
        """
        n_children = len(element.getchildren())
        if n_children == 1:
            # leaf
            self.value = float(element.xpath('output')[0].text.strip())
        elif n_children == 4:
            # split node
            self.feature_id = int(element.xpath('feature')[0].text.strip())
            self.split_val = float(element.xpath('threshold')[0].text.strip())
        else:
            print('invalid format: n_children is %d' % n_children)
        return element.xpath('split')

    def __str__(self):
        fields = (self.feature_id, self.split_val, str(self.ch), self.value,
                  self.sum_grad, self.sum_hess, self.sum_sqr_grad)
        return 'f: %d, split_val: %.2f, ch: %s, value: %.2f, sum_grad: %.2f, sum_hess: %.2f, sum_sqr_grad: %.2f' % fields
class Tree(object):
    """A regression tree stored as a flat list of Node objects (root at index 0).

    The tree can be loaded from / written to the XML <split> layout, used
    for prediction, or grown level by level with build_tree().  Relies on
    the module-level names Node, QNode, SplitInfo, etree, random and
    itemgetter.
    """

    def __init__(self):
        self.nodes = []          # Node objects; children are referenced by index
        self.to_be_splited = []  # QNode frontier of the level being grown
        self.grad_list = []      # per-sample target taken from feature.y
        self.hess_list = []      # per-sample weight taken from feature.hess
        self.sqr_grad_list = []  # per-sample squared target
        self.positions = []      # per-sample index into to_be_splited, -1 = inactive
        self.min_children = 1
        self.n = 0
        self.features = []

    def load_from_element(self, element):
        """Append the nodes found under an XML <tree> element, breadth first.

        A message is printed when the element does not have exactly one
        child; its first child is used as the root either way.
        """
        top_level = element.getchildren()
        if 1 != len(top_level):
            print('format invalid.')
        # A plain list with a read cursor is enough for a single-threaded
        # breadth-first walk; no locking queue is needed.
        pending = [top_level[0]]
        head = 0
        while head < len(pending):
            node = Node()
            children = node.load_from_element(pending[head])
            head += 1
            self.nodes.append(node)
            pending.extend(children)
            if children:
                # Nodes are appended in dequeue order, so the two elements
                # just queued will receive the last two upcoming indices.
                # Leaves keep the default ch of [-1, -1].
                remaining = len(pending) - head
                last = len(self.nodes) + remaining - 1
                node.ch = (last - 1, last)

    def to_xml(self, pre='\t\t'):
        """Serialise the tree as nested <split> elements.

        Each output line starts with `pre` followed by one tab per two
        columns of lxml's pretty-print indentation.
        """
        root = etree.Element('split')
        pending = [(self.nodes[0], root)]
        head = 0
        while head < len(pending):
            node, elem = pending[head]
            head += 1
            if node.feature_id < 0:  # leaf
                output = etree.SubElement(elem, 'output')
                output.text = ' %f ' % node.value
                continue
            feature = etree.SubElement(elem, 'feature')
            feature.text = ' %s ' % node.feature_id
            threshold = etree.SubElement(elem, 'threshold')
            threshold.text = ' %s ' % node.split_val
            for child_idx, side in zip(node.ch, ('left', 'right')):
                child_elem = etree.SubElement(elem, 'split')
                child_elem.attrib['pos'] = side
                pending.append((self.nodes[child_idx], child_elem))
        xml_str = etree.tostring(root, pretty_print=True)
        lines = []
        for line in xml_str.rstrip().split('\n'):
            start_idx = line.find('<')
            lines.append(pre + '\t' * (start_idx // 2) + line[start_idx:] + '\n')
        return ''.join(lines)

    def predict(self, feature_dict, is_debug=False):
        """Walk the tree for one sample and return the reached leaf value.

        Args:
            feature_dict: dict mapping feature id to value.  A missing
                feature is treated as 0.0, matching the zero padding of the
                dense training vectors.
            is_debug: when true, also return the decision path.

        Returns:
            The leaf value, or (value, path) when is_debug is true, where
            path is a list of (feature_id, '<=' or '>', split_val) tuples
            sorted by feature id.
        """
        idx = 0
        path = []
        while self.nodes[idx].feature_id >= 0:
            node = self.nodes[idx]
            # The original sent every missing feature to the left child,
            # which disagrees with "missing == 0" for negative thresholds.
            go_left = feature_dict.get(node.feature_id, 0.0) <= node.split_val
            if is_debug:
                path.append((node.feature_id, '<=' if go_left else '>', node.split_val))
            idx = node.ch[0] if go_left else node.ch[1]
        if is_debug:
            return (self.nodes[idx].value, sorted(path, key=itemgetter(0)))
        return self.nodes[idx].value

    def build_tree(self, features, max_depth=3, min_children=1, row_sample=1.0, col_sample=1.0):
        """Grow the tree level by level on the given samples.

        Args:
            features: list of samples exposing .f (dense values), .y
                (target) and .hess (weight).
            max_depth: maximum number of levels to grow.
            min_children: minimum number of samples on each side of a split.
            row_sample: probability of keeping each sample.
            col_sample: probability of keeping each feature column.
        """
        self.n = len(features)  # num of instances
        self.features = features
        self.min_children = min_children
        id_list = []
        sum_grad = 0.0
        sum_hess = 0.0
        sum_sqr_grad = 0.0
        self.grad_list = [0] * self.n
        self.hess_list = [0] * self.n
        self.sqr_grad_list = [0] * self.n
        self.positions = [0] * self.n
        # bootstrap, instance sampling
        for i, feature in enumerate(features):
            if random.random() < row_sample:
                id_list.append(i)
                self.grad_list[i] = feature.y
                self.hess_list[i] = feature.hess
                self.sqr_grad_list[i] = feature.y ** 2
                sum_grad += self.grad_list[i]
                sum_hess += self.hess_list[i]
                sum_sqr_grad += self.sqr_grad_list[i]
                self.positions[i] = 0
            else:
                self.positions[i] = -1
        print('bootstrap done')
        # add root node
        root = Node()
        root.feature_id = -1
        root.value = sum_grad / (sum_hess if sum_hess > 0 else 1.0)
        root.sum_grad = sum_grad          # for split find
        root.sum_hess = sum_hess
        root.sum_sqr_grad = sum_sqr_grad  # for split find
        self.nodes.append(root)
        print('add root done')
        if not id_list:
            return
        self.to_be_splited.append(QNode(0, len(id_list), sum_sqr_grad - (sum_grad ** 2) / len(id_list)))
        # Column sampling draws one random number per feature id, in id order.
        fid_list_sample = [fid for fid in range(len(features[0].f)) if random.random() < col_sample]
        # Sort the sampled samples once per column; only sampled columns are
        # ever scanned, so only those are materialised.
        col_fea = {}
        for fid in fid_list_sample:
            col_fea[fid] = sorted(((features[idx].f[fid], idx) for idx in id_list), key=itemgetter(0))
        print('samples sorted globally')
        # begin to grow
        for dep in range(max_depth):
            if not self.to_be_splited:
                break
            print('growing depth: %d' % dep)
            split_infos = [SplitInfo() for _ in self.to_be_splited]
            # A split must beat the loss of the unsplit node.
            for split_info, qnode in zip(split_infos, self.to_be_splited):
                split_info.loss = qnode.loss
            # One scan per column updates every frontier node at once, so
            # the scan is not repeated per node (repeating it cannot change
            # the chosen splits, it only multiplies the work).
            for fid in fid_list_sample:
                self.find_split(fid, col_fea[fid], split_infos)
            print('find split done')
            self.update_queue(split_infos)
            print('update queue done')

    def find_split(self, fid, fea_list, split_infos):
        """Scan one sorted feature column and improve split_infos in place.

        Args:
            fid: id of the feature being scanned.
            fea_list: (value, sample_id) pairs sorted by value.
            split_infos: one SplitInfo per entry of self.to_be_splited.
        """
        for info in split_infos:
            info.cnt0 = 0
            info.sum0 = 0.0
            info.sum0_hess = 0.0
            info.ss0 = 0.0
        for val, sample_id in fea_list:
            pos = self.positions[sample_id]
            if pos < 0:
                continue
            qnode = self.to_be_splited[pos]
            parent = self.nodes[qnode.node_id]
            info = split_infos[pos]
            right_cnt = qnode.cnt - info.cnt0
            # A threshold can only sit between two distinct values and must
            # leave at least min_children samples on each side.
            if (info.cnt0 >= self.min_children and right_cnt >= self.min_children
                    and self.sign(val - info.last_val) != 0):
                sum1 = parent.sum_grad - info.sum0
                sum1_hess = parent.sum_hess - info.sum0_hess
                ss1 = parent.sum_sqr_grad - info.ss0
                # Sum of squared deviations of the left and the right part.
                loss = info.ss0 - info.sum0 ** 2 / info.cnt0 + ss1 - sum1 ** 2 / right_cnt
                if self.sign(loss - info.loss) < 0:
                    info.loss = loss
                    info.feature_id = fid
                    info.split_val = (val + info.last_val) / 2.0
                    info.sum_grad_pair = [info.sum0, sum1]
                    info.sum_hess_pair = [info.sum0_hess, sum1_hess]
                    info.sum_sqr_grad_pair = [info.ss0, ss1]
                    info.cnt_pair = [info.cnt0, right_cnt]
            # Move the current sample to the left side for the next candidate.
            info.cnt0 += 1
            info.sum0 += self.grad_list[sample_id]
            info.sum0_hess += self.hess_list[sample_id]
            info.ss0 += self.sqr_grad_list[sample_id]
            info.last_val = val

    def update_queue(self, split_infos):
        """Apply the chosen splits, create the child nodes and re-route the samples."""
        new_to_be_splited = []
        child_q_pos = []
        for qnode, split_info in zip(self.to_be_splited, split_infos):
            # Frontier slots the two children would occupy (only meaningful
            # when the node is actually split).
            child_q_pos.append((len(new_to_be_splited), len(new_to_be_splited) + 1))
            if split_info.feature_id < 0:
                continue
            parent = self.nodes[qnode.node_id]
            parent.feature_id = split_info.feature_id
            parent.split_val = split_info.split_val
            parent.ch[0] = len(self.nodes)
            parent.ch[1] = len(self.nodes) + 1
            for side in range(2):
                child = Node()
                child_hess = split_info.sum_hess_pair[side]
                # Same zero-Hessian guard as the root: a child made only of
                # zero-weight samples must not raise ZeroDivisionError.
                child.value = split_info.sum_grad_pair[side] / (child_hess if child_hess > 0 else 1.0)
                child.sum_grad = split_info.sum_grad_pair[side]
                child.sum_hess = child_hess
                child.sum_sqr_grad = split_info.sum_sqr_grad_pair[side]
                loss = child.sum_sqr_grad - child.sum_grad ** 2 / split_info.cnt_pair[side]
                new_to_be_splited.append(QNode(len(self.nodes), split_info.cnt_pair[side], loss))
                self.nodes.append(child)
        self.to_be_splited = new_to_be_splited
        # Point every active sample at the frontier slot of its new node;
        # samples of nodes that were not split become inactive.
        for i in range(self.n):
            pos = self.positions[i]
            if pos >= 0 and split_infos[pos].feature_id >= 0:
                info = split_infos[pos]
                if self.features[i].f[info.feature_id] <= info.split_val:
                    self.positions[i] = child_q_pos[pos][0]
                else:
                    self.positions[i] = child_q_pos[pos][1]
            else:
                self.positions[i] = -1

    def sign(self, val):
        """Return 1, -1 or 0; values within 1e-4 of zero count as zero."""
        EPSI = 1e-4
        if val > EPSI:
            return 1
        if val < -EPSI:
            return -1
        return 0

    def __str__(self):
        return '\n'.join(str(i) + ': ' + str(e) for i, e in enumerate(self.nodes))
class Forest(object):
    """An additive ensemble of Tree objects scaled by a shared learning rate.

    Supports loading / saving the '## header + <ensemble>' model layout,
    prediction (optionally with merged decision paths) and training with
    squared-error boosting (build_gbdt) or pairwise ranking gradients
    (build_lambdamart).  Relies on the module-level names Tree, DFeature,
    etree, cal_ndcg, cal_rmse, math, groupby and itemgetter.
    """

    def __init__(self):
        self.lr = 0.1
        self.trees = []
        self.fmap = None  # optional dict: feature id -> feature name

    def load_from_xml(self, model_file):
        """Load trees and the learning rate from a model file.

        The first six lines are header lines; the fifth one carries the
        learning rate after an '='.  The rest of the file is parsed as XML
        and every child of its root becomes one tree.
        """
        with open(model_file) as fin:
            for i in range(6):
                header = fin.readline()
                if i == 4:  # learning rate
                    self.lr = float(header.split('=')[1])
            xml_str = fin.read()
        root = etree.fromstring(r'<?xml version="1.0"?>' + xml_str)
        for e in root.getchildren():
            tree = Tree()
            tree.load_from_element(e)
            self.trees.append(tree)

    def save_to_xml(self, model_file):
        """Write the ensemble to model_file; unknown header fields are written as -1."""
        with open(model_file, 'w') as fout:
            fout.write('## %s\n' % 'GBDT')
            fout.write('## No. of trees = %d\n' % len(self.trees))
            fout.write('## No. of leaves = %d\n' % -1)
            fout.write('## No. of threshold candidates = %d\n' % -1)
            fout.write('## Learning rate = %.2f\n' % self.lr)
            fout.write('## Stop early = %d\n\n' % -1)
            fout.write('<ensemble>\n')
            for tree_id, tree in enumerate(self.trees):
                fout.write('\t<tree id="%d" weight="%.2f">\n' % (tree_id, self.lr))
                fout.write(tree.to_xml(pre='\t\t'))
                fout.write('\t</tree>\n')
            fout.write('</ensemble>\n')

    def predict(self, feature, is_debug=False):
        """Return the ensemble score of one sample.

        Args:
            feature: a dict {feature id: value} or a DFeature.  Any other
                type is scored as a sample without features.
            is_debug: when true, also return the merged decision paths.

        Returns:
            The score, or (score, paths) when is_debug is true.  paths is
            the output of the path merge, sorted by weight (descending when
            the score is positive, ascending otherwise).
        """
        if isinstance(feature, dict):
            feature_dict = feature
        elif isinstance(feature, DFeature):
            feature_dict = dict(enumerate(feature.f))
        else:
            feature_dict = {}
        rs = 0
        paths = []
        for tree in self.trees:
            rs_once = tree.predict(feature_dict, is_debug)
            # Branch on the flag that decides the return shape instead of
            # on isinstance(..., float), which breaks for non-float leaves.
            if is_debug:
                value, path = rs_once
                paths.append((path, value * self.lr))
            else:
                value = rs_once
            rs += value * self.lr
        if is_debug:
            paths = self.__merge_path(paths)
            paths = sorted(paths, key=itemgetter(1), reverse=rs > 0)
            return (rs, paths)
        return rs

    def _mean_ndcg(self, pred_list, k):
        """Return the mean NDCG@k over the queries of pred_list.

        pred_list holds [qid, label, score] entries with equal qids
        adjacent; 0.0 is returned when it is empty.
        """
        total = 0.0
        n_query = 0
        for _qid, group in groupby(pred_list, itemgetter(0)):
            ranked = sorted(group, key=itemgetter(2), reverse=True)
            total += cal_ndcg([entry[1] for entry in ranked], k)
            n_query += 1
        return total / n_query if n_query else 0.0

    def build_lambdamart(self, features, group_list, n_tree, max_depth=3, min_children=1, row_sample=1.0, col_sample=1.0, vali_info=None):
        """Train n_tree trees on pairwise ranking gradients.

        Args:
            features: list of samples; their .y and .hess are overwritten.
            group_list: [(qid, label, idx), ...] already sorted by qid,
                aligned with features.
            n_tree: number of boosting rounds.
            max_depth, min_children, row_sample, col_sample: passed to
                Tree.build_tree.
            vali_info: optional (features, group_list) pair used only for
                reporting NDCG.
        """
        assert len(features) == len(group_list)
        print('num of samples: %d' % len(features))
        print('num of feature dim: %d' % len(features[0].f))
        # 0. make pairs inside every query
        lambda_pairs = []
        for _qid, group in groupby(group_list, itemgetter(0)):
            lambda_pairs.extend(self.make_lambda_pairs(list(group)))
        print('make pair done.')
        # 1. compute gradients
        pred_list = [[qid, label, 0.0] for qid, label, _idx in group_list]
        vali_features = vali_pred_list = None
        if vali_info is not None:
            vali_features, vali_group_list = vali_info
            vali_pred_list = [[qid, label, 0.0] for qid, label, _idx in vali_group_list]
        k = 10
        for i in range(n_tree):
            # init grad
            for feature in features:
                feature.y = 0
                feature.hess = 0
            # TODO: the pairs are not weighted by delta NDCG yet.
            for idx_left, idx_right in lambda_pairs:
                # idx_left carries the lower label of the pair.
                diff = pred_list[idx_right][2] - pred_list[idx_left][2]
                try:
                    grad = 1.0 / (1.0 + math.exp(diff))
                except OverflowError:
                    # exp() overflowed, so the quotient is effectively zero.
                    grad = 0.0
                hess = max(grad * (1.0 - grad), 1e-16)
                features[idx_left].y -= grad
                features[idx_right].y += grad
                features[idx_left].hess += hess
                features[idx_right].hess += hess
            print('compute grad && hess done.')
            new_tree = Tree()
            new_tree.build_tree(features, max_depth=max_depth, min_children=min_children, row_sample=row_sample, col_sample=col_sample)
            self.trees.append(new_tree)
            for j, feature in enumerate(features):
                pred_list[j][2] = self.predict(feature)
            print('train ndcg@%d: %.4f' % (k, self._mean_ndcg(pred_list, k)))
            if vali_features is not None:
                for j, feature in enumerate(vali_features):
                    vali_pred_list[j][2] = self.predict(feature)
                print('vali ndcg@%d: %.4f' % (k, self._mean_ndcg(vali_pred_list, k)))
            print('iter %d' % i)
            print('--------')

    def make_lambda_pairs(self, ranklist):
        """Return (idx_low, idx_high) for every pair whose first label is strictly lower.

        ranklist holds (qid, label, idx) entries of a single query.
        """
        return [(e1[2], e2[2]) for e1 in ranklist for e2 in ranklist if e1[1] < e2[1]]

    def build_gbdt(self, features, n_tree, max_depth=3, min_children=1, row_sample=1.0, col_sample=1.0, vali_features=None):
        """Train n_tree trees on squared-error residuals.

        The labels are read from feature.y once; afterwards feature.y is
        overwritten with the current residual before every round.
        vali_features, when given, is only used for reporting RMSE.
        """
        print('num of samples: %d' % len(features))
        print('num of feature dim: %d' % len(features[0].f))
        label_list = [feature.y for feature in features]
        vali_label_list = None
        if vali_features is not None:
            vali_label_list = [feature.y for feature in vali_features]
        pred_list = [0.0 for _feature in features]
        for i in range(n_tree):
            assert len(label_list) == len(pred_list)
            # compute gradients (residuals)
            for feature, label, pred in zip(features, label_list, pred_list):
                feature.y = label - pred
            new_tree = Tree()
            new_tree.build_tree(features, max_depth=max_depth, min_children=min_children, row_sample=row_sample, col_sample=col_sample)
            self.trees.append(new_tree)
            pred_list = [self.predict(feature) for feature in features]
            print('train rmse: %.4f' % cal_rmse(pred_list, label_list))
            if vali_features is not None:
                vali_pred_list = [self.predict(feature) for feature in vali_features]
                print('val rmse: %.4f' % cal_rmse(vali_pred_list, vali_label_list))
            print('iter %d' % i)
            print('--------')

    def __merge_path(self, paths):
        """Merge decision paths that test the same features in the same directions.

        Args:
            paths: list of (path, weight); path is a list of
                (fid, '<=' or '>', split_val) tuples.

        Returns:
            A list of (merged_path, summed_weight).  Paths are merged when
            they have the same length and the same set of (fid, symbol)
            conditions.  For '<=' the smallest threshold is kept, for '>'
            the largest.  Entries are (fid, symbol, val), or
            (fid, name, symbol, val) when a feature map is loaded.
        """
        def signature(pair):
            # groupby only merges neighbours, so the sort key must be the
            # grouping key itself; sorting by the raw path (as before) can
            # separate equal signatures that differ in their thresholds.
            conditions = set((node[0], node[1]) for node in pair[0])
            return (len(pair[0]), tuple(sorted(conditions)))

        merge_paths = []
        for (_length, key), group in groupby(sorted(paths, key=signature), signature):
            members = list(group)
            w = sum([val for _path, val in members])
            split_vals_dict = {}
            for path, _val in members:
                for fid, symbol, split_val in path:
                    cond = (fid, symbol)
                    if cond not in split_vals_dict:
                        split_vals_dict[cond] = split_val
                    elif symbol == '<=' and split_vals_dict[cond] > split_val:
                        split_vals_dict[cond] = split_val
                    elif symbol == '>' and split_vals_dict[cond] <= split_val:
                        split_vals_dict[cond] = split_val
            if self.fmap is None:
                merge_path = [(fid, symbol, split_vals_dict[(fid, symbol)]) for fid, symbol in key]
            else:
                merge_path = [(fid, self.fmap[fid] if fid in self.fmap else '', symbol, split_vals_dict[(fid, symbol)]) for fid, symbol in key]
            merge_paths.append((merge_path, w))
        return merge_paths

    def get_fmap(self, fmap_file):
        """Load the feature map from a file of 'fid name' lines; other lines are skipped."""
        self.fmap = {}
        with open(fmap_file) as fin:
            for line in fin:
                fields = line.rstrip().split()
                if len(fields) != 2:
                    continue
                fid, fname = fields
                self.fmap[int(fid)] = fname

    def fea_str_to_dict(self, s):
        """Parse 'fid:val fid:val ...' into a dict {int fid: float val}."""
        feature_dict = {}
        for feature in s.strip().split():
            fid, val = feature.split(':')
            feature_dict[int(fid)] = float(val)
        return feature_dict

    def fea_dict_to_str(self, fea_dict):
        """Format a feature dict as 'fid:val ' items sorted by feature id.

        With a feature map loaded each item reads 'fid(name):val'; ids
        missing from the map get an empty name.  Trailing zeros of the
        value are stripped.  Every item is followed by one space.
        """
        parts = []
        for fid, val in sorted(fea_dict.items(), key=itemgetter(0)):
            item = '%d' % fid
            if self.fmap is not None:
                item += '(%s)' % self.fmap.get(fid, '')
            # The ':' separator used to be emitted only together with the
            # feature name, gluing id and value together without a map.
            item += ':' + ('%f' % val).rstrip('0').rstrip('.')
            parts.append(item + ' ')
        return ''.join(parts)
def load_feat(fea_file):
    """Read fea_file and return one DFeature per line, in file order."""
    def parse(line):
        sample = DFeature()
        sample.from_str(line)
        return sample

    with open(fea_file) as fin:
        return [parse(line) for line in fin]
def load_rank_feat(fea_file):
    """Read a ranking file with lines 'label qid:<id> fid:val ... # comment'.

    Returns (features, group_list): features is a list of DFeature whose
    target is parsed from a constant '0', and group_list holds one
    (qid, label, line index) tuple per line.
    """
    features, group_list = [], []
    with open(fea_file) as fin:
        for idx, line in enumerate(fin):
            # Everything after the first '#' is ignored.
            payload = line.rstrip().split('#')[0]
            label, qid_str, feature_str = payload.split(' ', 2)
            qid = int(qid_str.split(':')[1])
            sample = DFeature()
            sample.from_str('0 ' + feature_str)
            features.append(sample)
            group_list.append((qid, int(label), idx))
    print('num of feature: %d.' % (len(features)))
    return (features, group_list)
def main():
    """Train a pairwise LambdaMART forest on 'tmp_fea_rank' and save it to 'model.xml'.

    Two other experiments used to live here as commented-out code:
      * prediction: Forest.load_from_xml('ranklib_model.txt'),
        Forest.get_fmap('fmap.txt'), fea_str_to_dict() on a
        'fid:val fid:val ...' string, then predict(feature_dict, True)
        and print the score, the paths and the elapsed microseconds;
      * GBDT training: load_feat() on 'tmp_fea' / 'tmp_fea_vali' (or
        'heart_scale'), then build_gbdt() with n_tree=100, max_depth=3,
        min_children=10, row_sample=0.5, col_sample=0.5, and
        save_to_xml('model.xml').
    """
    from datetime import datetime

    # --------------- test lambdamart train ----------------- #
    train_file = 'tmp_fea_rank'
    vali_file = 'tmp_fea_rank_vali'
    params = dict(max_depth=3, min_children=10, row_sample=0.3, col_sample=0.3)
    n_tree = 50

    features, group_list = load_rank_feat(train_file)
    params['vali_info'] = load_rank_feat(vali_file)

    started = datetime.now()
    forest = Forest()
    forest.build_lambdamart(features, group_list, n_tree, **params)
    print('build forest cost: ' + str((datetime.now() - started)))
    forest.save_to_xml('model.xml')


if __name__ == '__main__':
    main()
@dongguosheng
Copy link
Author

dongguosheng commented Sep 30, 2016

ok

@dongguosheng
Copy link
Author

dongguosheng commented Oct 10, 2016

Added pairwise LambdaMART; each leaf's output weight is G/H (sum of gradients divided by sum of Hessians).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment