Skip to content

Instantly share code, notes, and snippets.

@tcapelle
Created March 24, 2021 13:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tcapelle/463f3e12962c32842dbe551c829ad738 to your computer and use it in GitHub Desktop.
Save tcapelle/463f3e12962c32842dbe551c829ad738 to your computer and use it in GitHub Desktop.
def _get_col_idxs(df, cols):
    "Positions (usable with `iloc`) of the names in `cols` that exist in `df`; unknown names are dropped"
    positions = []
    # `L` (presumably fastcore's list wrapper) lets a lone name be treated like a one-name sequence.
    for name in L(cols):
        if name in df:
            positions.append(df.columns.get_loc(name))
    return positions
def _iloc(df, rows, cols=None):
    "Positional `df.iloc[rows, cols]` where `cols` may also be a column name or a list/tuple of names"
    if cols is None:
        # No column selector given: keep every column.
        cols = slice(None)
    elif isinstance(cols, (tuple, list, str)):
        # Translate names to positions; names missing from `df` are skipped by `_get_col_idxs`.
        cols = _get_col_idxs(df, cols)
    # Any other selector (int position, slice, boolean mask, ...) goes straight to `iloc`
    # instead of being silently replaced by "all columns".
    return df.iloc[rows, cols]
# Cell
class _Iloc:
    "Read-only indexer over `ds.df`: rows by position, columns by position or by name"
    def __init__(self, ds):
        # Keep the owner rather than its frame, so every lookup reads the current `ds.df`.
        self.ds = ds

    def __getitem__(self, idxs):
        frame = self.ds.df
        # `obj[rows, cols]` arrives as a tuple; a bare `obj[rows]` means "all columns".
        rows, cols = idxs if isinstance(idxs, tuple) else (idxs, slice(None))
        return _iloc(frame, rows, cols)
# Cell
class WindowDataset:
    "Sliding windows over a DataFrame: item `i` is an `(x, y)` pair of `x_cols` rows and the `y_cols` rows after them"
    def __init__(self, df, x_cols=None, y_cols=None, bsteps=2, fsteps=1, overlap=None, shift=1, debug=False):
        # `store_attr` (fastcore) provides `self.df`, `self.bsteps`, `self.fsteps`, `self.shift`, `self.debug`
        # used below; the column arguments are excluded because they are normalised to lists here instead.
        store_attr(but='x_cols,y_cols')
        self.x_cols, self.y_cols = listify(x_cols), listify(y_cols)
        self.overlap = ifnone(overlap, 0)  # `None` means no offset before the target rows
        self.check_validity()

    def check_validity(self):
        "Fail fast on a configuration that cannot produce a single window"
        assert self.x_cols and self.y_cols, "plaease give me some x_cols y_cols to play with"
        # `__len__` divides by `shift`; values below 1 would divide by zero or count backwards.
        assert self.shift >= 1, f"shift must be >= 1, got {self.shift}"
        # Window 0 spans rows [0, bsteps + fsteps), so a frame with exactly bsteps + fsteps rows
        # already holds one window (`__len__` returns 1 for it).
        n_rows, needed = len(self.df), self.bsteps + self.fsteps
        assert n_rows >= needed, (
            f"the dataset is too short: {n_rows} rows, but one window needs bsteps + fsteps = {needed}")

    @property
    def iloc(self):
        "A better iloc, that supports col names"
        return _Iloc(self)

    def _get_x(self, idx):
        "The `bsteps` rows of `x_cols` ending at row `idx` (inclusive)"
        i, j = idx - self.bsteps + 1, idx + 1
        return self.iloc[i:j, self.x_cols]

    def _get_y(self, idx):
        "Rows `idx+1+overlap` .. `idx+fsteps` (inclusive) of `y_cols`, i.e. the targets after row `idx`"
        i, j = idx + 1 + self.overlap, idx + self.fsteps + 1
        return self.iloc[i:j, self.y_cols]

    def __getitem__(self, idx):
        "Get the `(x, y)` pair of window number `idx`; negative values count back from the last window"
        n_windows = len(self)
        if idx < 0:
            idx += n_windows
        # Bound by the number of windows, not rows: beyond `len(self)` the slices would run past
        # the end of `df` and `iloc` would silently return truncated windows.
        if not 0 <= idx < n_windows:
            raise IndexError(f"idx out of bounds, len(ds)={n_windows}, python indexing starts at zero.")
        row = idx * self.shift + self.bsteps - 1  # last x row of this window
        x, y = self._get_x(row), self._get_y(row)
        if self.debug:
            display_xy(x, y)
        return x, y

    def __len__(self):
        "Number of windows: window `i` needs row `i*shift + bsteps + fsteps - 1` to exist in `df`"
        return (len(self.df) - self.fsteps - self.bsteps) // self.shift + 1
from itertools import chain
class MetaDataset:
    " A dataset capable of indexing mutiple datasets at the same time!"
    def __init__(self, datasets):
        self.datasets = datasets
        # One `(dataset index, index inside that dataset)` pair per global index, built once.
        self.mapping = self._mapping_simple()

    def __len__(self):
        return sum(len(ds) for ds in self.datasets)

    def _mapping(self):
        "Same pairs as `_mapping_simple`, assembled per dataset and flattened with `chain`"
        per_dataset = [[(ds_no, local) for local in range(len(ds))]
                       for ds_no, ds in enumerate(self.datasets)]
        return list(chain.from_iterable(per_dataset))

    def _mapping_simple(self):
        "`(dataset index, local index)` for every item, dataset by dataset, in order"
        return [(ds_no, local)
                for ds_no, ds in enumerate(self.datasets)
                for local in range(len(ds))]

    def __getitem__(self, idx):
        which, local = self.mapping[idx]
        return self.datasets[which][local]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment