Skip to content

Instantly share code, notes, and snippets.

@frederik-encord
Created June 26, 2023 08:09
Show Gist options
  • Save frederik-encord/0f093aa398ec5f33306b577edca63638 to your computer and use it in GitHub Desktop.
Save frederik-encord/0f093aa398ec5f33306b577edca63638 to your computer and use it in GitHub Desktop.
from pathlib import Path
import numpy as np
from encord_active.lib.labels.label_transformer import (
BoundingBox,
BoundingBoxLabel,
DataLabel,
LabelTransformer,
PolygonLabel,
)
from pydantic import BaseModel
from yaml import safe_load
class DataDefinition(BaseModel):
    """Fields read from a YOLOv5 dataset definition file (e.g. ``data.yaml``).

    ``names[i]`` is the class name for class index ``i`` as written in the
    ``.txt`` label files; ``nc`` is the declared number of classes.
    """

    train: Path
    # Explicit ``= None`` defaults keep these optional under pydantic v2 too,
    # where an ``X | None`` annotation without a default is a *required* field.
    # NOTE(review): YOLOv5 data files call the validation split ``val``, not
    # ``validation``; unknown keys are ignored by pydantic's default config, so
    # this field is presumably always None for stock YOLOv5 files — confirm.
    validation: Path | None = None
    test: Path | None = None
    nc: int
    # NOTE(review): newer YOLOv5 versions may write ``names`` as an
    # ``{index: name}`` mapping, which this list annotation will reject.
    names: list[str]
def read_dataset(data_file: Path) -> DataDefinition | None:
    """Try to parse *data_file* as a YOLOv5 dataset definition.

    The file is read as UTF-8 YAML and validated against ``DataDefinition``.

    Args:
        data_file: Path of a candidate dataset definition file.

    Returns:
        The parsed ``DataDefinition``, or ``None`` if the file cannot be read,
        is not valid YAML, or does not match the expected schema. Callers probe
        every label file with this function, so ``None`` is an expected
        outcome rather than an error.
    """
    import logging

    from yaml import YAMLError

    try:
        obj = safe_load(data_file.read_text(encoding="utf-8"))
        return DataDefinition.parse_obj(obj)
    except (OSError, YAMLError, ValueError, TypeError) as err:
        # pydantic's ValidationError and UnicodeDecodeError are both ValueError
        # subclasses. Log quietly instead of printing a traceback / opening a
        # debugger: most probed files (plain .txt labels) are not dataset files.
        logging.getLogger(__name__).debug(
            "%s is not a dataset definition: %s", data_file, err
        )
        return None
class YOLOv5Transformer(LabelTransformer):
    """Convert YOLOv5 ``.txt`` annotations into encord-active ``DataLabel`` objects.

    Each non-blank line of a label file is ``<class_idx> <values...>``:
    exactly four values are read as a centre-based box
    ``x_center y_center width height``; more than four are read as polygon
    vertices ``x1 y1 x2 y2 ...``. Class indices are resolved through the
    ``names`` list of the dataset definition.
    """

    @staticmethod
    def _label_file_for(data_file: Path) -> Path:
        """Return the path of the ``.txt`` label file expected for *data_file*."""
        label_name = data_file.with_suffix(".txt").name
        parent_parts = data_file.parent.parts
        if "images" in parent_parts:
            # YOLOv5 convention: replace the last "images" directory with
            # "labels", e.g. root/images/train/a.jpg -> root/labels/train/a.txt.
            idx = len(parent_parts) - 1 - parent_parts[::-1].index("images")
            yolo_path = (
                Path(*parent_parts[:idx], "labels", *parent_parts[idx + 1 :])
                / label_name
            )
            if yolo_path.is_file():
                return yolo_path
        # Fallback kept from the original implementation:
        # <root>/<image-dir>/a.jpg -> <root>/labels/a.txt.
        return data_file.parents[1] / "labels" / label_name

    def get_labels_for_data_file(
        self, data_file: Path, data_def: DataDefinition
    ) -> list[DataLabel]:
        """Parse the label file belonging to *data_file*.

        Args:
            data_file: Path of the image the labels belong to.
            data_def: Dataset definition used to map class indices to names.

        Returns:
            One ``DataLabel`` per non-blank line; an empty list if no label
            file exists for *data_file*.

        Raises:
            ValueError: on a malformed line (non-numeric values, fewer than
                four values, or an odd number of polygon coordinates).
            IndexError: if a class index is outside ``data_def.names``.
        """
        label_file = self._label_file_for(data_file)
        if not label_file.is_file():
            return []

        labels: list[DataLabel] = []
        with label_file.open(encoding="utf-8") as f:
            for line_no, raw_line in enumerate(f, start=1):
                # Lines from a file keep their "\n", so a blank line is truthy;
                # strip first so blank/trailing lines are skipped instead of
                # crashing the unpacking below.
                line = raw_line.strip()
                if not line:
                    continue

                cls_idx_str, *pieces = line.split()
                cls_name = data_def.names[int(cls_idx_str)]

                if len(pieces) > 4:
                    if len(pieces) % 2:
                        raise ValueError(
                            f"{label_file}:{line_no}: polygon has an odd number "
                            f"of coordinates ({len(pieces)})"
                        )
                    points = np.array([float(p) for p in pieces]).reshape(-1, 2)
                    label = PolygonLabel(class_=cls_name, polygon=points)
                else:
                    x_center, y_center, width, height = map(float, pieces)
                    # YOLO stores box centres; BoundingBox takes the top-left corner.
                    bbox = BoundingBox(
                        x=x_center - width / 2,
                        y=y_center - height / 2,
                        w=width,
                        h=height,
                    )
                    label = BoundingBoxLabel(bounding_box=bbox, class_=cls_name)

                labels.append(DataLabel(abs_data_path=data_file, label=label))
        return labels

    def from_custom_labels(
        self, label_files: list[Path], data_files: list[Path]
    ) -> list[DataLabel]:
        """Collect labels for all *data_files*.

        The first entry of *label_files* that parses as a dataset definition
        (see ``read_dataset``) supplies the class names.

        Raises:
            ValueError: if none of *label_files* is a dataset definition.
        """
        data_def: DataDefinition | None = next(
            filter(None, map(read_dataset, label_files)), None
        )
        if data_def is None:
            raise ValueError("Missing dataset file as a label glob")

        return [
            label
            for data_file in data_files
            for label in self.get_labels_for_data_file(data_file, data_def)
        ]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment