Skip to content

Instantly share code, notes, and snippets.

@strickvl
Created December 18, 2021 21:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save strickvl/b7d5f50cc563cf7d2fde71e033925d7a to your computer and use it in GitHub Desktop.
Save strickvl/b7d5f50cc563cf7d2fde71e033925d7a to your computer and use it in GitHub Desktop.
from pathlib import Path
from typing import List, Optional
import random
import prodigy
from prodigy.components.loaders import Images
from prodigy.util import split_string
from prodigy.util import b64_uri_to_bytes
from prodigy.components.filters import filter_inputs
from prodigy import set_hashes
from prodigy.components.db import connect
from icevision.all import *
from icevision.models.checkpoint import model_from_checkpoint
from icevision.models import mmdet
from icevision.models.checkpoint import *
from fastai.vision.all import *
# Module-level handle returned by Prodigy's `connect()`.
db = connect()

# Location of the saved checkpoint to load (placeholder path; point it at a
# real checkpoint file before running the recipe).
checkpoint_path = "path/to/models/vfnet-checkpoint-full.pth"

# `model_from_checkpoint` returns a dict; besides the three entries unpacked
# below, its "model" entry is read later by `detect_objects`.
checkpoint_and_model = model_from_checkpoint(checkpoint_path)
model_type, class_map, img_size = (
    checkpoint_and_model[key]
    for key in ("model_type", "class_map", "img_size")
)
def detect_objects(img, detection_threshold=0.7):
    """Run the checkpoint's model on a single image and return its predictions.

    The image is resized/padded to the checkpoint's ``img_size`` and
    normalized before inference.

    Args:
        img: The image to run detection on. The recipe in this file passes
            an image opened with ``Image.open`` (presumably a PIL image --
            confirm against the star imports).
        detection_threshold: Minimum score forwarded to ``end2end_detect``.
            Defaults to 0.7, the value that used to be hard-coded, so
            existing callers behave exactly as before.

    Returns:
        The prediction dict produced by ``model_type.end2end_detect``. Other
        code in this file reads ``["detection"]["scores"]``,
        ``["detection"]["bboxes"]`` and ``["detection"]["labels"]`` from it.
    """
    # Validation-time transforms: resize + pad to the training size, then
    # normalize.
    valid_tfms = tfms.A.Adapter(
        [*tfms.A.resize_and_pad(img_size), tfms.A.Normalize()]
    )
    return model_type.end2end_detect(
        img,
        valid_tfms,
        checkpoint_and_model["model"],
        class_map=class_map,
        detection_threshold=detection_threshold,
    )
def convert_prediction_to_prodigy(preds, threshold):
    """Convert a prediction dict into a list of Prodigy image spans.

    Args:
        preds: Prediction dict with a ``"detection"`` entry holding the
            parallel sequences ``"scores"``, ``"bboxes"`` and ``"labels"``.
            Each bbox must offer ``to_tensor()`` yielding four values
            (xmin, ymin, xmax, ymax), each supporting ``.item()``.
        threshold: Only detections whose score is strictly greater than
            this value are converted.

    Returns:
        A list of dicts, one per kept detection, each with:
          * ``"points"``: the four box corners in perimeter order
            (top-left, bottom-left, bottom-right, top-right), so that
            consecutive points are joined by box edges rather than
            diagonals.
          * ``"label"``: the predicted label, upper-cased.
    """
    detection = preds["detection"]
    spans = []
    for score, bbox, label in zip(
        detection["scores"], detection["bboxes"], detection["labels"]
    ):
        if score > threshold:
            # Only unpack coordinates for detections that are actually kept.
            xmin, ymin, xmax, ymax = (
                coord.item() for coord in bbox.to_tensor()
            )
            spans.append(
                {
                    # Walk the rectangle's outline; the previous order
                    # (TL, BL, TR, BR) produced a self-crossing shape.
                    "points": [
                        [xmin, ymin],
                        [xmin, ymax],
                        [xmax, ymax],
                        [xmax, ymin],
                    ],
                    "label": label.upper(),
                }
            )
    return spans
def before_db(examples):
    """Swap inline image data for the file path on each task, in place.

    Passed to Prodigy as the recipe's ``before_db`` component. For every
    task whose ``"image"`` value is a ``data:`` URI and which also has a
    ``"path"`` key, ``"image"`` is overwritten with that path. Other tasks
    are left unchanged.

    Returns the same ``examples`` object that was passed in.
    """
    for task in examples:
        if not task["image"].startswith("data:"):
            continue
        if "path" not in task:
            continue
        # Keep only the reference to the original file, not the encoded bytes.
        task["image"] = task["path"]
    return examples
# Recipe decorator with argument annotations: (description, argument type,
# shortcut, type / converter function called on value before it's passed to
# the function). Descriptions are also shown when typing --help.
@prodigy.recipe(
    "random_object_annotation",
    dataset=("The dataset to use", "positional", None, str),
    source=("Path to a directory of images", "positional", None, str),
    label=("One or more comma-separated labels", "option", "l", split_string),
    exclude=("Names of datasets to exclude", "option", "e", split_string),
    darken=("Darken image to make boxes stand out more", "flag", "D", bool),
)
def random_object_annotation(
    dataset: str,
    source: str,
    label: Optional[List[str]] = None,
    exclude: Optional[List[str]] = None,
    darken: bool = False,
):
    """
    Serve a random sample of the images in ``source`` for manual annotation,
    with each served image pre-filled with the boxes predicted by the model
    loaded at module level.
    """

    def sampled_tasks():
        # Input hashes already stored in the target dataset; images that
        # match one of them are not served again.
        seen_hashes = connect().get_input_hashes(dataset)
        for task in Images(source):
            task = set_hashes(task)
            # One random draw per image: only about 3% of them get through.
            if random.random() >= 0.03:
                continue
            if task["_input_hash"] in seen_hashes:
                continue
            # Open the original file and let the model propose boxes for it.
            with Image.open(Path(task["path"])) as picture:
                detections = detect_objects(picture)
            task["spans"] = convert_prediction_to_prodigy(detections, 0.5)
            yield task

    label_summary = "all" if label is None else ", ".join(label)
    darken_amount = 0.3 if darken else 0

    return {
        # Annotation interface to use
        "view_id": "image_manual",
        # Name of the dataset annotations are saved to
        "dataset": dataset,
        # Lazily evaluated stream of sampled, pre-annotated tasks
        "stream": sampled_tasks(),
        # Names of datasets to exclude
        "exclude": exclude,
        "before_db": before_db,
        # Additional settings, mostly for the app UI
        "config": {
            "label": label_summary,
            "labels": label,
            "darken_image": darken_amount,
        },
    }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment