from typing import Optional

import cv2
import numpy as np

from giskard_vision.core.dataloaders.base import DataIteratorBase
from giskard_vision.core.dataloaders.meta import MetaData
from giskard_vision.core.issues import EthicalIssueMeta, PerformanceIssueMeta


class DataLoaderClassification(DataIteratorBase):
    """Example classification data loader built on ``DataIteratorBase``.

    The loader reads images from ``self.image_paths`` and attaches a fixed
    label and a fixed set of example metadata entries to every sample.

    NOTE(review): ``image_paths`` is not defined in this class; it is
    presumably set by a constructor or a parent class -- confirm before use.
    """

    @property
    def idx_sampler(self) -> np.ndarray:
        """Return the sample indices ``0 .. len(self.image_paths) - 1``."""
        return np.arange(len(self.image_paths))

    def get_image(self, idx: int) -> np.ndarray:
        """Read the image stored at ``self.image_paths[idx]`` with OpenCV.

        Args:
            idx: Position of the image in ``self.image_paths``.

        Returns:
            Whatever ``cv2.imread`` returns for that path.
        """
        # NOTE(review): cv2.imread returns None instead of raising when the
        # file is missing or unreadable; callers may need to handle that.
        return cv2.imread(str(self.image_paths[idx]))

    def get_label(self, idx: int) -> Optional[np.ndarray]:
        """Return the label for sample ``idx``.

        Args:
            idx: Sample index (currently unused).

        Returns:
            The placeholder value ``"label"`` for every sample.
        """
        # NOTE(review): this is a str, not the annotated Optional[np.ndarray];
        # looks like a template placeholder to be replaced with real labels.
        return "label"

    def get_meta(self, idx: int) -> Optional[MetaData]:
        """Return the parent metadata for ``idx`` extended with example entries.

        Args:
            idx: Sample index, forwarded to the parent ``get_meta``.

        Returns:
            A ``MetaData`` whose ``data``, ``categories`` and ``issue_groups``
            contain the parent's values (if any) plus the example entries
            ``meta1``, ``meta2``, ``categorical_meta1`` and
            ``categorical_meta2``.
        """
        default_meta = super().get_meta(idx)  # To load default metadata

        # The return type is Optional, so tolerate a parent that provides no
        # metadata (or leaves individual fields unset) instead of crashing.
        if default_meta is None:
            base_data, base_categories, base_issue_groups = {}, [], {}
        else:
            base_data = default_meta.data or {}
            base_categories = list(default_meta.categories or [])
            base_issue_groups = default_meta.issue_groups or {}

        return MetaData(
            data={
                **base_data,
                "meta1": "value1",
                "meta2": "value2",
                "categorical_meta1": "cat_value1",
                "categorical_meta2": "cat_value2",
            },
            categories=base_categories
            + ["categorical_meta1", "categorical_meta2"],
            issue_groups={
                **base_issue_groups,
                "meta1": PerformanceIssueMeta,
                "meta2": EthicalIssueMeta,
                "categorical_meta1": PerformanceIssueMeta,
                "categorical_meta2": EthicalIssueMeta,
            },
        )


# NOTE(review): instantiated with no arguments -- confirm that the parent
# constructor does not require any (e.g. a name or the image paths).
giskard_dataset = DataLoaderClassification()