@Ehsan1997
Created March 25, 2020 11:14
Script to Generate Image Classification Dataset from Google Images
# Requires Firefox plus geckodriver on the PATH; element lookups use the
# Selenium 3 `find_elements_by_xpath` API (removed in Selenium 4).
import os
import time
import urllib.request

from PIL import Image
from selenium import webdriver
from selenium.webdriver.firefox.options import Options


def cr_folder(folder_name):
    if not os.path.exists(folder_name):
        os.makedirs(folder_name)


# --Some important variables--
# Number of seconds to wait after each scroll ('2' works on my system)
sleep_time = 2
# Minimum number of image tiles to load before downloading starts
# (the final count usually lands on a multiple of 100).
# This is unreliable; tune it by trial and error.
n_min = 500
# Class of the "Show more results" button (clicking it loads further images
# once plain scrolling stops adding new tiles)
btn_class = "mye4qd"
# Class of the image tiles in the results grid
# (these Google class names are obfuscated and change over time, so they may need updating)
img_class = 'rg_i Q4LuWd tx8vtf'
# Classes to build the dataset for (one Google Images query per class)
classes = ['Torterra', 'Chimchar', 'Monferno', 'Greninja', 'Bidoof']
# Name of the parent folder to create (one subfolder per class goes inside it)
parent_folder = "Pokemon_5_Dataset"
# Size (width, height) to which every saved image is resized
im_size = (80, 80)

cr_folder(parent_folder)

# Run Firefox in headless mode so no browser window opens
options = Options()
options.headless = True

for c in classes:
    print(f'Processing for class: {c}')
    class_folder = parent_folder + '/' + c
    cr_folder(class_folder)
    browser = webdriver.Firefox(options=options)
    browser.get(f'https://www.google.com/search?tbm=isch&q={c}')
    ads = browser.find_elements_by_xpath(f"//img[@class='{img_class}']")
    n = len(ads)
    # Keep scrolling until at least n_min image tiles are loaded
    while n < n_min:
        print('Scrolling Down!!')
        browser.execute_script("window.scrollTo(0, document.body.scrollHeight);")
        time.sleep(sleep_time)
        ads = browser.find_elements_by_xpath(f"//img[@class='{img_class}']")
        n_last = n
        n = len(ads)
        print(n)
        # Scrolling added no new tiles: try the "Show more results" button
        if n == n_last:
            btns = browser.find_elements_by_xpath(f"//input[@class='{btn_class}']")
            if not btns:
                break
            browser.execute_script("arguments[0].click();", btns[0])
            time.sleep(sleep_time)
            ads = browser.find_elements_by_xpath(f"//img[@class='{img_class}']")
            n = len(ads)
            # Neither scrolling nor the button helped, so give up for this class
            if n == n_last:
                break
    # Download every tile that exposes a usable URL; number the files consecutively
    failed_count = 0
    for i, element in enumerate(ads):
        print(element.get_attribute('alt'))
        # Thumbnails may embed a data: URI in 'src'; lazily loaded tiles use 'data-src'
        url = element.get_attribute('src')
        if url is None:
            url = element.get_attribute('data-src')
        print(url)
        if url is not None:
            try:
                image = Image.open(urllib.request.urlopen(url))
                # Convert to RGB so PNG/RGBA thumbnails can still be saved as JPEG
                image = image.convert('RGB').resize(im_size)
                image.save(class_folder + '/' + f"{i - failed_count}.jpg")
                print(image.size)
            except Exception:
                failed_count += 1
        print('_________________')
    print(len(ads))
    print("Failed attempts: ", failed_count)
    browser.quit()
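
Because the script writes one subfolder per class under Pokemon_5_Dataset, the output is already in the folder-per-class layout most training pipelines expect. As a minimal sketch (assuming PyTorch/torchvision, which the gist itself does not use), the generated dataset could be loaded like this:

# Minimal loading sketch: torchvision is an assumption, not part of the gist above.
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize((80, 80)),  # same size as im_size used by the scraper
    transforms.ToTensor(),
])

# Expects Pokemon_5_Dataset/<class name>/<index>.jpg, which is what the scraper produces
dataset = datasets.ImageFolder("Pokemon_5_Dataset", transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
print(dataset.classes)  # class names inferred from the subfolder names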