Skip to content

Instantly share code, notes, and snippets.

@ChillFish8
Created July 19, 2025 16:56
Show Gist options
  • Select an option

  • Save ChillFish8/a20272907ebc5b8685976c21282b61d8 to your computer and use it in GitHub Desktop.

Select an option

Save ChillFish8/a20272907ebc5b8685976c21282b61d8 to your computer and use it in GitHub Desktop.
enable f16;

// Bind group 0 resources used by compute_cosine.
// NOTE(review): every binding is declared read_write even though compute_cosine
// only reads bindings 0 and 1. Changing the access mode would also require the
// host-side bind group layout to change, so it is intentionally left as is.

// Binding 0: the vector every invocation reads in full (the "a" operand).
@group(0) @binding(0) var<storage, read_write> buffer_0_global: array<vec4<f16>>;
// Binding 1: vectors addressed per invocation via an offset (the "b" operand).
@group(0) @binding(1) var<storage, read_write> buffer_1_global: array<vec4<f16>>;
// Binding 2: one f32 result written per invocation.
@group(0) @binding(2) var<storage, read_write> buffer_2_global: array<f32>;
// Binding 3: not referenced by compute_cosine; presumably generator metadata — TODO confirm.
@group(0) @binding(3) var<storage, read_write> info: array<u32>;

// Workgroup dimensions; these mirror the literal @workgroup_size(64, 1, 1) on compute_cosine.
const WORKGROUP_SIZE_X: u32 = 64u;
const WORKGROUP_SIZE_Y: u32 = 1u;
const WORKGROUP_SIZE_Z: u32 = 1u;
// Sums the four lanes of a vec4<f32> left to right.
fn hsum_vec4(v: vec4<f32>) -> f32 {
    return v.x + v.y + v.z + v.w;
}

// Cosine distance between one "needle" vector and a set of "haystack" vectors.
//
// Each invocation compares buffer_0_global (the needle: 1024 x vec4<f16>) with
// one haystack vector stored contiguously in buffer_1_global, starting at vec4
// index `id * 1024`. It writes `1 - dot(a, b) / sqrt(|a|^2 * |b|^2)` to
// buffer_2_global[id].
//
// Inputs are f16, but all products and sums are computed in f32. An f16
// accumulator (11-bit significand, max 65504) silently drops small terms once
// the running sum grows, and overflows to inf on moderately sized inputs.
// The product of two f16 values is exact in f32.
//
// If either vector has zero norm, the final division is 0/0 and the result
// is undefined (typically NaN).
@compute
@workgroup_size(64, 1, 1)
fn compute_cosine(
    @builtin(global_invocation_id) global_invocation_id: vec3<u32>,
) {
    // Number of vec4<f16> elements per vector (4096 scalars).
    const VEC4_PER_VECTOR: u32 = 1024u;

    // Global (not local) id, so dispatching several workgroups covers several
    // haystack vectors instead of redundantly rewriting the first 64 outputs.
    let haystack_id = global_invocation_id.x;
    // Invocations beyond the output buffer have nothing to write.
    if (haystack_id >= arrayLength(&buffer_2_global)) {
        return;
    }

    // Haystack offset, in vec4 elements.
    let haystack_offset = haystack_id * VEC4_PER_VECTOR;

    var acc_product = vec4<f32>(0.0);
    var acc_norm_a = vec4<f32>(0.0);
    var acc_norm_b = vec4<f32>(0.0);
    for (var i: u32 = 0u; i < VEC4_PER_VECTOR; i++) {
        // Load each element once and widen to f32 before multiplying.
        let a = vec4<f32>(buffer_0_global[i]);
        let b = vec4<f32>(buffer_1_global[haystack_offset + i]);
        acc_product += a * b;
        acc_norm_a += a * a;
        acc_norm_b += b * b;
    }

    // 1.0 - product / sqrt(norm_a * norm_b)
    let product = hsum_vec4(acc_product);
    let norm_a = hsum_vec4(acc_norm_a);
    let norm_b = hsum_vec4(acc_norm_b);
    buffer_2_global[haystack_id] = 1.0 - product / sqrt(norm_a * norm_b);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment