Skip to content

Instantly share code, notes, and snippets.

@nousr
Created January 25, 2023 20:44
Show Gist options
  • Save nousr/4cdca926ca5ad21a89f0e80d341d4b15 to your computer and use it in GitHub Desktop.
@lru_cache(maxsize=None)
def load_safety_model(clip_model):
    """Load the NSFW safety classifier matching *clip_model*.

    On first use the model weights are downloaded into the cache folder; the
    loaded model is warmed up with one dummy batch and memoized via lru_cache
    so repeated calls return the same instance.

    Args:
        clip_model: one of "ViT-L/14", "ViT-B/32", or "open_clip:ViT-H-14".

    Returns:
        A callable model (Keras model for the autokeras detectors, a torch
        nn.Module for the ViT-H-14 detector) that scores embedding batches.

    Raises:
        ValueError: if no safety model is available for *clip_model*.
    """
    import torch  # pylint: disable=import-outside-toplevel
    import autokeras as ak  # pylint: disable=import-outside-toplevel
    from tensorflow.keras.models import load_model  # pylint: disable=import-outside-toplevel

    class H14_NSFW_Detector(nn.Module):
        """MLP head mapping 1024-d ViT-H-14 embeddings to a single NSFW score."""

        def __init__(self, input_size=1024):
            super().__init__()
            self.input_size = input_size
            self.layers = nn.Sequential(
                nn.Linear(self.input_size, 1024),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(1024, 2048),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(2048, 1024),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(1024, 256),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(256, 128),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(128, 16),
                nn.Linear(16, 1),
            )

        def forward(self, x):
            return self.layers(x)

    cache_folder = get_cache_folder(clip_model)

    if clip_model == "ViT-L/14":
        model_dir = cache_folder + "/clip_autokeras_binary_nsfw"
        dim = 768
    elif clip_model == "ViT-B/32":
        model_dir = cache_folder + "/clip_autokeras_nsfw_b32"
        dim = 512
    elif clip_model == "open_clip:ViT-H-14":
        model_dir = cache_folder + "/h14_nsfw_detector"
        # BUGFIX: dim was never set on this branch, so building the warm-up
        # batch below raised NameError. ViT-H-14 embeddings are 1024-d
        # (matching H14_NSFW_Detector's default input_size).
        dim = 1024
    else:
        raise ValueError(f"Safety model for {clip_model} not available.")

    if not os.path.exists(model_dir):
        os.makedirs(cache_folder, exist_ok=True)
        from urllib.request import urlretrieve  # pylint: disable=import-outside-toplevel

        if clip_model == "ViT-L/14":
            url_model = "https://raw.githubusercontent.com/LAION-AI/CLIP-based-NSFW-Detector/main/clip_autokeras_binary_nsfw.zip"
        elif clip_model == "ViT-B/32":
            url_model = (
                "https://raw.githubusercontent.com/LAION-AI/CLIP-based-NSFW-Detector/main/clip_autokeras_nsfw_b32.zip"
            )
        elif clip_model == "open_clip:ViT-H-14":
            url_model = "https://github.com/LAION-AI/CLIP-based-NSFW-Detector/raw/main/h14_nsfw.pth"
        else:
            raise ValueError("Unknown model {}".format(clip_model))  # pylint: disable=consider-using-f-string

        if clip_model == "open_clip:ViT-H-14":
            # BUGFIX: the H14 checkpoint is a raw .pth file, not a zip archive.
            # The original code downloaded it to the zip path and then tried to
            # extract it (BadZipFile), and the torch.load path below never
            # existed. Download it straight to where torch.load expects it.
            os.makedirs(model_dir, exist_ok=True)
            urlretrieve(url_model, os.path.join(model_dir, "h14_nsfw.pth"))
        else:
            path_to_zip_file = cache_folder + "/clip_autokeras_binary_nsfw.zip"
            urlretrieve(url_model, path_to_zip_file)
            import zipfile  # pylint: disable=import-outside-toplevel

            with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref:
                zip_ref.extractall(cache_folder)

    # Warm-up batch: one dummy forward pass initializes the model so the first
    # real inference call does not pay the setup cost.
    FAKE_BATCH_SIZE = 10**3
    fake_batch = np.random.rand(FAKE_BATCH_SIZE, dim).astype("float32")

    if clip_model == "open_clip:ViT-H-14":
        state = torch.load(os.path.join(model_dir, "h14_nsfw.pth"), map_location="cpu")
        loaded_model = H14_NSFW_Detector()
        loaded_model.load_state_dict(state)
        loaded_model(torch.from_numpy(fake_batch))
    else:
        loaded_model = load_model(model_dir, custom_objects=ak.CUSTOM_OBJECTS)
        loaded_model.predict(fake_batch, batch_size=FAKE_BATCH_SIZE)

    return loaded_model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment