@vurtun
Last active November 1, 2023 16:13
Neural Network
#include <stdio.h>
#include <stdlib.h>
#include <limits.h>
#include <float.h>
#include <string.h>
#include <assert.h>
#include <stdint.h>
#include <stdarg.h>
#include <math.h>
#define unused(a) ((void)a)
#define cast(t, p) ((t)(p))
#define casts(p) cast(short,p)
#define casti(p) cast(int,p)
#define castu(p) cast(unsigned,p)
#define castl(p) cast(long,p)
#define castll(p) cast(long long,p)
#define castull(p) cast(unsigned long long,p)
#define castf(p) cast(float,p)
#define castd(p) cast(double,p)
#define castsz(p) cast(size_t,p)
#define castss(p) cast(ssize_t,p)
#define recast(T,p) ((T)cast(void*,(p)))
#define szof(a) ((int)sizeof(a))
#define cntof(a) ((int)(sizeof(a) / sizeof((a)[0])))
#define flag(n) ((1u) << (n))
#define min(a,b) ((a) < (b) ? (a) : (b))
#define max(a,b) ((a) > (b) ? (a) : (b))
#define clamp(a, v, b) (max(min(b, v), a))
#define iswap(x,y) do {((x) ^= (y), (y) ^= (x), (x) ^= (y)); }while(0)
#define align_mask(a) ((a)-1)
#define align_down_masked(n, m) ((n) & ~(m))
#define align_down(n, a) align_down_masked(n, align_mask(a))
#define align_up(n, a) align_down((n) + align_mask(a), (a))
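/* e.g. align_up(13, 4) == 16, align_down(13, 4) == 12; alignment must be a power of two */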
#define xglue(x, y) x##y
#define glue(x, y) xglue(x, y)
#define uniqid(name) glue(name, __LINE__)
static void
die(const char *fmt, ...) {
va_list args;
va_start(args, fmt);
vfprintf(stderr, fmt, args);
fprintf(stderr, "\n");
va_end(args);
exit(1);
}
/* ---------------------------------------------------------------------------
* Platform
* --------------------------------------------------------------------------- */
#ifdef _MSC_VER
#include <intrin.h>
#define alignto(x) __declspec(align(x))
#define bit_cnt(u) __popcnt(u)
#define bit_cnt64(u) __popcnt64(u)
static int bit_ffs32(unsigned int u) {unsigned long i; _BitScanForward(&i, u); return casti(i);}
static int bit_ffs64(unsigned long long u) {unsigned long i; _BitScanForward64(&i, u); return casti(i);}
#else /* GCC, CLANG */
#define alignto(x) __attribute__((aligned(x)))
#define bit_cnt(u) __builtin_popcount(u)
#define bit_cnt64(u) __builtin_popcountll(u)
#define bit_ffs32(u) __builtin_ctz(u)
#define bit_ffs64(u) __builtin_ctzll(u)
#endif
#ifdef _WIN32 /* Windows */
#include <windows.h>
#include <ntsecapi.h> /* RtlGenRandom (SystemFunction036), link against advapi32 */
static unsigned long long
sys_rnd64(void) {
unsigned long long rnd = 0;
if (!RtlGenRandom(&rnd, sizeof(rnd))) {
fprintf(stderr, "failed to generate system random number\n");
exit(1);
}
return rnd;
}
#else /* UNIX */
#include <fcntl.h>
#include <unistd.h>
static unsigned long long
sys_rnd64(void) {
ssize_t res;
unsigned long long rnd = 0;
int fp = open("/dev/urandom", O_RDONLY);
if (fp == -1) {
fprintf(stderr, "failed to access system random number\n");
exit(1);
}
res = read(fp, cast(char*, &rnd), sizeof(rnd));
if (res < szof(rnd)) {
fprintf(stderr, "failed to generate system random number\n");
exit(1);
}
close(fp);
return rnd;
}
#endif
#ifdef __x86_64__ /* SSE */
#define SSE_ALIGN_BYTES 16
#define SSE_ALIGN alignto(SSE_ALIGN_BYTES)
#include <emmintrin.h>
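/* str_chr: return a pointer to the first occurrence of chr in [s,e), or e if
 * not found. The loop scans 16 bytes per iteration; a window into ovr_msk
 * zeroes the bytes past e in the final partial block, and the 16-byte
 * overread is intentional (file_load() below allocates SSE_ALIGN_BYTES of
 * slack for exactly this reason). */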
static const char*
str_chr(const char *s, const char *e, int chr) {
static const char unsigned ovr_msk[32] = {
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
};
__m128i m = _mm_set1_epi8(chr & 0xff);
for (; s < e; s += 16) {
int r = (int)(e - s); r = r > 16 ? 16 : r;
__m128i o = _mm_loadu_si128((const __m128i *)(ovr_msk + 16 - r));
__m128i d = _mm_loadu_si128((const __m128i *)(const void*)s);
__m128i v = _mm_and_si128(d, o);
unsigned msk = _mm_movemask_epi8(_mm_cmpeq_epi8(v,m));
if (msk) {
return s + bit_ffs32(msk);
}
}
return e;
}
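/* line_cnt: count '\n' bytes using the same masked 16-byte compare, summing
 * the popcount of each block's match mask. */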
static int
line_cnt(const char *s, int n) {
static const char unsigned ovr_msk[32] = {
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
};
int cnt = 0;
const char *e = s + n;
__m128i m = _mm_set1_epi8('\n');
for (; s < e; s += 16) {
int r = casti(e - s);
int l = r > 16 ? 16 : r;
__m128i o = _mm_loadu_si128((const __m128i *)(ovr_msk + 16 - l));
__m128i d = _mm_loadu_si128((const __m128i *)(const void*)s);
__m128i v = _mm_and_si128(d, o);
unsigned msk = _mm_movemask_epi8(_mm_cmpeq_epi8(v,m));
cnt += bit_cnt(msk);
}
return cnt;
}
#elif defined(__arm__) || defined(__aarch64__) /* ARM NEON */
#include <arm_neon.h>
#define SSE_ALIGN_BYTES 16
#define SSE_ALIGN alignto(SSE_ALIGN_BYTES)
static const char*
str_chr(const char *s, const char *e, int chr) {
static const char unsigned ovr_msk[32] = {
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
};
uint8x16_t m = vdupq_n_u8(chr & 0xff);
for (; s < e; s += 16) {
int r = (int)(e - s); r = r > 16 ? 16 : r;
uint8x16_t o = vld1q_u8(ovr_msk + 16 - r);
uint8x16_t d = vld1q_u8((const unsigned char*)s);
uint8x16_t v = vandq_u8(d, o);
uint8x16_t c = vceqq_u8(v, m);
uint64x2_t p = vreinterpretq_u64_u8(c);
uint64_t vlo = vgetq_lane_u64(p, 0);
if (vlo) {
return s + ((bit_ffs64(vlo)) >> 3);
}
uint64_t vhi = vgetq_lane_u64(p, 1);
if (vhi) {
return s + 8 + ((bit_ffs64(vhi)) >> 3);
}
}
return e;
}
static int
line_cnt(const char *s, int n) {
static const char unsigned ovr_msk[32] = {
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
};
int cnt = 0;
const char *e = s + n;
uint8x16_t m = vdupq_n_u8('\n');
for (; s < e; s += 16) {
int r = casti(e - s);
int l = r > 16 ? 16 : r;
uint8x16_t o = vld1q_u8(ovr_msk + 16 - l);
uint8x16_t d = vld1q_u8((const unsigned char*)s);
uint8x16_t v = vandq_u8(d, o);
uint8x16_t c = vceqq_u8(v, m);
uint64x2_t p = vreinterpretq_u64_u8(c);
cnt += bit_cnt64(vgetq_lane_u64(p, 0)) >> 3;
cnt += bit_cnt64(vgetq_lane_u64(p, 1)) >> 3;
}
return cnt;
}
#else /* standard c */
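/* Scalar fallback: str_chr advances byte-wise to a 4-byte boundary, then uses
 * the word-at-a-time "has zero byte" trick on v ^ (chr repeated in every
 * byte); bit_ffs32(k) >> 3 recovers the byte index, which assumes a
 * little-endian target. */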
#define SSE_ALIGN_BYTES 16
#define SSE_ALIGN alignto(SSE_ALIGN_BYTES)
static const char*
str_chr(const char *s, const char *e, int c) {
unsigned m = 0x01010101u * (unsigned)(c & 0xff);
for (;s < e && ((uintptr_t)s & 0x03); ++s) {
if (s[0] == c) {
return s;
}
}
for (;s + 4 <= e; s += 4) {
unsigned v = *(const unsigned*)(const void*)s;
unsigned x = v ^ m;
unsigned k = (x - 0x01010101u) & ~x & 0x80808080u;
if (k) {
return s + (bit_ffs32(k) >> 3u);
}
}
for (;s < e; ++s) {
if (s[0] == c) {
return s;
}
}
return e;
}
static int
line_cnt(const char *str, int len) {
int cnt = 0; /* count newline characters to match the SIMD variants */
for (int i = 0; i < len; ++i) {
if (str[i] == '\n') {
cnt++;
}
}
return cnt;
}
#endif
/* ---------------------------------------------------------------------------
* String
* --------------------------------------------------------------------------- */
struct str {
const char *str;
const char *end;
int len;
};
#define str_rhs(s, n) str_sub(s, min((s).len, n), (s).len)
#define str_lhs(s, n) str_sub(s, 0, min((s).len, n))
#define str_cut_lhs(s, n) *(s) = str_rhs(*(s), n)
#define str_cut_rhs(s, n) *(s) = str_lhs(*(s), n)
#define for_str_tok(it, rest, src, delim) \
for ((rest) = (src), (it) = str_split_cut(&(rest), (delim)); \
(it).len; (it) = str_split_cut(&(rest), (delim)))
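/* Illustrative tokenizer use (see nn_train_data_load below):
 *   struct str it, rest;
 *   for_str_tok(it, rest, str("1.0,2.0,1", 9), ',') {
 *     printf("%.*s\n", it.len, it.str);
 *   }
 * Iteration stops at the first empty token. */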
static struct str
str(const char *p, int len) {
struct str s = {0};
s.str = p;
s.end = p + len;
s.len = len;
return s;
}
static struct str
str_sub(struct str s, int from, int to) {
int b = min(from, to);
int e = max(from, to);
struct str r = {0};
r.str = s.str + min(b, s.len);
r.end = s.str + min(e, s.len);
r.len = casti(r.end - r.str);
return r;
}
static struct str
str_split_cut(struct str *s, int delim) {
const char *at = str_chr(s->str, s->end, delim);
if (at < s->end) {
int p = casti(at - s->str);
struct str res = str_lhs(*s, p);
str_cut_lhs(s, p + 1);
return res;
} else {
struct str res = *s;
memset(s, 0, sizeof(*s));
return res;
}
}
/* ---------------------------------------------------------------------------
* Command Arguments
* --------------------------------------------------------------------------- */
#define CMD_ARGC() argc_
#define cmd_arg_opt_str(argv, x) ((argv[0][1] == '\0' && argv[1] == 0)?\
((x), (char *)0) : (brk_ = 1, (argv[0][1] != '\0') ?\
(&argv[0][1]) : (argc--, argv++, argv[0])))
#define cmd_arg_opt_int(argv,x) cmd_arg_int(cmd_arg_opt_str(argv,x))
#define cmd_arg_opt_flt(argv,x) cmd_arg_flt(cmd_arg_opt_str(argv,x))
#define CMD_ARG_BEGIN(argv0, argc, argv) \
for (argv0 = *argv, argv++, argc--; argv[0] && argv[0][1] && argv[0][0] == '-'; argc--, argv++) {\
char argc_, **argv_; int brk_;\
if (argv[0][1] == '-' && argv[0][2] == '\0') {argv++; argc--; break;}\
for (brk_ = 0, argv[0]++, argv_ = argv; argv[0][0] && !brk_; argv[0]++) {\
if (argv_ != argv) break;\
argc_ = argv[0][0];\
switch (argc_)
#define CMD_ARG_END }}
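/* Option parsing pattern (see main below): the switch body runs once per
 * option character, cmd_arg_opt_* consumes a value that is either attached
 * ("-n150") or the following argument ("-n 150"), and a bare "--" ends option
 * parsing. Afterwards argv points at the first positional argument. */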
static int
cmd_arg_int(const char *str) {
char *ep = 0;
long n = strtol(str, &ep, 10);
if (*ep != '\0' || ep == str) {
die("Invalid argument number: %s\n", str);
}
if (n < INT_MIN || n > INT_MAX) {
die("Argument number: %ld is out of range\n", n);
}
return casti(n);
}
static float
cmd_arg_flt(const char *str) {
char *ep = 0;
float n = strtof(str, &ep);
if (*ep != '\0' || ep == str) {
die("Invalid argument number: %s\n", str);
}
return n;
}
/* ---------------------------------------------------------------------------
* Utility
* --------------------------------------------------------------------------- */
#define swap(x,y) do { \
unsigned char uniqid(t)[szof(x) == szof(y) ? szof(x) : -1]; \
memcpy(uniqid(t),&y,sizeof(x)); \
memcpy(&y,&x,sizeof(x)); \
memcpy(&x,uniqid(t),sizeof(x)); \
} while(0)
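/* arr_shfl: permute array a in place so that a[i] ends up holding the element
 * that was at a[p[i]], following permutation cycles. The index array p is
 * consumed in the process (entries are replaced by their one's complement). */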
#define arr_shfl(a,n,p) do { \
for (int uniqid(i) = 0; uniqid(i) < n; ++uniqid(i)) { \
if (p[uniqid(i)] >= 0) { \
int uniqid(j) = uniqid(i); \
while (p[uniqid(j)] != uniqid(i)) { \
const int uniqid(d) = p[uniqid(j)]; \
swap(a[uniqid(j)], a[uniqid(d)]); \
p[uniqid(j)] = -1 - uniqid(d); \
uniqid(j) = uniqid(d); \
} p[uniqid(j)] = -1 - p[uniqid(j)]; \
} \
}} while (0)
static void*
xalloc(int siz) {
void *mem = calloc(castsz(siz), 1);
if (!mem) {
die("Out of Memory");
}
return mem;
}
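/* Pseudo random number generation in the style of splitmix64: rnd_gen steps
 * the state by a multiple of the 64-bit golden-ratio constant and rnd_mix
 * applies the avalanche/output function. State is seeded from sys_rnd64(). */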
static unsigned long long
rnd_gen(unsigned long long x, int n) {
return x + castull(n) * 0x9E3779B97F4A7C15llu;
}
static unsigned long long
rnd_mix(unsigned long long z) {
z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9llu;
z = (z ^ (z >> 27)) * 0x94D049BB133111EBllu;
return z ^ (z >> 31llu);
}
static unsigned long long
rnd_split_mix(unsigned long long *x, int i) {
*x = rnd_gen(*x, i);
return rnd_mix(*x);
}
static unsigned long long
rnd(unsigned long long *x) {
return rnd_split_mix(x, 1);
}
static unsigned
rndu(unsigned long long *x) {
unsigned long long z = rnd(x);
return castu(z & 0xffffffffu);
}
static unsigned
rnduu(unsigned long long *x, unsigned mini, unsigned maxi) {
unsigned lo = min(mini, maxi);
unsigned hi = max(mini, maxi);
unsigned rng = castu(-1);
unsigned n = hi - lo + 1u;
if (n == 1u) {
return mini;
} else if(n == 0u) {
return rndu(x);
} else {
unsigned v = 0;
unsigned remainder = rng % n;
do {v = rndu(x);}
while(v >= rng - remainder);
return mini + v % n;
}
}
static float
rndf01(unsigned long long *x) {
unsigned u = rndu(x);
double du = castd(u);
double div = castd((unsigned)-1);
return castf(du/div);
}
static float
rnduf(unsigned long long *x, float mini, float maxi) {
float lo = min(mini, maxi);
float hi = max(mini, maxi);
unsigned u = rndu(x);
float rng = hi - lo;
double du = castd(u);
double div = castd((unsigned)-1);
return lo + rng * castf(du/div);
}
static void
seq_lin(int *seq, int cnt) {
int i = 0;
for (i = 0; i < cnt; ++i) {
seq[i] = i;
}
}
static void
seq_rnd(int *seq, int n, unsigned long long *r) {
int i = 0;
for (i = n - 1; i > 0; --i) {
unsigned at = rndu(r) % castu(i + 1);
if (at != castu(i)) {
/* the xor based iswap would zero the element when both refer to the same slot */
iswap(seq[i], seq[at]);
}
}
}
static char*
file_load(int *siz, const char *path) {
FILE *fd = 0;
char *mem = 0;
size_t ret = 0;
/* open file */
assert(path);
fd = fopen(path, "rb");
if (!fd) {
die("Unable to open file: %s\n", path);
}
/* calculate file size */
fseek(fd, 0, SEEK_END);
*siz = casti(ftell(fd));
if (*siz < 0) {
die("Unable to access file: %s\n", path);
}
if (*siz == 0) {
die("File is empty: %s\n", path);
}
fseek(fd, 0, SEEK_SET);
/* read file into memory */
mem = xalloc(*siz + SSE_ALIGN_BYTES);
assert(mem);
if (!mem) {
die("file calloc failed");
}
ret = fread(mem, 1, castsz(*siz), fd);
if (ret < castsz(*siz)) {
die("file fread failed: %d\n", ret);
}
fclose(fd);
return mem;
}
/* ---------------------------------------------------------------------------==
*
*
* Neural Network
*
*
* --------------------------------------------------------------------------- */
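/* Fully connected feed-forward network with a single hidden layer and sigmoid
 * activations. A bias neuron clamped to -1 is appended to the input and
 * hidden layers, which is why most loops run up to and including in_cnt and
 * hid_cnt. The caller fills in in_cnt/hid_cnt/out_cnt and points mem at a
 * zero-initialized block of at least nn_req_net_siz() bytes. */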
struct nn_ctx {
/* in */
void *mem;
int in_cnt;
int out_cnt;
int hid_cnt;
/* out */
int n_in_cnt;
int n_out_cnt;
int n_hid_cnt;
int w_in_hid_cnt;
int w_hid_out_cnt;
float *n_in;
float *n_hid;
float *n_out;
int *outs;
float *w_in_hid;
float *w_hid_out;
/* intern */
unsigned setup:1;
};
static int
nn_req_net_siz(struct nn_ctx *ctx) {
int n_in_siz = szof(float) * (ctx->in_cnt + 1);
int n_hid_siz = szof(float) * (ctx->hid_cnt + 1);
int n_out_siz = szof(float) * ctx->out_cnt;
int out_siz = szof(int) * ctx->out_cnt;
int wih_siz = szof(float) * ((ctx->in_cnt + 1) * (ctx->hid_cnt + 1));
int woh_siz = szof(float) * ((ctx->hid_cnt + 1) * ctx->out_cnt);
return n_in_siz + n_hid_siz + n_out_siz + out_siz + wih_siz + woh_siz;
}
static void
nn__setup_net(struct nn_ctx *ctx) {
ctx->setup = 1;
ctx->n_in_cnt = ctx->in_cnt + 1;
ctx->n_out_cnt = ctx->out_cnt;
ctx->n_hid_cnt = ctx->hid_cnt + 1;
ctx->w_in_hid_cnt = (ctx->n_in_cnt * ctx->n_hid_cnt);
ctx->w_hid_out_cnt = (ctx->n_out_cnt * ctx->n_hid_cnt);
/* setup memory */
ctx->n_in = ctx->mem;
ctx->n_hid = ctx->n_in + ctx->n_in_cnt;
ctx->n_out = ctx->n_hid + ctx->n_hid_cnt;
ctx->w_in_hid = ctx->n_out + ctx->n_out_cnt;
ctx->w_hid_out = ctx->w_in_hid + ctx->w_in_hid_cnt;
ctx->outs = cast(int*, ctx->w_hid_out + ctx->w_hid_out_cnt);
ctx->n_in[ctx->in_cnt] = -1.0;
ctx->n_hid[ctx->hid_cnt] = -1.0;
}
static int
nn__in_hide_weight_idx(const struct nn_ctx *ctx, int in_idx, int hid_idx) {
return in_idx * ctx->hid_cnt + hid_idx;
}
static int
nn__hide_out_weight_idx(const struct nn_ctx *ctx, int hid_idx, int out_idx) {
return hid_idx * ctx->out_cnt + out_idx;
}
static void
nn_init(struct nn_ctx *ctx) {
int ii, hi, oi;
unsigned long long rnd_gen = sys_rnd64();
float dist = (2.4f / ctx->in_cnt);
if (!ctx->setup) {
nn__setup_net(ctx);
}
for (ii = 0; ii <= ctx->in_cnt; ++ii) {
for (hi = 0; hi <= ctx->hid_cnt; ++hi) {
int wi = nn__in_hide_weight_idx(ctx, ii, hi);
ctx->w_in_hid[wi] = (rndf01(&rnd_gen) * 2.0f - 1.0f) * dist;
}
}
for (hi = 0; hi <= ctx->hid_cnt; ++hi) {
for (oi = 0; oi < ctx->out_cnt; ++oi) {
int wi = nn__hide_out_weight_idx(ctx, hi, oi);
ctx->w_hid_out[wi] = (rndf01(&rnd_gen) * 2.0f - 1.0f) * dist;
}
}
}
static float
nn__sigmoid(float x) {
return 1.0f / ( 1.0f + expf(-x));
}
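/* The logistic derivative s'(x) = s(x) * (1 - s(x)) is what the trainer's
 * out * (1 - out) gradient factors below rely on. */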
static int
nn__clamp(float x) {
if (x < 0.1f) return 0;
else if (x > 0.9f) return 1;
else return -1;
}
static int*
nn_eval(struct nn_ctx *ctx, const float *in) {
int oi, ii, hi = 0;
memcpy(ctx->n_in, in, sizeof(float) * castsz(ctx->in_cnt));
for (hi = 0; hi < ctx->hid_cnt; ++hi) {
/* calculate weighted sum of pattern and bias neuron */
ctx->n_hid[hi] = 0;
for (ii = 0; ii <= ctx->in_cnt; ++ii) {
int wi = nn__in_hide_weight_idx(ctx, ii, hi);
ctx->n_hid[hi] += ctx->n_in[ii] * ctx->w_in_hid[wi];
}
ctx->n_hid[hi] = nn__sigmoid(ctx->n_hid[hi]);
}
for (oi = 0; oi < ctx->out_cnt; ++oi) {
/* calculate output values - include bias neuron */
ctx->n_out[oi] = 0;
for (hi = 0; hi <= ctx->hid_cnt; ++hi) {
int wi = nn__hide_out_weight_idx(ctx, hi, oi);
ctx->n_out[oi] += ctx->n_hid[hi] * ctx->w_hid_out[wi];
}
ctx->n_out[oi] = nn__sigmoid(ctx->n_out[oi]);
ctx->outs[oi] = nn__clamp(ctx->n_out[oi]);
}
return ctx->outs;
}
/* ---------------------------------------------------------------------------
* Data
* ---------------------------------------------------------------------------
*/
struct nn_train_elm {
float *in;
int *out;
};
struct nn_train_set {
int begin;
int end;
int cnt;
};
struct nn_train_data {
/* in */
void *mem;
int in_cnt;
int hid_cnt;
int out_cnt;
int elm_cnt;
/* out */
struct nn_train_elm *elms;
struct nn_train_set train;
struct nn_train_set gen;
struct nn_train_set val;
int *seq;
/* intern */
unsigned setup:1;
};
static int
nn_req_train_data_siz(const struct nn_train_data *trn) {
int es = szof(struct nn_train_elm) * trn->elm_cnt;
int ins = szof(float) * trn->in_cnt * trn->elm_cnt;
int outs = szof(int) * trn->out_cnt * trn->elm_cnt;
int shfl = szof(int) * trn->elm_cnt;
return es + ins + shfl + outs;
}
static void
nn__setup_train_data(struct nn_train_data *ctx) {
int i;
float *ins;
int *outs;
ctx->elms = recast(struct nn_train_elm*, ctx->mem);
ctx->seq = recast(int*, ctx->elms + ctx->elm_cnt);
ins = recast(float*, (ctx->seq + ctx->elm_cnt));
outs = recast(int*, (ins + ctx->in_cnt * ctx->elm_cnt));
for (i = 0; i < ctx->elm_cnt; ++i, ins += ctx->in_cnt, outs += ctx->out_cnt) {
ctx->elms[i].in = ins;
ctx->elms[i].out = outs;
}
ctx->setup = 1;
}
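/* The shuffled samples are split 60% training, 20% generalization (used to
 * monitor accuracy after every epoch) and the remaining ~20% validation
 * (evaluated once after training). */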
static void
nn__split_train_data_set(struct nn_train_data *dat, int cnt) {
float elm_cnt = castf(cnt);
int train_cnt = casti(0.6f * elm_cnt);
int gen_cnt = casti(ceil(0.2f * elm_cnt));
dat->train.begin = 0;
dat->train.end = train_cnt;
dat->train.cnt = train_cnt;
dat->gen.begin = train_cnt;
dat->gen.end = dat->gen.begin + gen_cnt;
dat->gen.cnt = gen_cnt;
dat->val.begin = dat->gen.end;
dat->val.end = cnt;
dat->val.cnt = cnt - dat->val.begin;
}
static void
nn__load_train_data_elm(struct nn_train_elm *elm, const struct str *tok,
int in_cnt, int out_cnt) {
int i = 0;
struct str it, _;
for_str_tok(it, _, *tok, ',') {
char *ep = 0;
float f = 0.0;
if (i >= in_cnt + out_cnt) {
break;
}
*(char*)it.end = '\0';
f = strtof(it.str, &ep);
*(char*)it.end = ',';
if (i < in_cnt) {
elm->in[i] = f;
} else {
elm->out[i - in_cnt] = casti(f);
}
i++;
}
}
static void
nn_train_data_load(struct nn_train_data* dat, const char *file, int len) {
int i = 0;
unsigned long long rnd_gen = sys_rnd64();
struct str ln, _, in = str(file,len);
if (!dat->setup) {
nn__setup_train_data(dat);
}
for_str_tok(ln, _, in, '\n') {
if (ln.len > 2 && i < dat->elm_cnt) {
struct nn_train_elm *elm = &dat->elms[i];
nn__load_train_data_elm(elm, &ln, dat->in_cnt, dat->out_cnt);
i++;
}
}
seq_lin(dat->seq, dat->elm_cnt);
seq_rnd(dat->seq, dat->elm_cnt, &rnd_gen);
arr_shfl(dat->elms, dat->elm_cnt, dat->seq);
nn__split_train_data_set(dat, dat->elm_cnt);
}
/* ---------------------------------------------------------------------------
* Trainer
* ---------------------------------------------------------------------------
*/
struct nn_trn {
/* in */
void *mem;
/* in: settings */
unsigned use_batch:1;
float learn_rate;
float momentum;
float tar_acc;
/* out */
float trn_acc;
float trn_mse;
float gen_acc;
float gen_mse;
float val_acc;
float val_mse;
unsigned done:1;
/* intern */
unsigned setup:1;
float *dt_in_hid;
float *dt_hid_out;
float *err_hid;
float *err_out;
};
static int
nn_req_trn_siz(struct nn_ctx *ctx) {
int ih = szof(float) * ctx->w_in_hid_cnt;
int oh = szof(float) * ctx->w_hid_out_cnt;
int nh = szof(float) * ctx->n_hid_cnt;
int no = szof(float) * ctx->n_out_cnt;
return ih + oh + nh + no;
}
static void
nn_trn_init(struct nn_trn *trn, struct nn_ctx *ctx) {
trn->setup = 1;
trn->learn_rate = (trn->learn_rate == 0.0f) ? 0.001f : trn->learn_rate;
trn->momentum = (trn->momentum == 0.0f) ? 0.9f : trn->momentum;
trn->tar_acc = (trn->tar_acc == 0.0f) ? 90.0f : trn->tar_acc;
/* pointer arithmetic on float* already scales by sizeof(float) */
trn->dt_in_hid = trn->mem;
trn->dt_hid_out = trn->dt_in_hid + ctx->w_in_hid_cnt;
trn->err_hid = trn->dt_hid_out + ctx->w_hid_out_cnt;
trn->err_out = trn->err_hid + ctx->n_hid_cnt;
}
static void
nn__trn_update_weights(struct nn_trn *trn, struct nn_ctx *ctx) {
int ii, hi, oi;
/* input -> hidden weights */
for (ii = 0; ii <= ctx->in_cnt; ++ii) {
for (hi = 0; hi <= ctx->hid_cnt; ++hi) {
int w_idx = nn__in_hide_weight_idx(ctx, ii, hi);
ctx->w_in_hid[w_idx] += trn->dt_in_hid[w_idx];
if (trn->use_batch) {
trn->dt_in_hid[w_idx] = 0;
}
}
}
/* output -> hidden weights */
for (hi = 0; hi <= ctx->hid_cnt; ++hi) {
for (oi = 0; oi < ctx->out_cnt; ++oi) {
int wi = nn__hide_out_weight_idx(ctx, hi, oi);
ctx->w_hid_out[wi] += trn->dt_hid_out[wi];
if (trn->use_batch) {
trn->dt_hid_out[wi] = 0;
}
}
}
}
static float
nn__trn_out_err_gradient(float tar_val, float out_val) {
return out_val * (1.0f - out_val) * (tar_val - out_val);
}
static float
nn__trn_hid_err_gradient(struct nn_trn *trn, struct nn_ctx *ctx, int hi) {
int oi = 0;
float w_sum = 0;
for (oi = 0; oi < ctx->out_cnt; ++oi) {
int w_idx = nn__hide_out_weight_idx(ctx, hi, oi);
w_sum += ctx->w_hid_out[w_idx] * trn->err_out[oi];
}
return ctx->n_hid[hi] * (1.0f - ctx->n_hid[hi]) * w_sum;
}
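/* Backpropagation with momentum:
 *   output delta:  d_o = y_o * (1 - y_o) * (t_o - y_o)
 *   hidden delta:  d_h = y_h * (1 - y_h) * sum_o(w_ho * d_o)
 *   weight delta:  dw = learn_rate * input * delta, plus momentum * previous
 *   dw in stochastic mode; in batch mode deltas are accumulated and applied
 *   once per epoch by nn__trn_update_weights(). */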
static void
nn__trn_backpropagate(struct nn_trn *trn, struct nn_ctx *ctx, const int *exp_out) {
int ii = 0, oi = 0, hi = 0;
/* modify deltas between hidden and output layers */
for (oi = 0; oi < ctx->out_cnt; ++oi) {
/* get error gradient for output node */
trn->err_out[oi] = nn__trn_out_err_gradient(castf(exp_out[oi]), ctx->n_out[oi]);
/* for all nodes in hidden layer and bias neuron calculate change in weight */
for (hi = 0; hi <= ctx->hid_cnt; ++hi) {
int wi = nn__hide_out_weight_idx(ctx, hi, oi);
if (trn->use_batch) {
trn->dt_hid_out[wi] += trn->learn_rate * ctx->n_hid[hi] * trn->err_out[oi];
} else {
trn->dt_hid_out[wi] = trn->learn_rate * ctx->n_hid[hi] * trn->err_out[oi] + trn->momentum * trn->dt_hid_out[wi];
}
}
}
/* modify deltas between input and hidden layers */
for (hi = 0; hi <= ctx->hid_cnt; ++hi) {
trn->err_hid[hi] = nn__trn_hid_err_gradient(trn, ctx, hi);
for (ii = 0; ii <= ctx->in_cnt; ++ii) {
int wi = nn__in_hide_weight_idx(ctx, ii, hi);
if (trn->use_batch) {
trn->dt_in_hid[wi] += trn->learn_rate * ctx->n_in[ii] * trn->err_hid[hi];
} else {
trn->dt_in_hid[wi] = trn->learn_rate * ctx->n_in[ii] * trn->err_hid[hi] + trn->momentum * trn->dt_in_hid[wi];
}
}
}
/* If using stochastic learning update the weights immediately */
if (!trn->use_batch) {
nn__trn_update_weights(trn, ctx);
}
}
static void
nn__trn_tst(float *mse, float *acc, struct nn_trn *trn, struct nn_ctx *ctx,
const struct nn_train_data *dat, const struct nn_train_set *set) {
int oi, i = 0;
float false_cnt = 0;
*mse = *acc = 0;
for (i = set->begin; i < set->end; ++i) {
int is_correct = 1;
struct nn_train_elm *elm = dat->elms + i;
nn_eval(ctx, elm->in);
for (oi = 0; oi < ctx->out_cnt; ++oi) {
if (ctx->outs[oi] != elm->out[oi]) {
is_correct = 0;
}
*mse += castf(pow((ctx->n_out[oi] - elm->out[oi]), 2));
}
if (!is_correct) {
false_cnt += 1.0f;
}
}
*acc = 100.0f - (false_cnt / castf(set->cnt)) * 100.0f;
*mse = *mse / (ctx->out_cnt * castf(set->cnt));
}
static int
nn_trn(struct nn_trn *trn, struct nn_ctx *ctx, struct nn_train_data *dat) {
int i = 0, oi = 0;
float mse = 0;
float incorrect_entries = 0;
if (trn->trn_acc >= trn->tar_acc && trn->gen_acc >= trn->tar_acc) {
trn->done = 1;
return 0;
}
for (i = dat->train.begin; i < dat->train.end; ++i) {
int is_correct = 1;
struct nn_train_elm *elm = dat->elms + i;
nn_eval(ctx, elm->in);
nn__trn_backpropagate(trn, ctx, elm->out);
for (oi = 0; oi < ctx->out_cnt; ++oi) {
if (ctx->outs[oi] != elm->out[oi]) {
is_correct = 0;
}
mse += castf(pow((ctx->n_out[oi] - elm->out[oi]), 2));
}
if (!is_correct) {
incorrect_entries += 1.0f;
}
}
if (trn->use_batch) {
nn__trn_update_weights(trn, ctx);
}
trn->trn_acc = 100.0f - (incorrect_entries / castf(dat->train.cnt)) * 100.0f;
trn->trn_mse = mse / (ctx->out_cnt * castf(dat->train.cnt));
nn__trn_tst(&trn->gen_mse, &trn->gen_acc, trn, ctx, dat, &dat->gen);
return 1;
}
static void
nn_trn_tst(struct nn_trn *trn, struct nn_ctx *ctx, struct nn_train_data *dat) {
nn__trn_tst(&trn->val_mse, &trn->val_acc, trn, ctx, dat, &dat->val);
}
/* ---------------------------------------------------------------------------
*
*
* App
*
*
* --------------------------------------------------------------------------- */
static void
usage(const char *app) {
die("\n"
"usage: %s [options] data in hidden out\n"
"\n"
" arguments:\n"
"\n"
" data, Path to training data file\n"
" in, Number of input neurons\n"
" hidden, Number of hidden neurons\n"
" out, Number of output neurons\n"
"\n"
" options:\n"
" -n <epochs>, epoch count (150))\n"
" -l <rate>, Learning rate (0.001)\n"
" -a <acc>, Desired Accuracy (90)\n"
" -m <mom>, Momentum (0.9)\n"
" -o <path>, Code out file path\n"
" -b enable batch learning\n"
" -h help message\n"
"\n",
app
);
exit(1);
}
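/* write_result emits a standalone ARM NEON nn_eval() with the trained weights
 * baked into static arrays. The generated kernel rounds its inner loops up to
 * multiples of four and relies on the zero padded tails of n_in/n_hid so the
 * final partial 4-wide load does not perturb the sums; the including
 * translation unit needs <arm_neon.h>, <math.h> and <string.h>. */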
static void
write_result(FILE *fp, const struct nn_ctx *ctx) {
int hi, ii, oi, n = 0;
static const char code[] =
" for (hi = 0; hi < hid_cnt; ++hi) {\n"
" float32x4_t sum = vdupq_n_f32(0.0f);\n"
" for (ii = 0; ii <= in_cnt; i += 4) {\n"
" float32x4_t in = vld1q_f32(n_in + ii);\n"
" float32x4_t w = vld1q_f32(w_ih + hi * in_cnt + ii);\n"
" float32x4_t m = vmulq_f32(in, w);\n"
" sum = vaddq_f32(sum, m);\n"
" }\n"
" sum = vpaddq_f32(sum,sum);\n"
" sum = vpaddq_f32(sum,sum);\n"
" n_hid[hi] = vgetq_lane_f32(sum,0);\n"
" n_hid[hi] = 1.0f / (1.0f + expf(-n_hid[hi]));\n"
" }\n"
" for (oi = 0; oi < out_cnt; ++oi) {\n"
" float32x4_t sum = vdupq_n_f32(0.0f);\n"
" for (hi = 0; hi <= hid_cnt; hi += 4) {\n"
" float32x4_t in = vld1q_f32(n_hid + hi);\n"
" float32x4_t w = vld1q_f32(w_ho + oi * hid_cnt + hi);\n"
" float32x4_t m = vmulq_f32(in, w);\n"
" sum = vaddq_f32(sum, m);\n"
" }\n"
" sum = vpaddq_f32(sum,sum);\n"
" sum = vpaddq_f32(sum,sum);\n"
" n_out[oi] = vgetq_lane_f32(sum,0);\n"
" n_out[oi] = 1.0f / (1.0f + expf(-n_out[oi]));\n"
" if (n_out[oi] < 0.1f) {\n"
" out[oi] = 0.0f;\n"
" } else if (x > 0.9f) {\n"
" out[oi] = 1.0f;\n"
" } else {\n"
" out[oi] = n_out[oi];\n"
" }\n"
" }\n";
fprintf(fp, "static void nn_eval(float *out, const float *in){\n");
fprintf(fp, " static const int in_cnt = %d;\n", ctx->in_cnt);
fprintf(fp, " static const int hid_cnt = %d;\n", ctx->hid_cnt);
fprintf(fp, " static const int out_cnt = %d;\n", ctx->out_cnt);
fprintf(fp, " static const float w_ih[] = {\n");
fprintf(fp, " ");
for (hi = 0; hi < ctx->hid_cnt; ++hi) {
for (ii = 0; ii <= ctx->in_cnt; ++ii) {
int wi = nn__in_hide_weight_idx(ctx, ii, hi);
fprintf(fp, "%f,", ctx->w_in_hid[wi]);
if ((++n % 7 == 0)) {
fprintf(fp, "\n ");
}
}
}
n = 0;
fprintf(fp, "\n };\n");
fprintf(fp, " static const float w_ho[] = {\n");
fprintf(fp, " ");
for (oi = 0; oi < ctx->out_cnt; ++oi) {
for (hi = 0; hi <= ctx->hid_cnt; ++hi) {
int wi = nn__hide_out_weight_idx(ctx, hi, oi);
fprintf(fp, "%f,", ctx->w_hid_out[wi]);
if ((++n % 7 == 0)) {
fprintf(fp, "\n ");
}
}
}
fprintf(fp, "\n };\n");
fprintf(fp, " float n_in[%d] = {0};\n",align_up(ctx->n_in_cnt, 4));
fprintf(fp, " float n_hid[%d] = {0};\n", align_up(ctx->n_hid_cnt, 4));
fprintf(fp, " float n_out[%d] = {0};\n",ctx->out_cnt);
fprintf(fp, " int hi, ii, oi;\n\n");
fprintf(fp, " memcpy(n_in, in, (size_t)in_cnt * sizeof(float));\n");
fprintf(fp, " n_in[%d] = n_hid[%d] = -1;\n", ctx->in_cnt, ctx->hid_cnt);
fprintf(fp, "%s", code);
fprintf(fp, "}\n");
}
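/* Program flow: parse options, load and shuffle the CSV training data, build
 * and randomly initialize the network, train for up to -n epochs while
 * reporting accuracy/MSE, validate on the held out set and optionally emit
 * generated code with -o. Illustrative invocation (binary and file name made
 * up):
 *   ./nn -n 300 -l 0.01 -a 95 train.csv 16 8 1
 * where 16/8/1 are the input/hidden/output neuron counts. */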
extern int
main(int argc, char **argv) {
const char *app = 0;
const char *out_file = 0;
int epochs = 150, epoch = 0;
struct nn_train_data dat = {0};
struct nn_ctx ctx = {0};
struct nn_trn trn = {0};
trn.use_batch = 0;
trn.tar_acc = 90.0f;
trn.momentum = 0.9f;
trn.learn_rate = 0.001f;
/* Command Arguments */
CMD_ARG_BEGIN(app, argc, argv){
case 'h': default: usage(app); break;
case 'b': trn.use_batch = 1; break;
case 'n': epochs = cmd_arg_opt_int(argv,usage(app)); break;
case 'l': trn.learn_rate = cmd_arg_opt_flt(argv,usage(app)); break;
case 'm': trn.momentum = cmd_arg_opt_flt(argv,usage(app)); break;
case 'o': out_file = cmd_arg_opt_str(argv,usage(app)); break;
case 'a': {
float v = cmd_arg_opt_flt(argv,usage(app));
trn.tar_acc = clamp(0.0f, v, 100.0f);
} break;
} CMD_ARG_END;
if (argc < 4) {
usage(app);
}
{
/* Training Data */
int file_siz = 0;
char *file = file_load(&file_siz, argv[0]);
dat.in_cnt = cmd_arg_int(argv[1]);
dat.hid_cnt = cmd_arg_int(argv[2]);
dat.out_cnt = cmd_arg_int(argv[3]);
dat.elm_cnt = line_cnt(file, file_siz);
dat.mem = xalloc(nn_req_train_data_siz(&dat));
nn_train_data_load(&dat, file, file_siz);
free(file);
}
/* Network */
ctx.in_cnt = dat.in_cnt;
ctx.out_cnt = dat.out_cnt;
ctx.hid_cnt = dat.hid_cnt;
ctx.mem = xalloc(nn_req_net_siz(&ctx));
nn_init(&ctx);
/* Trainer */
trn.mem = xalloc(nn_req_trn_siz(&ctx));
nn_trn_init(&trn, &ctx);
printf("\nNeural Network Training Starting:\n");
printf("=================================================================\n");
printf("\tLearning Rate: %f\n", trn.learn_rate);
printf("\tMomentum: %f\n", trn.momentum);
printf("\tEpochs: %d\n", 150);
printf("\tInput Neurons: %d\n", ctx.in_cnt);
printf("\tHidden Neurons: %d\n", ctx.hid_cnt);
printf("\tOutput Neurons: %d\n", ctx.out_cnt);
printf("=================================================================\n");
for (epoch = 0; epoch < epochs; ++epoch) {
nn_trn(&trn, &ctx, &dat);
if (trn.done) break;
printf("\tEpoch: %d Train Accuracy: %f Train MSE: %f General Accuracy: %f General MSE: %f\n",
epoch, trn.trn_acc, trn.trn_mse, trn.gen_acc, trn.gen_mse);
}
nn_trn_tst(&trn, &ctx, &dat);
printf("\nNeural Network Training Complete!!!\n");
printf("\tEllapsed Epochs: %d\n", epoch);
printf("\tAccuracy: %f\n", trn.val_acc);
printf("\tMSE: %f\n\n", trn.val_mse);
if (out_file) {
FILE *fp = fopen(out_file, "w");
if (fp) {
write_result(fp, &ctx);
fclose(fp);
} else {
printf("[error] failed to open output file!\n");
}
}
/* Cleanup */
free(trn.mem);
free(ctx.mem);
free(dat.mem);
return 0;
}