Created
April 9, 2025 07:57
-
-
Save kojix2/08b15818fff1978bbff63c23c879f3ea to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require "../src/onnxruntime" | |
require "vips" | |
require "option_parser" | |
# YOLOv7 Object Detection Example
# Demonstrates using ONNXRuntime.cr with YOLOv7 for object detection.

# The 80 COCO class names, indexed by the class id emitted by the model.
LABELS = [
  "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
  "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
  "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra",
  "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
  "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
  "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
  "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
  "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
  "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse",
  "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
  "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
  "hair drier", "toothbrush",
]
# A single detected object: its COCO class id, confidence score, and
# bounding-box corners (x1, y1) top-left / (x2, y2) bottom-right in
# original-image pixel coordinates.
struct Detection
  property class_id : Int32, score : Float32
  property x1 : Int32, y1 : Int32
  property x2 : Int32, y2 : Int32

  def initialize(@class_id : Int32, @score : Float32, @x1 : Int32, @y1 : Int32, @x2 : Int32, @y2 : Int32)
  end

  # Human-readable COCO class name for this detection.
  def label
    LABELS[@class_id]
  end
end
# Convert the model's post-NMS output into Detection structs.
#
# scores  - one confidence value per surviving box
# indices - flat Int64 array, six values per box in the order
#           (batch_idx, class_idx, y1, x1, y2, x2), in model-input pixels
# orig_width / orig_height   - size of the original image (scale target)
# model_width / model_height - model input size the coordinates refer to
# confidence_threshold       - boxes scoring below this are dropped
#
# Boxes whose index data would run past the end of `indices` are skipped,
# so a truncated output array no longer raises IndexError.
def process_detections(scores : Array(Float32), indices : Array(Int64),
                       orig_width : Int32, orig_height : Int32,
                       model_width : Int32 = 640, model_height : Int32 = 640,
                       confidence_threshold : Float32 = 0.25) : Array(Detection)
  detections = [] of Detection
  # Loop-invariant scale factors from model space to original-image space.
  x_scale = orig_width.to_f64 / model_width
  y_scale = orig_height.to_f64 / model_height
  scores.each_with_index do |score, i|
    # Skip low-confidence detections
    next if score < confidence_threshold
    base = i * 6
    # Guard: stop if the indices array cannot supply six values for this box.
    break if base + 6 > indices.size
    # indices layout: (batch_idx, class_idx, y1, x1, y2, x2)
    class_id = indices[base + 1].to_i32
    y1 = (indices[base + 2] * y_scale).to_i
    x1 = (indices[base + 3] * x_scale).to_i
    y2 = (indices[base + 4] * y_scale).to_i
    x2 = (indices[base + 5] * x_scale).to_i
    detections << Detection.new(class_id, score, x1, y1, x2, y2)
  end
  detections
end
# Convert interleaved HWC pixel data (Height, Width, Channel) into planar
# NCHW layout (Batch=1, Channel, Height, Width) as expected by the model.
#
# If `pixels` holds fewer values than width * height * channels, the channel
# count is re-derived from the actual data size (handles an upstream
# gray/RGB/RGBA band mismatch); source reads are bounds-checked so truncated
# data yields zero-filled remainder instead of raising.
def hwc_to_nchw(pixels : Array(Float32), width : Int32, height : Int32, channels : Int32 = 3) : Array(Float32)
  plane = width * height
  if pixels.size < plane * channels
    # Infer the real channel count from the data actually present.
    inferred = pixels.size // plane
    channels = inferred if inferred > 0
  end
  # Output pre-filled with zeros; untouched entries stay zero.
  nchw = Array(Float32).new(plane * channels, 0.0_f32)
  channels.times do |c|
    height.times do |h|
      width.times do |w|
        src = (h * width + w) * channels + c
        # Skip source reads past the end of truncated input data.
        # (dst is always in range: it is bounded by plane * channels - 1.)
        next if src >= pixels.size
        nchw[(c * height + h) * width + w] = pixels[src]
      end
    end
  end
  nchw
end
# Run YOLOv7 object detection on the image at `input_path` and write a copy
# annotated with bounding boxes to `output_path`.
#
# input_path           - path of the image to analyze
# output_path          - where the annotated image is written
# confidence_threshold - minimum score for a detection to be drawn
#
# Any error (missing model file, unreadable image, inference failure) is
# reported to stdout and terminates the process with status 1.
def main(input_path : String, output_path : String, confidence_threshold : Float32 = 0.5)
  # Expected model: YOLOv7 with post-processing baked in (640x640 input).
  model_path = "models/yolov7_post_640x640.onnx"
  # Color palette (RGB, 0.0-255.0), cycled by class id when drawing boxes.
  # Hoisted out of the drawing loop — it is loop-invariant.
  palette = [
    [255.0, 0.0, 0.0],   # Red
    [0.0, 255.0, 0.0],   # Green
    [0.0, 0.0, 255.0],   # Blue
    [255.0, 255.0, 0.0], # Yellow
    [255.0, 0.0, 255.0], # Magenta
    [0.0, 255.0, 255.0], # Cyan
    [255.0, 128.0, 0.0], # Orange
    [128.0, 0.0, 255.0], # Purple
    [0.0, 128.0, 255.0], # Light blue
    [255.0, 0.0, 128.0], # Pink
  ]
  begin
    model = OnnxRuntime::Model.new(model_path)
    # Load image using Vips and remember the original size for rescaling boxes.
    image = Vips::Image.new_from_file(input_path)
    width = image.width
    height = image.height
    # Scale each axis independently so the result is exactly 640x640
    # (thumbnail_image preserves aspect ratio, which we do not want here).
    resized = image.resize(640.0 / image.width, vscale: 640.0 / image.height)
    # Normalize band count to 3 (RGB).
    if resized.bands == 1
      # Grayscale to RGB
      resized = resized.bandjoin([resized, resized])
    elsif resized.bands == 4
      # RGBA to RGB
      resized = resized.extract_band(0, n: 3)
    end
    # Raw interleaved pixel bytes, normalized to the 0-1 float range.
    data_ptr, data_size = resized.write_to_memory
    data_slice = Slice.new(data_ptr.as(UInt8*), data_size)
    pixels = Array(Float32).new(data_size) do |i|
      data_slice[i].to_f32 / 255.0_f32
    end
    # The buffer size can disagree with width*height*bands (e.g. unexpected
    # band layout from libvips); recover a usable HWC buffer either way.
    expected_size = resized.width * resized.height * resized.bands
    if pixels.size != expected_size
      if pixels.size % (resized.width * resized.height) == 0
        # Size is a whole number of planes: derive channels from the data.
        actual_channels = (pixels.size / (resized.width * resized.height)).to_i
        input_tensor = hwc_to_nchw(pixels, resized.width, resized.height, actual_channels)
      else
        # Otherwise pad/truncate into the expected size.
        adjusted_pixels = Array(Float32).new(expected_size, 0.0_f32)
        pixels.each_with_index do |value, i|
          adjusted_pixels[i] = value if i < adjusted_pixels.size
        end
        input_tensor = hwc_to_nchw(adjusted_pixels, resized.width, resized.height, resized.bands)
      end
    else
      input_tensor = hwc_to_nchw(pixels, resized.width, resized.height, resized.bands)
    end
    # Derive the NCHW input shape from the tensor actually built.
    channels = (input_tensor.size / (resized.width * resized.height)).to_i
    input_shape = [1_i64, channels.to_i64, resized.height.to_i64, resized.width.to_i64]
    # Run inference
    output = model.predict(
      { "images" => input_tensor },
      nil,
      shape: { "images" => input_shape }
    )
    # Model outputs: per-box scores, then flat (batch, class, y1, x1, y2, x2) indices.
    output_values = output.values
    scores = output_values[0].as(Array(Float32))
    indices = output_values[1].as(Array(Int64))
    detections = process_detections(scores, indices, width, height, 640, 640, confidence_threshold)
    # Keep only the 50 highest-scoring boxes to bound drawing time.
    if detections.size > 50
      detections = detections.sort_by { |det| -det.score }[0...50]
    end
    # Draw on a fresh copy of the original (full-resolution) image.
    result_image = Vips::Image.new_from_file(input_path)
    result_image = result_image.mutate do |mutable|
      detections.each do |det|
        # Select color based on class_id
        r, g, b = palette[det.class_id % palette.size]
        # Draw rectangle outline
        mutable.draw_rect([r, g, b], det.x1, det.y1, det.x2 - det.x1, det.y2 - det.y1, fill: false)
        # Draw label - commented out as text rendering may not be supported
        # label_text = "#{det.label} (#{(det.score * 100).round(1)}%)"
        # mutable.text(label_text, x: det.x1, y: det.y1 - 10, font: "sans", fontsize: 12, color: [r, g, b])
      end
    end
    # Save result
    result_image.write_to_file(output_path)
    # Release ONNX resources
    model.release
    OnnxRuntime::InferenceSession.release_env
  rescue ex
    puts "Error: #{ex.message}"
    puts ex.backtrace.join("\n")
    exit(1)
  end
end
# ---- Command-line interface ----
input_path = ""
output_path = ""
confidence_threshold = 0.7_f32

OptionParser.parse do |parser|
  parser.banner = "Usage: crystal examples/yolov7.cr [options]"
  parser.on("-i PATH", "--input=PATH", "Input image path") do |value|
    input_path = value
  end
  parser.on("-o PATH", "--output=PATH", "Output image path") do |value|
    output_path = value
  end
  parser.on("-c THRESHOLD", "--confidence=THRESHOLD", "Confidence threshold (0.0-1.0)") do |value|
    confidence_threshold = value.to_f32
  end
  parser.on("-h", "--help", "Show this help") do
    puts parser
    exit
  end
  parser.invalid_option do |flag|
    STDERR.puts "ERROR: #{flag} is not a valid option."
    STDERR.puts parser
    exit(1)
  end
end

# Both paths are mandatory; abort prints to STDERR and exits with status 1.
abort "ERROR: Input image path is required." if input_path.empty?
abort "ERROR: Output image path is required." if output_path.empty?

main(input_path, output_path, confidence_threshold)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment