Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Script to Generate Image Classification Dataset from Google Images
import os
import time
import urllib.parse
import urllib.request

from PIL import Image
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.firefox.options import Options
def cr_folder(folder_name):
    """Create *folder_name* (and any missing parent folders) if it does not exist.

    Calling it again for an existing directory is a no-op. ``exist_ok=True``
    replaces the former "check, then create" sequence. That sequence could
    raise if the folder appeared between the check and the ``makedirs`` call.

    Raises:
        FileExistsError: if *folder_name* exists but is not a directory
            (previously this was silently ignored and every later save failed).
        OSError: if the directory cannot be created (e.g. permissions).
    """
    os.makedirs(folder_name, exist_ok=True)
# --Some important variables--
# Number of seconds to wait after each scroll or button click so new
# thumbnails can load ('2' works on the author's system).
sleep_time = 2
# Minimum number of thumbnails to collect per class before scrolling stops
# (usually chosen as a multiple of 100).
# This is a very buggy feature; tune it by trial and error for now.
n_min = 500
# Class of the button that some result pages require clicking before they
# load more images.
btn_class = "mye4qd"
# Exact class attribute of the image tiles, used to find thumbnails on the page.
# NOTE(review): these look like generated CSS class names of Google's results
# page and may change at any time — confirm against the live page if the
# scraper finds no images.
img_class = 'rg_i Q4LuWd tx8vtf'
# Classes (each one is used as a search query and gets its own sub-folder)
classes = ['Torterra', 'Chimchar', 'Monferno', 'Greninja', 'Bidoof']
# Name of the parent folder to create
parent_folder = "Pokemon_5_Dataset"
# Size (width, height) every saved image is resized to
im_size = (80, 80)
cr_folder(parent_folder)
options = Options()
# Run Firefox without a window. The "-headless" argument works on every
# Selenium version. The `Options.headless` property was deprecated in
# Selenium 4.8 and later removed, so assigning it may silently have no effect.
options.add_argument("-headless")
url_dict = dict()  # NOTE(review): not used anywhere in the visible script
def _find_tiles(browser):
    """Return the <img> elements whose class attribute is exactly img_class."""
    return browser.find_elements(By.XPATH, f"//img[@class='{img_class}']")


def _scroll_for_tiles(browser):
    """Scroll the results page until n_min tiles are loaded or loading stalls.

    When a scroll adds no new tiles, the first <input> whose class is
    btn_class is clicked once. Scrolling stops if that adds nothing either,
    or if no such button exists.

    Returns:
        The list of tile elements present on the page at the end.
    """
    ads = _find_tiles(browser)
    n = len(ads)
    while n < n_min:
        print('Scrolling Down!!')
        browser.execute_script("window.scrollTo(0, document.body.scrollHeight);")
        time.sleep(sleep_time)
        ads = _find_tiles(browser)
        n_last = n
        n = len(ads)
        print(n)
        if n == n_last:
            buttons = browser.find_elements(By.XPATH, f"//input[@class='{btn_class}']")
            if not buttons:
                # No button to click: nothing more can be loaded (indexing
                # [0] here used to raise IndexError).
                break
            # Clicked through JavaScript, presumably so it also works when the
            # button is off-screen or covered.
            browser.execute_script("arguments[0].click();", buttons[0])
            time.sleep(sleep_time)
            ads = _find_tiles(browser)
            n = len(ads)
            if n == n_last:
                break
    return ads


def _save_tiles(ads, class_folder, timeout=10):
    """Download each tile's image, resize it to im_size and save it as JPEG.

    Saved images are numbered consecutively (0.jpg, 1.jpg, ...) inside
    class_folder, so failed tiles leave no gaps.

    Args:
        ads: tile elements returned by _find_tiles.
        class_folder: existing directory to write the images into.
        timeout: seconds to wait for each network download.

    Returns:
        Number of tiles that were not saved (no URL, or a download/decode/save error).
    """
    saved_count = 0
    failed_count = 0
    for element in ads:
        print(element.get_attribute('alt'))
        # Fall back to `data-src` when `src` is missing or empty.
        url = element.get_attribute('src') or element.get_attribute('data-src')
        print(url)
        if not url:
            failed_count += 1
            continue
        try:
            # urlopen handles both http(s) and data: URLs. The `with` closes the
            # response even when decoding fails.
            with urllib.request.urlopen(url, timeout=timeout) as response:
                # JPEG cannot store alpha/palette modes, so normalise to RGB.
                image = Image.open(response).convert('RGB').resize(im_size)
            image.save(os.path.join(class_folder, f"{saved_count}.jpg"))
        except Exception as err:  # best effort: one bad tile must not stop the run
            failed_count += 1
            print(f'Failed to save image: {err}')
        else:
            saved_count += 1
            print(image.size)
        print('_________________')
    return failed_count


for c in classes:
    print(f'Processing for class: {c}')
    class_folder = os.path.join(parent_folder, c)
    cr_folder(class_folder)
    browser = webdriver.Firefox(options=options)
    try:
        # Quote the class name so spaces or characters like '&' survive in the URL.
        query = urllib.parse.quote_plus(c)
        browser.get(f'https://www.google.com/search?tbm=isch&q={query}')
        ads = _scroll_for_tiles(browser)
        failed_count = _save_tiles(ads, class_folder)
        print(len(ads))
        print("Failed attempts: ", failed_count)
    finally:
        # Always shut Firefox down, even if scraping this class raised.
        browser.quit()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.