Skip to content

Instantly share code, notes, and snippets.

@Roger-luo
Created August 8, 2019 21:05
Show Gist options
  • Save Roger-luo/22534259242116474ce0b57342550d06 to your computer and use it in GitHub Desktop.
Save Roger-luo/22534259242116474ce0b57342550d06 to your computer and use it in GitHub Desktop.
using Flux, Tracker, DelimitedFiles
using LinearAlgebra, Random
using Flux: onehotbatch
using Flux.Optimise
using Flux.Optimise: update!
using Tracker: TrackedReal, data
using Base.Iterators: partition
using BitBasis
"""
    generate_sample(m, L, batch_size=64)

Reset `m`, then apply it `L` times in a row, starting from a `1 × batch_size` array of
zeros and feeding every output back in as the next input. The `L` outputs are
concatenated vertically, in the order they were produced, and returned.
"""
function generate_sample(m, L, batch_size=64)
    Flux.reset!(m)
    outputs = []
    current = zeros(1, batch_size)
    for _ in 1:L
        # The previous output is the next input.
        current = m(current)
        push!(outputs, current)
    end
    return vcat(outputs...)
end
"""
    amplitude(m, L)

Evaluate the model `m` on every configuration produced by `basis(L)` and return the
element-wise square root of the normalised per-configuration products.

For each configuration `s`, the bits `readbit(s, 1)`, …, `readbit(s, L)` are fed to `m`
one at a time and the untracked scalar outputs are multiplied together. The product is
stored at index `s + 1`; the vector is divided by its sum before the square root is
taken.

The recurrent state of `m` is reset before each configuration, so every bit string is
scored independently of the ones visited before it.
"""
function amplitude(m, L)
    configs = basis(L)
    amp = zeros(length(configs))
    for s in configs
        # Each configuration is an independent sequence: start from a fresh state
        # instead of inheriting the hidden state left behind by the previous one.
        Flux.reset!(m)
        # Accumulate in the element type of `amp` so `p` never changes type.
        p = one(eltype(amp))
        for k in 1:L
            # NOTE(review): the model output is multiplied in regardless of the bit
            # value — confirm this is the intended per-bit probability.
            p *= Tracker.data(m(readbit(s, k)))[]
        end
        amp[s + 1] = p
    end
    return sqrt.(amp / sum(amp))
end
# Optimiser and model: an LSTM layer (1 input, 2 outputs) followed by a
# sigmoid-activated dense layer mapping back to a single output.
opt = ADAM()
m = Chain(LSTM(1, 2), Dense(2, 1, σ))
batch_size = 64

# Both files are read relative to the current working directory. The sample matrix is
# transposed, so each column of `raw_data` is one row of the file.
raw_data = readdlm("data/Samples.txt")'
amp = readdlm("data/Amplitudes.txt")[:, 1]  # first column only

# Split the column indices into consecutive chunks of at most `batch_size`; every
# training item is a 1-tuple holding the corresponding slice of columns.
mb_itr = partition(1:size(raw_data, 2), batch_size)
train_set = [(raw_data[:, k],) for k in mb_itr]

# Mean squared error between the model roll-out and a batch `y`; the two sizes of `y`
# supply the `L` and `batch_size` arguments of `generate_sample`.
loss(y) = Flux.mse(generate_sample(m, size(y)...), y)

for epoch in 1:100
    train!(loss, params(m), train_set, opt)
    # NOTE(review): the 10 presumably equals the number of rows of `raw_data` and
    # log2(length(amp)) — confirm against the data files.
    println("fidelity = ", amp' * amplitude(m, 10))
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment