Skip to content

Instantly share code, notes, and snippets.

@micmn
Created June 22, 2017 11:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save micmn/2b2d61f08b8e303d8d7767f3cddcab3f to your computer and use it in GitHub Desktop.
Save micmn/2b2d61f08b8e303d8d7767f3cddcab3f to your computer and use it in GitHub Desktop.
'''
@file lda.py
@author Michele Mazzoni
LDA Classifier with shogun.
'''
import os
import sys
import inspect
# Make the shared "util" helpers and the "metrics" definitions importable.
# Both folders are located relative to this file and resolved with realpath,
# so this works even if the path contains symlinks to modules. Each folder is
# prepended to sys.path only if it is not already there; util is handled
# first and metrics second, so when both are inserted metrics ends up ahead.
_here = os.path.split(inspect.getfile(inspect.currentframe()))[0]
cmd_subfolder, metrics_folder = (
    os.path.realpath(os.path.abspath(os.path.join(_here, rel)))
    for rel in ("../../util", "../metrics"))
for _folder in (cmd_subfolder, metrics_folder):
    if _folder not in sys.path:
        sys.path.insert(0, _folder)
del _here, _folder
from log import *
from timer import *
from definitions import *
from misc import *
import numpy as np
import modshogun
class LDA(object):
    '''
    This class implements the LDA Classifier benchmark using shogun's
    multiclass LDA (modshogun.MCLDA).
    '''

    def __init__(self, dataset, timeout=0, verbose=True):
        '''
        Create the LDA Classifier benchmark instance.

        @param dataset - Input dataset(s) to perform LDA on: dataset[0] is the
          training file (features, with the labels in the last column),
          dataset[1] the optional test file (features only) and dataset[2]
          the optional file with the true test labels.
        @param timeout - The time until the timeout. Default no timeout.
        @param verbose - Display informational messages.
        '''
        self.verbose = verbose
        self.dataset = dataset
        self.timeout = timeout

    @staticmethod
    def _ShiftLabels(labels):
        '''
        Shift the labels so that the smallest one becomes 0. Labels whose
        minimum is already 0 or negative are returned unchanged.

        @param labels - Array of numeric class labels.
        @return - The (possibly shifted) labels as a new array, or the input
          itself when no shift is needed.
        '''
        lowest = np.min(labels)
        if lowest > 0:
            # NOTE(review): presumably shogun's MulticlassLabels expects class
            # indices starting at 0 — confirm.
            return labels - lowest
        return labels

    def LDAShogun(self, options):
        '''
        Use the shogun library to implement LDA Classifier.

        Only training (and, with exactly two dataset files, prediction on the
        test set) is timed; loading the data happens outside the timer.

        @param options - Extra options for the method (currently unused).
        @return - Elapsed time in seconds or a negative value if the method was
          not successful.
        '''
        def RunLDAShogun(q):
            totalTimer = Timer()

            Log.Info("Loading dataset", self.verbose)
            try:
                # Load train and test dataset.
                trainData = np.genfromtxt(self.dataset[0], delimiter=',')
                trainFeat = modshogun.RealFeatures(trainData[:, :-1].T)

                # NOTE(review): when a third (true labels) file is given as
                # well, the test set is neither loaded nor timed here —
                # confirm this is intended.
                if len(self.dataset) == 2:
                    # Every column of the test file is used as a feature.
                    testData = np.genfromtxt(self.dataset[1], delimiter=',')
                    testFeat = modshogun.RealFeatures(testData.T)

                # Labels are the last column of the training set.
                labels = modshogun.MulticlassLabels(
                    self._ShiftLabels(trainData[:, -1]))

                with totalTimer:
                    model = modshogun.MCLDA(trainFeat, labels)
                    model.train()
                    if len(self.dataset) == 2:
                        model.apply(testFeat).get_labels()
            except Exception as e:
                # Report the failure instead of silently returning -1.
                Log.Info("Exception: " + str(e), self.verbose)
                q.put(-1)
                return -1

            elapsed = totalTimer.ElapsedTime()
            q.put(elapsed)
            return elapsed

        return timeout(RunLDAShogun, self.timeout)

    def RunMetrics(self, options):
        '''
        Perform LDA Classifier. If the method has been successfully completed
        return the elapsed time in seconds.

        When a third dataset file with the true test labels is given, a model
        is trained again in a child Process and its predictions on the test
        set are scored (ACC, MCC, Precision, Recall, MSE).

        @param options - Extra options for the method.
        @return - A dict with the 'Runtime' (plus the metrics above when true
          labels are available), or the negative value returned by LDAShogun
          if the method was not successful.
        '''
        Log.Info("Perform LDA.", self.verbose)
        results = self.LDAShogun(options)
        if results < 0:
            return results

        def test(q):
            # Always put exactly one item on the queue so that the parent's
            # q.get() cannot block forever when something here raises.
            try:
                trainData, labels = SplitTrainData(self.dataset)
                testData = LoadDataset(self.dataset[1])
                truelabels = LoadDataset(self.dataset[2])

                labels = self._ShiftLabels(labels)
                truelabels = self._ShiftLabels(truelabels)

                model = modshogun.MCLDA(modshogun.RealFeatures(trainData.T),
                                        modshogun.MulticlassLabels(labels))
                model.train()
                predictions = model.apply(
                    modshogun.RealFeatures(testData.T)).get_labels()

                confusionMatrix = Metrics.ConfusionMatrix(truelabels,
                                                          predictions)
                metrics = {}
                metrics['ACC'] = Metrics.AverageAccuracy(confusionMatrix)
                metrics['MCC'] = Metrics.MCCMultiClass(confusionMatrix)
                metrics['Precision'] = Metrics.AvgPrecision(confusionMatrix)
                metrics['Recall'] = Metrics.AvgRecall(confusionMatrix)
                metrics['MSE'] = Metrics.SimpleMeanSquaredError(truelabels,
                                                                predictions)
                q.put(metrics)
            except Exception as e:
                Log.Info("Exception: " + str(e), self.verbose)
                q.put({})

        metrics = {'Runtime': results}
        if len(self.dataset) >= 3:
            q = Queue()
            p = Process(target=test, args=(q,))
            p.start()
            # Drain the queue before joining: assuming Queue/Process come from
            # multiprocessing (via the star imports), joining a process that
            # still has queued data to flush can deadlock.
            metrics.update(q.get())
            p.join()
        return metrics
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment