```python
from typing import List, Dict, Generator

def flatten_luple(struct: List | tuple) -> Generator:
    """Flatten a list/tuple.
    E.g. [(1,2,3), [5,6], 3] -> 1, 2, 3, 5, 6, 3
    Args:
        struct (List | tuple): A list or tuple
```
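The preview above shows only the signature and the start of the docstring. Below is a minimal sketch of a recursive generator that produces the documented behaviour. It assumes that only `list` and `tuple` count as containers, so strings, dicts and other iterables are yielded unchanged as leaves. The name `flatten` is hypothetical and this is not the gist's own implementation.

```python
from typing import Any, Iterator, Union


def flatten(struct: Union[list, tuple]) -> Iterator[Any]:
    """Yield the leaves of a nested list/tuple, depth-first, left to right.

    Hypothetical helper: only list and tuple are treated as containers.
    """
    for item in struct:
        if isinstance(item, (list, tuple)):
            # Recurse into the nested container and re-yield its leaves.
            yield from flatten(item)
        else:
            # Anything that is not a list/tuple is a leaf.
            yield item


print(list(flatten([(1, 2, 3), [5, 6], 3])))
```

Because it is a generator, leaves are produced lazily. Very deep nesting will hit Python's recursion limit, so an explicit stack would be needed for pathological inputs.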
```python
from typing import NamedTuple, Callable
import jax.numpy as jnp
import jax.random as random
import jax.scipy.stats as stats
import numpy as np
import plotext as plt
from jax import vmap
```
Pretty-print tables summarizing properties of tensor arrays in NumPy, PyTorch, JAX, etc.
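Only the description of this gist is shown, so the following is a minimal sketch of the idea rather than the gist's code. It assumes NumPy is available and that the "properties" of interest are shape, dtype, min, max and mean; the column choice and the name `summarize_arrays` are hypothetical. Inputs go through `np.asarray`, which accepts NumPy arrays and JAX arrays. PyTorch tensors on a GPU or with `requires_grad=True` must first be converted with `.detach().cpu()`.

```python
from typing import Any, Mapping

import numpy as np


def summarize_arrays(arrays: Mapping[str, Any]) -> str:
    """Return a plain-text table with one row per named numeric array.

    Hypothetical helper: columns are name, shape, dtype, min, max, mean.
    """
    headers = ["name", "shape", "dtype", "min", "max", "mean"]
    rows = []
    for name, arr in arrays.items():
        a = np.asarray(arr)  # convert array-likes to a NumPy view/copy
        if a.size:
            stats = [f"{a.min():.4g}", f"{a.max():.4g}", f"{a.mean():.4g}"]
        else:
            # min/max are undefined for empty arrays, so print placeholders.
            stats = ["-", "-", "-"]
        rows.append([name, str(tuple(a.shape)), str(a.dtype), *stats])

    # Each column is as wide as its longest cell, header included.
    widths = [max(len(r[i]) for r in [headers, *rows]) for i in range(len(headers))]

    def fmt(row: list) -> str:
        return " | ".join(cell.ljust(w) for cell, w in zip(row, widths))

    separator = "-+-".join("-" * w for w in widths)
    return "\n".join([fmt(headers), separator, *(fmt(r) for r in rows)])


print(summarize_arrays({"x": np.arange(6).reshape(2, 3), "y": np.zeros(0)}))
```

Returning a string instead of printing directly keeps the helper usable for logging as well as for the terminal.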