Last active
April 25, 2022 09:19
-
-
Save charbonnierg/cd003e653a8cd900ee7ea8aa7a5b6121 to your computer and use it in GitHub Desktop.
Azure Custom Vision file downloader
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import json | |
import urllib.request as request | |
from concurrent.futures import ThreadPoolExecutor | |
from pathlib import Path | |
from typing import Any, Dict, Iterator, List, Optional, Set, Union | |
from azure.cognitiveservices.vision.customvision.training import ( | |
CustomVisionTrainingClient, | |
) | |
from msrest.authentication import ApiKeyCredentials | |
from pydantic import BaseModel, BaseSettings, Field | |
from structlog import get_logger | |
# Module-level structured logger used by the download functions below.
logger = get_logger()
class FilesSettings(BaseSettings):
    """Locations of the files from which settings values can be read.

    This model does not hold the settings themselves, only the paths of the
    files that contain them. Fields are looked up in environment variables
    using the ``vision_`` prefix, without case sensitivity.
    """

    # Path of a file parsed into a complete DownloaderSettings instance.
    config_file: Optional[str] = None
    # Path of a file whose first line is the training key.
    training_key_file: Optional[str] = None
    # Path of a file whose first line is the project identifier.
    project_file: Optional[str] = None

    class Config:
        case_sensitive = False
        env_prefix = "vision_"
class DownloaderSettings(BaseSettings):
    """Settings consumed by the downloader.

    Fields are looked up in environment variables using the ``vision_``
    prefix, without case sensitivity.

    NOTE: The training key must never be written in a script or a module.
    It must be provided through an environment variable or a file; the
    default value below is only a placeholder.
    """

    # Endpoint handed to the Custom Vision training client.
    endpoint: str = "https://quara-poc-vision.cognitiveservices.azure.com/"
    # Value sent in the "Training-key" header of the training client.
    training_key: str = "SECRET"
    # Identifier of the Custom Vision project to download.
    project: str = "1cc5c053-39ed-4693-8e85-ff5096b9ea1c"
    # Directory where images and their metadata are written.
    output: str = "./customvision/annotated_datasets/poc-vision"

    class Config:
        case_sensitive = False
        env_prefix = "vision_"
def _read_first_line(path: Union[Path, str]) -> str:
    """Return the first line of the text file at *path*, without its line ending."""
    lines = Path(path).read_text().splitlines()
    if not lines:
        raise ValueError(f"Cannot read a value from empty file: {path}")
    return lines[0]


def get_settings(
    settings: Optional[DownloaderSettings] = None,
    config_file: Union[Path, str, None] = None,
) -> DownloaderSettings:
    """Build the downloader settings from the environment and/or files.

    Resolution order:

    1. If a config file is known (``config_file`` argument, else the
       ``VISION_CONFIG_FILE`` environment variable), settings are loaded from
       it with ``DownloaderSettings.parse_file``; otherwise they are built
       from the environment and the defaults.
    2. If ``VISION_TRAINING_KEY_FILE`` is set, the first line of that file
       replaces the training key.
    3. If ``VISION_PROJECT_FILE`` is set, the first line of that file
       replaces the project.
    4. Fields explicitly set on ``settings`` override everything above.

    Args:
        settings: Optional settings whose explicitly-set fields take
            precedence over the parsed values.
        config_file: Optional path to a settings file.

    Returns:
        The resolved settings.

    Raises:
        ValueError: If the training key file or the project file is empty.
    """
    # Init keywords outrank environment variables in BaseSettings, so
    # forwarding ``config_file=None`` would mask VISION_CONFIG_FILE. The
    # field is declared as a string, so Path values are converted first.
    init_kwargs: Dict[str, Any] = {}
    if config_file is not None:
        init_kwargs["config_file"] = str(config_file)
    files_settings = FilesSettings(**init_kwargs)

    # Parse settings from the config file when there is one, else from env
    if files_settings.config_file:
        parsed_settings = DownloaderSettings.parse_file(files_settings.config_file)
    else:
        parsed_settings = DownloaderSettings()

    # Override key from file
    if files_settings.training_key_file:
        parsed_settings.training_key = _read_first_line(
            files_settings.training_key_file
        )

    # Override project from file
    if files_settings.project_file:
        parsed_settings.project = _read_first_line(files_settings.project_file)

    # Explicitly-set fields of the caller's settings win over parsed values
    if settings is not None:
        return parsed_settings.copy(update=settings.dict(exclude_unset=True))

    return parsed_settings
class AnnotatedRegion(BaseModel):
    """A tagged rectangular region annotated on an image."""

    # Name of the tag attached to the region.
    tag_name: str
    # Position and size of the rectangle.
    # NOTE(review): the example values suggest fractions of the image
    # dimensions rather than pixels -- confirm against the service.
    left: float
    top: float
    width: float
    height: float

    class Config:
        schema_extra = {
            "examples": [
                dict(
                    tag_name="Closed",
                    left=0.00104166672,
                    top=0.00208333344,
                    width=0.288541675,
                    height=0.5104167,
                ),
            ],
        }
class AnnotatedImage(BaseModel):
    """Metadata stored next to a downloaded image.

    The model can be populated either with the ``FileName`` alias (the key
    used when metadata is serialized by alias) or with the ``filename``
    field name. The image content is read from disk through the ``image``
    property.
    """

    # Path of the image file described by this metadata.
    filename: str = Field(..., alias="FileName")
    # Regions annotated on the image, if any.
    regions: Optional[List[AnnotatedRegion]] = None

    @property
    def image(self) -> bytes:
        """Return the raw bytes of the file located at ``filename``."""
        return Path(self.filename).read_bytes()

    class Config:
        allow_population_by_field_name = True
        schema_extra = {
            "examples": [
                {
                    # The key must match the field alias for the example to
                    # be a valid payload.
                    "FileName": "0001.jpg",
                    "regions": [
                        {
                            "tag_name": "Closed",
                            "left": 0.00104166672,
                            "top": 0.00208333344,
                            "width": 0.288541675,
                            "height": 0.5104167,
                        }
                    ],
                }
            ]
        }
def download_file(
    name: str,
    image: Any,
    output_dir: Path,
    all_tags: Optional[Set[str]] = None,
    step: Optional[str] = None,
) -> None:
    """Download a single image and write its annotations next to it.

    Two files are created in ``output_dir``: ``<name>.jpg`` with the image
    content and ``<name>.json`` with the serialized annotations.

    Args:
        name: Base name (without extension) of the files to create.
        image: Object exposing ``original_image_uri`` and ``regions``, where
            ``regions`` is ``None`` or an iterable of objects exposing
            ``tag_name``, ``left``, ``top``, ``width`` and ``height``.
        output_dir: Directory where both files are written.
        all_tags: Optional set updated in place with the tag names found on
            this image.
        step: Optional progress label; a debug log is emitted only when it
            is provided.
    """
    image_filepath = output_dir / f"{name}.jpg"
    metadata_filepath = output_dir / f"{name}.json"

    # Format every region found on the image (none when regions is None)
    regions: List[Dict[str, Any]] = [
        {
            "tag_name": region_data.tag_name,
            "left": region_data.left,
            "top": region_data.top,
            "width": region_data.width,
            "height": region_data.height,
        }
        for region_data in (image.regions or [])
    ]
    tags: Set[str] = {region["tag_name"] for region in regions}

    # Gather annotations
    annotations = AnnotatedImage(filename=image_filepath.as_posix(), regions=regions)

    # Fetch the image first: if the download fails, no metadata file is left
    # behind pointing at an image that does not exist.
    request.urlretrieve(url=image.original_image_uri, filename=image_filepath)
    # Write annotations to file
    metadata_filepath.write_text(annotations.json(by_alias=True))

    # Append to all tags
    if all_tags is not None:
        all_tags.update(tags)

    # Leave a log message only if step is known
    if step is not None:
        logger.debug(
            "Downloaded image",
            path=image_filepath.as_posix(),
            step=step,
        )
def download_files(
    settings: Optional[DownloaderSettings] = None,
    config_file: Union[str, Path, None] = None,
    executor: Optional[ThreadPoolExecutor] = None,
) -> Path:
    """Download all tagged images of the project and return the output directory.

    Images are listed page by page and each download is submitted to a
    ThreadPoolExecutor so that several images are fetched at the same time.
    Once every download has completed, the tag names that were found are
    written, sorted, to ``allTags.json`` in the output directory.

    Args:
        settings: Optional settings overriding the parsed ones.
        config_file: Optional path to a settings file.
        executor: Optional executor used to run the downloads. It is used as
            a context manager, so it is shut down before this function
            returns. A default ThreadPoolExecutor is created when omitted.

    Returns:
        The absolute path of the directory where files are stored.

    Raises:
        ValueError: If the number of tagged images cannot be determined.
    """
    # Create settings if they do not exist
    settings = get_settings(settings, config_file=config_file)

    # Make sure output directory exists
    output_dir = Path(settings.output).absolute()
    if not output_dir.exists():
        logger.info("Creating output directory", path=output_dir.as_posix())
        output_dir.mkdir(parents=True, exist_ok=True)
    else:
        logger.info("Using existing output directory", path=output_dir.as_posix())

    # Create azure SDK resources
    credentials = ApiKeyCredentials(in_headers={"Training-key": settings.training_key})
    trainer = CustomVisionTrainingClient(settings.endpoint, credentials)

    images_count: Optional[int] = trainer.get_tagged_image_count(settings.project)
    # Make sure we've got an image count
    if images_count is None:
        raise ValueError(
            "Could not determine total number of images in custom vision workspace."
        )

    # Width used to zero-pad file indexes so that names sort in listing order
    precision = len(str(images_count))
    # Number of images requested per listing call
    page_size = 50
    # Tags collected by the download tasks
    all_tags: Set[str] = set()
    # (file index, future) pairs, kept so that failures can be reported
    submitted: List[Any] = []

    # Create executor if it does not exist
    if executor is None:
        executor = ThreadPoolExecutor()

    # Leaving the block waits for every submitted download to finish
    with executor:
        for skip in range(0, images_count, page_size):
            # Never request more than the count announced by the service
            take = min(page_size, images_count - skip)
            page = trainer.get_tagged_images(
                project_id=settings.project, take=take, skip=skip
            )
            for idx, image in enumerate(page, start=skip):
                # Generate formatted index
                current_file_index = format(idx + 1, f"0{precision}")
                # Download and parse image data
                future = executor.submit(
                    download_file,
                    current_file_index,
                    image,
                    output_dir,
                    all_tags,
                    f"{current_file_index}/{images_count}",
                )
                submitted.append((current_file_index, future))

    # Report failed downloads instead of silently dropping their exceptions
    for file_index, future in submitted:
        error = future.exception()
        if error is not None:
            logger.error(
                "Failed to download image",
                index=file_index,
                error=repr(error),
            )

    # Write all tags, sorted so that the file content is deterministic
    Path(output_dir, "allTags.json").write_text(json.dumps(sorted(all_tags)))

    # Return path to output directory
    return output_dir
def get_all_tags(directory: Union[str, Path]) -> List[str]:
    """Return the list of tags stored in ``allTags.json`` inside *directory*."""
    tags_file = Path(directory) / "allTags.json"
    raw_content = tags_file.read_bytes()
    tags: List[str] = json.loads(raw_content)
    return tags
def load_files(directory: Union[str, Path]) -> Iterator[AnnotatedImage]:
    """Iterate over the annotated images stored in *directory*.

    Every ``*.json`` file found directly in the directory, except
    ``allTags.json``, is parsed as an AnnotatedImage. Files are visited in
    sorted path order so that iteration is deterministic.

    Args:
        directory: Directory containing the metadata files.

    Yields:
        One AnnotatedImage per metadata file.
    """
    directory = Path(directory)
    # Sort the matches: glob itself does not guarantee any order
    for filepath in sorted(directory.glob("*.json")):
        # allTags.json holds the list of tags, not image metadata
        if filepath.name == "allTags.json":
            continue
        yield AnnotatedImage.parse_file(filepath)
# Usage:
#   - Read the training key from a file:
#       VISION_TRAINING_KEY_FILE=.vision_key python3 scripts/download_images.py
#   - Pass the training key as an environment variable:
#       VISION_TRAINING_KEY=<KEY_VALUE> python3 scripts/download_images.py
if __name__ == "__main__":
    # Run the download with a pool of 16 worker threads.
    pool = ThreadPoolExecutor(max_workers=16)
    download_files(executor=pool)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment