Created
December 31, 2023 00:33
-
-
Save alexander-wei/9748f6e3cdd02e82f0e64e756063b14d to your computer and use it in GitHub Desktop.
RBM Ingredients Pipeline
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from ast import literal_eval
from random import sample, randint

from pandas import DataFrame as pdf, Series, read_csv
from tqdm.auto import tqdm
class DataFrame(pdf):
    """DataFrame structure for loading ingredient lists from the Food.com dataset.

    Extends pandas.DataFrame. Implements preprocess() and collect_first_words()
    for parsing raw ingredient data.
    """

    # Columns the pipeline reads/writes (pandas stores these as regular
    # columns; the annotations below are documentation only).
    ingredients: Series
    sampled_words: Series
    id: Series

    def __init__(self, data: str | pdf, *ac, **av) -> None:
        """Init from a CSV path or an existing pandas DataFrame.

        Extra positional/keyword arguments are forwarded to read_csv when
        *data* is a path. Parsing (preprocess) and word sampling
        (collect_first_words) run immediately after construction.
        """
        if isinstance(data, str):
            data = DataFrame.read_csv(data, *ac, **av)
        super().__init__(data=data)
        self.preprocess()
        self.collect_first_words()

    def collect_first_words(self):
        """Build the 'sampled_words' column.

        For each recipe, draw a random subset (possibly empty) of its
        ingredients and keep the first two space-separated words of each
        sampled ingredient string. Nondeterministic by design (random.sample).
        """
        ingr_aggregator = []
        # dtype=object: the column will hold Python lists per row
        self['sampled_words'] = Series(dtype=object)
        idx_ingredients = list(zip(self.id, self.ingredients))
        for idx, ingred_list in tqdm(idx_ingredients):
            # idx: recipe id; ingred_list: list of ingredient strings
            for ingr_group in sample(ingred_list, randint(0, len(ingred_list))):
                # random subset of this recipe's ingredients
                for w in collect_first_k_words(ingr_group, 2):
                    # first two words (split ' ') of each ingredient
                    ingr_aggregator.append({'id': idx,
                                            'ingredients': w})
        parsed_ingrs = pdf(ingr_aggregator)
        parsed_ingr_groups = parsed_ingrs.groupby("id").agg(list)
        # BUG FIX: the original referenced self.df, which does not exist on a
        # pandas.DataFrame subclass (AttributeError); use self directly.
        idx_lookup = self.reset_index().set_index('id').to_dict()['index']
        for idx, ingreds in tqdm(parsed_ingr_groups.iterrows(),
                                 total=len(parsed_ingr_groups)):
            # ingreds is a one-element Series holding the word list
            self.at[idx_lookup[idx], 'sampled_words'] = ingreds.to_list()[0]

    @staticmethod
    def read_csv(path, *ac, **av):
        """Load a CSV via pandas.read_csv (module-level import)."""
        df = read_csv(path, *ac, **av)
        return df

    def preprocess(self):
        """Parse the raw 'ingredients' column (stringified Python lists)
        into actual lists; unparseable rows become []."""
        ingredients = self.ingredients.apply(try_literal_eval)
        self['ingredients'] = ingredients
def collect_first_k_words(words: str, k: int):
    """Yield the first *k* space-separated words of *words*.

    Fewer than k words are yielded when the string is shorter.
    (Annotation fixed: the body calls str.split, so *words* is a string,
    not a list as originally annotated.)
    """
    yield from words.split(" ")[:k]
def try_literal_eval(s):
    """Convert a raw string into a Python object via ast.literal_eval.

    Returns [] when parsing fails. The bare `except:` of the original is
    narrowed to the exceptions literal_eval actually raises, so signals
    like KeyboardInterrupt are no longer swallowed.
    """
    try:
        return literal_eval(s)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return []
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List, Iterable, Tuple | |
from types import NoneType | |
from sklearn.preprocessing import MultiLabelBinarizer | |
from ingest import DataFrame | |
from util import split_iter | |
import numpy as np | |
from joblib import numpy_pickle | |
class Binarizer(MultiLabelBinarizer):
    """Multi-hot encoder for the 'sampled_words' column of a DataFrame."""

    # class-level placeholder kept from the original interface
    dataframe = None

    def __init__(self, dataframe: DataFrame, classes: List | None = None,
                 sparse_output: bool = False) -> None:
        """Fit the label binarizer on dataframe.sampled_words."""
        super().__init__(classes=classes, sparse_output=sparse_output)
        self.fit(dataframe.sampled_words)

    @staticmethod
    def load(dataframe: DataFrame, mlb: MultiLabelBinarizer | None = None):
        """Transform sampled_words into a float16 multi-hot matrix.

        Fits a fresh Binarizer on *dataframe* when *mlb* is not supplied.
        float16 halves memory relative to the default int array.
        """
        # Idiom fix: `mlb is None` instead of isinstance(mlb, NoneType)
        if mlb is None:
            mlb = Binarizer(dataframe=dataframe)
        xtrain = mlb.transform(dataframe.sampled_words)
        return xtrain.astype(np.float16)
class BatchedDataFrame:
    """Convenience methods for generation of the final training dataset."""

    label_binarizer: Binarizer
    iterator: Iterable

    @staticmethod
    def batches(data: DataFrame, batch_size=36, *ac, **av) -> Tuple[Binarizer, Iterable]:
        """Fit a Binarizer on the non-null rows of *data* and return it
        together with an iterator over float16 multi-hot batches."""
        # Hoisted: the original called data.dropna() twice.
        clean = data.dropna()
        label_binarizer = Binarizer(clean)
        iterator = split_iter(Binarizer.load(clean, label_binarizer), batch_size)
        return label_binarizer, iterator

    @staticmethod
    def save(df: DataFrame, batch_size=36, dest="."):
        """Dump all batches to <dest>/out/batches_0_<batch_size>.job.

        BUG FIX: the original accepted *dest* but ignored it, hard-coding
        "./out/...". With the default dest="." the output path is unchanged.
        """
        import os
        _, batches = BatchedDataFrame.batches(df, batch_size)
        path = os.path.join(dest, "out", "batches_%d_%d.job" % (0, batch_size))
        numpy_pickle.dump(list(batches), path)
def collect_first_k_words(words: str, k: int):
    """Yield the first *k* space-separated words of *words*.

    (Annotation fixed: the body calls str.split, so *words* is a string,
    not a list as originally annotated. Duplicated from the ingest module.)
    """
    yield from words.split(" ")[:k]
def split(a: List, n: int):
    """Partition *a* into len(a)//n contiguous, nearly equal chunks.

    The first `len(a) % (len(a)//n)` chunks get one extra element.
    Asserts that at least one chunk can be formed.
    """
    num_chunks = len(a) // n
    assert num_chunks > 0
    base, extra = divmod(len(a), num_chunks)

    def bound(i):
        # start index of chunk i: each of the first `extra` chunks is one longer
        return i * base + min(i, extra)

    return (a[bound(i):bound(i + 1)] for i in range(num_chunks))
def split_iter(a: List, n: int):
    """Yield successive slices of *a* of length *n*; the final slice may be
    shorter when len(a) is not a multiple of n.

    BUG FIX: the original computed `q = len(a) // n - 1` and yielded only q
    slices, silently discarding the last full batch AND any remainder
    (e.g. len(a)=100, n=36 yielded a single batch, dropping 64 elements).
    All elements are now emitted exactly once.
    """
    for start in range(0, len(a), n):
        yield a[start:start + n]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment