Skip to content

Instantly share code, notes, and snippets.

@kojix2
Created April 9, 2025 07:57
Show Gist options
  • Save kojix2/08b15818fff1978bbff63c23c879f3ea to your computer and use it in GitHub Desktop.
Save kojix2/08b15818fff1978bbff63c23c879f3ea to your computer and use it in GitHub Desktop.
require "../src/onnxruntime"
require "vips"
require "option_parser"
# YOLOv7 Object Detection Example
# This example demonstrates how to use ONNXRuntime.cr with YOLOv7 for object detection

# COCO dataset labels (80 classes). The array index IS the class id the
# model emits, so ordering must not be changed.
LABELS = ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
"truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter",
"bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
"bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie",
"suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite",
"baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
"bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana",
"apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza",
"donut", "cake", "chair", "couch", "potted plant", "bed", "dining table",
"toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
"microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock",
"vase", "scissors", "teddy bear", "hair drier", "toothbrush"]
# Detection result structure
# One detected object. Coordinates are pixel positions in the ORIGINAL
# image (scaled back from model space by process_detections), with
# (x1, y1) the top-left and (x2, y2) the bottom-right corner.
struct Detection
  property class_id : Int32
  property score : Float32
  property x1 : Int32
  property y1 : Int32
  property x2 : Int32
  property y2 : Int32

  def initialize(@class_id : Int32, @score : Float32, @x1 : Int32, @y1 : Int32, @x2 : Int32, @y2 : Int32)
  end

  # COCO class name for this detection. Uses the nilable index operator so
  # a class id outside the 80-entry label table yields a synthetic
  # "class_N" name instead of raising IndexError while drawing results.
  def label
    LABELS[@class_id]? || "class_#{@class_id}"
  end
end
# Process model output to get detection results
# Translate the model's raw outputs — a flat score list plus a flat list of
# index sextuples — into Detection records scaled to the source image size.
#
# scores  - one confidence value per candidate detection
# indices - six Int64 values per detection:
#           (batch_idx, class_idx, y1, x1, y2, x2) in model-input pixels
# Returns only detections at or above confidence_threshold.
def process_detections(scores : Array(Float32), indices : Array(Int64),
                       orig_width : Int32, orig_height : Int32,
                       model_width : Int32 = 640, model_height : Int32 = 640,
                       confidence_threshold : Float32 = 0.25) : Array(Detection)
  results = [] of Detection
  scores.each_with_index do |score, i|
    # Drop anything below the caller-supplied confidence floor.
    next if score < confidence_threshold
    # Slice out the six index values belonging to detection i.
    tuple = indices[i * 6, 6]
    cls = tuple[1].to_i32
    # Rescale model-space box corners into original-image pixel space.
    top    = (tuple[2].to_i32 * orig_height / model_height).to_i
    left   = (tuple[3].to_i32 * orig_width / model_width).to_i
    bottom = (tuple[4].to_i32 * orig_height / model_height).to_i
    right  = (tuple[5].to_i32 * orig_width / model_width).to_i
    results << Detection.new(cls, score, left, top, right, bottom)
  end
  results
end
# Convert HWC format to NCHW format (Height, Width, Channel -> Batch, Channel, Height, Width)
# Convert interleaved HWC pixel data (Height, Width, Channel) into planar
# NCHW layout (Batch=1, Channel, Height, Width) as expected by the model.
#
# If `pixels` is smaller than width*height*channels, the channel count is
# re-derived from the actual buffer size — preserving the original's
# tolerance for band-count mismatches coming out of libvips.
#
# Positions past the end of a short source buffer are left at 0.0.
def hwc_to_nchw(pixels : Array(Float32), width : Int32, height : Int32, channels : Int32 = 3) : Array(Float32)
  if pixels.size < width * height * channels
    # Infer the real channel count from the data actually present.
    inferred = pixels.size // (width * height)
    channels = inferred if inferred > 0
  end

  nchw = Array(Float32).new(channels * height * width, 0.0_f32)

  channels.times do |c|
    height.times do |h|
      width.times do |w|
        src = (h * width + w) * channels + c
        # dst is always < channels*height*width, so only src needs guarding:
        # skip (leaving zero) rather than raise if the source buffer is short.
        next if src >= pixels.size
        dst = (c * height + h) * width + w
        nchw[dst] = pixels[src]
      end
    end
  end

  nchw
end
# Main function
# Run the full detection pipeline: load the ONNX model, preprocess the
# input image with libvips, run inference, scale boxes back to the source
# resolution, draw them, and write the annotated image to output_path.
#
# input_path  - source image file readable by libvips
# output_path - destination path for the annotated image
# confidence_threshold - minimum score for a detection to be kept
#
# NOTE(review): any failure is caught, printed, and exits the process with
# status 1 — this function never raises to its caller.
def main(input_path : String, output_path : String, confidence_threshold : Float32 = 0.5)
# Load YOLOv7 model (path is relative to the working directory)
model_path = "models/yolov7_post_640x640.onnx"
begin
model = OnnxRuntime::Model.new(model_path)
# Load image using Vips
image = Vips::Image.new_from_file(input_path)
width = image.width
height = image.height
# Resize image to model input size (640x640)
# Use resize instead of thumbnail_image to ensure exact 640x640 dimensions
resized = image.resize(640.0 / image.width, vscale: 640.0 / image.height)
# Convert to RGB if needed
if resized.bands == 1
# Grayscale to RGB: duplicate the single band twice
resized = resized.bandjoin([resized, resized])
elsif resized.bands == 4
# RGBA to RGB: drop the alpha band
resized = resized.extract_band(0, n: 3)
end
# Get pixel data and normalize to 0-1
# NOTE(review): assumes write_to_memory returns a (pointer, byte count)
# pair of UInt8 data — verify against the crystal-vips version in use
data_ptr, data_size = resized.write_to_memory
data_slice = Slice.new(data_ptr.as(UInt8*), data_size)
# Convert to Float32 array and normalize each byte into [0, 1]
pixels = Array(Float32).new(data_size) do |i|
data_slice[i].to_f32 / 255.0_f32
end
# Check and adjust data size if needed
expected_size = resized.width * resized.height * resized.bands
if pixels.size != expected_size
# Estimate channels based on actual data size
if pixels.size % (resized.width * resized.height) == 0
actual_channels = (pixels.size / (resized.width * resized.height)).to_i
input_tensor = hwc_to_nchw(pixels, resized.width, resized.height, actual_channels)
else
# Rebuild data if size mismatch: copy what fits, zero-pad the rest
adjusted_pixels = Array(Float32).new(expected_size, 0.0_f32)
pixels.each_with_index do |value, i|
adjusted_pixels[i] = value if i < adjusted_pixels.size
end
input_tensor = hwc_to_nchw(adjusted_pixels, resized.width, resized.height, resized.bands)
end
else
# Process data normally if size matches
input_tensor = hwc_to_nchw(pixels, resized.width, resized.height, resized.bands)
end
# Derive the channel count actually present in the tensor for the shape
channels = (input_tensor.size / (resized.width * resized.height)).to_i
input_shape = [1_i64, channels.to_i64, resized.height.to_i64, resized.width.to_i64]
# Run inference ("images" is the model's input tensor name)
output = model.predict(
{ "images" => input_tensor },
nil,
shape: { "images" => input_shape }
)
# Process output
# NOTE(review): assumes output[0] is a flat list of scores and output[1]
# a flat list of 6-value index tuples — confirm against the exported model
output_values = output.values
scores = output_values[0].as(Array(Float32))
indices = output_values[1].as(Array(Int64))
# Get detection results scaled back to the original image size
detections = process_detections(scores, indices, width, height, 640, 640, confidence_threshold)
# Limit the number of detections to avoid performance issues
if detections.size > 50
detections = detections.sort_by { |det| -det.score }[0...50]
end
# Load original image for drawing (full resolution, not the resized copy)
result_image = Vips::Image.new_from_file(input_path)
# Draw detections
result_image = result_image.mutate do |mutable|
detections.each do |det|
# Color palette based on class_id (0.0-255.0 range)
colors = [
[255.0, 0.0, 0.0], # Red
[0.0, 255.0, 0.0], # Green
[0.0, 0.0, 255.0], # Blue
[255.0, 255.0, 0.0], # Yellow
[255.0, 0.0, 255.0], # Magenta
[0.0, 255.0, 255.0], # Cyan
[255.0, 128.0, 0.0], # Orange
[128.0, 0.0, 255.0], # Purple
[0.0, 128.0, 255.0], # Light blue
[255.0, 0.0, 128.0] # Pink
]
# Select color based on class_id (cycles through the 10-color palette)
color_idx = det.class_id % colors.size
r, g, b = colors[color_idx]
# Draw rectangle outline (draw_rect takes x, y, width, height)
mutable.draw_rect([r, g, b].map(&.to_f64), det.x1, det.y1, det.x2 - det.x1, det.y2 - det.y1, fill: false)
# Draw label - commented out as text rendering may not be supported
# label_text = "#{det.label} (#{(det.score * 100).round(1)}%)"
# mutable.text(label_text, x: det.x1, y: det.y1 - 10, font: "sans", fontsize: 12, color: [r, g, b])
end
end
# Save result
result_image.write_to_file(output_path)
# Release resources
model.release
OnnxRuntime::InferenceSession.release_env
rescue ex
puts "Error: #{ex.message}"
puts ex.backtrace.join("\n")
exit(1)
end
end
# ---------------------------------------------------------------------------
# CLI entry point: parse arguments, validate required paths, run detection.
# ---------------------------------------------------------------------------
input_path = ""
output_path = ""
confidence_threshold = 0.7_f32

OptionParser.parse do |parser|
  parser.banner = "Usage: crystal examples/yolov7.cr [options]"
  parser.on("-i PATH", "--input=PATH", "Input image path") do |path|
    input_path = path
  end
  parser.on("-o PATH", "--output=PATH", "Output image path") do |path|
    output_path = path
  end
  parser.on("-c THRESHOLD", "--confidence=THRESHOLD", "Confidence threshold (0.0-1.0)") do |threshold|
    confidence_threshold = threshold.to_f32
  end
  parser.on("-h", "--help", "Show this help") do
    puts parser
    exit
  end
  parser.invalid_option do |flag|
    STDERR.puts "ERROR: #{flag} is not a valid option."
    STDERR.puts parser
    exit(1)
  end
end

# Both image paths are mandatory; bail out with a specific message if missing.
if input_path.empty?
  STDERR.puts "ERROR: Input image path is required."
  exit(1)
end
if output_path.empty?
  STDERR.puts "ERROR: Output image path is required."
  exit(1)
end

main(input_path, output_path, confidence_threshold)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment