Skip to content

Instantly share code, notes, and snippets.

@goddoe
Last active December 26, 2017 06:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save goddoe/2debab59938a042bea9d1e14fecb8d35 to your computer and use it in GitHub Desktop.
Save goddoe/2debab59938a042bea9d1e14fecb8d35 to your computer and use it in GitHub Desktop.
import numpy as np
import rpy2.robjects as robjects
from rpy2.robjects import numpy2ri
from rpy2.robjects.packages import importr
# Turn on rpy2's global numpy <-> R conversion so NumPy arrays can be handed
# directly to R functions (Model.predict passes an ndarray to R's predict()).
numpy2ri.activate()
# Module-wide shortcut to rpy2's embedded R interpreter; R functions are then
# reachable as attributes, e.g. r.readRDS / r.predict / r.attr.
r = robjects.r
class Model(object):
    """
    R Model Loader

    Loads an R model serialized to disk and exposes its class-probability
    predictions to Python through rpy2.

    A model saved under a path prefix ``path`` consists of two files:

    * ``<path>.rds`` -- the serialized R model object, read with ``readRDS``.
    * ``<path>.dep`` -- plain text, one R package name per line (blank lines
      are ignored); every listed package is imported with rpy2's ``importr``
      when the model is loaded.

    Attributes
    ----------
    model : R object or None
        The deserialized R model, or ``None`` until :meth:`load` has run.

    Examples
    --------
    Illustrative only -- ``MODEL_PATH`` is a placeholder, so this is not a
    runnable doctest.

    >>> import numpy as np
    >>> from rpy2_wrapper.model import Model

    >>> # Constants
    >>> MODEL_PATH = "path/to/model"

    >>> # Example Input
    >>> X = np.array([[3, 1, 0, 0, 3],   # 0
    ...               [1, 2, 3, 0, 2]])  # 1

    >>> # Example Run
    >>> model = Model().load(MODEL_PATH)
    >>> pred = model.predict(X)

    >>> # Check result
    >>> label = np.argmax(pred, axis=1)
    >>> assert np.array_equal(label, [0, 1]), "Result is wrong"

    >>> # Example output
    >>> print("=" * 30)
    >>> print("probs")
    >>> print(pred)
    """

    def __init__(self):
        # Set by load(); predict() refuses to run while this is still None.
        self.model = None

    def load(self, path):
        """
        Load the R model stored under ``path`` and import its R packages.

        Parameters
        ----------
        path : str
            Path prefix without extension. ``"<path>.rds"`` must hold the
            serialized model and ``"<path>.dep"`` the R package names it
            needs, one per line.

        Returns
        -------
        Model
            ``self``, so construction and loading can be chained:
            ``Model().load(path)``.

        Raises
        ------
        OSError
            If ``"<path>.dep"`` cannot be opened.

        Notes
        -----
        Errors raised by R while reading the ``.rds`` file or importing a
        package propagate out of rpy2 unchanged. Progress is printed to
        stdout.
        """
        model_rds_path = "{}.rds".format(path)
        model_dep_path = "{}.dep".format(path)

        print("=" * 30)
        print("load model from: {}".format(model_rds_path))
        self.model = r.readRDS(model_rds_path)

        print("load model dep from: {}".format(model_dep_path))
        with open(model_dep_path, "rt") as f:
            for line in f:
                dep = line.strip()
                # Skip blank lines (e.g. a trailing newline at end of file).
                if not dep:
                    continue
                # Called only for its side effect of importing the package
                # into the R session -- presumably so R can dispatch
                # predict() on the model's class. TODO confirm.
                importr(dep)
        print("=" * 30)
        return self

    def predict(self, X):
        """
        Perform classification on samples in X.

        Parameters
        ----------
        X : array_like, shape (n_samples, n_features)
            Samples to classify. Anything ``np.asarray`` accepts (nested
            lists, tuples, ndarrays) is converted to a NumPy array first.

        Returns
        -------
        pred_probs : ndarray
            The ``"probabilities"`` attribute of R's
            ``predict(model, X, probability = TRUE)`` result, converted to a
            NumPy array. Expected shape is (n_samples, n_classes) -- the
            class example takes ``argmax`` over ``axis=1`` to get labels.

        Raises
        ------
        RuntimeError
            If no model has been loaded yet (call :meth:`load` first).
        """
        if self.model is None:
            # RuntimeError subclasses Exception, so callers that caught the
            # previous bare Exception still catch this.
            raise RuntimeError("There is no Model")

        # Returns an existing ndarray unchanged and converts anything else,
        # matching the previous "convert only if not already an ndarray".
        X = np.asarray(X)
        pred = r.predict(self.model, X, probability=True)
        # NOTE(review): reading probabilities from an attribute of the
        # predict() result matches e1071::svm's convention; other R model
        # types may not set it -- confirm for the models in use.
        probs = r.attr(pred, "probabilities")
        return np.array(probs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment