Skip to content

Instantly share code, notes, and snippets.

@charbonnierg
Last active April 25, 2022 09:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save charbonnierg/cd003e653a8cd900ee7ea8aa7a5b6121 to your computer and use it in GitHub Desktop.
Az Custom Vision file downloader
from __future__ import annotations
import json
import urllib.request as request
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Set, Union
from azure.cognitiveservices.vision.customvision.training import (
CustomVisionTrainingClient,
)
from msrest.authentication import ApiKeyCredentials
from pydantic import BaseModel, BaseSettings, Field
from structlog import get_logger
logger = get_logger()
class FilesSettings(BaseSettings):
    """Locations of files that *contain* settings values.

    This is distinct from the settings themselves (DownloaderSettings):
    each field here points to a file on disk, and may be supplied through
    a case-insensitive environment variable prefixed with ``VISION_``.
    """

    # Path to a JSON document holding a full DownloaderSettings payload
    config_file: Optional[str] = None
    # Path to a file whose first line is the training key
    training_key_file: Optional[str] = None
    # Path to a file whose first line is the project id
    project_file: Optional[str] = None

    class Config:
        # Equivalent to the class-keyword form
        # ``BaseSettings, case_sensitive=False, env_prefix="vision_"``
        # in pydantic v1.
        case_sensitive = False
        env_prefix = "vision_"
class DownloaderSettings(BaseSettings):
    """Values driving the download: endpoint, key, project and output dir.

    NOTE: Training key must never be written in a script or a module. It
    must be parsed from environment variable or file.
    """

    endpoint: str = "https://quara-poc-vision.cognitiveservices.azure.com/"
    # Placeholder only -- expected to be overridden via env var or key file
    training_key: str = "SECRET"
    project: str = "1cc5c053-39ed-4693-8e85-ff5096b9ea1c"
    output: str = "./customvision/annotated_datasets/poc-vision"

    class Config:
        # Equivalent to the class-keyword form
        # ``BaseSettings, case_sensitive=False, env_prefix="vision_"``
        # in pydantic v1.
        case_sensitive = False
        env_prefix = "vision_"
def _first_line(path: Union[str, Path]) -> str:
    """Return the first line of *path* (without the newline)."""
    return Path(path).read_text().splitlines(False)[0]


def get_settings(
    settings: Optional[DownloaderSettings] = None,
    config_file: Union[Path, str, None] = None,
) -> DownloaderSettings:
    """Parse downloader settings from environment and/or files.

    Precedence, lowest to highest: environment / config file, key and
    project files, then the explicit ``settings`` argument (only fields
    the caller actually set are applied).

    Args:
        settings: Explicit overrides; fields left unset are ignored.
        config_file: Optional path to a JSON settings document.

    Returns:
        A fully-resolved DownloaderSettings instance.
    """
    # FilesSettings.config_file is declared as ``str`` -- coerce a
    # pathlib.Path argument so pydantic v1's str validator (which does
    # not accept arbitrary objects) cannot reject it.
    files_settings = FilesSettings(
        config_file=str(config_file) if config_file is not None else None
    )
    # Parse settings from config file when one is given, else from env
    if files_settings.config_file:
        parsed_settings = DownloaderSettings.parse_file(files_settings.config_file)
    else:
        parsed_settings = DownloaderSettings()
    # Override key from file
    if files_settings.training_key_file:
        parsed_settings.training_key = _first_line(files_settings.training_key_file)
    # Override project from file
    if files_settings.project_file:
        parsed_settings.project = _first_line(files_settings.project_file)
    # Apply explicit overrides last (only fields the caller set)
    if settings:
        return parsed_settings.copy(update=settings.dict(exclude_unset=True))
    return parsed_settings
class AnnotatedRegion(BaseModel):
    """Metadata found for annotated region within an image.

    Holds one tagged bounding box. The coordinates appear to be
    normalized fractions of the image size (the example values are all
    < 1, matching Custom Vision's region format) -- TODO confirm.
    """

    # Name of the tag attached to this region
    tag_name: str
    # Bounding box: (left, top) corner plus width/height
    left: float
    top: float
    width: float
    height: float

    class Config:
        # Example embedded in the generated JSON schema
        schema_extra = {
            "examples": [
                {
                    "tag_name": "Closed",
                    "left": 0.00104166672,
                    "top": 0.00208333344,
                    "width": 0.288541675,
                    "height": 0.5104167,
                }
            ]
        }
class AnnotatedImage(BaseModel):
    """Metadata found for an annotated image.

    Image can be accessed using .image property, which reads the file at
    ``filename`` from disk on each access.
    """

    # Serialized under the "FileName" alias; population by field name is
    # enabled in Config so AnnotatedImage(filename=...) also works.
    filename: str = Field(..., alias="FileName")
    # None (or empty) when the image carries no annotated regions
    regions: Optional[List[AnnotatedRegion]] = None

    @property
    def image(self) -> bytes:
        # Lazily load the raw image bytes; raises if the file is missing.
        return Path(self.filename).read_bytes()

    class Config:
        allow_population_by_field_name = True
        # NOTE(review): the example key "fileName" does not match the
        # declared alias "FileName" -- looks like a typo; verify.
        schema_extra = {
            "examples": [
                {
                    "fileName": "0001.jpg",
                    "regions": [
                        {
                            "tag_name": "Closed",
                            "left": 0.00104166672,
                            "top": 0.00208333344,
                            "width": 0.288541675,
                            "height": 0.5104167,
                        }
                    ],
                }
            ]
        }
def download_file(
    name: str,
    image: Any,
    output_dir: Path,
    all_tags: Optional[Set[str]] = None,
    step: Optional[str] = None,
) -> None:
    """Download a single image and write its annotation JSON next to it.

    Writes ``<name>.jpg`` and ``<name>.json`` into *output_dir*.

    Args:
        name: Basename (without extension) for the two output files.
        image: SDK image object; must expose ``.regions`` and
            ``.original_image_uri`` (Azure Custom Vision tagged image).
        output_dir: Existing directory receiving the files.
        all_tags: Optional set collecting every tag name seen, mutated
            in place.
        step: Optional progress label (e.g. "003/120") used for logging.
    """
    image_filepath = output_dir / f"{name}.jpg"
    metadata_filepath = output_dir / f"{name}.json"
    # Collect tag names and region dicts from the image annotations
    tags: Set[str] = set()
    regions: List[Dict[str, Any]] = []
    if image.regions is not None:
        for region_data in image.regions:
            tags.add(region_data.tag_name)
            regions.append(
                {
                    "tag_name": region_data.tag_name,
                    "left": region_data.left,
                    "top": region_data.top,
                    "width": region_data.width,
                    "height": region_data.height,
                }
            )
    # Download the image FIRST: the original wrote the metadata before
    # fetching, which left an orphan .json file when the download failed.
    request.urlretrieve(url=image.original_image_uri, filename=image_filepath)
    # Write annotations to file (metadata_filepath is already a Path)
    annotations = AnnotatedImage(filename=image_filepath.as_posix(), regions=regions)
    metadata_filepath.write_text(annotations.json(by_alias=True))
    # Merge this image's tags into the shared set
    if all_tags is not None:
        all_tags.update(tags)
    # Leave a log message only if step is known
    if step is not None:
        logger.debug(
            "Downloaded image",
            path=image_filepath.as_posix(),
            step=step,
        )
def download_files(
    settings: Optional[DownloaderSettings] = None,
    config_file: Union[str, Path, None] = None,
    executor: Optional[ThreadPoolExecutor] = None,
) -> Path:
    """Download all tagged images of the project; return the output dir.

    Downloads run concurrently through a ThreadPoolExecutor. Note that a
    caller-provided executor is shut down by this function (the ``with``
    block), matching the original behavior.

    Args:
        settings: Optional explicit settings (see ``get_settings``).
        config_file: Optional path to a JSON settings document.
        executor: Optional executor to reuse; a default one is created
            otherwise.

    Raises:
        ValueError: If the tagged-image count cannot be determined.
        Exception: Any error raised by an individual download is
            re-raised once all downloads have completed.
    """
    executor = executor or ThreadPoolExecutor()
    settings = get_settings(settings, config_file=config_file)
    # Make sure output directory exists
    output_dir = Path(settings.output).absolute()
    if not output_dir.exists():
        logger.info("Creating output directory", path=output_dir.as_posix())
        output_dir.mkdir(parents=True, exist_ok=True)
    else:
        logger.info("Using existing output directory", path=output_dir.as_posix())
    # Create azure SDK resources
    credentials = ApiKeyCredentials(in_headers={"Training-key": settings.training_key})
    trainer = CustomVisionTrainingClient(settings.endpoint, credentials)
    images_count: Optional[int] = trainer.get_tagged_image_count(settings.project)
    if images_count is None:
        raise ValueError(
            "Could not determine total number of images in custom vision workspace."
        )
    # Zero-pad file indexes so names sort lexicographically
    precision = len(str(images_count))
    all_tags: Set[str] = set()
    # Keep the futures: the original discarded them, so exceptions raised
    # inside download_file were silently swallowed.
    futures = []
    with executor:
        for idx in range(images_count):
            # NOTE(review): one API call per image (take=1) is slow; the
            # API supports larger pages -- kept as-is to preserve the
            # original per-index naming behavior.
            for image in trainer.get_tagged_images(
                project_id=settings.project, take=1, skip=idx
            ):
                current_file_index = format(idx + 1, f"0{precision}")
                futures.append(
                    executor.submit(
                        download_file,
                        current_file_index,
                        image,
                        output_dir,
                        all_tags,
                        f"{current_file_index}/{images_count}",
                    )
                )
    # Executor is shut down here, so every download has finished; surface
    # the first error, if any.
    for future in futures:
        future.result()
    # all_tags is now complete and safe to serialize
    Path(output_dir, "allTags.json").write_text(json.dumps(list(all_tags)))
    return output_dir
def get_all_tags(directory: Union[str, Path]) -> List[str]:
    """Read the list of tag names stored in ``allTags.json``.

    Args:
        directory: Dataset directory produced by ``download_files``.

    Returns:
        The tag names, in the order they were serialized.
    """
    tags_file = Path(directory) / "allTags.json"
    return list(json.loads(tags_file.read_bytes()))
def load_files(directory: Union[str, Path]) -> Iterator[AnnotatedImage]:
    """Yield an AnnotatedImage for every annotation file in *directory*.

    Iterates the ``*.json`` files, skipping the aggregate
    ``allTags.json`` index.

    Args:
        directory: Dataset directory produced by ``download_files``.

    Yields:
        Parsed AnnotatedImage models.
    """
    directory = Path(directory)
    for filepath in directory.glob("*.json"):
        if filepath.name == "allTags.json":
            continue
        # The original rebuilt ``directory / (filepath.stem + ".json")``,
        # which is exactly ``filepath`` -- parse it directly.
        yield AnnotatedImage.parse_file(filepath)
# Usage:
# - Use a key stored in a file:
# VISION_TRAINING_KEY_FILE=.vision_key python scripts/download_images.py
# - Set key as environment variable
# VISION_TRAINING_KEY=<KEY_VALUE> python3 scripts/download_images.py
if __name__ == "__main__":
    # Script entry point: download every tagged image using 16 worker
    # threads (see the usage comments above for key/env configuration).
    download_files(executor=ThreadPoolExecutor(16))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment