Created
May 9, 2020 16:24
-
-
Save siddharththakur26/235a394e2d321816c3946689b7c3e01a to your computer and use it in GitHub Desktop.
TestDome-DataScience
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from sklearn import datasets,svm | |
from sklearn.model_selection import train_test_split | |
from sklearn import metrics | |
def train_and_predict(train_input_features, train_outputs, prediction_features):
    """
    Fit a support-vector classifier on labelled iris measurements and
    classify a second batch of measurements.

    :param train_input_features: (numpy.array) Two-dimensional array; every row holds
                                 sepal length, sepal width, petal length, and petal width.
    :param train_outputs: (numpy.array) One-dimensional array giving the species code for
                          the matching row of train_input_features: 0 is Iris setosa,
                          1 is Iris versicolor, and 2 is Iris virginica.
    :param prediction_features: (numpy.array) Two-dimensional array laid out like
                                train_input_features, holding the rows to classify.
    :returns: An iterable (the result of ``SVC.predict``) with one predicted species
              code per row of prediction_features.
    """
    # SVC is built with random_state=0 and otherwise default settings.
    classifier = svm.SVC(random_state=0)
    # fit() hands back the fitted estimator, so the prediction can be chained.
    return classifier.fit(train_input_features, train_outputs).predict(prediction_features)
# Demo run: train on part of the bundled iris data and score the rest.
iris = datasets.load_iris()

# Hold out 30% of the samples for evaluation; random_state=0 fixes the split.
X_train, X_test, y_train, y_test = train_test_split(
    iris.data,
    iris.target,
    test_size=0.3,
    random_state=0,
)

y_pred = train_and_predict(X_train, y_train, X_test)

# Only report accuracy when train_and_predict actually returned predictions.
if y_pred is not None:
    print(metrics.accuracy_score(y_test, y_pred))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment