Last active
April 18, 2022 09:29
-
-
Save yashbonde/a4fe05c49f9e1a5c8e21ae878c0e816e to your computer and use it in GitHub Desktop.
[Guesslang](https://github.com/yoeo/guesslang/) is cool, but it has too many bells and whistles. Here is a script that makes it easy to use.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sys | |
import json | |
import logging | |
import subprocess | |
from pathlib import Path | |
try:
    import tensorflow as tf
except ImportError as _tf_import_error:
    # TensorFlow is missing: install a pinned, platform-appropriate build and
    # retry the import. Only ImportError is caught so that Ctrl-C and real
    # errors raised while importing TensorFlow are not silently swallowed.
    if sys.platform.startswith("linux") or sys.platform == "win32":
        _tf_requirement = "tensorflow==2.5.0"
    elif sys.platform == "darwin":
        _tf_requirement = "tensorflow-macos==2.8.0"
    else:
        raise ImportError(
            "tensorflow is not installed and no pinned build is known for "
            f"platform {sys.platform!r}; please install it manually."
        ) from _tf_import_error
    # Run pip through the current interpreter so the package is installed into
    # the same environment this script runs in (a bare "pip" on PATH may
    # belong to another Python). check_call raises if the install fails.
    subprocess.check_call([sys.executable, "-m", "pip", "install", _tf_requirement])
    import importlib

    # The package directory was created after startup; make the import system
    # rescan sys.path entries before re-importing.
    importlib.invalidate_caches()
    import tensorflow as tf
# Constants -- every model asset lives in a "data" folder next to this script.
DATA_DIR = Path(__file__).absolute().parent / 'data'
DEFAULT_MODEL_DIR = DATA_DIR / 'model'
LANGUAGES_FILE = DATA_DIR / 'languages.json'
# Files that make up the pretrained model, given relative to DATA_DIR.
# Despite the name, these are path suffixes: the downloader appends each one
# to the Guesslang repository's raw data URL and saves it under DATA_DIR.
MODEL_FILES_URL = [
    "languages.json",
    "model/saved_model.pb",
    "model/variables/variables.data-00000-of-00001",
    "model/variables/variables.index",
]
class Guess:
    """Minimal wrapper around the pretrained Guesslang TensorFlow model.

    Constructing an instance downloads the model assets listed in
    ``MODEL_FILES_URL`` into ``DATA_DIR`` (if any are missing), loads the
    SavedModel and ``languages.json``, and builds a reverse lookup from the
    model's class labels to language names. Calling the instance on source
    code returns ranked language guesses.
    """

    @staticmethod
    def download_files():
        """Fetch every file in ``MODEL_FILES_URL`` from the Guesslang repo.

        Files are saved under ``DATA_DIR``; ``data/``, ``data/model/`` and
        ``data/model/variables/`` are created as needed. Each file is first
        written to a ``.part`` sibling and only renamed into place once fully
        downloaded, so an interrupted run never leaves a truncated file under
        the final name.

        Raises:
            OSError: if a file cannot be fetched or written (this includes
                ``urllib.error.URLError`` / ``HTTPError``).
        """
        # Local imports: only needed on the first run, when assets are missing.
        import shutil
        import urllib.request

        base_url = (
            "https://raw.githubusercontent.com/yoeo/guesslang/master/"
            "guesslang/data/"
        )
        # exist_ok: an earlier, partially failed run may already have created
        # some of these directories (os.makedirs without it would crash).
        DATA_DIR.joinpath("model", "variables").mkdir(parents=True, exist_ok=True)
        for suff in MODEL_FILES_URL:
            dest = DATA_DIR.joinpath(suff)
            tmp = dest.with_name(dest.name + ".part")
            # urllib is used instead of shelling out to `wget`, which is not
            # installed by default everywhere and whose failures were ignored.
            with urllib.request.urlopen(base_url + suff, timeout=60) as resp, \
                    open(tmp, "wb") as fh:
                shutil.copyfileobj(resp, fh)
            os.replace(tmp, dest)

    def __init__(self):
        # Check every asset, not just the model directory: after a
        # half-finished download the directory exists but files are missing,
        # and the old directory-only check would never repair it.
        if not all(DATA_DIR.joinpath(suff).exists() for suff in MODEL_FILES_URL):
            self.download_files()
        self._saved_model_dir = str(DEFAULT_MODEL_DIR)
        self._model = tf.saved_model.load(self._saved_model_dir)
        # Explicit encoding: the default text encoding is locale-dependent.
        self.language_map = json.loads(LANGUAGES_FILE.read_text(encoding="utf-8"))
        # Invert language_map (presumably language name -> extension) so the
        # model's class labels can be mapped back to a language name.
        self.extension_map = {
            ext: name for name, ext in self.language_map.items()
        }

    def __call__(self, source_code: "str | list[str]", cutoff_p: float = 0.8, topk: int = 5):
        """Guess the programming language of one or more source snippets.

        Args:
            source_code: a single snippet, or a list (or tuple) of snippets.
            cutoff_p: a snippet is only reported when at least one class
                score exceeds this value.
            topk: maximum number of ``(language, score)`` pairs kept per
                snippet.

        Returns:
            A list with one entry per snippet (a single-element list for a
            ``str`` input). Each entry is ``None`` when the prediction is not
            considered reliable, otherwise a list of ``(language_name, score)``
            tuples sorted by descending score. An empty input yields ``[]``.

        Raises:
            ValueError: if ``source_code`` is not a ``str`` or a list/tuple
                made only of ``str``.
        """
        if isinstance(source_code, str):
            snippets = [source_code]
        elif isinstance(source_code, (list, tuple)) and all(
            isinstance(snippet, str) for snippet in source_code
        ):
            snippets = list(source_code)
        else:
            raise ValueError(f"Unknown type for source_code: {type(source_code)}")
        if not snippets:
            # Nothing to classify; return early instead of feeding the model
            # an empty batch (the old code crashed here with IndexError).
            return []

        predicted = self._model.signatures['serving_default'](tf.constant(snippets))
        scores = predicted['scores'].numpy()
        extensions = predicted['classes'].numpy()
        return self._rank_predictions(scores, extensions, cutoff_p, topk)

    def _rank_predictions(self, scores, extensions, cutoff_p, topk):
        """Filter unreliable rows and rank the rest as (language, score) lists.

        ``scores`` / ``extensions`` are the model's ``scores`` and ``classes``
        outputs with one row per snippet (presumably shape
        ``(n_snippets, n_classes)``; class labels are bytes).
        """
        # Reliable only if the best score is more than 2 standard deviations
        # above the row mean ...
        reliable = scores.max(-1) > (scores.mean(-1) + 2 * scores.std(-1))
        # ... and at least one class score clears the caller's cutoff.
        reliable = reliable & (scores > cutoff_p).any(-1)

        results = []
        for row_scores, row_exts, ok in zip(scores, extensions, reliable):
            if not ok:
                results.append(None)
                continue
            ranked = sorted(
                (
                    (self.extension_map[ext.decode()], value)
                    for value, ext in zip(row_scores, row_exts)
                ),
                key=lambda pair: pair[1],
                reverse=True,
            )
            results.append(ranked[:topk])
        return results
if __name__ == "__main__":
    # Smoke check: load (downloading on first run) AND run one prediction, so
    # "working fine" means inference actually ran, not just model loading.
    gl = Guess()
    sample = "import os\n\nfor name in os.listdir('.'):\n    print(name)\n"
    print(gl(sample))
    print("Guesslang working fine!")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment