Skip to content

Instantly share code, notes, and snippets.

@urigoren
Last active August 31, 2022 15:54
Show Gist options
  • Save urigoren/4c1b6c6882c861d51678e584915153b6 to your computer and use it in GitHub Desktop.
Save urigoren/4c1b6c6882c861d51678e584915153b6 to your computer and use it in GitHub Desktop.
from collections import defaultdict
from itertools import product
from scipy import sparse
from sklearn.base import TransformerMixin
class InteractionBySplit(TransformerMixin):
    """
    Pairwise interaction features between two column groups of a sparse matrix.

    Columns ``[0, split_index)`` form the "pre" group and columns
    ``[split_index, n_cols)`` form the "post" group. For every row, each
    stored pre value ``x[i]`` is multiplied by each stored post value
    ``x[split_index + j]`` and the product is placed in output column
    ``i + j * split_index``. The output therefore has
    ``split_index * (n_cols - split_index)`` columns.

    Parameters
    ----------
    split_index : int
        Index of the first "post" column; must satisfy
        ``0 <= split_index <= n_cols`` for the matrix given to ``transform``.
    *args, **kwargs
        Forwarded unchanged to the parent class constructor.
    """

    def __init__(self, split_index, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.split_index = split_index

    def fit(self, X, y=None):
        """
        No-op fit: the transformation has no learned state. Returns ``self``.

        Defined so that the inherited ``fit_transform`` (which calls ``fit``
        before ``transform``) and pipeline usage work.
        """
        return self

    def transform(self, X):
        """
        Return the pre x post interaction matrix of ``X``.

        Parameters
        ----------
        X : scipy sparse matrix (any format) or array-like, shape (n_rows, n_cols)

        Returns
        -------
        scipy.sparse.csr_matrix of float64,
            shape (n_rows, split_index * (n_cols - split_index))

        Raises
        ------
        ValueError
            If ``split_index`` is not within ``[0, n_cols]``.
        """
        X = sparse.coo_matrix(X)
        n_rows, n_cols = X.shape
        split = self.split_index
        if not 0 <= split <= n_cols:
            raise ValueError(
                f"split_index must be in [0, {n_cols}], got {split}"
            )

        # Bucket each row's stored entries into the two column groups.
        pre, post = defaultdict(list), defaultdict(list)
        for row, col, value in zip(X.row, X.col, X.data):
            if col < split:
                pre[row].append((col, value))
            else:
                post[row].append((col - split, value))

        # Only rows with entries in BOTH groups produce interactions.
        out_rows, out_cols, out_data = [], [], []
        for row, pre_entries in pre.items():
            post_entries = post.get(row)
            if not post_entries:
                continue
            for (i, a), (j, b) in product(pre_entries, post_entries):
                out_rows.append(row)
                out_cols.append(i + j * split)
                out_data.append(a * b)

        M = sparse.coo_matrix(
            (out_data, (out_rows, out_cols)),
            shape=(n_rows, split * (n_cols - split)),
            dtype=float,
        ).tocsr()
        # tocsr() sums duplicate coordinates, so duplicate input entries
        # contribute (x1 + x2) * y rather than "last write wins". Drop
        # stored zeros so the result holds only non-zero products.
        M.eliminate_zeros()
        return M
if __name__ == "__main__":
    # Small demo: three 6-column rows split after column 3, producing a
    # 3 x (3 * 3) interaction matrix that is printed in dense form.
    demo_rows = [
        [1, 0, 0, 1, 0, 0],
        [1, 0, 0, 0, 1, 0],
        [0, 1, 0, 1, 0, 0],
    ]
    X = sparse.coo_matrix(demo_rows)
    interactions = InteractionBySplit(3)
    Y = interactions.transform(X).todense()
    print(Y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment