Last active
April 23, 2024 10:46
-
-
Save datancoffee/4575e8e8900264546051a7b0b53eb8fa to your computer and use it in GitHub Desktop.
Place this file in the `core.inference` module.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from transformers import pipeline as hf_pipeline | |
| from typing import Any | |
| from core.actions import Action | |
class InferWithHuggingface(Action):
    """
    Action that runs a Hugging Face ``transformers`` pipeline on its inputs.

    The pipeline object is constructed once in ``__init__`` and reused for
    every call to :meth:`do`.
    """

    def __init__(
            self,
            actionname: str = None,
            *args,
            **kwargs):
        """
        :param actionname: Name handed to the base :class:`Action` constructor.
        :param args: Positional arguments forwarded unchanged to
            ``transformers.pipeline`` (e.g. the task name).
        :param kwargs: Keyword arguments forwarded unchanged to
            ``transformers.pipeline`` (e.g. ``model=...``).
        """
        super().__init__(actionname)
        self.hf_pipeline = hf_pipeline(*args, **kwargs)

    def do(self, inputs, *args: Any, **kwargs: Any):
        """
        Run the wrapped pipeline on ``inputs`` and return its raw output.

        ``truncation=True`` is supplied by default (presumably so over-long
        texts are cut by the tokenizer instead of failing — confirm against
        the model in use). Callers may override it by passing ``truncation``
        explicitly; previously that raised ``TypeError`` for a duplicate
        keyword argument.

        :param inputs: Whatever the underlying pipeline accepts (a single item
            or a list of items).
        :param args: Extra positional arguments for the pipeline call.
        :param kwargs: Extra keyword arguments for the pipeline call.
        :return: The pipeline's output, unmodified.
        """
        # Default, not force: lets callers pass their own ``truncation`` value
        # without colliding with this one.
        kwargs.setdefault("truncation", True)
        # pass the baton to __call__ of transformers.Pipeline
        return self.hf_pipeline(inputs, *args, **kwargs)
class EnrichWithHuggingface(InferWithHuggingface):
    """
    Variant of :class:`InferWithHuggingface` that merges each model output
    back together with the input row it was computed from.
    """

    def do(self, inputs: list, field: str, *args: Any, **kwargs: Any):
        """
        Run the model on one field of every row and merge the rows into the outputs.

        Rows whose dict has no ``field`` key contribute an empty string ``''``
        to the model. Each output element is updated in place with all keys of
        its corresponding input row, so on a key collision the input row's
        value wins.

        Results and rows are paired positionally via ``zip``; if the lengths
        differ, the extra items are silently left unmerged. Each output element
        is assumed to be a dict (it must support ``.update``) — TODO confirm
        for the pipeline task in use.

        :param inputs: Rows to enrich, assumed to be a list of dicts.
        :param field: Key of the dict whose value is fed to the model.
        :param args: Extra positional arguments passed on to the pipeline call.
        :param kwargs: Extra keyword arguments passed on to the pipeline call.
        :return: The list of model outputs, each merged with its input row.
        """
        source_texts = [row.get(field, '') for row in inputs]
        enriched = super().do(source_texts, *args, **kwargs)
        # Pipeline results are assumed to come back in the same order as the inputs.
        for output_record, input_record in zip(enriched, inputs):
            # Input keys are applied last, so they take precedence on collisions.
            output_record.update(input_record)
        return enriched
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment