Stefano Bosisio (Steboss89)

@Steboss89
Steboss89 / floatingpoint.csv
Created July 24, 2023 13:33
Floating point table specifications
floating point format,number of bits,bits for sign,bits for exponent,bits for mantissa,base
f16 (half-precision),16,1,5,10,2
f32 (single-precision),32,1,8,23,2
f64 (double-precision),64,1,11,52,2
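The mantissa column counts the stored fraction bits; normal numbers also carry an implicit leading 1, so their effective precision is one bit more (11, 24 and 53 bits). As a minimal sketch, the three fields can be pulled out of a Python float with the standard `struct` module (format codes `e`, `f` and `d` pack 16-, 32- and 64-bit IEEE 754 values), and the value of a normal number rebuilt from them. `FORMATS`, `decompose` and `recompose` are illustrative names, not part of the gist.

```python
import struct
from typing import Tuple

# (struct format code, exponent bits, mantissa bits) for each row of the table
FORMATS = {
    "f16": ("e", 5, 10),
    "f32": ("f", 8, 23),
    "f64": ("d", 11, 52),
}


def decompose(x: float, fmt: str) -> Tuple[int, int, int]:
    """Split x into its raw (sign, biased exponent, mantissa) bit fields."""
    code, exp_bits, man_bits = FORMATS[fmt]
    raw = int.from_bytes(struct.pack(">" + code, x), "big")  # big-endian bit pattern
    sign = raw >> (exp_bits + man_bits)
    exponent = (raw >> man_bits) & ((1 << exp_bits) - 1)
    mantissa = raw & ((1 << man_bits) - 1)
    return sign, exponent, mantissa


def recompose(sign: int, exponent: int, mantissa: int, fmt: str) -> float:
    """Rebuild a *normal* number: (-1)^s * (1 + m / 2^M) * 2^(e - bias)."""
    _, exp_bits, man_bits = FORMATS[fmt]
    bias = (1 << (exp_bits - 1)) - 1  # 15, 127 and 1023 for f16, f32, f64
    return (-1) ** sign * (1 + mantissa / (1 << man_bits)) * 2.0 ** (exponent - bias)
```

For example, `decompose(1.0, "f32")` gives `(0, 127, 0)`: sign 0, an exponent field equal to the bias 127 (so a factor of 2^0), and an all-zero fraction.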
Steboss89 / example.py
Created June 2, 2023 13:41
Example of a working KFP (Kubeflow Pipelines) pipeline
# Pipeline to retrieve some data on staging
# trustedplatform-pl-stging.italianpromo.fake_dataset
from pathlib import Path
from typing import List
import kfp
import yaml
from trustedplatform_kfpvertex.pipelines.components import bigquery, virtual_machine
from trustedplatform_kfpvertex.pipelines.utils import compile
Matrix size,JAX,Strassen,AlphaTensor
8192,2.51 +/- 0.01,2.34 +/- 0.01,2.24 +/- 0.01
10240,5.18 +/- 0.03,4.31 +/- 0.01,4.12 +/- 0.01
12288,8.67 +/- 0.04,8.09 +/- 0.02,7.94 +/- 0.05
14336,13.67 +/- 0.03,11.52 +/- 0.01,11.16 +/- 0.02
16384,20.41 +/- 0.03,21.69 +/- 0.03,21.03 +/- 0.02
18432,30.64 +/- 0.09,25.50 +/- 0.01,24.96 +/- 0.03
Steboss89 / T4_comparison.csv
Created February 10, 2023 15:37
Comparison of performance on a Tesla T4
Matrix size,JAX,Strassen,AlphaTensor
8192,2.513199887099989+/-0.014528186425767666,2.344882236099994+/-0.013139273132835319,2.235233979999987+/-0.0060807684769800165
10240,5.178147598400005+/-0.0341956234999571,4.310240157300029+/-0.006896268670154992,4.115188468500003+/-0.009684188165498608
12288,8.674847026200018+/-0.042683489691167704,8.094868272800023+/-0.023601467532573948,7.936362483399989+/-0.04820043866223
14336,13.67069818479997+/-0.026384762961236184,11.523782438399985+/-0.011888051059463828,11.158847375799905+/-0.023378686004509534
16384,20.408328052200023+/-0.029682819426008433,21.686704270200018+/-0.03336953765400482,21.02908948869999+/-0.023961257266623504
18432,30.6381309062+/-0.09175637807267167,25.495546809300027+/-0.008717825537352966,24.961940088700022+/-0.03458530070354073
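Assuming the values are run times (the gist does not state the unit, so lower is better), one quick way to read the comparison is to divide the JAX mean by each method's mean: a ratio above 1 means that method is faster than the JAX baseline. The sketch below does this on the T4 means rounded to two decimals, which are the same numbers as the rounded table above; `T4_TABLE`, `parse_cell` and `speedups` are illustrative names.

```python
from typing import Dict, Tuple

# T4 timings rounded to two decimals, as "mean +/- std" cells
T4_TABLE = """\
Matrix size,JAX,Strassen,AlphaTensor
8192,2.51 +/- 0.01,2.34 +/- 0.01,2.24 +/- 0.01
10240,5.18 +/- 0.03,4.31 +/- 0.01,4.12 +/- 0.01
12288,8.67 +/- 0.04,8.09 +/- 0.02,7.94 +/- 0.05
14336,13.67 +/- 0.03,11.52 +/- 0.01,11.16 +/- 0.02
16384,20.41 +/- 0.03,21.69 +/- 0.03,21.03 +/- 0.02
18432,30.64 +/- 0.09,25.50 +/- 0.01,24.96 +/- 0.03
"""


def parse_cell(cell: str) -> Tuple[float, float]:
    """Turn a 'mean +/- std' cell into a (mean, std) pair."""
    mean, std = cell.split("+/-")
    return float(mean), float(std)


def speedups(table: str) -> Dict[int, Dict[str, float]]:
    """Ratio JAX mean / method mean for every non-JAX column, keyed by matrix size."""
    lines = table.strip().splitlines()
    methods = lines[0].split(",")[1:]
    result = {}
    for line in lines[1:]:
        size, *cells = line.split(",")
        means = [parse_cell(cell)[0] for cell in cells]
        baseline = means[methods.index("JAX")]
        result[int(size)] = {m: baseline / t for m, t in zip(methods, means) if m != "JAX"}
    return result
```

With these numbers, 16384 is the only size at which both Strassen and AlphaTensor come out slower than the JAX baseline (ratios below 1); at 8192, for instance, AlphaTensor's ratio is 2.51 / 2.24, about 1.12.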
Steboss89 / jax_concatenate_matrix.py
Created October 21, 2022 13:59
Reassemble the block product matrix into a single SIZE x SIZE matrix
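import numpy as np

c = f(a, b)  # block-matrix product of `a` and `b` (f, a, b and SIZE are defined elsewhere)
# JAX dispatches work asynchronously: block until every block of the result is ready
for row in c:
    for block in row:
        block.block_until_ready()
# stitch the n x n grid of blocks back into a single SIZE x SIZE matrix
c_arr = np.hstack(np.concatenate(np.array(c), axis=1)).reshape(SIZE, SIZE)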
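Why the one-liner rebuilds the matrix correctly: `np.array(c)` has shape `(n, n, bs, bs)`, `np.concatenate(..., axis=1)` stacks the blocks of each block-column on top of each other (giving shape `(n, SIZE, bs)`), and `np.hstack` places those columns side by side, so the array is already `SIZE x SIZE` and the final `reshape` is a no-op. A small NumPy check, using a hypothetical 2 x 2 grid of 3 x 3 blocks cut from a known matrix, also shows that `np.block(c)` performs the same reassembly directly from the nested list:

```python
import numpy as np

n, bs = 2, 3  # hypothetical sizes: a 2 x 2 grid of 3 x 3 blocks
SIZE = n * bs
M = np.arange(SIZE * SIZE).reshape(SIZE, SIZE)

# nested list of blocks: c[i][j] is the block in block-row i, block-column j
c = [[M[i * bs:(i + 1) * bs, j * bs:(j + 1) * bs] for j in range(n)] for i in range(n)]

# the reassembly expression from the gist
c_arr = np.hstack(np.concatenate(np.array(c), axis=1)).reshape(SIZE, SIZE)

# np.block builds the same matrix from the nested list in one call
c_block = np.block(c)
```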
Steboss89 / jax_strassen.py
Created October 21, 2022 13:52
Main function for computing matmul with Strassen in JAX
def f(a: BlockMatrix, b: BlockMatrix) -> BlockMatrix:
    """Multiplies block matrices `a` and `b`."""
    n = len(a)
    result = [[None] * n for _ in range(n)]
    for alpha in range(rank):
        left = None
        for i in range(n):
            for j in range(n):
                if factors[0][i, j, alpha] != 0:
                    curr = factors[0][i, j, alpha] * a[i][j]
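The preview stops partway through `f`. Judging from the visible loop, each rank index `alpha` forms one linear combination of the blocks of `a`, one of `b`, multiplies the two once, and scatters the product into the output blocks. Below is a self-contained NumPy sketch of that scheme, not the gist's code: it passes the factors explicitly as arrays `u`, `v`, `w` of shape `(n, n, rank)` and assumes the convention `C[i][k] = sum over alpha of w[i, k, alpha] * (left_alpha @ right_alpha)`, which may differ in index order from the gist's `factors`. `block_multiply`, `strassen_factors` and `to_blocks` are hypothetical helper names, and Strassen's algorithm serves as the rank-7 example.

```python
import numpy as np


def block_multiply(a, b, u, v, w):
    """Multiply n x n block matrices `a` and `b` with rank-R factors u, v, w of shape (n, n, R)."""
    n = len(a)
    result = [[None] * n for _ in range(n)]
    for alpha in range(u.shape[-1]):
        # linear combinations of the blocks of a and b selected by the factors
        left = sum(u[i, j, alpha] * a[i][j]
                   for i in range(n) for j in range(n) if u[i, j, alpha] != 0)
        right = sum(v[i, j, alpha] * b[i][j]
                    for i in range(n) for j in range(n) if v[i, j, alpha] != 0)
        product = left @ right  # the only block multiplication for this alpha
        # add the product into every output block with a non-zero coefficient
        for i in range(n):
            for k in range(n):
                if w[i, k, alpha] != 0:
                    term = w[i, k, alpha] * product
                    result[i][k] = term if result[i][k] is None else result[i][k] + term
    return result


def strassen_factors():
    """Rank-7 factors (u, v, w) encoding Strassen's products for a 2 x 2 block split."""
    # (coefficients on blocks of a, coefficients on blocks of b); index (0, 1) is block 12
    products = [
        ({(0, 0): 1, (1, 1): 1}, {(0, 0): 1, (1, 1): 1}),   # (A11 + A22)(B11 + B22)
        ({(1, 0): 1, (1, 1): 1}, {(0, 0): 1}),              # (A21 + A22) B11
        ({(0, 0): 1}, {(0, 1): 1, (1, 1): -1}),             # A11 (B12 - B22)
        ({(1, 1): 1}, {(1, 0): 1, (0, 0): -1}),             # A22 (B21 - B11)
        ({(0, 0): 1, (0, 1): 1}, {(1, 1): 1}),              # (A11 + A12) B22
        ({(1, 0): 1, (0, 0): -1}, {(0, 0): 1, (0, 1): 1}),  # (A21 - A11)(B11 + B12)
        ({(0, 1): 1, (1, 1): -1}, {(1, 0): 1, (1, 1): 1}),  # (A12 - A22)(B21 + B22)
    ]
    # coefficient of each product in the output blocks C11, C12, C21, C22
    combos = {(0, 0): {0: 1, 3: 1, 4: -1, 6: 1}, (0, 1): {2: 1, 4: 1},
              (1, 0): {1: 1, 3: 1}, (1, 1): {0: 1, 1: -1, 2: 1, 5: 1}}
    u, v, w = (np.zeros((2, 2, 7)) for _ in range(3))
    for alpha, (a_coef, b_coef) in enumerate(products):
        for (i, j), coef in a_coef.items():
            u[i, j, alpha] = coef
        for (i, j), coef in b_coef.items():
            v[i, j, alpha] = coef
    for (i, k), terms in combos.items():
        for alpha, coef in terms.items():
            w[i, k, alpha] = coef
    return u, v, w


def to_blocks(m, n):
    """Split a square matrix into an n x n nested list of equal square blocks."""
    bs = m.shape[0] // n
    return [[m[i * bs:(i + 1) * bs, j * bs:(j + 1) * bs] for j in range(n)] for i in range(n)]
```

For example, `np.block(block_multiply(to_blocks(A, 2), to_blocks(B, 2), *strassen_factors()))` matches `A @ B` for even-sized square `A` and `B` while performing 7 block multiplications instead of 8; swapping in a different set of factors changes the algorithm without touching `block_multiply`.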
def strassen(A, B):
    """
    Parameters
    ----------
    A : np.ndarray
        Block matrix A.
    B : np.ndarray
        Block matrix B.

    Returns
    -------
    C : np.ndarray
        Product matrix C.
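The preview ends before the body of `strassen`. For reference, here is a minimal recursive sketch of Strassen's method in NumPy, not the author's implementation: it assumes square matrices whose size is a power of two, and `leaf_size` is a hypothetical cutoff below which it falls back to the ordinary `@` product. It is named `strassen_recursive` to keep it distinct from the gist's function.

```python
import numpy as np


def strassen_recursive(A, B, leaf_size=64):
    """Strassen product of square matrices whose size is a power of two (sketch)."""
    n = A.shape[0]
    if n <= leaf_size:
        return A @ B  # small blocks: use the ordinary product
    h = n // 2
    A11, A12, A21, A22 = A[:h, :h], A[:h, h:], A[h:, :h], A[h:, h:]
    B11, B12, B21, B22 = B[:h, :h], B[:h, h:], B[h:, :h], B[h:, h:]
    # seven half-size products instead of the schoolbook eight
    M1 = strassen_recursive(A11 + A22, B11 + B22, leaf_size)
    M2 = strassen_recursive(A21 + A22, B11, leaf_size)
    M3 = strassen_recursive(A11, B12 - B22, leaf_size)
    M4 = strassen_recursive(A22, B21 - B11, leaf_size)
    M5 = strassen_recursive(A11 + A12, B22, leaf_size)
    M6 = strassen_recursive(A21 - A11, B11 + B12, leaf_size)
    M7 = strassen_recursive(A12 - A22, B21 + B22, leaf_size)
    # recombine into the four quadrants of C = A @ B
    C11 = M1 + M4 - M5 + M7
    C12 = M3 + M5
    C21 = M2 + M4
    C22 = M1 - M2 + M3 + M6
    return np.block([[C11, C12], [C21, C22]])
```

The cutoff matters in practice: the extra block additions make Strassen's recursion slower than a tuned `@` on small blocks, so implementations usually stop recursing well above 1 x 1.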
Steboss89 / strassen_improvements.csv
Created October 18, 2022 22:07
Improvements over time of the matrix multiplication exponent ω, starting from the Strassen algorithm
Year,ω (exponent in O(n^ω)),Author
1969,2.81,Strassen [8]
1978,2.79,Pan [15]
1979,2.78,Bini [3]
1981,2.55,Schönhage [16]
1987,2.38,"Coppersmith, Winograd [5]"
2013,2.37,"Stothers, Munro [17]"
2014,2.37,Le Gall [12]
2021,2.37,"Alman, Vassilevska Williams [11]"
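The 2.81 in the 1969 row follows from the recursion itself. Splitting each matrix into 2 x 2 blocks replaces one n x n product with seven products of size n/2 plus 18 block additions or subtractions (10 to form the operands, 8 to combine the results), so the cost satisfies the recurrence below. Because log2 7 > 2, the master theorem says the seven recursive calls dominate the quadratic addition cost; the same argument with the schoolbook eight products gives log2 8 = 3.

```latex
T(n) = 7\,T\!\left(\tfrac{n}{2}\right) + \Theta\!\left(n^{2}\right)
\quad\Longrightarrow\quad
T(n) = \Theta\!\left(n^{\log_2 7}\right), \qquad \log_2 7 \approx 2.807
```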
I = (A_{11} + A_{22})(B_{11} + B_{22})
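Line I above is the first of Strassen's seven block products. Assuming the gist follows the I to VII numbering of Strassen's 1969 paper, the full set of products and the way they combine into the blocks of C = AB are:

```latex
\begin{aligned}
\mathrm{I}   &= (A_{11} + A_{22})(B_{11} + B_{22}) \\
\mathrm{II}  &= (A_{21} + A_{22})\,B_{11} \\
\mathrm{III} &= A_{11}\,(B_{12} - B_{22}) \\
\mathrm{IV}  &= A_{22}\,(B_{21} - B_{11}) \\
\mathrm{V}   &= (A_{11} + A_{12})\,B_{22} \\
\mathrm{VI}  &= (A_{21} - A_{11})(B_{11} + B_{12}) \\
\mathrm{VII} &= (A_{12} - A_{22})(B_{21} + B_{22}) \\[4pt]
C_{11} &= \mathrm{I} + \mathrm{IV} - \mathrm{V} + \mathrm{VII}, \qquad
C_{12} = \mathrm{III} + \mathrm{V}, \\
C_{21} &= \mathrm{II} + \mathrm{IV}, \qquad
C_{22} = \mathrm{I} - \mathrm{II} + \mathrm{III} + \mathrm{VI}.
\end{aligned}
```

Expanding, for instance, C_{12} = A_{11}B_{12} - A_{11}B_{22} + A_{11}B_{22} + A_{12}B_{22} = A_{11}B_{12} + A_{12}B_{22} recovers the ordinary block formula.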