from giskard_vision.core.dataloaders.base import DataIteratorBase
from giskard_vision.core.dataloaders.meta import MetaData
from giskard_vision.core.issues import EthicalIssueMeta, PerformanceIssueMeta
import numpy as np
import cv2
from typing import Optional

class DataLoaderClassification(DataIteratorBase):
    """Example classification data loader built on ``DataIteratorBase``.

    Images are read from ``self.image_paths`` with OpenCV. Labels and
    metadata are placeholder values meant to be replaced with real ones.
    """

    # Sequence of image file paths (str or path-like) indexed by every
    # accessor below.
    # NOTE(review): nothing in this class assigns it. It must be provided as a
    # class attribute or set on the instance before the loader is used.
    image_paths: list

    @property
    def idx_sampler(self) -> np.ndarray:
        """Return the indices of all samples, in order.

        Returns:
            An integer ``np.ndarray`` ``[0, 1, ..., len(self.image_paths) - 1]``
            (empty when there are no image paths).
        """
        return np.arange(len(self.image_paths))

    def get_image(self, idx: int) -> np.ndarray:
        """Load the image at position ``idx`` of ``self.image_paths``.

        Args:
            idx: Index into ``self.image_paths``.

        Returns:
            The array returned by ``cv2.imread`` for that path.

        Raises:
            OSError: if OpenCV cannot read the file.
        """
        path = str(self.image_paths[idx])
        image = cv2.imread(path)
        # cv2.imread reports a missing/unreadable file by returning None rather
        # than raising; fail loudly here instead of leaking None downstream.
        if image is None:
            raise OSError(f"Could not read image file: {path}")
        return image

    def get_label(self, idx: int) -> Optional[np.ndarray]:
        """Return the label for sample ``idx``.

        NOTE(review): placeholder. It always returns the string ``'label'``,
        which does not match the annotated ``Optional[np.ndarray]`` return
        type. Replace it with the real label lookup.
        """
        return 'label'

    def get_meta(self, idx: int) -> Optional[MetaData]:
        """Return the base-class metadata for ``idx``, extended with example fields.

        Args:
            idx: Index of the sample; forwarded to the base implementation.

        Returns:
            A new ``MetaData`` whose ``data``, ``categories`` and
            ``issue_groups`` contain the base entries (if any) plus the four
            example entries below. Same-named base entries are overridden.
        """
        # Load the default metadata for this specific sample.
        default_meta = super().get_meta(idx)

        # get_meta is annotated Optional[MetaData], so tolerate a base that
        # returns None (or leaves fields unset) instead of crashing.
        if default_meta is None:
            base_data, base_categories, base_groups = {}, [], {}
        else:
            base_data = default_meta.data or {}
            base_categories = list(default_meta.categories or [])
            base_groups = default_meta.issue_groups or {}

        return MetaData(
            data={
                **base_data,
                'meta1': 'value1',
                'meta2': 'value2',
                'categorical_meta1': 'cat_value1',
                'categorical_meta2': 'cat_value2',
            },
            categories=base_categories + ['categorical_meta1', 'categorical_meta2'],
            issue_groups={
                **base_groups,
                'meta1': PerformanceIssueMeta,
                'meta2': EthicalIssueMeta,
                'categorical_meta1': PerformanceIssueMeta,
                'categorical_meta2': EthicalIssueMeta,
            },
        )

# Module-level loader instance, created once at import time with no
# constructor arguments.
# NOTE(review): the loader reads `image_paths`, which is not assigned anywhere
# in this file. Confirm it is provided (as a class or instance attribute)
# before this loader is iterated.
giskard_dataset: DataLoaderClassification = DataLoaderClassification()