# Created April 24, 2024 01:48
# Save BuckyI/f8e5458bffaec3a195c834bdf13949e2 to your computer and use it in GitHub Desktop.
# PyTorch demo: classify the Iris dataset #prototype
# This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
# --- Data -----------------------------------------------------------------
# Load the Iris dataset: `data` holds the feature matrix and `target` the
# integer class label of each sample.
iris = load_iris()
X = iris["data"]
y = iris["target"]

# Hold out 20% of the samples for evaluation; the fixed random_state makes
# the split reproducible across runs.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Convert to tensors with explicit dtypes (instead of the legacy
# torch.FloatTensor / torch.LongTensor constructors):
#   * features -> float32, the default dtype of nn.Linear weights;
#   * labels   -> int64 class indices, the target format nn.CrossEntropyLoss expects.
X_train = torch.tensor(X_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)
class ANN(nn.Module):
    """Fully connected classifier with two ReLU hidden layers and a linear output.

    The defaults reproduce the original 4 -> 16 -> 12 -> 3 architecture used
    for the Iris demo. The layer sizes are parameters so the same network can
    be reused for other small tabular classification problems.

    Args:
        in_features: Number of input features per sample.
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        out_features: Number of output logits (one per class).
    """

    def __init__(
        self,
        in_features: int = 4,
        hidden1: int = 16,
        hidden2: int = 12,
        out_features: int = 3,
    ) -> None:
        super().__init__()
        # Attribute names are kept as fc1 / fc2 / output so state_dict keys
        # stay compatible with checkpoints from the original class.
        self.fc1 = nn.Linear(in_features=in_features, out_features=hidden1)
        self.fc2 = nn.Linear(in_features=hidden1, out_features=hidden2)
        self.output = nn.Linear(in_features=hidden2, out_features=out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (unnormalized) class logits for ``x``.

        ``x`` must have ``in_features`` as its last dimension. It may be a
        single sample of shape ``(in_features,)`` or a batch of shape
        ``(N, in_features)``. No softmax is applied, because
        ``nn.CrossEntropyLoss`` works on logits directly.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.output(x)
# --- Model ----------------------------------------------------------------
# Seed PyTorch's RNG so the weight initialization is reproducible, matching
# the fixed random_state used for the train/test split.
torch.manual_seed(42)
model = ANN()
# CrossEntropyLoss applies log-softmax internally, so the model emits raw logits.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# --- Training ---------------------------------------------------------------
# Full-batch training: each epoch runs the whole training split through the
# model and takes one optimizer step.
epochs = 100
loss_arr = []  # per-epoch training loss, stored as plain Python floats
model.train()  # explicit training mode (pairs with model.eval() at test time)
for epoch in range(epochs):
    y_hat = model(X_train)
    loss = criterion(y_hat, y_train)
    # Record a detached float. Appending the loss tensor itself would keep
    # each epoch's autograd graph alive for as long as the list exists.
    loss_value = loss.item()
    loss_arr.append(loss_value)
    if epoch % 10 == 0:
        print(f"Epoch: {epoch} Loss: {loss_value:.4f}")
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

plt.plot(loss_arr, marker=".")
plt.xlabel("Epoch")
plt.ylabel("Training loss")
plt.show()
# --- Evaluation -------------------------------------------------------------
# Predict every test sample in one batched forward pass instead of a Python
# loop. Calling the module (rather than .forward) ensures any registered
# hooks run.
model.eval()
with torch.no_grad():
    pred_tensor = model(X_test).argmax(dim=1)
preds = pred_tensor.tolist()  # predicted class index for each test sample
# Fraction of correctly classified test samples, as a plain Python float.
accuracy = (pred_tensor == y_test).float().mean().item()
print(f"Accuracy: {accuracy:.4f}")
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.