-
-
Save sklam/4f08128860485485c7b3dbd2c9149e07 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import namedtuple

import numpy as np
from numba import njit, literally
from numba import typed
from numba.core import types
from numba.core.errors import TypingError
from numba.core.extending import overload
from numba.types import unliteral
# Minimal immutable "dataframe": a tuple of column names plus a tuple of
# column values, paired by position.
DF = namedtuple("DF", "columns values")
def dataframe(d):
    """Build a ``DF`` from a dict literal of columns inside Numba code.

    This is only a placeholder for ``@overload`` to attach to; the real
    implementation is generated at compile time by the overload registered
    for it.  Called from plain Python it does nothing and returns ``None``.
    """
@overload(dataframe)
def _ov_dataframe(d):
    """Numba typing-time implementation of ``dataframe(d)``.

    ``d`` is the Numba *type* of the argument. It must be a literal
    string-keyed dict (``types.LiteralStrKeyDict``), i.e. a dict literal
    whose keys are compile-time string constants. Source code for an
    ``impl`` is generated, exec'd, and returned for Numba to compile. That
    ``impl`` copies each entry of ``d`` into ``DF.values``, in the key order
    of ``d.literal_value``.

    Raises:
        TypingError: if ``d`` is not a literal string-keyed dict.
    """
    if not isinstance(d, types.LiteralStrKeyDict):
        raise TypingError(
            "dataframe() requires a dict literal with constant string keys, "
            f"got {d}"
        )
    # Column names are the Python values of the literal string keys.
    columns = tuple(key.literal_value for key in d.literal_value)
    # Emit a trailing comma after every element so the generated expression
    # is always a tuple: "()" for no columns, "(d['a'], )" for one column.
    # Joining with ", " alone produced "(d['a'])" for a single column, which
    # is the bare value rather than a 1-tuple.
    values = "".join(f"d[{k!r}], " for k in columns)
    # Dynamically generate the source that populates the dataframe.
    source = f"""
def impl(d):
    values = ({values})
    return DF(columns=columns, values=values)
"""
    # ``columns`` is visible to the generated function as a global constant.
    namespace = dict(DF=DF, columns=columns)
    exec(source, namespace)
    return namespace["impl"]
def df_transform(df, ops):
    """Apply per-column operations from ``ops`` to ``df`` inside Numba code.

    This is only a placeholder for ``@overload`` to attach to; the real
    implementation is generated at compile time by the overload registered
    for it.  Called from plain Python it does nothing and returns ``None``.
    """
@overload(df_transform)
def _ov_transform(df, ops):
    """Numba typing-time implementation of ``df_transform(df, ops)``.

    ``ops`` must be a literal string-keyed dict (``types.LiteralStrKeyDict``)
    that maps a column name to a callable usable from Numba code.

    The generated ``impl`` pairs the i-th key of ``ops`` with the i-th entry
    of ``df.values`` (positional pairing). It applies ``ops[key]`` to that
    entry and returns a new ``DF`` whose columns are the keys of ``ops``.
    Because the pairing is positional, ``impl`` first checks at run time that
    ``df.columns[i]`` equals the i-th key. On a mismatch it raises
    ``ValueError`` rather than silently applying an op to the wrong column.
    ``ops`` may therefore name a leading subset of ``df``'s columns, in order.

    Raises:
        TypingError: if ``ops`` is not a literal string-keyed dict.
    """
    if not isinstance(ops, types.LiteralStrKeyDict):
        raise TypingError(
            "df_transform() requires a dict literal with constant string "
            f"keys for 'ops', got {ops}"
        )
    columns = tuple(key.literal_value for key in ops.literal_value)

    # One run-time guard per column so a reordered ``ops`` dict fails loudly.
    check_lines = []
    for i, k in enumerate(columns):
        message = f"df_transform: column {i} of df is not {k!r}"
        check_lines.append(f"    if df.columns[{i}] != {k!r}:")
        check_lines.append(f"        raise ValueError({message!r})")
    checks = "\n".join(check_lines)

    # Trailing comma after every element keeps the result a tuple even for
    # zero or one column ("(x)" would be the bare value, not a 1-tuple).
    values = "".join(
        f"ops[{k!r}](df.values[{i}]), " for i, k in enumerate(columns)
    )
    # Dynamically generate the source that populates the new dataframe.
    source = f"""
def impl(df, ops):
{checks}
    values = ({values})
    return DF(columns=columns, values=values)
"""
    # ``columns`` is visible to the generated function as a global constant.
    namespace = dict(DF=DF, columns=columns)
    exec(source, namespace)
    return namespace["impl"]
@njit
def incr_count(x):
    """Return ``x`` plus one; the op applied to the ``'count'`` column."""
    incremented = x + 1
    return incremented
@njit
def name_length(xs):
    """Return a ``typed.List`` holding ``len()`` of each element of ``xs``."""
    lengths = typed.List()
    for item in xs:
        lengths.append(len(item))
    return lengths
@njit
def make_dataframe(names):
    """Build a two-column DF from ``names`` and return it with a transformed copy.

    The columns are ``'names'`` (the input list) and ``'count'``
    (``np.arange(len(names))``). The transform maps names to their lengths
    via ``name_length`` and increments the counts via ``incr_count``.

    Returns:
        ``(df, df_after)``: the original and the transformed dataframe.
    """
    columns = {
        'names': names,
        'count': np.arange(len(names)),
    }
    df = dataframe(columns)
    ops = {'names': name_length, 'count': incr_count}
    return df, df_transform(df, ops)
# Demo: build a three-row dataframe, then show it before and after transform.
names = typed.List(["alice", "bob", "charles"])
df, df_after = make_dataframe(names)
# To dump Numba's type annotations for the compiled function, call:
#     make_dataframe.inspect_types()
print(df, df_after, sep="\n")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment