Skip to content

Instantly share code, notes, and snippets.

@cwpearson
Created November 8, 2018 18:11
Show Gist options
  • Save cwpearson/c14525b6634f377a532cad9b94467689 to your computer and use it in GitHub Desktop.
Save cwpearson/c14525b6634f377a532cad9b94467689 to your computer and use it in GitHub Desktop.
#! /bin/env python
"""Convert higgs dataset from GBM-Benchmarks to libsvm format"""
import pandas as pd
def csv_line_to_libsvm(line):
    """Convert one HIGGS CSV row into a libsvm-format line.

    The first comma-separated field is the label and is written as an
    integer.  Every remaining field becomes ``index:value`` using zero-based
    indices; the value text is copied from the CSV unchanged.

    Args:
        line: One CSV row, with or without its trailing line terminator.

    Returns:
        The libsvm line, always terminated by exactly one newline.

    Raises:
        ValueError: If the label field cannot be parsed as a number.
    """
    # Strip the terminator up front so it never rides along on the last
    # field; otherwise a label-only row or an unterminated final row would
    # be written without a newline and fuse with the next record.
    fields = line.rstrip("\r\n").split(",")
    label = str(int(float(fields[0])))
    features = ["%d:%s" % (idx, value) for idx, value in enumerate(fields[1:])]
    return " ".join([label] + features) + "\n"


def convert_higgs(csv_path="HIGGS.csv", train_path="HIGGS.csv.train",
                  test_path="HIGGS.csv.test", train_fraction=0.95):
    """Split a HIGGS CSV file into libsvm-format train and test files.

    The input is read twice: once to count rows, once to convert them.  The
    first ``int(rows * train_fraction)`` rows go to ``train_path`` and the
    rest to ``test_path``.  Both outputs are truncated before writing.
    Blank lines are ignored in both passes.  The row count and a progress
    fraction (every 10000 rows) are printed to stdout.

    Args:
        csv_path: Path of the source CSV file.
        train_path: Path of the training output file.
        test_path: Path of the test output file.
        train_fraction: Share of rows (from the top of the file) that is
            written to the training file.

    Returns:
        A ``(train_rows, test_rows)`` tuple with the number of rows written
        to each output file.

    Raises:
        IOError / OSError: If a file cannot be opened.
        ValueError: If a row's label is not numeric.
    """
    with open(csv_path) as src:
        num_lines = sum(1 for row in src if row.strip())
    print(num_lines)
    train_lines = int(num_lines * train_fraction)

    # Context managers guarantee both outputs are flushed and closed, even
    # if a malformed row raises half-way through the conversion.
    with open(csv_path) as src, \
            open(train_path, "w") as train_f, \
            open(test_path, "w") as test_f:
        rows = (row for row in src if row.strip())
        for li, row in enumerate(rows):
            if li % 10000 == 0:
                print(float(li) / num_lines)
            target = train_f if li < train_lines else test_f
            target.write(csv_line_to_libsvm(row))

    return train_lines, num_lines - train_lines


if __name__ == "__main__":
    convert_higgs()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment