Last active
August 29, 2015 14:13
-
-
Save shun91/0b1ce77b1f1dd339d250 to your computer and use it in GitHub Desktop.
LIBSVMの学習データのスケーリングを行う.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/python | |
# -*- coding: utf-8 -*- | |
############################################################################### | |
# LIBSVM(LIBLINEAR)の学習データのスケーリング(標準化)を行う. | |
# 各素性が平均0,分散1の正規分布に従うようにスケーリングする. | |
# | |
# 次のコマンドで実行できる. | |
# $ python libsvm_gaussian_scaler.py [options] | |
# [options] | |
# -i file: 入力ファイル.省略すると35行目付近のINPUT_PATHで指定した値となる. | |
# -o file: 出力ファイル.基本は省略して構わないが,名前を指定したい時に. | |
# | |
# 入力ファイルのフォーマットはLIBSVMの学習データのフォーマットと同様. | |
# | |
# 出力ファイルは2種類のテキストファイル. | |
# デフォルトでは入力ファイルと同一ディレクトリに出力される. | |
# 1. 入力ファイル名.out | |
# スケーリングを行った学習データのファイル.フォーマットは入力と同じ. | |
# 2. 入力ファイル名.out.stats | |
# 各素性の「素性id 平均 標準偏差」がこの順番で記されているtsvファイル. | |
# | |
# LIBSVM | |
# http://www.csie.ntu.edu.tw/~cjlin/libsvm/ | |
############################################################################## | |
import argparse | |
import os | |
import re | |
import sys | |
import numpy as np | |
from scipy import stats | |
# デフォルトの入力ファイル | |
INPUT_PATH = '' | |
# 省略された素性の値 | |
DEFAULT_FEATURE_VALUE = 0 | |
def exec_argparse(): | |
''' | |
引数をパースした結果を連想配列で返す. | |
input_file: 入力ファイルパス | |
output_file: 出力ファイルパス | |
''' | |
parser = argparse.ArgumentParser(description='') | |
parser.add_argument('-i', '--input_file', help='入力するファイル', default=INPUT_PATH) | |
parser.add_argument('-o', '--output_file', help='出力するファイル') | |
return parser.parse_args() | |
def get_number_of_features(input_file): | |
''' | |
入力ファイルの素性数を返す. | |
(素性idの最大値を見つける.) | |
''' | |
# 素性idの最大値(これが素性数になる) | |
max_id = 0 | |
# 入力ファイルを開く | |
f_input_file = open(input_file, 'r') | |
# 事例ごとの素性idの最大値を,現在の素性idの最大値と比較し,更新する. | |
for line in f_input_file: | |
line = line.rstrip() # 末尾の改行を削除 | |
field = line.split(' ') # 半角スペースでフィールド分割 | |
del field[0] # ラベルフィールドは不要なので削除 | |
for feature in field: # 素性を1つずつ見ていく | |
field2 = feature.split(':') # idと素性値に分割 | |
if int(field2[0])>max_id: # 最大値なら更新 | |
max_id=int(field2[0]) | |
f_input_file.close() | |
return max_id | |
def complement_feature(field, number_of_features): | |
''' | |
省略された素性idを補完した素性値配列を生成して返す | |
''' | |
feature_dict = {} # 素性値を一旦辞書に格納(keyが素性id). | |
for feature in field: | |
field2 = feature.split(':') | |
feature_dict[int(field2[0])]=field2[1] | |
feature_array = [] | |
for i in range(1, number_of_features+1): | |
if i in feature_dict: | |
feature_array.append(feature_dict[i]) | |
else: | |
feature_array.append(DEFAULT_FEATURE_VALUE) | |
return feature_array | |
def convert_to_numpy_array_from_liblinear_file(input_file): | |
''' | |
liblinear形式のファイルを2つの配列に格納.2つの配列を含むtupleを返す. | |
''' | |
# 素性数のチェック | |
number_of_features = get_number_of_features(input_file) | |
# 配列の初期化 | |
label_array = [] # ラベル配列 | |
feature_array = np.zeros(number_of_features) # 素性値のnumpy配列 | |
# 入力ファイルを開く | |
f_input_file = open(input_file, 'r') | |
# 1行ずつ処理. | |
for line in f_input_file: | |
line = line.rstrip() # 末尾の改行を削除 | |
field = line.split(' ') # 半角スペースでフィールド分割 | |
label_array.append(field[0]) # ラベル配列に追加 | |
del field[0] # ラベルフィールドは素性値配列に不要なので削除 | |
# 省略された素性値の補完 | |
field = complement_feature(field, number_of_features) | |
# feature_arrayに追加 | |
feature_array=np.vstack((feature_array, np.array(field, dtype=float))) | |
f_input_file.close() | |
feature_array = np.delete(feature_array, 0, 0) # 初期化に利用した最初の行は不要なので削除 | |
return (label_array, feature_array) | |
def write_zscore(input_file, output_file, zscore_array, liblinear_tuple): | |
''' | |
スケーリング結果をliblinear形式のファイルとして出力 | |
''' | |
# 出力ファイルの初期化(削除) | |
if os.path.exists(output_file): | |
os.remove(output_file) | |
# 出力 | |
f_output_file = open(output_file, "a") | |
for i in range(0, len(liblinear_tuple[0])): | |
line = [] | |
line.append(liblinear_tuple[0][i]) | |
line.append(' ') | |
for j in range(0, len(zscore_array[0])): | |
line.append(str(j+1)) | |
line.append(':') | |
if zscore_array[i][j]!=zscore_array[i][j]: # NaN対策 | |
zscore_array[i][j]=0.0 | |
line.append(str(zscore_array[i][j])) | |
line.append(' ') | |
line.append('\n') | |
f_output_file.write(''.join(line)) | |
f_output_file.close() | |
def write_stats(output_file, mean_array, std_array): | |
''' | |
素性毎の平均と標準偏差をtsv形式でファイル出力. | |
''' | |
# 出力ファイルのパスはoutput_file.stats | |
output_file = output_file+'.stats' | |
# 出力ファイルの初期化(削除) | |
if os.path.exists(output_file): | |
os.remove(output_file) | |
#出力 | |
f_output_file = open(output_file, "a") | |
for i in range(0, len(mean_array)): | |
line = [] | |
line.append(str(i+1)) | |
line.append('\t') | |
line.append(str(mean_array[i])) | |
line.append('\t') | |
line.append(str(std_array[i])) | |
line.append('\n') | |
f_output_file.write(''.join(line)) | |
f_output_file.close() | |
def write_results(input_file, output_file, zscore_array, mean_array, std_array, liblinear_tuple): | |
''' | |
結果をファイル出力 | |
''' | |
# 出力ファイルが指定されていなければinput_file.outに出力 | |
if output_file == None: | |
output_file=input_file+'.out' | |
# 出力 | |
write_zscore(input_file, output_file, zscore_array, liblinear_tuple) | |
write_stats(output_file, mean_array, std_array) | |
def main(): | |
# 引数をパースしてargsに格納 | |
args = exec_argparse() | |
# liblinear形式のファイルを2つの配列に格納.2つの配列を含むtupleが返ってくる. | |
sys.stdout.write('loading file ... ') | |
liblinear_tuple = convert_to_numpy_array_from_liblinear_file(args.input_file) | |
print 'done' | |
# スケーリング | |
sys.stdout.write('scaling ... ') | |
zscore_array = stats.zscore(liblinear_tuple[1], axis=0, ddof=1).astype(np.float16) # スケーリング後の素性値 | |
mean_array = np.mean(liblinear_tuple[1], axis=0, dtype=np.float16) # 平均 | |
std_array = np.std(liblinear_tuple[1], axis=0, ddof=1, dtype=np.float16) # 標準偏差 | |
print 'done' | |
# スケーリング結果をファイル出力. | |
sys.stdout.write('writing ... ') | |
write_results(args.input_file, args.output_file, zscore_array, mean_array, std_array, liblinear_tuple) | |
print 'done' | |
print 'Finish.' | |
if __name__ == "__main__": | |
main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment