Skip to content

Instantly share code, notes, and snippets.

@python273
Last active April 8, 2024 16:41
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save python273/c709de026ce43684292c29cf1f43e7ee to your computer and use it in GitHub Desktop.
#include "tglang.h"
#include "weights.h"
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
#include <stdio.h>
#include <math.h>
static inline float silu(float x) {
return x / (1.0 + expf(-x));
}
/* Binary search for x in the sorted global `probes` table (weights.h).
 * Returns the matching index, or -1 when x is not present. */
int probes_binary_search(int l, int r, uint32_t x) {
while (l <= r) {
int mid = l + (r - l) / 2;
if (probes[mid] == x) return mid;
if (probes[mid] > x) {
r = mid - 1;
} else {
l = mid + 1;
}
}
return -1;
}
/* qsort comparator for uint32_t values, ascending.
 * Yields -1, 0, or 1 via the classic branchless sign idiom. */
int compare_slices(const void* a, const void* b) {
uint32_t lhs = *(const uint32_t*)a;
uint32_t rhs = *(const uint32_t*)b;
return (lhs > rhs) - (lhs < rhs);
}
const char *remove_prefixes(const char *text) {
const char *prefixes[] = {
"python\n", "javascript\n", "java\n", "lua\n", "bash\n", "csharp\n", "html\n",
"js\n", "php\n", "kotlin\n", "dart\n", "c\n", "sql\n", "css\n", "cpp\n",
"c++\n", "rust\n",
};
const size_t num_prefixes = sizeof(prefixes) / sizeof(prefixes[0]);
for (size_t i = 0; i < num_prefixes; ++i) {
size_t prefix_len = strlen(prefixes[i]);
if (strncmp(text, prefixes[i], prefix_len) == 0) {
text += prefix_len;
break;
}
}
return text;
}
/* Classify the programming language of `text`.
 * Builds 4-byte rolling-window n-grams of the input, looks each unique one
 * up in the sorted `probes` vocabulary (weights.h), sums the matching
 * embedding rows, then runs a small MLP (l2 -> residual l3 -> residual l4
 * -> bias-free classifier). Mirrors the PyTorch `Net` reference model.
 * Returns TGLANG_LANGUAGE_OTHER for inputs that are too short or contain
 * no known n-grams. */
enum TglangLanguage tglang_detect_programming_language(const char *text) {
/* Per-layer activation accumulators, zero-initialized. */
float l1[l1_OUT] = {0.0};
float l2[l2_OUT] = {0.0};
float l3[l3_OUT] = {0.0};
float l4[l4_OUT] = {0.0};
float classifier[classifier_OUT] = {0.0};
/* Strip leading whitespace, a leading language-name line, then re-strip. */
while (*text == '\n' || *text == ' ') { ++text; }
text = remove_prefixes(text);
while (*text == '\n' || *text == ' ') { ++text; }
size_t text_len = strlen(text);
// printf("len %ld\n", text_len);
if (text_len < 6) return TGLANG_LANGUAGE_OTHER;
/* Cap the work at 16 KiB — also bounds the `slices` buffer below. */
if (text_len > 4096*4) text_len = 4096*4;
// we need unique 4 byte slices from text
uint32_t slices[4096*4];
uint32_t slice = 0x0A0A0A0A; /* window pre-seeded with '\n' bytes */
size_t slices_index = 0;
for (; slices_index < text_len; slices_index++) {
unsigned char c = (unsigned char)text[slices_index];
/* Normalize tabs and carriage returns to spaces. */
if (c == '\t') c = ' ';
if (c == '\r') c = ' ';
slice = slice << 8 | c; /* rolling 4-byte window */
slices[slices_index] = slice;
}
/* Sort so duplicate n-grams become adjacent and can be skipped once. */
qsort(slices, slices_index, sizeof(slices[0]), compare_slices);
uint32_t prev_slice = -1; /* 0xFFFFFFFF sentinel: no previous slice yet */
uint8_t found_slice = 0;
for (size_t i = 0; i < slices_index; i++) {
uint32_t slice = slices[i];
if (slice == prev_slice) { continue; } /* count each unique slice once */
prev_slice = slice;
int ind = probes_binary_search(0, (sizeof(probes) / sizeof(probes[0]))-1, slice);
if (ind == -1) continue; /* n-gram not in the vocabulary */
found_slice = 1;
// printf("%d ", v);
/* Embedding lookup-and-sum: matches Net.emb(x).sum(dim=1). */
for (int j = 0; j < l1_OUT; j++) {
l1[j] += l1_weight[ind][j];
}
}
if (found_slice == 0) return TGLANG_LANGUAGE_OTHER;
// printf("\n");
for (int i = 0; i < l1_OUT; i++) {
l1[i] += l1_bias[i];
}
/* l2 = silu(l1 @ W2 + b2) */
for (int j = 0; j < l2_IN; j++) {
for (int i = 0; i < l2_OUT; i++) {
l2[i] += l1[j] * l2_weight[j][i];
}
}
for (int i = 0; i < l2_OUT; i++) {
l2[i] = silu(l2[i] + l2_bias[i]);
}
/* Residual block: l2 += silu(l2 @ W3 + b3) */
for (int j = 0; j < l3_IN; j++) {
for (int i = 0; i < l3_OUT; i++) {
l3[i] += l2[j] * l3_weight[j][i];
}
}
for (int i = 0; i < l3_OUT; i++) {
l2[i] += silu(l3[i] + l3_bias[i]);
}
/* Residual block: l2 += silu(l2 @ W4 + b4) */
for (int j = 0; j < l4_IN; j++) {
for (int i = 0; i < l4_OUT; i++) {
l4[i] += l2[j] * l4_weight[j][i];
}
}
for (int i = 0; i < l4_OUT; i++) {
l2[i] += silu(l4[i] + l4_bias[i]);
}
/* Final linear classifier (no bias term). */
for (int j = 0; j < classifier_IN; j++) {
for (int i = 0; i < classifier_OUT; i++) {
classifier[i] += l2[j] * classifier_weight[j][i];
}
}
// for (int i = 0; i < classifier_OUT; i++) {
// printf("%f ", classifier[i]);
// }
// printf("\n");
/* Argmax over the logits, then map to the public enum value. */
size_t argmax = 0;
for (int i = 1; i < classifier_OUT; i++) {
if (classifier[i] > classifier[argmax]) {
argmax = i;
}
}
// printf("%ld\n", argmax);
return class_mapping[argmax];
}
# for reference, actual definition slightly different in weights init
PROBES_NUM = 177606  # vocabulary size: number of known 4-byte n-gram probes
d_l1 = 32  # embedding width (first-layer output size)
d_model = 512  # hidden width of the MLP layers
class Net(nn.Module):
    """Reference model mirrored by the C inference code.

    Sums n-gram embeddings over the sequence, then applies an MLP with
    two residual SiLU blocks and a bias-free linear classifier.
    """

    def __init__(self):
        super().__init__()
        # padding_idx row is zero, so padded positions add nothing to the sum
        self.emb = nn.Embedding(PROBES_NUM+1, d_l1, padding_idx=PROBES_NUM)
        self.emb_bias = nn.Parameter(torch.zeros([d_l1]))
        self.l2 = nn.Linear(d_l1, d_model)
        self.l3 = nn.Linear(d_model, d_model)
        self.l4 = nn.Linear(d_model, d_model)
        self.classifier = nn.Linear(d_model, LABEL_LEN, bias=False)

    def forward(self, x):
        h = self.emb(x).sum(dim=1) + self.emb_bias
        h = F.silu(self.l2(h))
        h = h + F.silu(self.l3(h))
        h = h + F.silu(self.l4(h))
        return self.classifier(h)
import timeit
import ctypes

# Parse the TglangLanguage enum names out of the C header so the library's
# integer return codes can be mapped back to readable names.
with open('tglang.h') as f:
    enum_body = f.read().split('enum TglangLanguage {', 1)[1].split('}', 1)[0]
TglangLanguage = [
    line.strip().removesuffix(',')
    for line in enum_body.split('\n')
    # Filter on the *stripped* line: a whitespace-only line in the enum body
    # would otherwise produce an empty entry and shift every index.
    if line.strip()
]

libtglang = ctypes.CDLL('./build/libtglang.so')
libtglang.tglang_detect_programming_language.argtypes = [ctypes.c_char_p]
libtglang.tglang_detect_programming_language.restype = ctypes.c_int
print(libtglang)

def detect_programming_language(text):
    """Run the C detector on `text` (bytes) and return the enum name."""
    r = libtglang.tglang_detect_programming_language(text)
    return TglangLanguage[r]

print(detect_programming_language(b'SELECT * FROM users;'))

# Benchmark on a real source file; use distinct names instead of reusing
# one variable for the header text, the sample bytes, and the timing.
with open('../train_model.py', 'rb') as f:
    sample = f.read()
print(len(sample))
sample = sample[:4096*4]  # match the 16 KiB cap inside the C detector
print(len(sample))
print(detect_programming_language(sample))

elapsed = timeit.timeit(lambda: detect_programming_language(sample), number=10000)
print(f'{elapsed/10000} {10000/elapsed}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment