Last active
May 28, 2023 08:09
-
-
Save yt7589/e6f28328a0ce56f21db3861113ea5c94 to your computer and use it in GitHub Desktop.
Multiplication of mojo NDBuffer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "b1f2ae38", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from Buffer import NDBuffer\n", | |
"from List import DimList\n", | |
"from DType import DType\n", | |
"from Pointer import DTypePointer\n", | |
"from Index import StaticIntTuple\n", | |
"from List import VariadicList\n", | |
"\n", | |
"from Benchmark import Benchmark\n", | |
"from Intrinsics import strided_load\n", | |
"from Math import div_ceil, min\n", | |
"from Memory import memset_zero\n", | |
"from Object import object, Attr\n", | |
"from Random import rand, random_f64\n", | |
"from TargetInfo import dtype_sizeof, dtype_simd_width\n", | |
"from Functional import vectorize_unroll\n", | |
"from Functional import Static2DTileUnitFunc as Tile2DFunc\n", | |
"from Functional import parallelize" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "f519d0a2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Compile-time configuration shared by every function below.\n", | |
"# Mojo has SIMD vector types; nelts is the vector width used to vectorize the matmul code.\n", | |
"alias nelts = dtype_simd_width[DType.f32]() # The SIMD vector width.\n", | |
"\n", | |
"# A: left operand of the product (rank, static shape, element type).\n", | |
"alias A_r: Int = 3\n", | |
"alias A_d: DimList = DimList(2, 1, 5)\n", | |
"alias A_t: DType = DType.f32\n", | |
"\n", | |
"# B: right operand of the product.\n", | |
"alias B_r: Int = 3\n", | |
"alias B_d: DimList = DimList(2, 5, 1)\n", | |
"alias B_t: DType = DType.f32\n", | |
"\n", | |
"# C: result of the product.\n", | |
"alias C_r: Int = 3\n", | |
"alias C_d: DimList = DimList(2, 1, 1)\n", | |
"alias C_t: DType = DType.f32\n", | |
"\n", | |
"\n", | |
"\n", | |
"# Perform 2D tiling over the iteration space [0, end_x) x [0, end_y): tiled_fn[tile_x, tile_y]\n", | |
"# is invoked with the origin (x, y) of every tile, rows of tiles first. [Copied from Matmul.ipynb]\n", | |
"fn tile[tiled_fn: Tile2DFunc, tile_x: Int, tile_y: Int](end_x: Int, end_y: Int):\n", | |
"    # Note: this assumes that the ends are multiples of the tile sizes; the last tile\n", | |
"    # is handed to tiled_fn at full size even if it overhangs an end.\n", | |
"    for tile_row in range(0, end_y, tile_y):\n", | |
"        for tile_col in range(0, end_x, tile_x):\n", | |
"            tiled_fn[tile_x, tile_y](tile_col, tile_row)\n", | |
"\n", | |
"# Entry point for NDBuffer multiplication: picks the 2D, 3D or 4D kernel from the rank of A\n", | |
"# and calls it once per combination of leading indices. [Adapted from Matmul.ipynb]\n", | |
"fn matmul(inout C: NDBuffer[C_r, C_d, C_t], A: NDBuffer[A_r, A_d, A_t], B: NDBuffer[B_r, B_d, B_t]):\n", | |
"    # Number of leading dimensions in front of the trailing two (matrix) dimensions.\n", | |
"    let batch_dims = A.get_rank() - 2\n", | |
"    if batch_dims == 0:\n", | |
"        # Plain matrix multiplication.\n", | |
"        matmul_2d(C, A, B)\n", | |
"    elif batch_dims == 1:\n", | |
"        for b in range(A.dim(0)):\n", | |
"            matmul_3d(C, A, B, b)\n", | |
"    elif batch_dims == 2:\n", | |
"        for b1 in range(A.dim(0)):\n", | |
"            for b2 in range(A.dim(1)):\n", | |
"                matmul_4d(C, A, B, b1, b2)\n", | |
"    else:\n", | |
"        print('Error: We only support 2, 3, 4 dimension NDBuffer multiplication at present')\n", | |
"\n", | |
"# 2D matrix multiplication, C += A * B. [Adapted from Matmul.ipynb]\n", | |
"# C is accumulated into, so it must hold zeros on entry. The tiling assumes that C's column\n", | |
"# count is a multiple of nelts*x_tile_size and A's column count a multiple of y_tile_size.\n", | |
"fn matmul_2d(C: NDBuffer[C_r, C_d, C_t], A: NDBuffer[A_r, A_d, A_t], B: NDBuffer[B_r, B_d, B_t]):\n", | |
"    @parameter\n", | |
"    fn calc_row(m: Int):\n", | |
"        @parameter\n", | |
"        fn calc_tile[tile_x: Int, tile_y: Int](x: Int, y: Int):\n", | |
"            # x is the first output column of the tile, y the first reduction index k.\n", | |
"            for k in range(y, y + tile_y):\n", | |
"                @parameter\n", | |
"                fn dot[nelts : Int,](n : Int):\n", | |
"                    # C[m, n+x ...] += A[m, k] * B[k, n+x ...]: A[m, k] is loaded as a single\n", | |
"                    # element and multiplied with nelts consecutive elements of row k of B.\n", | |
"                    C.simd_store[nelts](StaticIntTuple[C_r](m,n+x), C.simd_load[nelts](VariadicList(m,n+x)) + A.simd_load[1](VariadicList(m, k)) * B.simd_load[nelts](VariadicList(k,n+x)) )\n", | |
"\n", | |
"                # Vectorize by nelts and unroll by tile_x//nelts (= x_tile_size).\n", | |
"                vectorize_unroll[nelts, tile_x//nelts, dot](tile_x)\n", | |
"\n", | |
"        alias x_tile_size = 5\n", | |
"        alias y_tile_size = 1\n", | |
"        # end_x bounds the output columns (last dim of C); end_y bounds k (last dim of A).\n", | |
"        tile[calc_tile, nelts*x_tile_size, y_tile_size](C.dim(C.get_rank()-1), A.dim(A.get_rank()-1))\n", | |
"\n", | |
"    parallelize[calc_row](C.dim(C.get_rank()-2))\n", | |
" \n", | |
"# 3D tensor multiplication for one leading index: C[idx] += A[idx] * B[idx]. [Adapted from Matmul.ipynb]\n", | |
"# C is accumulated into, so it must hold zeros on entry.\n", | |
"fn matmul_3d(inout C: NDBuffer[C_r, C_d, C_t], A: NDBuffer[A_r, A_d, A_t], B: NDBuffer[B_r, B_d, B_t], idx: Int):\n", | |
"    @parameter\n", | |
"    fn calc_row(m: Int):\n", | |
"        # Every load/store here is one element wide, so the loops run over the real extents\n", | |
"        # (columns of C, then the shared dimension k) instead of nelts-wide tiles. This keeps\n", | |
"        # all indices in bounds for any shape, including a C with a single column.\n", | |
"        for n in range(C.dim(C.get_rank()-1)):\n", | |
"            for k in range(A.dim(A.get_rank()-1)):\n", | |
"                C.simd_store[1](StaticIntTuple[C_r](idx, m, n), C.simd_load[1](VariadicList(idx, m, n)) + A.simd_load[1](VariadicList(idx, m, k)) * B.simd_load[1](VariadicList(idx, k, n)) )\n", | |
"\n", | |
"    # calc_row is run through parallelize for every row index of C's second-to-last dimension.\n", | |
"    parallelize[calc_row](C.dim(C.get_rank()-2))\n", | |
" \n", | |
"# 4D tensor multiplication for one pair of leading indices: C[i1, i2] += A[i1, i2] * B[i1, i2]. [Adapted from Matmul.ipynb]\n", | |
"# A separate function per rank is used because the author did not find a way to concatenate two\n", | |
"# VariadicLists into one index list; a single rank-generic kernel would be preferable.\n", | |
"# C is accumulated into, so it must hold zeros on entry. The tiling assumes that C's column\n", | |
"# count is a multiple of nelts*x_tile_size and A's column count a multiple of y_tile_size.\n", | |
"fn matmul_4d(C: NDBuffer[C_r, C_d, C_t], A: NDBuffer[A_r, A_d, A_t], B: NDBuffer[B_r, B_d, B_t], i1: Int, i2: Int):\n", | |
"    @parameter\n", | |
"    fn calc_row(m: Int):\n", | |
"        @parameter\n", | |
"        fn calc_tile[tile_x: Int, tile_y: Int](x: Int, y: Int):\n", | |
"            # x is the first output column of the tile, y the first reduction index k.\n", | |
"            for k in range(y, y + tile_y):\n", | |
"                @parameter\n", | |
"                fn dot[nelts : Int,](n : Int):\n", | |
"                    # C[.., m, n+x ...] += A[.., m, k] * B[.., k, n+x ...]: A[.., m, k] is loaded as a\n", | |
"                    # single element and multiplied with nelts consecutive elements of row k of B.\n", | |
"                    C.simd_store[nelts](StaticIntTuple[C_r](i1, i2, m,n+x), C.simd_load[nelts](VariadicList(i1, i2,m,n+x)) + A.simd_load[1](VariadicList(i1, i2, m, k)) * B.simd_load[nelts](VariadicList(i1, i2, k,n+x)) )\n", | |
"\n", | |
"                # Vectorize by nelts and unroll by tile_x//nelts (= x_tile_size).\n", | |
"                vectorize_unroll[nelts, tile_x//nelts, dot](tile_x)\n", | |
"\n", | |
"        alias x_tile_size = 5\n", | |
"        alias y_tile_size = 1\n", | |
"        # end_x bounds the output columns (last dim of C); end_y bounds k (last dim of A).\n", | |
"        tile[calc_tile, nelts*x_tile_size, y_tile_size](C.dim(C.get_rank()-1), A.dim(A.get_rank()-1))\n", | |
"\n", | |
"    parallelize[calc_row](C.dim(C.get_rank()-2))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "716492f3", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# X is the right-hand operand passed to matmul; it is viewed with the static shape B_d = (2, 5, 1).\n", | |
"# The allocation holds 8*5*1 elements, more than the 2*5*1 the shape requires.\n", | |
"var X_data = DTypePointer[B_t].alloc(8*5*1)\n", | |
"var X = NDBuffer[B_r, B_d, B_t](X_data)\n", | |
"# Fill entry (batch, feature, 0) with batch*10 + feature + 100.\n", | |
"for batch in range(2):\n", | |
"    for feature in range(5):\n", | |
"        X[VariadicList(batch, feature, 0)] = batch*10 + feature + 100\n", | |
"# Show both batch entries, one per line.\n", | |
"print(X[0,0,0], X[0,1,0], X[0, 2, 0], X[0,3,0], X[0,4,0])\n", | |
"print(X[1,0,0], X[1,1,0], X[1,2,0], X[1,3,0], X[1,4,0])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "d6b8b00d", | |
"metadata": {}, | |
"source": [ | |
"100.000000 101.000000 102.000000 103.000000 104.000000 \n", | |
"110.000000 111.000000 112.000000 113.000000 114.000000" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "19bfc4cf", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# W is the left-hand operand passed to matmul; it is viewed with the static shape A_d = (2, 1, 5).\n", | |
"# The allocation holds 8*1*5 elements, more than the 2*1*5 the shape requires.\n", | |
"var W_data = DTypePointer[A_t].alloc(8*1*5)\n", | |
"var W = NDBuffer[A_r, A_d, A_t](W_data)\n", | |
"# Every batch entry receives the same five weights 1.0, 1.1, ..., 1.4.\n", | |
"for batch in range(2):\n", | |
"    for feature in range(5):\n", | |
"        W[VariadicList(batch, 0, feature)] = 1.0 + feature*0.1\n", | |
"# Show both batch entries, one per line.\n", | |
"print(W[0,0,0], W[0,0,1], W[0,0,2], W[0,0,3], W[0,0,4])\n", | |
"print(W[1,0,0], W[1,0,1], W[1,0,2], W[1,0,3], W[1,0,4])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "79c9b7b4", | |
"metadata": {}, | |
"source": [ | |
"1.000000 1.100000 1.200000 1.300000 1.400000 \n", | |
"1.000000 1.100000 1.200000 1.300000 1.400000" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "df6c03df", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Z receives the product; it is viewed with the static shape C_d = (2, 1, 1).\n", | |
"var Z_data = DTypePointer[C_t].alloc(2*1*1)\n", | |
"# matmul accumulates into its output (C += A * B), so the freshly allocated storage\n", | |
"# has to be cleared first; otherwise the result depends on whatever the memory contained.\n", | |
"memset_zero(Z_data, 2*1*1)\n", | |
"var Z = NDBuffer[C_r, C_d, C_t](Z_data)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "fefe1b06", | |
"metadata": {}, | |
"source": [ | |
"I will do:\n", | |
"$$\n", | |
"Z = W \\cdot X \\quad Z \\in R^{2 \\times 1 \\times 1}, W \\in R^{2 \\times 1 \\times 5}, X \\in R^{2 \\times 5 \\times 1}\n", | |
"$$\n", | |
"Notes:\n", | |
"* 2 is the batch_size;\n", | |
"* each feature vector is $\\boldsymbol{x} \\in R^{5 \\times 1}$;\n", | |
"* X is the design matrix of one batch and W is the weight matrix.\n", | |
"\n", | |
"I use $R^{n \\times 1}$ to represent a vector." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "d1998b2d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Multiply W by X into Z with the matmul entry point defined above, then print the two entries of Z.\n", | |
"matmul(Z, W, X)\n", | |
"print(Z[0,0,0], Z[1,0,0])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "8e7eb8ff", | |
"metadata": {}, | |
"source": [ | |
"613.000000 \n", | |
"673.000000" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "e84aff37", | |
"metadata": {}, | |
"source": [ | |
"My Questions:\n", | |
"1. It seems that I have to write a separate matmul function for every shape of matrix multiplication. That is unacceptable;\n", | |
"2. How do I transpose an NDBuffer?\n", | |
"3. How do I reshape an NDBuffer?\n", | |
"4. How do I broadcast? I have to duplicate the weights manually in order to multiply them with the mini-batch design matrix.\n", | |
"5. To set a value I cannot use the X[i,j,k] = value form, although I can use X[i,j,k] to read that value." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "5e351db9", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment