{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Real Time AI\n",
"## Image Capture"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"import os\n",
"\n",
"def takePictures(className):\n",
" #get video capture object\n",
" capture = cv2.VideoCapture(0)\n",
" os.mkdir(f'images/{className}/')#make a folder with the class name\n",
" i = 0\n",
" while(True):\n",
" ret, frame = capture.read()\n",
" cv2.imshow('imageCapture',frame)\n",
" k = cv2.waitKey(1)\n",
" if k == ord('t'):\n",
" frame = cv2.resize(frame, (32, 32))\n",
" cv2.imwrite(f'images/{className}/{i:04}.jpg', frame)#https://note.nkmk.me/en/python-opencv-imread-imwrite/\n",
" i+=1\n",
" print(i)\n",
" elif k == ord('q'):\n",
" break\n",
"\n",
" capture.release()\n",
" cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n",
"2\n",
"3\n",
"4\n",
"5\n",
"6\n",
"7\n",
"8\n",
"9\n",
"10\n",
"11\n",
"12\n",
"13\n",
"14\n",
"15\n",
"16\n",
"17\n",
"18\n",
"19\n",
"20\n",
"21\n",
"22\n",
"23\n",
"24\n",
"25\n",
"26\n",
"27\n",
"28\n",
"29\n",
"30\n",
"31\n",
"32\n",
"33\n",
"34\n",
"35\n",
"36\n",
"37\n",
"38\n",
"39\n",
"40\n",
"41\n",
"42\n",
"43\n",
"44\n",
"45\n",
"46\n",
"47\n",
"48\n",
"49\n",
"50\n",
"51\n",
"52\n",
"53\n",
"54\n",
"55\n",
"56\n",
"57\n",
"58\n",
"59\n",
"60\n",
"61\n",
"62\n",
"63\n",
"64\n",
"65\n",
"66\n",
"67\n",
"68\n",
"69\n",
"70\n",
"71\n",
"72\n",
"73\n",
"74\n",
"75\n",
"76\n",
"77\n",
"78\n",
"79\n",
"80\n"
]
}
],
"source": [
"takePictures(\"Center\")#please set class(folder) name. !input them to className later"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create Dataset\n",
"### Get file paths"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of training datas : 200\n",
"['./images/Center/0000.jpg', './images/Center/0001.jpg', './images/Center/0002.jpg']\n",
"Number of validation datas : 55\n",
"['./images/Center/0040.jpg', './images/Center/0041.jpg', './images/Center/0042.jpg']\n"
]
}
],
"source": [
"import cv2\n",
"import os\n",
"#based on this site https://venoda.hatenablog.com/entry/2020/10/11/221117\n",
"def make_filepath_list():\n",
" train_file_list = []\n",
" test_file_list = []\n",
"\n",
" for top_dir in os.listdir('./images/'):#get each folder in image folder\n",
" file_dir = os.path.join('./images/', top_dir)\n",
" file_list = os.listdir(file_dir)\n",
" \n",
" # 80% of images are training data, 20% of images are test data\n",
" num_data = len(file_list)\n",
" num_split = int(num_data * 0.8)\n",
"\n",
" train_file_list += [os.path.join('./images', top_dir, file).replace('\\\\', '/') for file in file_list[:num_split]]\n",
" test_file_list += [os.path.join('./images', top_dir, file).replace('\\\\', '/') for file in file_list[num_split:]]\n",
" \n",
" return train_file_list, test_file_list#training file path set, test file path set\n",
"\n",
"# Get list that contains file path\n",
"train_file_list, test_file_list = make_filepath_list()\n",
"\n",
"print('Number of training datas : ', len(train_file_list))\n",
"# show head 3\n",
"print(train_file_list[:3])\n",
"\n",
"print('Number of validation datas : ', len(test_file_list))\n",
"# show head3\n",
"print(test_file_list[:3])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define image transform"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# transform = transforms.Compose([\n",
"# transforms.RandomResizedCrop(32),\n",
"# transforms.RandomHorizontalFlip(),\n",
"# transforms.ToTensor()#,\n",
"# #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
"# ])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create dataset"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([3, 32, 32])\n",
"1\n"
]
}
],
"source": [
"import torch.utils.data as data\n",
"from PIL import Image\n",
"import matplotlib.pyplot as plt\n",
"import torchvision.transforms as transforms\n",
"\n",
"class MyDataset(data.Dataset):#I don't use transform in this case\n",
" def __init__(self, file_list, classes, transform=None):\n",
" self.file_list = file_list\n",
" self.transform = transform\n",
" self.classes = classes\n",
" \n",
" def __len__(self):\n",
" return len(self.file_list)\n",
" \n",
" def __getitem__(self, index):\n",
" img_path = self.file_list[index]\n",
" img = Image.open(img_path)\n",
" if(self.transform != None):\n",
" img = self.transform(img)\n",
" else:\n",
" img = transforms.ToTensor()(img)#convert to tensor instead of transform\n",
" \n",
" label = img_path.split('/')[2]#get label name from path\n",
" label = self.classes.index(label)#convert label name to index using className[]\n",
" return img, label\n",
" \n",
"classNames = os.listdir('./images/')\n",
"train_dataset = MyDataset(\n",
" file_list=train_file_list,\n",
" classes = classNames\n",
")\n",
"\n",
"test_dataset = MyDataset(\n",
" file_list=test_file_list,\n",
" classes = classNames\n",
")\n",
"#check data\n",
"print(train_dataset.__getitem__(0)[0].size())\n",
"print(train_dataset.__getitem__(0)[1])"
]
},
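{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `transform` commented out earlier can be plugged into `MyDataset` when augmentation is wanted. A minimal sketch, not used in the recorded run (`augmented_dataset` is a hypothetical name):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"transform = transforms.Compose([\n",
"    transforms.RandomResizedCrop(32),\n",
"    transforms.RandomHorizontalFlip(),\n",
"    transforms.ToTensor()\n",
"])\n",
"#__getitem__ now applies the random augmentation on every access\n",
"augmented_dataset = MyDataset(\n",
"    file_list=train_file_list,\n",
"    classes=classNames,\n",
"    transform=transform\n",
")\n",
"print(augmented_dataset[0][0].size())#still torch.Size([3, 32, 32])"
]
},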
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create data loader"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([64, 3, 32, 32])\n",
"tensor([4, 3, 4, 4, 1, 5, 3, 1, 1, 1, 3, 3, 3, 2, 1, 5, 2, 2, 3, 1, 5, 1, 3, 3,\n",
" 4, 1, 2, 2, 2, 2, 4, 5, 5, 5, 1, 3, 3, 3, 5, 3, 1, 3, 4, 4, 5, 4, 5, 1,\n",
" 5, 5, 5, 5, 1, 4, 1, 2, 2, 3, 2, 2, 3, 2, 2, 2])\n"
]
}
],
"source": [
"batch_size = 64\n",
"\n",
"train_loader = data.DataLoader(\n",
" train_dataset, batch_size=batch_size, shuffle=True)\n",
"\n",
"test_loader = data.DataLoader(\n",
" test_dataset, batch_size=batch_size, shuffle=False)#!batch_size=1\n",
"\n",
"#check data loader\n",
"batch_iterator = iter(train_loader)\n",
"inputs, labels = next(batch_iterator)\n",
"print(inputs.size())\n",
"print(labels)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define Network"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class Net(nn.Module):\n",
" def __init__(self):\n",
" super(Net, self).__init__()\n",
" self.conv1 = nn.Conv2d(3, 6, 5)\n",
" self.pool = nn.MaxPool2d(2, 2)\n",
" self.conv2 = nn.Conv2d(6, 16, 5)\n",
" self.fc1 = nn.Linear(16*5*5, 120)\n",
" self.fc2 = nn.Linear(120, 84)\n",
" self.fc3 = nn.Linear(84, len(classNames))\n",
" \n",
" def forward(self, x):\n",
" x = self.pool(F.relu(self.conv1(x)))\n",
" x = self.pool(F.relu(self.conv2(x)))\n",
" x = x.view(-1, 16*5*5)\n",
" x = F.relu(self.fc1(x))\n",
" x = F.relu(self.fc2(x))\n",
" x = self.fc3(x)\n",
" x = F.softmax(x, dim=1)\n",
" return x\n",
" \n",
"net = Net()"
]
},
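{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of the `16*5*5` flatten size: each 5x5 convolution (no padding) shrinks each side by 4 and each 2x2 max pool halves it, so 32 -> 28 -> 14 -> 10 -> 5. A small sketch with a dummy input (`dummy` is just a made-up test batch):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"dummy = torch.randn(1, 3, 32, 32)#one fake image, only to check tensor shapes\n",
"print(net(dummy).shape)#expected: torch.Size([1, len(classNames)])"
]
},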
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define optimizer"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"import torch.optim as optim\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)#usually lr=0.001 but I set lr=0.01 to speed up"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train network"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 1.7905797672271728 20.0\n",
"1 1.7902201557159423 20.0\n",
"2 1.7895686388015748 20.0\n",
"3 1.7887357711791991 20.0\n",
"4 1.7878109264373778 20.0\n",
"5 1.7868560791015624 20.0\n"
]
}
],
"source": [
"import torch\n",
"\n",
"testset_size = len(test_dataset)\n",
"trainset_size = len(train_dataset)\n",
"\n",
"for epoch in range(200):\n",
" running_loss = 0.0\n",
" for i, data in enumerate(train_loader, 0):\n",
" inputs, labels = data\n",
" optimizer.zero_grad()#reset parameters gradient\n",
" outputs = net(inputs)\n",
" loss = criterion(outputs, labels)\n",
" loss.backward()\n",
" optimizer.step()\n",
" #visualize statistic information\n",
" running_loss += loss.item()*inputs.size(0)#inputs.size(0) is batch size in this time\n",
" \n",
" \n",
" correct = 0\n",
" for i, data in enumerate(test_loader, 0):\n",
" inputs, labels = data\n",
" outputs = net(inputs)\n",
" #get correct outputs\n",
" preds = torch.argmax(outputs, 1)#get prediction label index\n",
" correct += torch.sum(preds == labels).numpy()#add length of correct outputs\n",
" accuracy = correct/float(testset_size)*100#calcualte accuracy by dividing number of corrects with dataset length\n",
" print(epoch, running_loss/trainset_size, accuracy)\n",
"print(\"Finished Training\")"
]
},
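{
"cell_type": "markdown",
"metadata": {},
"source": [
"Training takes a while, so it may be worth saving the learned weights before the real-time part. A minimal sketch using PyTorch's `state_dict` API (the file name `net.pth` and the variable `net2` are arbitrary choices):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#save the learned parameters (the file name is an arbitrary choice)\n",
"torch.save(net.state_dict(), 'net.pth')\n",
"\n",
"#later, restore them into a freshly constructed Net before predicting\n",
"net2 = Net()\n",
"net2.load_state_dict(torch.load('net.pth'))\n",
"net2.eval()#switch to evaluation mode"
]
},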
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Test with test data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"dataiter = iter(test_loader)\n",
"images, labels = dataiter.next()#get first batch of test dataset\n",
"index = 0#get index'th element of batch\n",
"image = images[index]\n",
"label = labels[index]\n",
"#show image\n",
"plt.imshow(np.swapaxes(image.numpy(),0,2))\n",
"#show grandtruth and prediction\n",
"print(\"Grand Truth:\"+classNames[label])\n",
"print(\"Prediction:\"+classNames[net(images)[index].argmax()])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Real Time Prediction"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"\n",
"def cameraPredict():\n",
" #get video capture object\n",
" capture = cv2.VideoCapture(0)\n",
" while(True):\n",
" ret, frame = capture.read()\n",
" #prediction\n",
" frameInput = cv2.resize(frame, (32, 32))\n",
" l = net(transforms.ToTensor()(frameInput).unsqueeze(0))#predict and get label index\n",
" predLabel = classNames[l.argmax()]#predicted label\n",
" frame = cv2.putText(frame, str(predLabel), (50, 100) , cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 0) , 4, cv2.LINE_AA)#show label\n",
" k = cv2.waitKey(1)#get keyboard input\n",
" if k == ord('q'):#quit\n",
" break\n",
" \n",
" cv2.imshow('frame',frame)\n",
" capture.release()\n",
" cv2.destroyAllWindows()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cameraPredict()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}