Skip to content

Instantly share code, notes, and snippets.

@clungzta
Last active July 9, 2016 07:54
Show Gist options
  • Save clungzta/fa1a669f6447746b94b75a0cea5dbb14 to your computer and use it in GitHub Desktop.
Save clungzta/fa1a669f6447746b94b75a0cea5dbb14 to your computer and use it in GitHub Desktop.
Scikit-learn SVM classification for spectrometer data
#!/usr/bin/python
#import rospy
import sys
import time

import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
from scipy import signal
from serial import Serial
from serial.serialutil import SerialException
from sklearn import datasets
from sklearn import svm
np.set_printoptions(precision=3, suppress=True)

# Spectrometer geometry: spec_channels samples starting at min_wavelength with
# uniform spacing `increment`, stopping short of max_wavelength.
# (Units presumably nanometres -- TODO confirm against the device.)
min_wavelength = 340
max_wavelength = 780
spec_channels = 256
increment = (max_wavelength-min_wavelength)/float(spec_channels)
# x-axis as a (1, spec_channels) row. linspace(endpoint=False) always yields
# exactly spec_channels points spaced by `increment`; np.arange with a float
# step may yield one point too many or too few because of rounding, which
# would misalign the axis with the sensor frames.
x_values = np.linspace(min_wavelength, max_wavelength, spec_channels,
                       endpoint=False).reshape(1, -1)
# Put pyplot in interactive mode so later figure updates do not block, fix
# the initial y-range to 57..100, and seed the axes with a zero-height filled
# trace along the wavelength axis.
plt.ion()
plt.ylim(57, 100)
plt.fill(x_values[0], np.zeros_like(x_values[0]))
# Placeholder training spectra (not real measurements): four analytic curves,
# each sampled at spec_channels evenly spaced points over [0, pi].
sinx, cosx, tanx, expx = (
    curve(np.linspace(0, np.pi, spec_channels))
    for curve in (np.sin, np.cos, np.tan, np.exp)
)
# Each label maps to a list of example spectra; here every label has exactly
# one example.
training_dict = {
    'sinx': [sinx],
    'cosx': [cosx],
    'tanx': [tanx],
    'expx': [expx],
}
class SpectrometerSerial:
    """Read comma-separated spectrometer frames from a serial port.

    A frame is one text line that is split on commas; the final split element
    is discarded (presumably the line terminator after a trailing comma --
    TODO confirm against the device firmware), and the rest must number
    exactly ``spec_channels``.
    """

    def __init__(self, port, baudrate=115200):
        """Open a serial connection to ``port`` at ``baudrate``.

        :param port: serial device path, e.g. ``'/dev/ttyACM1'``.
        :param baudrate: line speed in baud.
        """
        self.baudrate = baudrate
        self.port = port
        self.ser = Serial(port=self.port, baudrate=self.baudrate)
        # pySerial normally opens the port when ``port`` is passed to the
        # constructor; this guards the case where it did not.
        if not self.ser.isOpen():
            self.ser.open()

    def read_data(self):
        """Return one frame as a float32 array of shape (1, spec_channels).

        Returns None when no data is waiting, when the line does not hold
        exactly ``spec_channels`` fields, or when reading/parsing fails. A
        failure is reported on stdout rather than raised, so a polling loop
        keeps running.
        """
        try:
            if not self.ser.inWaiting():
                return None
            raw_line = self.ser.readline()
            # On Python 3 readline() returns bytes; str.split(',') would
            # raise TypeError on bytes, so decode to text first.
            if isinstance(raw_line, bytes):
                raw_line = raw_line.decode('ascii', 'replace')
            # Drop the element after the last comma (see class docstring).
            sensor_readings = raw_line.split(',')[:-1]
            if len(sensor_readings) != spec_channels:
                return None
            return np.asarray(sensor_readings, dtype=np.float32).reshape(1, -1)
        # Serial/OS failures and unparseable numbers are expected at runtime.
        # IOError and OSError are distinct classes on Python 2.
        except (SerialException, IOError, OSError, ValueError) as e:
            print('failed to read data from serial port: %s' % e)
            return None
class SVMLearning:
    """Thin wrapper around an sklearn ``svm.SVC`` classifier.

    Every keyword argument given to the constructor is forwarded unchanged
    to ``svm.SVC``; the fitted estimator is exposed as ``self.clf``.
    """

    def __init__(self, **svc_params):
        self.clf = svm.SVC(**svc_params)

    def train(self, training_dict):
        """Fit ``self.clf`` on labelled example spectra.

        :param training_dict: mapping of label -> iterable of samples; every
            sample is paired with the label it is listed under.
        :returns: True once ``fit`` has returned.
        """
        samples = []
        labels = []
        for label, examples in training_dict.items():
            for sample in examples:
                samples.append(sample)
                labels.append(label)
        self.clf.fit(samples, labels)
        return True
# Train the SVM on the dummy spectra, then connect to the spectrometer.
classifier = SVMLearning(gamma=0.001, C=100, probability=True)
classifier.train(training_dict)
# The serial device may be given as the first command-line argument; without
# one, fall back to the previously hard-coded '/dev/ttyACM1'.
serial_port = sys.argv[1] if len(sys.argv) > 1 else '/dev/ttyACM1'
spectrometer_serial = SpectrometerSerial(serial_port)
print("trained!","classifying")

# Main program loop: poll for a frame, plot it, and classify it.
while True:
    sensor_reading = spectrometer_serial.read_data()
    if sensor_reading is None:
        # No complete/valid frame yet; sleep briefly so polling does not
        # spin a CPU core at 100%.
        time.sleep(0.005)
        continue

    print(x_values, sensor_reading)
    # cla() resets the axes, including their y-limits, so clear first and set
    # the limits after plotting (setting them before cla() discarded them).
    plt.cla()
    plt.plot(x_values[0], sensor_reading[0], 'r')  # raw frame
    plt.plot(x_values[0], signal.medfilt(sensor_reading[0], 5), 'b')  # 5-point median filter
    # Keep the fixed floor of 57, and keep the top above it even for low
    # readings (a top below the bottom would invert the y-axis).
    plt.ylim([57, max(float(np.amax(sensor_reading)) + 10, 58)])
    plt.pause(0.001)

    # SVM prediction on the raw (unfiltered) frame, shape (1, spec_channels).
    predicted_item = classifier.clf.predict(sensor_reading)
    print(predicted_item)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment