{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Real Time AI\n",
    "## Image Capture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import os\n",
    "\n",
"def takePictures(className):\n", | |
" #get video capture object\n", | |
" capture = cv2.VideoCapture(0)\n", | |
" os.mkdir(f'images/{className}/')#make a folder with the class name\n", | |
" i = 0\n", | |
" while(True):\n", | |
" ret, frame = capture.read()\n", | |
" cv2.imshow('imageCapture',frame)\n", | |
" k = cv2.waitKey(1)\n", | |
" if k == ord('t'):\n", | |
" frame = cv2.resize(frame, (32, 32))\n", | |
" cv2.imwrite(f'images/{className}/{i:04}.jpg', frame)#https://note.nkmk.me/en/python-opencv-imread-imwrite/\n", | |
" i+=1\n", | |
" print(i)\n", | |
" elif k == ord('q'):\n", | |
" break\n", | |
"\n", | |
" capture.release()\n", | |
" cv2.destroyAllWindows()" | |
   ]
  },
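  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check before training, the number of saved images per class can be counted. This is a minimal sketch, assuming the `images/<className>/` layout created by `takePictures` above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# count captured images per class folder\n",
    "# (assumes the images/<className>/*.jpg layout created by takePictures)\n",
    "for name in sorted(os.listdir('./images/')):\n",
    "    folder = os.path.join('./images/', name)\n",
    "    if os.path.isdir(folder):\n",
    "        print(name, len(os.listdir(folder)))"
   ]
  },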
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "8\n",
      "9\n",
      "10\n",
      "11\n",
      "12\n",
      "13\n",
      "14\n",
      "15\n",
      "16\n",
      "17\n",
      "18\n",
      "19\n",
      "20\n",
      "21\n",
      "22\n",
      "23\n",
      "24\n",
      "25\n",
      "26\n",
      "27\n",
      "28\n",
      "29\n",
      "30\n",
      "31\n",
      "32\n",
      "33\n",
      "34\n",
      "35\n",
      "36\n",
      "37\n",
      "38\n",
      "39\n",
      "40\n",
      "41\n",
      "42\n",
      "43\n",
      "44\n",
      "45\n",
      "46\n",
      "47\n",
      "48\n",
      "49\n",
      "50\n",
      "51\n",
      "52\n",
      "53\n",
      "54\n",
      "55\n",
      "56\n",
      "57\n",
      "58\n",
      "59\n",
      "60\n",
      "61\n",
      "62\n",
      "63\n",
      "64\n",
      "65\n",
      "66\n",
      "67\n",
      "68\n",
      "69\n",
      "70\n",
      "71\n",
      "72\n",
      "73\n",
      "74\n",
      "75\n",
      "76\n",
      "77\n",
      "78\n",
      "79\n",
      "80\n"
     ]
    }
   ],
   "source": [
"takePictures(\"Center\")#please set class(folder) name. !input them to className later" | |
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Dataset\n",
    "### Get file paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
"Number of training datas : 200\n", | |
"['./images/Center/0000.jpg', './images/Center/0001.jpg', './images/Center/0002.jpg']\n", | |
"Number of validation datas : 55\n", | |
"['./images/Center/0040.jpg', './images/Center/0041.jpg', './images/Center/0042.jpg']\n" | |
     ]
    }
   ],
"source": [ | |
"import cv2\n", | |
"import os\n", | |
"#based on this site https://venoda.hatenablog.com/entry/2020/10/11/221117\n", | |
"def make_filepath_list():\n", | |
" train_file_list = []\n", | |
" test_file_list = []\n", | |
"\n", | |
" for top_dir in os.listdir('./images/'):#get each folder in image folder\n", | |
" file_dir = os.path.join('./images/', top_dir)\n", | |
" file_list = os.listdir(file_dir)\n", | |
" \n", | |
" # 80% of images are training data, 20% of images are test data\n", | |
" num_data = len(file_list)\n", | |
" num_split = int(num_data * 0.8)\n", | |
"\n", | |
" train_file_list += [os.path.join('./images', top_dir, file).replace('\\\\', '/') for file in file_list[:num_split]]\n", | |
" test_file_list += [os.path.join('./images', top_dir, file).replace('\\\\', '/') for file in file_list[num_split:]]\n", | |
" \n", | |
" return train_file_list, test_file_list#training file path set, test file path set\n", | |
"\n", | |
"# Get list that contains file path\n", | |
"train_file_list, test_file_list = make_filepath_list()\n", | |
"\n", | |
"print('Number of training datas : ', len(train_file_list))\n", | |
"# show head 3\n", | |
"print(train_file_list[:3])\n", | |
"\n", | |
"print('Number of validation datas : ', len(test_file_list))\n", | |
"# show head3\n", | |
"print(test_file_list[:3])" | |
   ]
  },
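  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the files are split in directory order, the 80/20 split above can send only the most recently captured frames of each class to validation. A seeded shuffle before splitting avoids that bias; the sketch below shows the idea for one example class folder (the folder name `./images/Center` and the seed value are illustrative choices)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "\n",
    "# sketch: shuffle one class's files before the 80/20 split so validation\n",
    "# is not made up of the most recently captured frames only\n",
    "files = sorted(os.listdir('./images/Center'))  # example class folder\n",
    "random.seed(0)  # arbitrary seed, for reproducibility\n",
    "random.shuffle(files)\n",
    "num_split = int(len(files) * 0.8)\n",
    "train_part, test_part = files[:num_split], files[num_split:]\n",
    "print(len(train_part), len(test_part))"
   ]
  },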
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define image transform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# transform = transforms.Compose([\n",
    "#     transforms.RandomResizedCrop(32),\n",
    "#     transforms.RandomHorizontalFlip(),\n",
    "#     transforms.ToTensor()#,\n",
    "#     #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "# ])"
   ]
  },
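  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above is left commented out; the dataset below falls back to a plain `ToTensor()`. If augmentation is wanted, a pipeline like the following sketch could be passed to `MyDataset` via its `transform` argument. The normalization constants are the widely used ImageNet statistics, an assumption rather than values computed from the captured images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.transforms as transforms\n",
    "\n",
    "# a minimal augmentation pipeline; pass as MyDataset(..., transform=transform).\n",
    "# the mean/std are the common ImageNet statistics (an assumption, not values\n",
    "# computed from this dataset)\n",
    "transform = transforms.Compose([\n",
    "    transforms.RandomResizedCrop(32),   # random crop, rescaled to 32x32\n",
    "    transforms.RandomHorizontalFlip(),  # random left-right flip\n",
    "    transforms.ToTensor(),              # PIL image -> float tensor in [0, 1]\n",
    "    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "])"
   ]
  },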
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3, 32, 32])\n",
      "1\n"
     ]
    }
   ],
   "source": [
"import torch.utils.data as data\n", | |
"from PIL import Image\n", | |
"import matplotlib.pyplot as plt\n", | |
"import torchvision.transforms as transforms\n", | |
"\n", | |
"class MyDataset(data.Dataset):#I don't use transform in this case\n", | |
" def __init__(self, file_list, classes, transform=None):\n", | |
" self.file_list = file_list\n", | |
" self.transform = transform\n", | |
" self.classes = classes\n", | |
" \n", | |
" def __len__(self):\n", | |
" return len(self.file_list)\n", | |
" \n", | |
" def __getitem__(self, index):\n", | |
" img_path = self.file_list[index]\n", | |
" img = Image.open(img_path)\n", | |
" if(self.transform != None):\n", | |
" img = self.transform(img)\n", | |
" else:\n", | |
" img = transforms.ToTensor()(img)#convert to tensor instead of transform\n", | |
" \n", | |
" label = img_path.split('/')[2]#get label name from path\n", | |
" label = self.classes.index(label)#convert label name to index using className[]\n", | |
" return img, label\n", | |
" \n", | |
"classNames = os.listdir('./images/')\n", | |
"train_dataset = MyDataset(\n", | |
" file_list=train_file_list,\n", | |
" classes = classNames\n", | |
")\n", | |
"\n", | |
"test_dataset = MyDataset(\n", | |
" file_list=test_file_list,\n", | |
" classes = classNames\n", | |
")\n", | |
"#check data\n", | |
"print(train_dataset.__getitem__(0)[0].size())\n", | |
"print(train_dataset.__getitem__(0)[1])" | |
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create data loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([64, 3, 32, 32])\n",
      "tensor([4, 3, 4, 4, 1, 5, 3, 1, 1, 1, 3, 3, 3, 2, 1, 5, 2, 2, 3, 1, 5, 1, 3, 3,\n",
      "        4, 1, 2, 2, 2, 2, 4, 5, 5, 5, 1, 3, 3, 3, 5, 3, 1, 3, 4, 4, 5, 4, 5, 1,\n",
      "        5, 5, 5, 5, 1, 4, 1, 2, 2, 3, 2, 2, 3, 2, 2, 2])\n"
     ]
    }
   ],
   "source": [
"batch_size = 64\n", | |
"\n", | |
"train_loader = data.DataLoader(\n", | |
" train_dataset, batch_size=batch_size, shuffle=True)\n", | |
"\n", | |
"test_loader = data.DataLoader(\n", | |
" test_dataset, batch_size=batch_size, shuffle=False)#!batch_size=1\n", | |
"\n", | |
"#check data loader\n", | |
"batch_iterator = iter(train_loader)\n", | |
"inputs, labels = next(batch_iterator)\n", | |
"print(inputs.size())\n", | |
"print(labels)" | |
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
"class Net(nn.Module):\n", | |
" def __init__(self):\n", | |
" super(Net, self).__init__()\n", | |
" self.conv1 = nn.Conv2d(3, 6, 5)\n", | |
" self.pool = nn.MaxPool2d(2, 2)\n", | |
" self.conv2 = nn.Conv2d(6, 16, 5)\n", | |
" self.fc1 = nn.Linear(16*5*5, 120)\n", | |
" self.fc2 = nn.Linear(120, 84)\n", | |
" self.fc3 = nn.Linear(84, len(classNames))\n", | |
" \n", | |
" def forward(self, x):\n", | |
" x = self.pool(F.relu(self.conv1(x)))\n", | |
" x = self.pool(F.relu(self.conv2(x)))\n", | |
" x = x.view(-1, 16*5*5)\n", | |
" x = F.relu(self.fc1(x))\n", | |
" x = F.relu(self.fc2(x))\n", | |
" x = self.fc3(x)\n", | |
" x = F.softmax(x, dim=1)\n", | |
" return x\n", | |
" \n", | |
"net = Net()" | |
] | |
}, | |
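  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick shape check with a dummy batch (a sketch; `torch.randn` stands in for a real image) confirms that a 32x32 input flows through the two conv/pool stages and produces one score per class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# pass one fake 32x32 RGB image through the untrained network\n",
    "dummy = torch.randn(1, 3, 32, 32)\n",
    "print(net(dummy).shape)  # expected: torch.Size([1, len(classNames)])"
   ]
  },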
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.optim as optim\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)#usually lr=0.001 but I set lr=0.01 to speed up" | |
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 1.7905797672271728 20.0\n",
      "1 1.7902201557159423 20.0\n",
      "2 1.7895686388015748 20.0\n",
      "3 1.7887357711791991 20.0\n",
      "4 1.7878109264373778 20.0\n",
      "5 1.7868560791015624 20.0\n"
     ]
    }
   ],
   "source": [
"import torch\n", | |
"\n", | |
"testset_size = len(test_dataset)\n", | |
"trainset_size = len(train_dataset)\n", | |
"\n", | |
"for epoch in range(200):\n", | |
" running_loss = 0.0\n", | |
" for i, data in enumerate(train_loader, 0):\n", | |
" inputs, labels = data\n", | |
" optimizer.zero_grad()#reset parameters gradient\n", | |
" outputs = net(inputs)\n", | |
" loss = criterion(outputs, labels)\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
" #visualize statistic information\n", | |
" running_loss += loss.item()*inputs.size(0)#inputs.size(0) is batch size in this time\n", | |
" \n", | |
" \n", | |
" correct = 0\n", | |
" for i, data in enumerate(test_loader, 0):\n", | |
" inputs, labels = data\n", | |
" outputs = net(inputs)\n", | |
" #get correct outputs\n", | |
" preds = torch.argmax(outputs, 1)#get prediction label index\n", | |
" correct += torch.sum(preds == labels).numpy()#add length of correct outputs\n", | |
" accuracy = correct/float(testset_size)*100#calcualte accuracy by dividing number of corrects with dataset length\n", | |
" print(epoch, running_loss/trainset_size, accuracy)\n", | |
"print(\"Finished Training\")" | |
   ]
  },
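  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once training looks good, the weights can be persisted so `cameraPredict` works after a kernel restart without retraining. This is a minimal sketch with `torch.save`; the file name `realtime_net.pth` is an arbitrary choice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# save the trained weights (the file name is an arbitrary choice)\n",
    "torch.save(net.state_dict(), 'realtime_net.pth')\n",
    "\n",
    "# later, restore them into a fresh Net instance:\n",
    "# net = Net()\n",
    "# net.load_state_dict(torch.load('realtime_net.pth'))"
   ]
  },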
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test with test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
"import numpy as np\n", | |
"\n", | |
"dataiter = iter(test_loader)\n", | |
"images, labels = dataiter.next()#get first batch of test dataset\n", | |
"index = 0#get index'th element of batch\n", | |
"image = images[index]\n", | |
"label = labels[index]\n", | |
"#show image\n", | |
"plt.imshow(np.swapaxes(image.numpy(),0,2))\n", | |
"#show grandtruth and prediction\n", | |
"print(\"Grand Truth:\"+classNames[label])\n", | |
"print(\"Prediction:\"+classNames[net(images)[index].argmax()])" | |
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Real Time Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
"import cv2\n", | |
"\n", | |
"def cameraPredict():\n", | |
" #get video capture object\n", | |
" capture = cv2.VideoCapture(0)\n", | |
" while(True):\n", | |
" ret, frame = capture.read()\n", | |
" #prediction\n", | |
" frameInput = cv2.resize(frame, (32, 32))\n", | |
" l = net(transforms.ToTensor()(frameInput).unsqueeze(0))#predict and get label index\n", | |
" predLabel = classNames[l.argmax()]#predicted label\n", | |
" frame = cv2.putText(frame, str(predLabel), (50, 100) , cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 0) , 4, cv2.LINE_AA)#show label\n", | |
" k = cv2.waitKey(1)#get keyboard input\n", | |
" if k == ord('q'):#quit\n", | |
" break\n", | |
" \n", | |
" cv2.imshow('frame',frame)\n", | |
" capture.release()\n", | |
" cv2.destroyAllWindows()" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cameraPredict()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}