Last active
December 26, 2017 06:16
-
-
Save goddoe/2debab59938a042bea9d1e14fecb8d35 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import rpy2.robjects as robjects | |
from rpy2.robjects import numpy2ri | |
from rpy2.robjects.packages import importr | |
# Shorthand for rpy2's R interface object; Model uses it to call the R
# functions readRDS, predict and attr.
r = robjects.r | |
# Executed at import time, so it affects the whole process.
# NOTE(review): presumably this switches on rpy2's global numpy <-> R
# conversion so ndarrays can be passed straight to R calls -- confirm against
# the installed rpy2 version.
numpy2ri.activate() | |
class Model(object):
    """
    R Model Loader

    Loads a serialized R model through rpy2 and exposes a ``predict``
    method that returns the model's class probabilities as a NumPy array.

    Attributes
    ----------
    model : R object
        Object returned by R's ``readRDS``; ``None`` until ``load`` is called.

    Examples
    --------
    >>> import numpy as np
    >>> from rpy2_wrapper.model import Model
    >>>
    >>> # Constants
    >>> MODEL_PATH = "path/to/model"
    >>>
    >>> # Example Input
    >>> X = np.array([[3, 1, 0, 0, 3],   # 0
    ...               [1, 2, 3, 0, 2]])  # 1
    >>>
    >>> # Example Run
    >>> model = Model().load(MODEL_PATH)
    >>> pred = model.predict(X)
    >>>
    >>> # Check result
    >>> label = np.argmax(pred, axis=1)
    >>> assert np.array_equal(label, [0, 1]), "Result is wrong"
    >>>
    >>> # Example output
    >>> print("="*30)
    >>> print("probs")
    >>> print(pred)
    """

    def __init__(self):
        # No model is available until load() has been called.
        self.model = None

    def load(self, path):
        """
        Load a model and import the R packages listed next to it.

        Parameters
        ----------
        path : str
            Path prefix without extension. ``<path>.rds`` is read with R's
            ``readRDS``; ``<path>.dep`` is read as a text file with one R
            package name per line (blank lines are ignored).

        Returns
        -------
        self : Model
            The same instance, so calls can be chained.
        """
        model_rds_path = "{}.rds".format(path)
        model_dep_path = "{}.dep".format(path)

        print("=" * 30)
        print("load model from: {}".format(model_rds_path))
        self.model = r.readRDS(model_rds_path)

        print("load model dep from: {}".format(model_dep_path))
        with open(model_dep_path, "rt") as f:
            # importr is called only for its side effect; the returned package
            # objects are not needed here.
            # NOTE(review): presumably these packages provide the predict
            # method for the loaded model's class -- confirm.
            for line in f:
                dep = line.strip()
                if dep:
                    importr(dep)
        print("=" * 30)

        return self

    def predict(self, X):
        """
        Perform classification on samples in X.

        Parameters
        ----------
        X : array, shape (n_samples, n_features)
            Anything ``np.asarray`` accepts; converted to an ndarray before
            being handed to R.

        Returns
        -------
        pred_probs : array, shape (n_samples, probs)
            The ``"probabilities"`` attribute of the R prediction result.

        Raises
        ------
        RuntimeError
            If no model has been loaded. ``RuntimeError`` derives from
            ``Exception``, so handlers catching ``Exception`` keep working.
        """
        if self.model is None:
            raise RuntimeError("There is no Model")

        X = np.asarray(X)

        # NOTE(review): the probability=True argument and the "probabilities"
        # attribute look like the e1071 predict convention; other R model
        # types may not set this attribute -- confirm for the models in use.
        pred = r.predict(self.model, X, probability=True)
        probs = r.attr(pred, "probabilities")
        return np.array(probs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment