Skip to content

Instantly share code, notes, and snippets.

@siddheshgunjal
Created October 14, 2023 03:42
Show Gist options
  • Save siddheshgunjal/ff7c2b2ee0d98b66245e1efee258a6fa to your computer and use it in GitHub Desktop.
Download Training images for your ML task with selenium
import hashlib
import io
import os
import time

import requests
import selenium
from PIL import Image
from selenium import webdriver
from selenium.webdriver.common.by import By
# Put the path for your ChromeDriver here
DRIVER_PATH = 'chromedriver-linux64/chromedriver' ## For Linux
# DRIVER_PATH = 'chromedriver-windows64/chromedriver.exe' ## For Windows
def fetch_image_urls(query:str, max_links_to_fetch:int, wd:webdriver, sleep_between_interactions:int=1):
def scroll_to_end(wd):
wd.execute_script("window.scrollTo(0, document.body.scrollHeight);")
time.sleep(sleep_between_interactions)
# build the google query
search_url = "https://www.google.com/search?safe=off&site=&tbm=isch&source=hp&q={q}&oq={q}&gs_l=img"
# load the page
wd.get(search_url.format(q=query))
image_urls = set()
image_count = 0
results_start = 0
scroll_to_end(wd)
while image_count < max_links_to_fetch:
# scroll_to_end(wd)
# get all image thumbnail results
thumbnail_results = wd.find_elements_by_css_selector("img.Q4LuWd")
number_results = len(thumbnail_results)
print(f"Found: {number_results} search results. Extracting links from {results_start}:{number_results}")
for img in thumbnail_results[results_start:number_results]:
# try to click every thumbnail such that we can get the real image behind it
try:
img.click()
time.sleep(sleep_between_interactions)
except Exception:
continue
# extract image urls
actual_images = wd.find_elements_by_css_selector('img.r48jcc')
for actual_image in actual_images:
if actual_image.get_attribute('src') and 'http' in actual_image.get_attribute('src'):
image_urls.add(actual_image.get_attribute('src'))
image_count = len(image_urls)
# image_count += 1
if image_count >= max_links_to_fetch:
print(f"Found: {len(image_urls)} image links, done!")
break
if image_count < max_links_to_fetch:
print("Found:", image_count, "image links, looking for more ...")
load_more_button = wd.find_element_by_css_selector(".LZ4I")
if load_more_button:
wd.execute_script("document.querySelector('.LZ4I').click();")
time.sleep(5)
scroll_to_end(wd)
# move the result startpoint further down
results_start = len(thumbnail_results)
# print(image_urls)
return image_urls
def persist_image(folder_path:str,url:str):
try:
image_content = requests.get(url).content
except Exception as e:
print(f"ERROR - Could not download {url} - {e}")
try:
image_file = io.BytesIO(image_content)
image = Image.open(image_file).convert('RGB')
file_path = os.path.join(folder_path,hashlib.sha1(image_content).hexdigest()[:10] + '.jpg')
with open(file_path, 'wb') as f:
image.save(f, "JPEG", quality=85)
print(f"SUCCESS - saved {url} - as {file_path}")
except Exception as e:
print(f"ERROR - Could not save {url} - {e}")
def search_and_download(search_term:str,driver_path:str,target_path='./Data/Train_Data',number_images=5):
target_folder = os.path.join(target_path,'_'.join(search_term.lower().split(' ')))
if not os.path.exists(target_folder):
os.makedirs(target_folder)
with webdriver.Chrome(executable_path=driver_path) as wd:
res = fetch_image_urls(search_term, number_images, wd=wd, sleep_between_interactions=0.5)
for elem in res:
persist_image(target_folder, elem)
if __name__ == '__main__':
search_term = input("Enter your search term: ")
count = int(input("How many images to download?: "))
search_and_download(search_term=search_term, driver_path=DRIVER_PATH, number_images=count)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment