Created
April 17, 2023 11:19
-
-
Save git-hamza/2e2e52fd767ab872f3c2cb8073723b4a to your computer and use it in GitHub Desktop.
class/label based statistics in horizental bar chart to visualize the data distribution for yolov5 format dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import yaml | |
import os | |
import matplotlib.pyplot as plt | |
# Define paths to image and label directories, and class YAML file | |
data_dirs = ["train", "valid", "test"] | |
class_yaml_path = "data.yaml" | |
# Load class names from YAML file | |
with open(class_yaml_path, 'r') as f: | |
classes = yaml.safe_load(f)['names'] | |
# Initialize dictionary to store class-based statistics | |
class_stats = {c: {'num_labels': 0, 'num_images': 0} for c in classes} | |
# Loop through each data directory | |
for data_dir in data_dirs: | |
# Define paths to image and label directories in current data directory | |
image_dir = os.path.join(data_dir, 'images') | |
label_dir = os.path.join(data_dir, 'labels') | |
# Loop through images in image directory | |
for filename in os.listdir(image_dir): | |
if filename.endswith('.jpg'): | |
# Read in corresponding YOLO format label file | |
label_filename = os.path.splitext(filename)[0] + '.txt' | |
label_filepath = os.path.join(label_dir, label_filename) | |
with open(label_filepath, 'r') as f: | |
labels = f.readlines() | |
# Extract class labels from YOLO format label file | |
class_labels = [classes[int(label.split()[0])] for label in labels] | |
# Update class-based statistics | |
for c in set(class_labels): | |
class_stats[c]['num_labels'] += class_labels.count(c) | |
class_stats[c]['num_images'] += 1 | |
# Sort classes by number of labels in descending order | |
sorted_classes = sorted(classes, key=lambda c: class_stats[c]['num_labels']) | |
# Generate horizontal bar chart of class-based statistics | |
num_labels = [class_stats[c]['num_labels'] for c in sorted_classes] | |
num_images = [class_stats[c]['num_images'] for c in sorted_classes] | |
fig, ax = plt.subplots() | |
ax.barh(sorted_classes, num_labels) | |
ax.set_title('Number of Labels per Class') | |
ax.set_xlabel('Number of Labels') | |
ax.set_ylabel('Class') | |
plt.show() | |
fig, ax = plt.subplots() | |
ax.barh(sorted_classes, num_images) | |
ax.set_title('Number of Images per Class') | |
ax.set_xlabel('Number of Images') | |
ax.set_ylabel('Class') | |
plt.show() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment