-
-
Save vurtun/2de760b1bf9b35f3aa7620da4234e30f to your computer and use it in GitHub Desktop.
Neural Network
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <assert.h>
#include <errno.h>
#include <float.h>
#include <limits.h>
#include <math.h>
#include <stdarg.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#define unused(a) ((void)a) | |
#define cast(t, p) ((t)(p)) | |
#define casts(p) cast(short,p) | |
#define casti(p) cast(int,p) | |
#define castu(p) cast(unsigned,p) | |
#define castl(p) cast(long,p) | |
#define castll(p) cast(long long,p) | |
#define castull(p) cast(unsigned long long,p) | |
#define castf(p) cast(float,p) | |
#define castd(p) cast(double,p) | |
#define castsz(p) cast(size_t,p) | |
#define castss(p) cast(ssize_t,p) | |
#define recast(T,p) ((T)cast(void*,(p))) | |
#define szof(a) ((int)sizeof(a)) | |
#define cntof(a) ((int)(sizeof(a) / sizeof((a)[0]))) | |
#define flag(n) ((1u) << (n)) | |
#define min(a,b) ((a) < (b) ? (a) : (b)) | |
#define max(a,b) ((a) > (b) ? (a) : (b)) | |
#define clamp(a, v, b) (max(min(b, v), a)) | |
#define iswap(x,y) do {((x) ^= (y), (y) ^= (x), (x) ^= (y)); }while(0) | |
#define align_mask(a) ((a)-1) | |
#define align_down_masked(n, m) ((n) & ~(m)) | |
#define align_down(n, a) align_down_masked(n, align_mask(a)) | |
#define align_up(n, a) align_down((n) + align_mask(a), (a)) | |
#define xglue(x, y) x##y | |
#define glue(x, y) xglue(x, y) | |
#define uniqid(name) glue(name, __LINE__) | |
/* Prints a printf-style message to stderr followed by a newline and
 * terminates the process with exit code 1. Never returns. */
static void
die(const char *fmt, ...) {
    va_list ap;
    va_start(ap, fmt);
    vfprintf(stderr, fmt, ap);
    va_end(ap);
    fputc('\n', stderr);
    exit(1);
}
/* --------------------------------------------------------------------------- | |
* Platform | |
* --------------------------------------------------------------------------- */ | |
#ifdef _MSC_VER | |
#define alignto(x) __declspec(align(x)) | |
#define bit_cnt(u) __popcnt(u) | |
#define bit_cnt64(u) __popcnt64(u) | |
static int bit_ffs32(unsigned int u) {_BitScanForward(&u, u); return casti(u);} | |
static int bit_ffs64(unsigned long long u) {_BitScanForward64(&u, u); return casti(u);} | |
#else /* GCC, CLANG */ | |
#define alignto(x) __attribute__((aligned(x))) | |
#define bit_cnt(u) __builtin_popcount(u) | |
#define bit_cnt64(u) __builtin_popcountll(u) | |
#define bit_ffs32(u) __builtin_ctz(u) | |
#define bit_ffs64(u) __builtin_ctzll(u) | |
#endif | |
#ifdef _WIN32 /* Windows */ | |
#include "ntsecapi.h" | |
static unsigned long long | |
sys_rnd64(void) { | |
unsigned long long rnd = 0; | |
if (!RtlGenRandom(&rnd, sizeof(rnd))) { | |
fprintf(stderr, "failed to generate system random number\n"); | |
exit(1); | |
} | |
return rnd; | |
} | |
#else /* UNIX */ | |
#include <fcntl.h> | |
#include <unistd.h> | |
static unsigned long long | |
sys_rnd64(void) { | |
ssize_t res; | |
unsigned long long rnd = 0; | |
int fp = open("/dev/urandom", O_RDONLY); | |
if (fp == -1) { | |
fprintf(stderr, "failed to access system random number\n"); | |
exit(1); | |
} | |
res = read(fp, cast(char*, &rnd), sizeof(rnd)); | |
if (res < szof(rnd)) { | |
fprintf(stderr, "failed to generate system random number\n"); | |
exit(1); | |
} | |
close(fp); | |
return rnd; | |
} | |
#endif | |
#ifdef __x86_64__ /* SSE */ | |
#define SSE_ALIGN_BYTES 16 | |
#define SSE_ALIGN alignto(SSE_ALIGN_BYTES) | |
#include <emmintrin.h> | |
/* Finds the first occurrence of byte 'chr' in [s, e); returns e when absent.
 * SSE2 variant, 16 bytes per step; the trailing partial chunk is masked via
 * ovr_msk so bytes at or past e are zeroed before the compare.
 * NOTE(review): every load reads a full 16 bytes, so callers must guarantee
 * readable padding past e (file_load over-allocates SSE_ALIGN_BYTES for
 * this); also, chr == 0 can match the zeroed tail — callers never pass 0.
 * Fixed versus the previous version: removed the stray redeclaration of 'e'
 * (which referenced an undeclared 'n' and did not compile) and use the
 * LOWEST set mask bit (ctz) for the FIRST match; 31-clz picked the last. */
static const char*
str_chr(const char *s, const char *e, int chr) {
    static const char unsigned ovr_msk[32] = {
        255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    };
    __m128i m = _mm_set1_epi8(chr & 0xff);
    for (; s < e; s += 16) {
        int r = (int)(e - s); r = r > 16 ? 16 : r;
        /* first r bytes enabled, trailing bytes zeroed out */
        __m128i o = _mm_loadu_si128((const __m128i *)(const void*)(ovr_msk + 16 - r));
        __m128i d = _mm_loadu_si128((const __m128i *)(const void*)s);
        __m128i v = _mm_and_si128(d, o);
        unsigned msk = (unsigned)_mm_movemask_epi8(_mm_cmpeq_epi8(v, m));
        if (msk) {
            return s + __builtin_ctz(msk); /* lowest set bit == first match */
        }
    }
    return e;
}
/* Counts '\n' bytes in s[0..n) using 16-byte SSE2 chunks; the trailing
 * partial chunk is masked through tail_msk so bytes past the end never
 * match. Requires 16 readable bytes per load (see file_load padding). */
static int
line_cnt(const char *s, int n) {
    static const char unsigned tail_msk[32] = {
        255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    };
    int total = 0;
    const char *end = s + n;
    __m128i nl = _mm_set1_epi8('\n');
    while (s < end) {
        int left = (int)(end - s);
        int take = left < 16 ? left : 16;
        /* keep the first 'take' bytes, zero the rest */
        __m128i keep = _mm_loadu_si128((const __m128i *)(const void*)(tail_msk + 16 - take));
        __m128i data = _mm_loadu_si128((const __m128i *)(const void*)s);
        unsigned hits = (unsigned)_mm_movemask_epi8(
            _mm_cmpeq_epi8(_mm_and_si128(data, keep), nl));
        total += __builtin_popcount(hits);
        s += 16;
    }
    return total;
}
#elif defined(__arm__) || defined(__aarch64__) /* ARM NEON */ | |
#include <arm_neon.h> | |
#define SSE_ALIGN_BYTES 16 | |
#define SSE_ALIGN alignto(SSE_ALIGN_BYTES) | |
/* Finds the first occurrence of byte 'chr' in [s, e); returns e when absent.
 * NEON variant: 16 bytes per step; the trailing partial chunk is masked via
 * ovr_msk so bytes past e are zeroed before the compare.
 * NOTE(review): every vld1q_u8 reads a full 16 bytes, so callers must
 * guarantee readable padding past e (file_load over-allocates for this);
 * also, when chr == 0 the zeroed tail bytes would match — confirm callers
 * never search for 0. */
static const char*
str_chr(const char *s, const char *e, int chr) {
    static const char unsigned ovr_msk[32] = {
        255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    };
    uint8x16_t m = vdupq_n_u8(chr & 0xff);
    for (; s < e; s += 16) {
        int r = (int)(e - s); r = r > 16 ? 16 : r;
        /* first r bytes enabled, trailing bytes zeroed out */
        uint8x16_t o = vld1q_u8(ovr_msk + 16 - r);
        uint8x16_t d = vld1q_u8((const unsigned char*)s);
        uint8x16_t v = vandq_u8(d, o);
        uint8x16_t c = vceqq_u8(v, m);
        uint64x2_t p = vreinterpretq_u64_u8(c);
        uint64_t vlo = vgetq_lane_u64(p, 0);
        if (vlo) {
            /* matching byte == first set bit / 8 within the low lane */
            return s + ((bit_ffs64(vlo)) >> 3);
        }
        uint64_t vhi = vgetq_lane_u64(p, 1);
        if (vhi) {
            return s + 8 + ((bit_ffs64(vhi)) >> 3);
        }
    }
    return e;
}
/* Counts '\n' bytes in s[0..n) with NEON, 16 bytes per step. Each matching
 * byte compares to 0xFF (eight set bits), so popcount >> 3 gives the number
 * of matches per 8-byte lane. Requires readable padding past the buffer end
 * (see file_load). */
static int
line_cnt(const char *s, int n) {
    static const char unsigned ovr_msk[32] = {
        255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    };
    int cnt = 0;
    const char *e = s + n;
    uint8x16_t m = vdupq_n_u8('\n');
    for (; s < e; s += 16) {
        int r = casti(e - s);
        int l = r > 16 ? 16 : r;
        /* keep the first l bytes, zero the rest */
        uint8x16_t o = vld1q_u8(ovr_msk + 16 - l);
        uint8x16_t d = vld1q_u8((const unsigned char*)s);
        uint8x16_t v = vandq_u8(d, o);
        uint8x16_t c = vceqq_u8(v, m);
        uint64x2_t p = vreinterpretq_u64_u8(c);
        cnt += bit_cnt64(vgetq_lane_u64(p, 0)) >> 3;
        cnt += bit_cnt64(vgetq_lane_u64(p, 1)) >> 3;
    }
    return cnt;
}
#else /* standard c */ | |
/* Finds the first occurrence of byte 'c' in [s, e); returns e when absent.
 * Portable SWAR fallback: aligns s, scans four bytes per word with the
 * classic has-zero-byte trick, then finishes byte-wise. Fixed versus the
 * previous version, which referenced an undeclared 'end', used ~v instead
 * of ~(v^m) in the zero-byte test, could scan past e, and fell off the end
 * of a non-void function when no match existed. */
static const char*
str_chr(const char *s, const char *e, int c) {
    unsigned m = 0x01010101u * (unsigned)(c & 0xff);
    /* byte-wise until s is 4-byte aligned */
    for (; s < e && ((uintptr_t)s & 0x03); ++s) {
        if (*s == (char)c) {
            return s;
        }
    }
    /* word-wise: stop on the first word containing the byte */
    for (; s + 4 <= e; s += 4) {
        unsigned v, t;
        memcpy(&v, s, sizeof(v)); /* avoids strict-aliasing UB of *(unsigned*)s */
        t = v ^ m;                /* matching byte becomes 0x00 */
        if (((t - 0x01010101u) & ~t & 0x80808080u) != 0u) {
            break; /* match somewhere in these four bytes; pin it below */
        }
    }
    /* byte-wise tail (also resolves the exact position after a word hit) */
    for (; s < e; ++s) {
        if (*s == (char)c) {
            return s;
        }
    }
    return e;
}
/* Counts '\n' bytes in str[0..len). Kept consistent with the SIMD variants,
 * which also return the number of newline characters; the previous scalar
 * version started the count at 1 and therefore reported one more "line"
 * than the other builds, desynchronizing elm_cnt from the loaded rows. */
static int
line_cnt(const char *str, int len) {
    int cnt = 0;
    for (int i = 0; i < len; ++i) {
        cnt += (str[i] == '\n');
    }
    return cnt;
}
#endif | |
/* --------------------------------------------------------------------------- | |
* String | |
* --------------------------------------------------------------------------- */ | |
/* Non-owning string slice: the half-open byte range [str, end). */
struct str {
    const char *str; /* first byte */
    const char *end; /* one past the last byte */
    int len;         /* end - str */
};
#define str_rhs(s, n) str_sub(s, min((s).len, n), (s).len) /* drop first n bytes */
#define str_lhs(s, n) str_sub(s, 0, min((s).len, n))       /* keep first n bytes */
#define str_cut_lhs(s, n) *(s) = str_rhs(*(s), n)
#define str_cut_rhs(s, n) *(s) = str_lhs(*(s), n)
/* Iterates tokens of 'src' separated by 'delim'; 'rest' carries the
 * unscanned remainder. NOTE: stops at the first EMPTY token (it.len == 0),
 * so consecutive delimiters end the iteration early. */
#define for_str_tok(it, rest, src, delim) \
    for ((rest) = (src), (it) = str_split_cut(&(rest), (delim)); \
        (it).len; (it) = str_split_cut(&(rest), (delim)))
static struct str | |
str(const char *p, int len) { | |
struct str s = {0}; | |
s.str = p; | |
s.end = p + len; | |
s.len = len; | |
return s; | |
} | |
static struct str | |
str_sub(struct str s, int from, int to) { | |
int b = min(from, to); | |
int e = max(from, to); | |
struct str r = {0}; | |
r.str = s.str + min(b, s.len); | |
r.end = s.str + min(e, s.len); | |
r.len = casti(r.end - r.str); | |
return r; | |
} | |
static struct str | |
str_split_cut(struct str *s, int delim) { | |
const char *at = str_chr(s->str, s->end, delim); | |
if (at < s->end) { | |
int p = casti(at - s->str); | |
struct str res = str_lhs(*s, p); | |
str_cut_lhs(s, p + 1); | |
return res; | |
} else { | |
struct str res = *s; | |
memset(s, 0, sizeof(*s)); | |
return res; | |
} | |
} | |
/* --------------------------------------------------------------------------- | |
* Command Arguments | |
* --------------------------------------------------------------------------- */ | |
/* getopt-style option scanning in the spirit of Plan 9 / suckless ARG
 * macros. CMD_ARG_BEGIN walks leading "-xyz" clusters (a bare "--" ends
 * option parsing); inside the switch, CMD_ARGC() is the current option
 * character and cmd_arg_opt_* fetch the option's value either from the
 * rest of the cluster ("-n5") or from the next argv entry ("-n 5").
 * The 'x' expression (typically usage()) is evaluated when the value is
 * missing. */
#define CMD_ARGC() argc_
#define cmd_arg_opt_str(argv, x) ((argv[0][1] == '\0' && argv[1] == 0)?\
    ((x), (char *)0) : (brk_ = 1, (argv[0][1] != '\0') ?\
    (&argv[0][1]) : (argc--, argv++, argv[0])))
#define cmd_arg_opt_int(argv,x) cmd_arg_int(cmd_arg_opt_str(argv,x))
#define cmd_arg_opt_flt(argv,x) cmd_arg_flt(cmd_arg_opt_str(argv,x))
#define CMD_ARG_BEGIN(argv0, argc, argv) \
    for (argv0 = *argv, argv++, argc--; argv[0] && argv[0][1] && argv[0][0] == '-'; argc--, argv++) {\
        char argc_, **argv_; int brk_;\
        if (argv[0][1] == '-' && argv[0][2] == '\0') {argv++; argc--; break;}\
        for (brk_ = 0, argv[0]++, argv_ = argv; argv[0][0] && !brk_; argv[0]++) {\
            if (argv_ != argv) break;\
            argc_ = argv[0][0];\
            switch (argc_)
#define CMD_ARG_END }}
/* Parses a decimal integer command-line argument; dies with a message on
 * malformed or out-of-range input. Checks errno for ERANGE because on
 * platforms where long == int the old INT_MIN/INT_MAX comparison could
 * never fire and overflow went undetected. */
static int
cmd_arg_int(const char *str) {
    char *ep = 0;
    long n;
    errno = 0;
    n = strtol(str, &ep, 10);
    if (*ep != '\0' || ep == str) {
        die("Invalid argument number: %s\n", str);
    }
    if (errno == ERANGE || n < INT_MIN || n > INT_MAX) {
        die("Argument number: %ld is out of range\n", n);
    }
    return casti(n);
}
/* Parses a float command-line argument; dies with a message on malformed
 * or out-of-range input (ERANGE was previously ignored). */
static float
cmd_arg_flt(const char *str) {
    char *ep = 0;
    float n;
    errno = 0;
    n = strtof(str, &ep);
    if (*ep != '\0' || ep == str) {
        die("Invalid argument number: %s\n", str);
    }
    if (errno == ERANGE) {
        die("Argument number: %s is out of range\n", str);
    }
    return n;
}
/* --------------------------------------------------------------------------- | |
* Utility | |
* --------------------------------------------------------------------------- */ | |
/* Type-generic byte-wise swap; the negative array size forces a compile
 * error when x and y differ in size. */
#define swap(x,y) do { \
    unsigned char uniqid(t)[szof(x) == szof(y) ? szof(x) : -1]; \
    memcpy(uniqid(t),&y,sizeof(x)); \
    memcpy(&y,&x,sizeof(x)); \
    memcpy(&x,uniqid(t),sizeof(x)); \
} while(0)
/* Applies permutation p to array a of n elements in place, cycle by cycle.
 * p must hold a valid permutation of 0..n-1 (see seq_lin/seq_rnd) and is
 * destroyed: entries are re-encoded as -1-i while each cycle is resolved. */
#define arr_shfl(a,n,p) do { \
    for (int uniqid(i) = 0; uniqid(i) < n; ++uniqid(i)) { \
        if (p[uniqid(i)] >= 0) { \
            int uniqid(j) = uniqid(i); \
            while (p[uniqid(j)] != uniqid(i)) { \
                const int uniqid(d) = p[uniqid(j)]; \
                swap(a[uniqid(j)], a[uniqid(d)]); \
                p[uniqid(j)] = -1 - uniqid(d); \
                uniqid(j) = uniqid(d); \
            } p[uniqid(j)] = -1 - p[uniqid(j)]; \
        } \
    }} while (0)
/* Zero-initialized allocation that never returns NULL; dies on OOM.
 * Caller owns the returned memory and frees it with free(). */
static void*
xalloc(int siz) {
    void *ptr = calloc(castsz(siz), 1);
    if (ptr == 0) {
        die("Out of Memory");
    }
    return ptr;
}
/* Advances a splitmix64-style state x by n increments of the 64-bit
 * golden-ratio constant and returns the new state. */
static unsigned long long
rnd_gen(unsigned long long x, int n) {
    const unsigned long long golden = 0x9E3779B97F4A7C15llu;
    return x + (unsigned long long)n * golden;
}
/* splitmix64 output mixer: bijective avalanche of the raw state so that
 * sequential states yield statistically independent-looking outputs. */
static unsigned long long
rnd_mix(unsigned long long z) {
    z ^= z >> 30;
    z *= 0xBF58476D1CE4E5B9llu;
    z ^= z >> 27;
    z *= 0x94D049BB133111EBllu;
    z ^= z >> 31;
    return z;
}
/* One splitmix64 step: advance state *x by i increments, return the
 * mixed output of the new state. */
static unsigned long long
rnd_split_mix(unsigned long long *x, int i) {
    unsigned long long nxt = rnd_gen(*x, i);
    *x = nxt;
    return rnd_mix(nxt);
}
/* Draws the next 64-bit pseudo-random value, advancing state *x one step. */
static unsigned long long
rnd(unsigned long long *x) {
    return rnd_split_mix(x, 1);
}
/* Draws the next pseudo-random value truncated to 32 bits. */
static unsigned
rndu(unsigned long long *x) {
    return (unsigned)(rnd(x) & 0xffffffffu);
}
/* Uniform random unsigned in [min(mini,maxi), max(mini,maxi)], inclusive.
 * Rejection sampling removes modulo bias. BUG FIX: the result is now offset
 * from the LOW bound; the old code added the raw 'mini', which walked past
 * the upper bound whenever mini > maxi. */
static unsigned
rnduu(unsigned long long *x, unsigned mini, unsigned maxi) {
    unsigned lo = min(mini, maxi);
    unsigned hi = max(mini, maxi);
    unsigned rng = castu(-1);
    unsigned n = hi - lo + 1u;
    if (n == 1u) {
        return lo; /* degenerate single-value range */
    } else if(n == 0u) {
        return rndu(x); /* full 32-bit range requested */
    } else {
        unsigned v = 0;
        unsigned remainder = rng % n;
        /* reject draws from the biased top of the range */
        do {v = rndu(x);}
        while(v >= rng - remainder);
        return lo + v % n;
    }
}
/* Uniform random float in [0, 1]: a 32-bit draw scaled by 1/(2^32-1). */
static float
rndf01(unsigned long long *x) {
    unsigned u = rndu(x);
    double num = castd(u);
    double den = castd((unsigned)-1);
    return castf(num/den);
}
/* Uniform random float in [min(mini,maxi), max(mini,maxi)]. */
static float
rnduf(unsigned long long *x, float mini, float maxi) {
    float lo = min(mini, maxi);
    float hi = max(mini, maxi);
    float rng = hi - lo;
    unsigned u = rndu(x);
    double num = castd(u);
    double den = castd((unsigned)-1);
    return lo + rng * castf(num/den);
}
/* Fills seq[0..cnt) with the identity permutation 0, 1, ..., cnt-1. */
static void
seq_lin(int *seq, int cnt) {
    int idx;
    for (idx = 0; idx < cnt; ++idx) {
        seq[idx] = idx;
    }
}
/* Fisher-Yates shuffle of seq[0..n). Uses a temporary for the exchange:
 * the xor-based iswap macro zeroes the element whenever the two indices
 * coincide, and at == i happens with probability 1/(i+1), silently
 * corrupting the permutation. */
static void
seq_rnd(int *seq, int n, unsigned long long *r) {
    int i = 0;
    for (i = n - 1; i > 0; --i) {
        unsigned at = rndu(r) % (unsigned)(i + 1);
        int tmp = seq[i];
        seq[i] = seq[(int)at];
        seq[(int)at] = tmp;
    }
}
/* The SIMD scanners over-read up to 16 bytes; the scalar fallback branch
 * never defined the padding constant, so guard it here. */
#ifndef SSE_ALIGN_BYTES
#define SSE_ALIGN_BYTES 16
#endif
/* Reads the whole file at 'path' into a zero-padded heap buffer (the
 * padding lets the SIMD scanners read past the end safely) and stores the
 * byte size in *siz. Dies on any failure; caller frees the result.
 * Fixes: open in binary mode so ftell matches the bytes fread returns on
 * Windows, guard the long->int size conversion, and print the size_t fread
 * result with %zu (the old "%d" was undefined behavior). */
static char*
file_load(int *siz, const char *path) {
    FILE *fd = 0;
    char *mem = 0;
    size_t ret = 0;
    long tell = 0;
    /* open file */
    assert(path);
    fd = fopen(path, "rb");
    if (!fd) {
        die("Unable to open file: %s\n", path);
    }
    /* calculate file size */
    fseek(fd, 0, SEEK_END);
    tell = ftell(fd);
    if (tell < 0) {
        die("Unable to access file: %s\n", path);
    }
    if (tell == 0) {
        die("File is empty: %s\n", path);
    }
    if (tell > INT_MAX - SSE_ALIGN_BYTES) {
        die("File too large: %s\n", path);
    }
    *siz = casti(tell);
    fseek(fd, 0, SEEK_SET);
    /* read file into memory (xalloc zero-fills, so padding stays 0) */
    mem = xalloc(*siz + SSE_ALIGN_BYTES);
    ret = fread(mem, 1, castsz(*siz), fd);
    if (ret < castsz(*siz)) {
        die("file fread failed: %zu\n", ret);
    }
    fclose(fd);
    return mem;
}
/* ---------------------------------------------------------------------------== | |
* | |
* | |
* Neural Network | |
* | |
* | |
* --------------------------------------------------------------------------- */ | |
/* Fully-connected 3-layer (input/hidden/output) network. The caller fills
 * the 'in' fields and points mem at nn_req_net_siz(ctx) zeroed bytes; all
 * remaining fields are derived by nn__setup_net. */
struct nn_ctx {
    /* in */
    void *mem;     /* backing storage, >= nn_req_net_siz() bytes, zeroed */
    int in_cnt;    /* input neurons, excluding the bias */
    int out_cnt;   /* output neurons */
    int hid_cnt;   /* hidden neurons, excluding the bias */
    /* out */
    int n_in_cnt;      /* in_cnt + 1 (bias slot) */
    int n_out_cnt;     /* == out_cnt */
    int n_hid_cnt;     /* hid_cnt + 1 (bias slot) */
    int w_in_hid_cnt;  /* n_in_cnt * n_hid_cnt weights */
    int w_hid_out_cnt; /* n_hid_cnt * n_out_cnt weights */
    float *n_in;       /* input activations; n_in[in_cnt] is the -1 bias */
    float *n_hid;      /* hidden activations; n_hid[hid_cnt] is the -1 bias */
    float *n_out;      /* raw sigmoid outputs */
    int *outs;         /* thresholded outputs: 0, 1 or -1 (undecided) */
    float *w_in_hid;   /* input->hidden weights */
    float *w_hid_out;  /* hidden->output weights */
    /* intern */
    unsigned setup:1;  /* nn__setup_net already ran */
};
static int | |
nn_req_net_siz(struct nn_ctx *ctx) { | |
int n_in_siz = szof(float) * (ctx->in_cnt + 1); | |
int n_hid_siz = szof(float) * (ctx->hid_cnt + 1); | |
int n_out_siz = szof(float) * ctx->out_cnt; | |
int out_siz = szof(int) * ctx->out_cnt; | |
int wih_siz = szof(float) * ((ctx->in_cnt + 1) * (ctx->hid_cnt + 1)); | |
int woh_siz = szof(float) * ((ctx->hid_cnt + 1) * ctx->out_cnt); | |
return n_in_siz + n_hid_siz + n_out_siz + out_siz + wih_siz + woh_siz; | |
} | |
/* Derives the layer sizes and carves the neuron/weight arrays out of
 * ctx->mem (layout must match nn_req_net_siz), then plants the constant
 * -1 bias activations at the end of the input and hidden layers. */
static void
nn__setup_net(struct nn_ctx *ctx) {
    ctx->setup = 1;
    ctx->n_in_cnt = ctx->in_cnt + 1;   /* + bias neuron */
    ctx->n_out_cnt = ctx->out_cnt;
    ctx->n_hid_cnt = ctx->hid_cnt + 1; /* + bias neuron */
    ctx->w_in_hid_cnt = (ctx->n_in_cnt * ctx->n_hid_cnt);
    ctx->w_hid_out_cnt = (ctx->n_out_cnt * ctx->n_hid_cnt);
    /* setup memory: [n_in | n_hid | n_out | w_in_hid | w_hid_out | outs] */
    ctx->n_in = ctx->mem;
    ctx->n_hid = ctx->n_in + ctx->n_in_cnt;
    ctx->n_out = ctx->n_hid + ctx->n_hid_cnt;
    ctx->w_in_hid = ctx->n_out + ctx->n_out_cnt;
    ctx->w_hid_out = ctx->w_in_hid + ctx->w_in_hid_cnt;
    ctx->outs = cast(int*, ctx->w_hid_out + ctx->w_hid_out_cnt);
    ctx->n_in[ctx->in_cnt] = -1.0;   /* fixed input-bias activation */
    ctx->n_hid[ctx->hid_cnt] = -1.0; /* fixed hidden-bias activation */
}
static int | |
nn__in_hide_weight_idx(const struct nn_ctx *ctx, int in_idx, int hid_idx) { | |
return in_idx * ctx->hid_cnt + hid_idx; | |
} | |
static int | |
nn__hide_out_weight_idx(const struct nn_ctx *ctx, int hid_idx, int out_idx) { | |
return hid_idx * ctx->out_cnt + out_idx; | |
} | |
static void | |
nn_init(struct nn_ctx *ctx) { | |
int ii, hi, oi; | |
unsigned long long rnd_gen = sys_rnd64(); | |
float dist = (2.4f / ctx->in_cnt); | |
if (!ctx->setup) { | |
nn__setup_net(ctx); | |
} | |
for (ii = 0; ii <= ctx->n_in_cnt; ++ii) { | |
for (hi = 0; hi <= ctx->hid_cnt; ++hi) { | |
int wi = nn__in_hide_weight_idx(ctx, ii, hi); | |
ctx->w_in_hid[wi] = (rndf01(&rnd_gen) * 2.0f - 1.0f) * dist; | |
} | |
} | |
for (hi = 0; hi <= ctx->n_hid_cnt; ++hi) { | |
for (oi = 0; oi < ctx->out_cnt; ++oi) { | |
int wi = nn__hide_out_weight_idx(ctx, hi, oi); | |
ctx->w_hid_out[wi] = (rndf01(&rnd_gen) * 2.0f - 1.0f) * dist; | |
} | |
} | |
} | |
static float | |
nn__sigmoid(float x) { | |
return 1.0f / ( 1.0f + expf(-x)); | |
} | |
/* Thresholds a sigmoid output into a hard decision: 0 below 0.1, 1 above
 * 0.9 and -1 ("undecided") in between. Returns a plain int; the old code
 * returned the float literal -1.0f from an int function, relying on an
 * implicit conversion. */
static int
nn__clamp(float x) {
    if (x < 0.1f) {
        return 0;
    }
    if (x > 0.9f) {
        return 1;
    }
    return -1;
}
/* Forward pass. Copies 'in' (in_cnt floats) into the input layer, computes
 * the sigmoid hidden activations (input bias included via ii <= in_cnt),
 * then the sigmoid outputs (hidden bias included via hi <= hid_cnt).
 * Fills ctx->n_out with raw sigmoid values and ctx->outs with thresholded
 * decisions; returns ctx->outs, which is owned by ctx and overwritten by
 * the next call. */
static int*
nn_eval(struct nn_ctx *ctx, const float *in) {
    int oi, ii, hi = 0;
    memcpy(ctx->n_in, in, sizeof(float) * castsz(ctx->in_cnt));
    for (hi = 0; hi < ctx->hid_cnt; ++hi) {
        /* weighted sum over the pattern plus the bias neuron */
        ctx->n_hid[hi] = 0;
        for (ii = 0; ii <= ctx->in_cnt; ++ii) {
            int wi = nn__in_hide_weight_idx(ctx, ii, hi);
            ctx->n_hid[hi] += ctx->n_in[ii] * ctx->w_in_hid[wi];
        }
        ctx->n_hid[hi] = nn__sigmoid(ctx->n_hid[hi]);
    }
    for (oi = 0; oi < ctx->out_cnt; ++oi) {
        /* output values - include the hidden bias neuron */
        ctx->n_out[oi] = 0;
        for (hi = 0; hi <= ctx->hid_cnt; ++hi) {
            int wi = nn__hide_out_weight_idx(ctx, hi, oi);
            ctx->n_out[oi] += ctx->n_hid[hi] * ctx->w_hid_out[wi];
        }
        ctx->n_out[oi] = nn__sigmoid(ctx->n_out[oi]);
        ctx->outs[oi] = nn__clamp(ctx->n_out[oi]);
    }
    return ctx->outs;
}
/* --------------------------------------------------------------------------- | |
* Data | |
* --------------------------------------------------------------------------- | |
*/ | |
/* One training sample: in_cnt float inputs and out_cnt int targets, both
 * pointing into nn_train_data's backing memory. */
struct nn_train_elm {
    float *in;
    int *out;
};
/* Half-open index range [begin, end) into nn_train_data.elms. */
struct nn_train_set {
    int begin;
    int end;
    int cnt; /* end - begin */
};
/* Training corpus. The caller fills the 'in' fields and points mem at
 * nn_req_train_data_siz() zeroed bytes; nn__setup_train_data derives the
 * rest and nn_train_data_load fills and splits the samples. */
struct nn_train_data {
    /* in */
    void *mem;   /* backing storage, zeroed */
    int in_cnt;  /* inputs per sample */
    int hid_cnt; /* hidden neuron count (mirrors nn_ctx) */
    int out_cnt; /* targets per sample */
    int elm_cnt; /* number of samples */
    /* out */
    struct nn_train_elm *elms; /* elm_cnt samples */
    struct nn_train_set train; /* ~60% training split */
    struct nn_train_set gen;   /* ~20% generalization split */
    struct nn_train_set val;   /* remainder: validation split */
    int *seq;                  /* scratch permutation used for shuffling */
    /* intern */
    unsigned setup:1; /* nn__setup_train_data already ran */
};
static int | |
nn_req_train_data_siz(const struct nn_train_data *trn) { | |
int es = szof(struct nn_train_elm) * trn->elm_cnt; | |
int ins = szof(float) * trn->in_cnt * trn->elm_cnt; | |
int outs = szof(int) * trn->out_cnt * trn->elm_cnt; | |
int shfl = szof(int) * trn->elm_cnt; | |
return es + ins + shfl + outs; | |
} | |
/* Carves elms/seq and the per-sample in/out arrays out of ctx->mem
 * (layout must match nn_req_train_data_siz) and wires each element's
 * pointers to its own slice. */
static void
nn__setup_train_data(struct nn_train_data *ctx) {
    int i;
    float *ins;
    int *outs;
    /* layout: [elms | seq | all inputs | all targets] */
    ctx->elms = recast(struct nn_train_elm*, ctx->mem);
    ctx->seq = recast(int*, ctx->elms + ctx->elm_cnt);
    ins = recast(float*, (ctx->seq + ctx->elm_cnt));
    outs = recast(int*, (ins + ctx->in_cnt * ctx->elm_cnt));
    for (i = 0; i < ctx->elm_cnt; ++i, ins += ctx->in_cnt, outs += ctx->out_cnt) {
        ctx->elms[i].in = ins;
        ctx->elms[i].out = outs;
    }
    ctx->setup = 1;
}
static void | |
nn__split_train_data_set(struct nn_train_data *dat, int cnt) { | |
float elm_cnt = castf(cnt); | |
int train_cnt = casti(0.6f * elm_cnt); | |
int gen_cnt = casti(ceil(0.2f * elm_cnt)); | |
dat->train.begin = 0; | |
dat->train.end = train_cnt; | |
dat->train.cnt = train_cnt; | |
dat->gen.begin = train_cnt; | |
dat->gen.end = dat->gen.begin + gen_cnt; | |
dat->gen.cnt = gen_cnt; | |
dat->val.begin = dat->gen.end; | |
dat->val.end = cnt; | |
dat->val.cnt = cnt - dat->val.begin; | |
} | |
static void | |
nn__load_train_data_elm(struct nn_train_elm *elm, const struct str *tok, | |
int in_cnt, int out_cnt) { | |
int i = 0; | |
struct str it, _; | |
for_str_tok(it, _, *tok, ',') { | |
char *ep = 0; | |
float f = 0.0; | |
if (i >= in_cnt + out_cnt) { | |
break; | |
} | |
*(char*)it.end = '\0'; | |
f = strtof(it.str, &ep); | |
*(char*)it.end = ','; | |
if (i < in_cnt) { | |
elm->in[i] = f; | |
} else { | |
elm->out[i - in_cnt] = casti(f); | |
} | |
i++; | |
} | |
} | |
static void | |
nn_train_data_load(struct nn_train_data* dat, const char *file, int len) { | |
int i = 0; | |
unsigned long long rnd_gen = sys_rnd64(); | |
struct str ln, _, in = str(file,len); | |
if (!dat->setup) { | |
nn__setup_train_data(dat); | |
} | |
for_str_tok(ln, _, in, '\n') { | |
if (ln.len > 2) { | |
struct nn_train_elm *elm = &dat->elms[i]; | |
nn__load_train_data_elm(elm, &ln, dat->in_cnt, dat->out_cnt); | |
i++; | |
} | |
} | |
seq_rnd(dat->seq, dat->elm_cnt, &rnd_gen); | |
arr_shfl(dat->elms, dat->elm_cnt, dat->seq); | |
nn__split_train_data_set(dat, dat->elm_cnt); | |
} | |
/* --------------------------------------------------------------------------- | |
* Trainer | |
* --------------------------------------------------------------------------- | |
*/ | |
/* Backpropagation trainer state. The caller fills the 'in' fields (zeros
 * select the documented defaults) and points mem at nn_req_trn_siz(ctx)
 * zeroed bytes; nn_trn_init derives the internal arrays. */
struct nn_trn {
    /* in */
    void *mem; /* backing storage for the delta/error arrays, zeroed */
    /* in: settings */
    unsigned use_batch:1; /* 1: accumulate deltas, apply once per epoch */
    float learn_rate;     /* 0 selects default 0.001 */
    float momentum;       /* 0 selects default 0.9 (stochastic mode only) */
    float tar_acc;        /* target accuracy that flips 'done' */
    float acc;            /* 0 selects default 90 */
    /* out */
    float trn_acc; /* last epoch: training accuracy (%) */
    float trn_mse; /* last epoch: training mean squared error */
    float gen_acc; /* last epoch: generalization accuracy (%) */
    float gen_mse; /* last epoch: generalization MSE */
    float val_acc; /* validation accuracy (%) after nn_trn_tst */
    float val_mse; /* validation MSE after nn_trn_tst */
    unsigned done:1; /* both accuracies reached tar_acc */
    /* intern */
    unsigned setup:1;
    float *dt_in_hid;  /* input->hidden weight deltas */
    float *dt_hid_out; /* hidden->output weight deltas */
    float *err_hid;    /* hidden-layer error gradients */
    float *err_out;    /* output-layer error gradients */
};
static int | |
nn_req_trn_siz(struct nn_ctx *ctx) { | |
int ih = szof(float) * ctx->w_in_hid_cnt; | |
int oh = szof(float) * ctx->w_hid_out_cnt; | |
int nh = szof(float) * ctx->n_hid_cnt; | |
int no = szof(float) * ctx->n_out_cnt; | |
return ih + oh + nh + no; | |
} | |
static void | |
nn_trn_init(struct nn_trn *trn, struct nn_ctx *ctx) { | |
trn->setup = 1; | |
trn->learn_rate = (trn->learn_rate == 0.0f) ? 0.001f : trn->learn_rate; | |
trn->momentum = (trn->momentum == 0.0f) ? 0.9f : trn->momentum; | |
trn->acc = (trn->acc == 0.0f) ? 90.0f : trn->acc; | |
trn->dt_in_hid = trn->mem; | |
trn->dt_hid_out = trn->dt_in_hid + szof(float) * ctx->w_in_hid_cnt; | |
trn->err_hid = trn->dt_hid_out + szof(float) * ctx->w_hid_out_cnt; | |
trn->err_out = trn->err_hid + szof(float) * ctx->n_hid_cnt; | |
} | |
/* Applies the accumulated deltas to every weight. In batch mode the deltas
 * are cleared afterwards so the next epoch accumulates from zero; in
 * stochastic mode they are kept as the momentum term. */
static void
nn__trn_update_weights(struct nn_trn *trn, struct nn_ctx *ctx) {
    int ii, hi, oi;
    /* input -> hidden weights (bias row/column included via <=) */
    for (ii = 0; ii <= ctx->in_cnt; ++ii) {
        for (hi = 0; hi <= ctx->hid_cnt; ++hi) {
            int w_idx = nn__in_hide_weight_idx(ctx, ii, hi);
            ctx->w_in_hid[w_idx] += trn->dt_in_hid[w_idx];
            if (trn->use_batch) {
                trn->dt_in_hid[w_idx] = 0;
            }
        }
    }
    /* hidden -> output weights */
    for (hi = 0; hi <= ctx->hid_cnt; ++hi) {
        for (oi = 0; oi < ctx->out_cnt; ++oi) {
            int wi = nn__hide_out_weight_idx(ctx, hi, oi);
            ctx->w_hid_out[wi] += trn->dt_hid_out[wi];
            if (trn->use_batch) {
                trn->dt_hid_out[wi] = 0;
            }
        }
    }
}
/* Error gradient of a sigmoid output node: f'(out) * (target - out),
 * where f'(out) = out * (1 - out) for the logistic function. */
static float
nn__trn_out_err_gradient(float tar_val, float out_val) {
    float deriv = out_val * (1.0f - out_val);
    return deriv * (tar_val - out_val);
}
static float | |
nn__trn_hid_err_gradient(struct nn_trn *trn, struct nn_ctx *ctx, int hi) { | |
int oi = 0; | |
float w_sum = 0; | |
for (oi = 0; oi < ctx->out_cnt; ++oi) { | |
int w_idx = nn__hide_out_weight_idx(ctx, hi, oi); | |
w_sum += ctx->w_hid_out[w_idx] * trn->err_out[oi]; | |
} | |
return ctx->n_hid[hi] * (1.0f - ctx->n_hid[hi]) * w_sum; | |
} | |
/* One backpropagation step for the sample most recently pushed through
 * nn_eval (the activations in ctx are reused here). Computes output then
 * hidden error gradients; in batch mode the weight deltas are accumulated,
 * in stochastic mode they are set with a momentum term and applied
 * immediately. exp_out holds the out_cnt target values. */
static void
nn__trn_backpropergate(struct nn_trn *trn, struct nn_ctx *ctx, const int *exp_out) {
    int ii = 0, oi = 0, hi = 0;
    /* modify deltas between hidden and output layers */
    for (oi = 0; oi < ctx->out_cnt; ++oi) {
        /* get error gradient for output node */
        trn->err_out[oi] = nn__trn_out_err_gradient(castf(exp_out[oi]), ctx->n_out[oi]);
        /* for all nodes in hidden layer and bias neuron calculate change in weight */
        for (hi = 0; hi <= ctx->hid_cnt; ++hi) {
            int wi = nn__hide_out_weight_idx(ctx, hi, oi);
            if (trn->use_batch) {
                trn->dt_hid_out[wi] += trn->learn_rate * ctx->n_hid[hi] * trn->err_out[oi];
            } else {
                trn->dt_hid_out[wi] = trn->learn_rate * ctx->n_hid[hi] * trn->err_out[oi] + trn->momentum * trn->dt_hid_out[wi];
            }
        }
    }
    /* modify deltas between input and hidden layers (err_out must be
     * complete before this point, since the hidden gradients sum over it) */
    for (hi = 0; hi <= ctx->hid_cnt; ++hi) {
        trn->err_hid[hi] = nn__trn_hid_err_gradient(trn, ctx, hi);
        for (ii = 0; ii <= ctx->in_cnt; ++ii) {
            int wi = nn__in_hide_weight_idx(ctx, ii, hi);
            if (trn->use_batch) {
                trn->dt_in_hid[wi] += trn->learn_rate * ctx->n_in[ii] * trn->err_hid[hi];
            } else {
                trn->dt_in_hid[wi] = trn->learn_rate * ctx->n_in[ii] * trn->err_hid[hi] + trn->momentum * trn->dt_in_hid[wi];
            }
        }
    }
    /* If using stochastic learning update the weights immediately */
    if (!trn->use_batch) {
        nn__trn_update_weights(trn, ctx);
    }
}
static void | |
nn__trn_tst(float *mse, float *acc, struct nn_trn *trn, struct nn_ctx *ctx, | |
const struct nn_train_data *dat, const struct nn_train_set *set) { | |
int oi, i = 0; | |
float false_cnt = 0; | |
*mse = *acc = 0; | |
for (i = set->begin; i < set->end; ++i) { | |
int is_correct = 1; | |
struct nn_train_elm *elm = dat->elms + i; | |
nn_eval(ctx, elm->in); | |
for (oi = 0; oi < ctx->out_cnt; ++oi) { | |
if (ctx->outs[oi] != elm->out[oi]) { | |
is_correct = 0; | |
} | |
*mse += castf(pow((ctx->n_out[oi] - elm->out[oi]), 2)); | |
} | |
if (!is_correct) { | |
false_cnt += 1.0f; | |
} | |
} | |
*acc = 100.0f - (false_cnt / castf(set->cnt)) * 100.0f; | |
*mse = *mse / (ctx->out_cnt * castf(set->cnt)); | |
} | |
/* Runs one training epoch over the training split: forward pass plus
 * backpropagation per sample, then (in batch mode) one weight update.
 * Refreshes the training and generalization accuracy/MSE figures.
 * Returns 0 and sets trn->done once BOTH accuracies from the previous
 * epoch reached trn->tar_acc, otherwise returns 1. */
static int
nn_trn(struct nn_trn *trn, struct nn_ctx *ctx, struct nn_train_data *dat) {
    int i = 0, oi = 0;
    float mse = 0;
    float incorrect_entries = 0;
    /* convergence check uses the PREVIOUS epoch's statistics */
    if (trn->trn_acc >= trn->tar_acc && trn->gen_acc >= trn->tar_acc) {
        trn->done = 1;
        return 0;
    }
    for (i = dat->train.begin; i < dat->train.end; ++i) {
        int is_correct = 1;
        struct nn_train_elm *elm = dat->elms + i;
        nn_eval(ctx, elm->in);
        nn__trn_backpropergate(trn, ctx, elm->out);
        for (oi = 0; oi < ctx->out_cnt; ++oi) {
            if (ctx->outs[oi] != elm->out[oi]) {
                is_correct = 0;
            }
            mse += castf(pow((ctx->n_out[oi] - elm->out[oi]), 2));
        }
        if (!is_correct) {
            incorrect_entries += 1.0f;
        }
    }
    if (trn->use_batch) {
        /* batch mode: apply the deltas accumulated over the whole epoch */
        nn__trn_update_weights(trn, ctx);
    }
    trn->trn_acc = 100.0f - (incorrect_entries / castf(dat->train.cnt)) * 100.0f;
    trn->trn_mse = mse / (ctx->out_cnt * castf(dat->train.cnt));
    nn__trn_tst(&trn->gen_mse, &trn->gen_acc, trn, ctx, dat, &dat->gen);
    return 1;
}
/* Final validation: accuracy and MSE over the held-out validation split,
 * stored in trn->val_acc / trn->val_mse. */
static void
nn_trn_tst(struct nn_trn *trn, struct nn_ctx *ctx, struct nn_train_data *dat) {
    nn__trn_tst(&trn->val_mse, &trn->val_acc, trn, ctx, dat, &dat->val);
}
/* --------------------------------------------------------------------------- | |
* | |
* | |
* App | |
* | |
* | |
* --------------------------------------------------------------------------- */ | |
/* Prints the command-line help to stderr and exits with status 1.
 * Fixed a stray double closing parenthesis after the epoch default. */
static void
usage(const char *app) {
    die("\n"
        "usage: %s [options] data in hidden out\n"
        "\n"
        " arguments:\n"
        "\n"
        "   data, Path to training data file\n"
        "   in, Number of input neurons\n"
        "   hidden, Number of hidden neurons\n"
        "   out, Number of output neurons\n"
        "\n"
        " options:\n"
        "   -n <epochs>, epoch count (150)\n"
        "   -l <rate>, Learning rate (0.001)\n"
        "   -a <acc>, Desired Accuracy (90)\n"
        "   -m <mom>, Momentum (0.9)\n"
        "   -o <path>, Code out file path\n"
        "   -b enable batch learning\n"
        "   -h help message\n"
        "\n",
        app
    );
    exit(1);
}
static void | |
write_result(FILE *fp, const struct nn_ctx *ctx) { | |
int hi, ii, oi, n = 0; | |
static const char code[] = | |
" for (hi = 0; hi < hid_cnt; ++hi) {\n" | |
" float32x4_t sum = vdupq_n_f32(0.0f);\n" | |
" for (ii = 0; ii <= in_cnt; i += 4) {\n" | |
" float32x4_t in = vld1q_f32(n_in + ii);\n" | |
" float32x4_t w = vld1q_f32(w_ih + hi * in_cnt + ii);\n" | |
" float32x4_t m = vmulq_f32(in, w);\n" | |
" sum = vaddq_f32(sum, m);\n" | |
" }\n" | |
" sum = vpaddq_f32(sum,sum);\n" | |
" sum = vpaddq_f32(sum,sum);\n" | |
" n_hid[hi] = vgetq_lane_f32(sum,0);\n" | |
" n_hid[hi] = 1.0f / (1.0f + expf(-n_hid[hi]));\n" | |
" }\n" | |
" for (oi = 0; oi < out_cnt; ++oi) {\n" | |
" float32x4_t sum = vdupq_n_f32(0.0f);\n" | |
" for (hi = 0; hi <= hid_cnt; hi += 4) {\n" | |
" float32x4_t in = vld1q_f32(n_hid + hi);\n" | |
" float32x4_t w = vld1q_f32(w_ho + oi * hid_cnt + hi);\n" | |
" float32x4_t m = vmulq_f32(in, w);\n" | |
" sum = vaddq_f32(sum, m);\n" | |
" }\n" | |
" sum = vpaddq_f32(sum,sum);\n" | |
" sum = vpaddq_f32(sum,sum);\n" | |
" n_out[oi] = vgetq_lane_f32(sum,0);\n" | |
" n_out[oi] = 1.0f / (1.0f + expf(-n_out[oi]));\n" | |
" if (n_out[oi] < 0.1f) {\n" | |
" out[oi] = 0.0f;\n" | |
" } else if (x > 0.9f) {\n" | |
" out[oi] = 1.0f;\n" | |
" } else {\n" | |
" out[oi] = n_out[oi];\n" | |
" }\n" | |
" }\n"; | |
fprintf(fp, "static void nn_eval(float *out, const float *in){\n"); | |
fprintf(fp, " static const int in_cnt = %d;\n", ctx->in_cnt); | |
fprintf(fp, " static const int hid_cnt = %d;\n", ctx->hid_cnt); | |
fprintf(fp, " static const int out_cnt = %d;\n", ctx->out_cnt); | |
fprintf(fp, " static const float w_ih[] = {\n"); | |
fprintf(fp, " "); | |
for (hi = 0; hi < ctx->hid_cnt; ++hi) { | |
for (ii = 0; ii <= ctx->in_cnt; ++ii) { | |
int wi = nn__in_hide_weight_idx(ctx, ii, hi); | |
fprintf(fp, "%f,", ctx->w_in_hid[wi]); | |
if ((++n % 7 == 0)) { | |
fprintf(fp, "\n "); | |
} | |
} | |
} | |
n = 0; | |
fprintf(fp, "\n };\n"); | |
fprintf(fp, " static const float w_ho[] = {\n"); | |
fprintf(fp, " "); | |
for (oi = 0; oi < ctx->out_cnt; ++oi) { | |
for (hi = 0; hi <= ctx->hid_cnt; ++hi) { | |
int wi = nn__hide_out_weight_idx(ctx, hi, oi); | |
fprintf(fp, "%f,", ctx->w_hid_out[wi]); | |
if ((++n % 7 == 0)) { | |
fprintf(fp, "\n "); | |
} | |
} | |
} | |
fprintf(fp, "\n };\n"); | |
fprintf(fp, " float n_in[%d] = {0};\n",align_up(ctx->n_in_cnt, 4)); | |
fprintf(fp, " float n_hid[%d] = {0};\n", align_up(ctx->n_hid_cnt, 4)); | |
fprintf(fp, " float n_out[%d] = {0};\n",ctx->out_cnt); | |
fprintf(fp, " int hi, ii, oi;\n\n"); | |
fprintf(fp, " memcpy(n_in, in, (size_t)in_cnt * sizeof(float));\n"); | |
fprintf(fp, " n_in[%d] = n_hid[%d] = -1;\n", ctx->in_cnt, ctx->hid_cnt); | |
fprintf(fp, "%s", code); | |
fprintf(fp, "}\n"); | |
} | |
extern int | |
main(int argc, char **argv) { | |
const char *app = 0; | |
const char *out_file = 0; | |
int epochs = 150, epoch = 0; | |
struct nn_train_data dat = {0}; | |
struct nn_ctx ctx = {0}; | |
struct nn_trn trn = {0}; | |
trn.use_batch = 0; | |
trn.tar_acc = 90.0f; | |
trn.momentum = 0.9f; | |
trn.learn_rate = 0.001f; | |
/* Command Arguments */ | |
CMD_ARG_BEGIN(app, argc, argv){ | |
case 'h': default: usage(app); break; | |
case 'b': trn.use_batch = 1; break; | |
case 'n': epochs = cmd_arg_opt_int(argv,usage(app)); break; | |
case 'l': trn.learn_rate = cmd_arg_opt_flt(argv,usage(app)); break; | |
case 'm': trn.momentum = cmd_arg_opt_flt(argv,usage(app)); break; | |
case 'o': out_file = cmd_arg_opt_str(argv,usage(app)); break; | |
case 'a': { | |
float v = cmd_arg_opt_flt(argv,usage(app)); | |
trn.tar_acc = clamp(0.0f, v, 100.0f); | |
} break; | |
} CMD_ARG_END; | |
if (argc < 4) { | |
usage(app); | |
} | |
{ | |
/* Training Data */ | |
int file_siz = 0; | |
char *file = file_load(&file_siz, argv[0]); | |
dat.in_cnt = cmd_arg_int(argv[1]); | |
dat.hid_cnt = cmd_arg_int(argv[2]); | |
dat.out_cnt = cmd_arg_int(argv[3]); | |
dat.elm_cnt = line_cnt(file, file_siz); | |
dat.mem = xalloc(nn_req_train_data_siz(&dat)); | |
nn_train_data_load(&dat, file, file_siz); | |
free(file); | |
} | |
/* Network */ | |
ctx.in_cnt = dat.in_cnt; | |
ctx.out_cnt = dat.out_cnt; | |
ctx.hid_cnt = dat.hid_cnt; | |
ctx.mem = xalloc(nn_req_net_siz(&ctx)); | |
nn_init(&ctx); | |
/* Trainer */ | |
trn.mem = xalloc(nn_req_trn_siz(&ctx)); | |
nn_trn_init(&trn, &ctx); | |
printf("\nNeural Network Training Starting:\n"); | |
printf("=================================================================\n"); | |
printf("\tLearning Rate: %f\n", trn.learn_rate); | |
printf("\tMomentum: %f\n", trn.momentum); | |
printf("\tEpochs: %d\n", 150); | |
printf("\tInput Neurons: %d\n", ctx.in_cnt); | |
printf("\tHidden Neurons: %d\n", ctx.hid_cnt); | |
printf("\tOutput Neurons: %d\n", ctx.out_cnt); | |
printf("=================================================================\n"); | |
for (epoch = 0; epoch < epochs; ++epoch) { | |
nn_trn(&trn, &ctx, &dat); | |
if (trn.done) break; | |
printf("\tEpoch: %d Train Accuracy: %f Train MSE: %f General Accuracy: %f General MSE: %f\n", | |
epoch, trn.trn_acc, trn.trn_mse, trn.gen_acc, trn.gen_mse); | |
} | |
nn_trn_tst(&trn, &ctx, &dat); | |
printf("\nNeural Network Training Complete!!!\n"); | |
printf("\tEllapsed Epochs: %d\n", epoch); | |
printf("\tAccuracy: %f\n", trn.val_acc); | |
printf("\tMSE: %f\n\n", trn.val_mse); | |
if (out_file) { | |
FILE *fp = fopen(out_file, "w"); | |
if (fp) { | |
write_result(fp, &ctx); | |
fclose(fp); | |
} else { | |
printf("[error] failed to open output file!\n"); | |
} | |
} | |
/* Cleanup */ | |
free(trn.mem); | |
free(ctx.mem); | |
free(dat.mem); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment