Skip to content

Instantly share code, notes, and snippets.

@mattn
Created April 16, 2019 05:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mattn/920779ca764b777174958db0430964ae to your computer and use it in GitHub Desktop.
Save mattn/920779ca764b777174958db0430964ae to your computer and use it in GitHub Desktop.
" Random number generator based on George Marsaglia's multiply-with-carry
" (MWC) algorithm. The generator state is two lanes, s:hi and s:lo; a lane
" that is still 0 gets seeded lazily the first time s:rand() runs.
let [s:hi, s:lo] = [0, 0]
function! s:srand(seed)
  " Seed the MWC generator from a:seed.
  " s:lo receives a:seed modulo 0x10000 and s:hi receives a:seed / 0x10000
  " plus 0x8000. A negative seed is first shifted down by 0x80000000, the
  " same reinterpretation s:rand() applies to negative lanes.
  let l:base = a:seed < 0 ? a:seed - 0x80000000 : a:seed
  let s:lo = l:base % 0x10000
  let s:hi = l:base / 0x10000 + 0x8000
endfunction
" Return the next raw value of the multiply-with-carry generator and advance
" its two lanes, s:hi and s:lo.
" Each lane is updated as  lane = k * (lane % 0x10000) + lane / 0x10000
" with k = 36969 for s:hi and k = 18273 for s:lo. The returned value is
" (new s:hi) * 0x10000 plus the low 16 bits of the new s:lo.
" NOTE(review): nothing here truncates to 32 bits, so with 64-bit Vim Numbers
" s:hi can exceed 0xFFFF and the result can exceed 0xFFFFFFFF — confirm that
" callers (e.g. s:random()) tolerate this.
function! s:rand()
" A lane equal to 0 would stay 0 forever (0 * k + 0), so seed it lazily
" from the LCG in s:random_seed().
if s:hi == 0
let s:hi = s:random_seed()
endif
if s:lo == 0
let s:lo = s:random_seed()
endif
" Negative lanes: subtracting 0x80000000 and adding 0x8000 back after the
" division presumably emulates an unsigned 32-bit `>> 16` when Vim Numbers
" are 32-bit. Non-negative lanes use plain % and /.
if s:hi < 0
let hi = s:hi - 0x80000000
let hi = 36969 * (hi % 0x10000) + (hi / 0x10000 + 0x8000)
else
let hi = s:hi
let hi = 36969 * (hi % 0x10000) + (hi / 0x10000)
endif
if s:lo < 0
let lo = s:lo - 0x80000000
let lo = 18273 * (lo % 0x10000) + (lo / 0x10000 + 0x8000)
else
let lo = s:lo
let lo = 18273 * (lo % 0x10000) + (lo / 0x10000)
endif
" Commit the new state before building the output.
let s:hi = hi
let s:lo = lo
" High lane shifted up 16 bits, plus the low 16 bits of the low lane (with
" the same negative-value reinterpretation as above).
return (hi * 0x10000) + ((lo < 0 ? lo - 0x80000000 : lo) % 0x10000)
endfunction
function! s:random()
  " Draw one value from s:rand() and scale it to a Float by dividing by
  " 2^32 - 1. A negative draw goes through an extra offset that presumably
  " reinterprets it as an unsigned 32-bit value (relevant for 32-bit Numbers).
  " NOTE(review): s:rand() is not truncated to 32 bits, so with 64-bit
  " Numbers the result can exceed 1.0 — confirm before relying on [0, 1).
  let l:r = s:rand()
  if l:r >= 0
    return l:r / 4294967295.0
  endif
  return (l:r - 0x80000000) / 4294967295.0 + (0x40000000 / (4294967295.0 / 2.0))
endfunction
" Initial state of the LCG in s:random_seed(). V8 seeds its generator from
" the C runtime random function, which is itself initialised from the clock;
" mimic that by deriving the seed from the current time (reltime()).
" fmod() keeps the value below 2^31 so it fits a signed 32-bit Number.
let s:seed = float2nr(fmod(str2float(reltimestr(reltime())) * 256, 2147483648.0))
function! s:random_seed()
  " Advance the linear congruential generator s:seed and return a value in
  " [0, 0x7FFF], using the same constants and output bits as the MSVC C
  " runtime rand(): seed = seed * 214013 + 2531011 (mod 2^32), then
  " (seed >> 16) & 0x7FFF.
  " Reducing modulo 2^32 keeps the state bounded. Without it, 64-bit Numbers
  " overflow after a few calls and the result can turn negative.
  let s:seed = and(s:seed * 214013 + 2531011, 0xFFFFFFFF)
  " With 32-bit Numbers the masked state can still be negative; subtracting
  " 0x80000000 clears the sign bit without touching bits 16-30.
  let l:state = s:seed < 0 ? s:seed - 0x80000000 : s:seed
  return l:state / 0x10000 % 0x8000
endfunction
function! s:add(x, y) abort
  " Element-wise addition IN PLACE: a:x[i] = a:x[i] + a:y[i] for every index
  " of a:x. a:y must be at least as long as a:x.
  " Returns the updated a:x (the same List object the caller passed).
  let l:acc = a:x
  for l:i in range(len(l:acc))
    let l:acc[l:i] = l:acc[l:i] + a:y[l:i]
  endfor
  return l:acc
endfunction
function! s:scale(x, f) abort
  " Return a NEW List whose items are the items of a:x multiplied by a:f.
  " a:x itself is left unchanged.
  let l:out = []
  for l:item in a:x
    call add(l:out, l:item * a:f)
  endfor
  return l:out
endfunction
function! s:dot(x, y) abort
  " Dot product: sum of a:x[i] * a:y[i] over the indexes of a:x.
  " a:y must be at least as long as a:x. Returns 0 for an empty a:x.
  " The sum is accumulated numerically. Going through join()/eval() would
  " pass each Float through its String form, which keeps only about 6
  " significant digits.
  let l:sum = 0
  for l:i in range(len(a:x))
    let l:sum = l:sum + a:x[l:i] * a:y[l:i]
  endfor
  return l:sum
endfunction
function! s:softmax(w, x) abort
  " Logistic sigmoid of the dot product of a:w and a:x:
  "   1 / (1 + exp(-(w . x)))
  " Despite its name this is the binary logistic function, not a softmax.
  let l:neg = -s:dot(a:w, a:x)
  return 1.0 / (exp(l:neg) + 1.0)
endfunction
function! s:predict(w, x) abort
  " Score the feature vector a:x under weights a:w with s:softmax()
  " (a logistic sigmoid), giving a Float in [0, 1].
  let l:score = s:softmax(a:w, a:x)
  return l:score
endfunction
function! s:logistic_regression(X, y, rate, ntrains) abort
  " Fit a single-output logistic model with per-sample gradient steps.
  "   a:X       - List of feature vectors (Lists of numbers), one per sample
  "   a:y       - target value for each sample, same length as a:X
  "   a:rate    - learning rate
  "   a:ntrains - number of passes over the whole data set
  " Returns the weight List, one weight per feature. There is no bias term.
  " Returns [] when a:X is empty.
  if empty(a:X)
    return []
  endif
  " Deterministic start: weights 0.1, 0.2, 0.3, ... sized to the feature
  " count. (i + 1) / 10.0 produces exactly the literals 0.1, 0.2, 0.3, 0.4.
  let l:w = map(range(len(a:X[0])), '(v:val + 1) / 10.0')
  for l:epoch in range(a:ntrains)
    for l:i in range(len(a:X))
      let l:x = a:X[l:i]
      let l:pred = s:softmax(l:w, l:x)
      " Descent step on the squared error (y - p)^2 / 2 through the sigmoid,
      " whose derivative is p * (1 - p).
      let l:step = a:rate * (a:y[l:i] - l:pred) * l:pred * (1.0 - l:pred)
      " s:add() updates l:w in place.
      call s:add(l:w, s:scale(l:x, l:step))
    endfor
  endfor
  return l:w
endfunction
function! s:token(line) abort
  " Split one CSV line on ',' and convert numeric fields to Floats.
  " A field is numeric when it is an optional sign, one or more digits, and
  " an optional '.' followed by digits (e.g. "5", "-2", "10.5", "5.").
  " Every other field (e.g. class names, header cells) stays a String.
  return map(split(a:line, ','),
        \ {k, v -> v =~# '^[-+]\?\d\+\%(\.\d*\)\?$' ? str2float(v) : v})
endfunction
function! s:make_vocab(names) abort
  " Map each distinct name in a:names to a Float index (0.0, 1.0, 2.0, ...)
  " in order of first appearance. Returns the resulting Dictionary.
  let l:vocab = {}
  let l:next = 0.0
  for l:nm in a:names
    if has_key(l:vocab, l:nm) | continue | endif
    let l:vocab[l:nm] = l:next
    let l:next += 1.0
  endfor
  return l:vocab
endfunction
function! s:bag_of_words(names, vocab) abort
  " Encode labels IN PLACE: each item of a:names is replaced by its index in
  " a:vocab divided by the vocabulary size (a Float). Returns a:names.
  " (Despite the name, this is a label encoding, not a bag of words.)
  let l:size = len(a:vocab)
  let l:labels = a:names
  for l:i in range(len(l:labels))
    let l:labels[l:i] = (0.0 + a:vocab[l:labels[l:i]]) / l:size
  endfor
  return l:labels
endfunction
function! s:shuffle(arr) abort
  " Shuffle a:arr IN PLACE with the Fisher-Yates algorithm and return it.
  " Randomness comes from s:rand(), so the order is reproducible after
  " s:srand() with a fixed seed.
  let l:list = a:arr
  let l:i = len(l:list) - 1
  while l:i > 0
    let l:m = l:i + 1
    " j must be drawn from [0, i]. Drawing from the whole list (or scaling
    " the draw by i before the modulo) biases the resulting permutation.
    " Vim's % keeps the dividend's sign, so fold negative draws into [0, m).
    let l:j = (s:rand() % l:m + l:m) % l:m
    if l:j != l:i
      let [l:list[l:i], l:list[l:j]] = [l:list[l:j], l:list[l:i]]
    endif
    let l:i -= 1
  endwhile
  return l:list
endfunction
function! s:main() abort
  " Demo: train the logistic model on iris.csv and echo the test accuracy.
  " Fixed seed (a former localtime() value) so the split is reproducible.
  call s:srand(1555386915)
  " Tokenize every non-blank line of iris.csv, then drop the header row.
  let l:lines = filter(readfile('iris.csv'), 'v:val !~# "^\\s*$"')
  let l:data = map(l:lines, 's:token(v:val)')[1:]
  call s:shuffle(l:data)
  " 80/20 train/test split (120/30 for the 150-row iris set). Vim slices
  " include both ends, so [0 : n-1] and [n :] are disjoint.
  let l:ntrain = len(l:data) * 4 / 5
  let l:train = l:ntrain > 0 ? l:data[0 : l:ntrain - 1] : []
  let l:test = l:data[l:ntrain :]
  " Columns 0-3 are the features; column 4 is the class name.
  let [l:X, l:y] = [[], []]
  for l:row in l:train
    call add(l:X, l:row[:3])
    call add(l:y, l:row[4])
  endfor
  " Index the class names, then encode l:y in place as index / nclasses.
  let l:vocab = s:make_vocab(l:y)
  call s:bag_of_words(l:y, l:vocab)
  " l:ni[k] is the class name whose vocabulary index is k.
  let l:ni = map(sort(map(keys(l:vocab), '[v:val, float2nr(l:vocab[v:val])]'), {a, b -> a[1] - b[1]}), 'v:val[0]')
  let l:w = s:logistic_regression(l:X, l:y, 0.1, 3000)
  let l:nclass = len(l:vocab)
  let l:count = 0
  for l:row in l:test
    let l:r = s:predict(l:w, l:row[:3])
    " Map the score back to a class index (the +0.1 shifts the bucket edges),
    " clamped to the last class.
    let l:k = min([float2nr(l:r * l:nclass + 0.1), l:nclass - 1])
    if l:ni[l:k] ==# l:row[4]
      let l:count += 1
    endif
  endfor
  echo (0.0 + l:count) / len(l:test)
endfunction
" Run the demo as soon as this script is sourced.
call s:main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment