-
-
Save strickvl/b7d5f50cc563cf7d2fde71e033925d7a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pathlib import Path | |
from typing import List, Optional | |
import random | |
import prodigy | |
from prodigy.components.loaders import Images | |
from prodigy.util import split_string | |
from prodigy.util import b64_uri_to_bytes | |
from prodigy.components.filters import filter_inputs | |
from prodigy import set_hashes | |
from prodigy.components.db import connect | |
from icevision.all import * | |
from icevision.models.checkpoint import model_from_checkpoint | |
from icevision.models import mmdet | |
from icevision.models.checkpoint import * | |
from fastai.vision.all import * | |
# Module-level bootstrap: connect to the Prodigy DB and load the VFNet
# checkpoint once at import time so every prediction reuses the same model.
db = connect()  # NOTE(review): `db` appears unused below — get_random_stream() calls connect() again; confirm this is needed
checkpoint_path = (
    "path/to/models/vfnet-checkpoint-full.pth"
)
# model_from_checkpoint returns a dict holding the instantiated model plus the
# metadata saved with it: the model type, class map and training image size.
checkpoint_and_model = model_from_checkpoint(checkpoint_path)
model_type = checkpoint_and_model["model_type"]
class_map = checkpoint_and_model["class_map"]
img_size = checkpoint_and_model["img_size"]
def detect_objects(img):
    """Run the checkpointed detection model on a single image.

    Builds the validation-time transform (resize/pad to the checkpoint's
    ``img_size`` plus normalization) and returns the prediction dict emitted
    by ``model_type.end2end_detect`` with a 0.7 detection threshold.
    """
    inference_tfms = tfms.A.Adapter(
        [*tfms.A.resize_and_pad(img_size), tfms.A.Normalize()]
    )
    return model_type.end2end_detect(
        img,
        inference_tfms,
        checkpoint_and_model["model"],
        class_map=class_map,
        detection_threshold=0.7,
    )
def convert_prediction_to_prodigy(preds, threshold):
    """Convert an icevision prediction dict into Prodigy image spans.

    Args:
        preds: prediction dict whose ``preds["detection"]`` entry holds
            parallel ``"scores"``, ``"bboxes"`` (objects with a
            ``.to_tensor()`` yielding xmin/ymin/xmax/ymax scalars) and
            ``"labels"`` sequences.
        threshold: minimum score (exclusive) for a detection to be kept.

    Returns:
        A list of ``{"points": [...], "label": ...}`` dicts for the
        ``image_manual`` interface; labels are upper-cased.
    """
    spans = []
    for idx, score in enumerate(preds["detection"]["scores"]):
        # Skip low-confidence detections before doing any tensor work
        # (the original converted every bbox, even rejected ones).
        if score <= threshold:
            continue
        xmin, ymin, xmax, ymax = preds["detection"]["bboxes"][idx].to_tensor()
        xmin, ymin, xmax, ymax = (
            xmin.item(),
            ymin.item(),
            xmax.item(),
            ymax.item(),
        )
        # Corners in drawing order (clockwise) so the polygon renders as a
        # rectangle. The previous order
        # [[xmin, ymin], [xmin, ymax], [xmax, ymin], [xmax, ymax]]
        # crossed its edges and drew a self-intersecting "bowtie" shape.
        spans.append(
            {
                "points": [
                    [xmin, ymin],
                    [xmax, ymin],
                    [xmax, ymax],
                    [xmin, ymax],
                ],
                "label": preds["detection"]["labels"][idx].upper(),
            }
        )
    return spans
def before_db(examples):
    """Replace base64 image payloads with file paths before saving.

    Tasks whose "image" field is a data URI get rewritten to point at the
    original file (when a "path" is present) so the stored annotations stay
    small. Tasks are modified in place and the same list is returned.
    """
    for task in examples:
        is_data_uri = task["image"].startswith("data:")
        if is_data_uri and "path" in task:
            task["image"] = task["path"]
    return examples
# Recipe decorator with argument annotations: (description, argument type, | |
# shortcut, type / converter function called on value before it's passed to | |
# the function). Descriptions are also shown when typing --help. | |
@prodigy.recipe(
    "random_object_annotation",
    dataset=("The dataset to use", "positional", None, str),
    source=("Path to a directory of images", "positional", None, str),
    label=("One or more comma-separated labels", "option", "l", split_string),
    exclude=("Names of datasets to exclude", "option", "e", split_string),
    darken=("Darken image to make boxes stand out more", "flag", "D", bool),
)
def random_object_annotation(
    dataset: str,
    source: str,
    label: Optional[List[str]] = None,
    exclude: Optional[List[str]] = None,
    darken: bool = False,
):
    """
    Manually annotate images by drawing rectangular bounding boxes or polygon
    shapes on the image.
    """

    def get_random_stream():
        # Input hashes already stored for this dataset; lets us skip images
        # that were annotated in a previous session.
        seen_hashes = connect().get_input_hashes(dataset)
        for task in Images(source):
            task = set_hashes(task)
            # Keep roughly 3% of images, and only unseen ones. The random
            # draw is evaluated first, matching the original short-circuit
            # order, so the random sequence is consumed identically.
            keep = (
                random.random() < 0.03
                and task["_input_hash"] not in seen_hashes
            )
            if not keep:
                continue
            # Pre-annotate with model predictions so the annotator only has
            # to correct the proposed boxes.
            with Image.open(Path(task["path"])) as im:
                predictions = detect_objects(im)
                task["spans"] = convert_prediction_to_prodigy(
                    predictions, 0.5
                )
                yield task

    return {
        "view_id": "image_manual",  # Annotation interface to use
        "dataset": dataset,  # Name of dataset to save annotations
        "stream": get_random_stream(),  # Incoming stream of examples
        "exclude": exclude,  # List of dataset names to exclude
        "before_db": before_db,
        "config": {  # Additional config settings, mostly for app UI
            "label": ", ".join(label) if label is not None else "all",
            "labels": label,  # Selectable label options,
            "darken_image": 0.3 if darken else 0,
        },
    }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment