Created
January 4, 2024 02:34
-
-
Save schroneko/055fd8c5e85724a988eb596e9dd56125 to your computer and use it in GitHub Desktop.
test: oshizo/japanese-sexual-moderation-v2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## https://huggingface.co/oshizo/japanese-sexual-moderation-v2
## How to use
## $ python3 -m venv venv && source venv/bin/activate
## $ pip install transformers torch sentencepiece
## $ python main.py --input "富士山は日本で一番高い山です。"
import argparse

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
class JapaneseSexualModeration:
    """Score text with a sequence-classification model loaded as a regressor.

    The tokenizer and model are both fetched with ``from_pretrained`` using
    the ``model_id`` given to the constructor.
    """

    def __init__(self, model_id="oshizo/japanese-sexual-moderation-v2"):
        """Load the tokenizer and regression model identified by ``model_id``."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_id, problem_type="regression"
        )

    def predict(self, texts, *, max_length=64):
        """Return the model's raw logits for ``texts``.

        Args:
            texts: Text(s) handed directly to the tokenizer.
            max_length: Token length every input is padded or truncated to.

        Returns:
            The ``logits`` tensor from the model's output, computed with
            gradient tracking disabled.
        """
        with torch.no_grad():
            encoding = self.tokenizer(
                texts,
                padding="max_length",
                # Without truncation, inputs longer than max_length keep
                # their full length, so a batch may not stack into a single
                # tensor and may exceed what the model accepts.
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )
            scores = self.model(**encoding).logits
        return scores
def main():
    """Read one or more ``--input`` texts from the command line and print a score for each."""
    cli = argparse.ArgumentParser(description="Japanese text moderation.")
    cli.add_argument(
        "--input",
        nargs="+",
        required=True,
        help="List of texts to be moderated",
    )
    sentences = cli.parse_args().input

    classifier = JapaneseSexualModeration()
    results = classifier.predict(sentences)

    # Pair every input text with its row of the prediction tensor.
    for sentence, result in zip(sentences, results):
        print(f"Text: {sentence}\nScore: {result.item()}\n")
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment