Created July 19, 2025 16:56
Save ChillFish8/a20272907ebc5b8685976c21282b61d8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
enable f16;

// Storage bindings in bind group 0, all read by the `compute_cosine` entry point:
//   binding 0: `a` operand, packed as vec4<f16>; indexed 0..1024 in the loop.
//   binding 1: `b` operands, packed as vec4<f16>, 1024 elements per invocation's slice.
//   binding 2: one f32 result written per invocation.
//   binding 3: u32 array that this shader never reads; presumably runtime metadata
//              supplied by the host. TODO confirm before removing, because it is part
//              of the bind-group layout.
// Every binding is declared read_write. Narrowing an access mode would change the
// bind-group layout the host pipeline is built against, so these stay as-is.
@group(0) @binding(0) var<storage, read_write> buffer_0_global: array<vec4<f16>>;
@group(0) @binding(1) var<storage, read_write> buffer_1_global: array<vec4<f16>>;
@group(0) @binding(2) var<storage, read_write> buffer_2_global: array<f32>;
@group(0) @binding(3) var<storage, read_write> info: array<u32>;

// Workgroup dimensions (64 x 1 x 1), matching the entry point's @workgroup_size.
const WORKGROUP_SIZE_X: u32 = 64u;
const WORKGROUP_SIZE_Y: u32 = 1u;
const WORKGROUP_SIZE_Z: u32 = 1u;
// Number of vec4<f16> elements in one vector (1024 x 4 = 4096 scalar components).
// buffer_0_global holds one such vector. buffer_1_global holds such vectors back to
// back, so vector `row` starts at element `row * VEC4S_PER_VECTOR`.
const VEC4S_PER_VECTOR: u32 = 1024u;

// Horizontal sum of the four lanes of an f32 accumulator.
fn sum_lanes(v: vec4<f32>) -> f32 {
    return (v.x + v.y) + (v.z + v.w);
}

// Cosine distance between the vector in buffer_0_global and each vector in
// buffer_1_global.
//
// Each invocation handles one vector of buffer_1_global, selected by its global
// invocation index `row`, and writes
//     buffer_2_global[row] = 1 - dot(a, b) / sqrt(|a|^2 * |b|^2)
// If either vector has zero length, the similarity is undefined; it is reported as 0,
// giving a distance of 1.
//
// The index is the *global* invocation id, so dispatching N workgroups covers
// N * WORKGROUP_SIZE_X rows. For a single-workgroup dispatch this matches the local id.
@compute
@workgroup_size(WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y, WORKGROUP_SIZE_Z)
fn compute_cosine(
    @builtin(global_invocation_id) global_invocation_id: vec3<u32>,
) {
    let row = global_invocation_id.x;
    // Invocations beyond the output (e.g. the tail of the last workgroup) have no row.
    if (row >= arrayLength(&buffer_2_global)) {
        return;
    }
    let b_offset = row * VEC4S_PER_VECTOR;

    // Accumulate in f32. f16 tops out at 65504 with an 11-bit significand, so summing
    // 4096 squared components in f16 overflows to inf or silently drops small terms.
    var dot_acc = vec4<f32>(0.0);
    var norm_a_acc = vec4<f32>(0.0);
    var norm_b_acc = vec4<f32>(0.0);

    for (var i: u32 = 0u; i < VEC4S_PER_VECTOR; i++) {
        // Load each operand once and widen before multiplying: the products alone can
        // exceed the f16 range.
        let a = vec4<f32>(buffer_0_global[i]);
        let b = vec4<f32>(buffer_1_global[b_offset + i]);
        dot_acc += a * b;
        norm_a_acc += a * a;
        norm_b_acc += b * b;
    }

    let product = sum_lanes(dot_acc);
    let denom = sqrt(sum_lanes(norm_a_acc) * sum_lanes(norm_b_acc));

    // WGSL treats x / 0 as an indeterminate value, so a zero-length vector is handled
    // explicitly instead of writing an unspecified result.
    var similarity: f32 = 0.0;
    if (denom > 0.0) {
        similarity = product / denom;
    }
    buffer_2_global[row] = 1.0 - similarity;
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.