Created
February 15, 2023 07:29
-
-
Save russelljjarvis/012762025dcba23f30cec03a2ad61658 to your computer and use it in GitHub Desktop.
Potjans_init.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# BUG FIX: the first line began with "sing" instead of "using", which is a
# hard error (undefined name / broken statement).  All original imports are
# kept, including the duplicates of LinearAlgebra/Random/Statistics/StatsBase.
using LinearAlgebra, Random, Statistics, StatsBase, ArgParse, SymmetricFormats
using Plots
using StatsBase
using Random
using ProgressMeter
using StatsPlots
using Plots
#using UnicodePlots
unicodeplots()  # select the UnicodePlots backend for Plots
# Activate the project environment one directory above this file, silently.
using Pkg; Pkg.activate(dirname(@__DIR__), io=devnull)
using LinearAlgebra, Random, JLD2, Statistics, StatsBase, ArgParse, SymmetricFormats, UnPack
# Report which BLAS backend is in use (useful when debugging performance).
println(BLAS.get_config())
# Teach ArgParse how to convert a CLI string such as "[1,2,3]" or "collect(1:3)"
# into a Vector{Int} by evaluating it as Julia code.
# SECURITY NOTE(review): eval of user-supplied text executes arbitrary code.
# Acceptable for a trusted local CLI; never expose this to untrusted input.
function ArgParse.parse_item(::Type{Vector{Int}}, x::AbstractString)
    expr = Meta.parse(x)
    return eval(expr)
end
#verbose = false
verbose = true

# --- command-line interface --- #
aps = ArgParseSettings()
@add_arg_table! aps begin
    "data_dir"
        help = "full path to the directory containing the parameters file"
        required = true
    "--itasks", "-t"
        help = "a vector specifying which tasks to learn"
        arg_type = Vector{Int}
        default = [1]
        range_tester = x->all(x.>0)   # task indices must all be positive
end
# FIX: corrected "synpatic" -> "synaptic" in the user-facing help text.
add_arg_group!(aps, "mutually exclusive arguments. if neither is specified, sinusoids\nwill be generated for synaptic inputs", exclusive = true);
@add_arg_table! aps begin
    "--utarg_file", "-u"
        help = "full path to the JLD file containing the synaptic current targets"
    "--spikerate_file", "-s"
        help = "full path to the JLD file containing the spike rates"
end
parsed_args = parse_args(aps)
# --- set up variables --- #
# Param is a sandbox module: the user's param.jl is include'd into it (below,
# via Param.include), defining all the simulation constants that
# save_params() snapshots to disk.
module Param
    using LinearAlgebra, SymmetricFormats, Random, JLD2
    # Record the output/parameter directory as a module-level global so
    # save_params() can find it later.
    function set_data_dir(path)
        global data_dir = path
    end
    # Persist every parameter into data_dir/param.jld2 as a single NamedTuple.
    # NOTE(review): none of the names below (PPrecision, seed, Ncells, ...)
    # are defined in this file — they are expected to come from the include'd
    # param.jl.  Confirm that file defines all of them, or save() will throw
    # UndefVarError.
    save_params() = save(joinpath(data_dir, "param.jld2"), "param", (;
        PPrecision, PScale, FloatPrecision, IntPrecision, PType,
        seed, rng_func,
        example_neurons, wid, maxrate,
        penlambda, penlamFF, penmu,
        learn_every,
        train_duration, stim_on, stim_off, train_time,
        dt, Nsteps,
        Ncells, Ne, Ni,
        tau_meme, tau_memi,
        K, L, LX, Lexc, Linh,
        X_bal,
        vre, threshe, threshi, refrac,
        tau_bale, tau_bali, tau_plas,
        noise_model, sig,
        correlation_var,
        # the *_file entries are normalized to absolute paths (Pair syntax is
        # valid inside a (;...) NamedTuple literal alongside bare names)
        :genXStim_file => abspath(@__DIR__, genXStim_file),
        genXStim_args,
        :genUTarget_file => abspath(@__DIR__, genUTarget_file),
        genUTarget_args,
        :genRateX_file => abspath(@__DIR__, genRateX_file),
        genRateX_args,
        :genStaticWeights_file => abspath(@__DIR__, genStaticWeights_file),
        genStaticWeights_args,
        :genPlasticWeights_file => abspath(@__DIR__, genPlasticWeights_file),
        genPlasticWeights_args,
        :cellModel_file => abspath(@__DIR__, cellModel_file),
        cellModel_args,
        choose_task_func,
    ))
end
# Load the user's parameter file into the Param sandbox module, snapshot the
# parameters to disk, then bring in the derived simulation variables.
Param.set_data_dir(parsed_args["data_dir"])
Param.include(joinpath(parsed_args["data_dir"], "param.jl"))
Param.save_params()
include(joinpath(@__DIR__, "cpu", "variables.jl"))
# Guard against cell-index overflow: IntPrecision must be wide enough to
# index all Ncells cells.  Exactly at the limit is OK on CPU but not on GPU.
let cap = typemax(Param.IntPrecision)
    if Param.Ncells > cap
        @error "IntPrecision is too small"
    elseif Param.Ncells == cap
        @warn "IntPrecision is too small for GPU (but fine for CPU)"
    end
end
# --- load code --- #
# Expand `@maybethread for ...` into a Threads.@threads loop when Julia was
# started with more than one thread; otherwise emit the loop unchanged.
# NOTE(review): the thread-count branch (and the @warn) runs at macro
# EXPANSION time, not at runtime — the warning fires once per expansion site
# during compilation, and the decision is baked into the generated code.
macro maybethread(loop)
    if Threads.nthreads()>1
        # Rebuild the loop expression with the iteration spec and body escaped
        # so user variables resolve in the caller's scope, then thread it.
        quote Threads.@threads $(Expr(loop.head,
            Expr(loop.args[1].head, esc.(loop.args[1].args)...),
            esc(loop.args[2]))); end
    else
        @warn "running single threaded"
        quote $(esc(loop)); end
    end
end
# `kind` selects which variant of the code in the included files gets defined
# (here: the initialization pass).
kind=:init
# Pull in the user-configurable generator functions; the *_file paths are
# supplied by the param.jl loaded into Param above.
include(Param.genStaticWeights_file)
include(Param.genPlasticWeights_file)
include(Param.genRateX_file)
include(Param.genUTarget_file)
include(Param.genXStim_file)
include(Param.cellModel_file)
# Relative include paths resolve against this file's directory.
include(joinpath("cpu","loop.jl"))
include("rate2utarg.jl")
# --- initialization --- #
# Build the static (non-plastic) recurrent weight matrix plus its index and
# per-cell connection-count bookkeeping from the user-supplied generator.
w0Index, w0Weights, nc0 = genStaticWeights(Param.genStaticWeights_args)
# NOTE(review): UnicodePlots is never loaded (`using UnicodePlots` is
# commented out at the top of the file), so this qualified call will throw
# UndefVarError unless the package is brought into scope — confirm.
UnicodePlots.spy(w0Weights) |> display
#UnicodePlots.spy(w0Index) |> display
# feed-forward input spike rates
rateX = genRateX(Param.genRateX_args)
itask = 1
# one growable spike-time list per cell (Vector{Any} of Float32 vectors)
times = []
for i in 1:Param.Ncells
    push!(times,Float32[])
end
# Run the simulation loop once (no learning) to measure the mean and std of
# the synaptic currents (uavg, ustd) and collect spike counts/times.
# NOTE(review): `p`, `dt`, `refrac`, `v`, `rng`, etc. are presumably defined
# by the include'd cpu/variables.jl — confirm.  The many `nothing` slots are
# positional placeholders whose meaning is fixed by loop_init's signature in
# cpu/loop.jl; do not reorder.
uavg, ns0, ustd,times = loop_init(itask, nothing, nothing, p.stim_off, p.train_time, dt,
    p.Nsteps, p.Ncells, p.Ne, nothing, refrac, vre, invtauedecay,
    invtauidecay, nothing, mu, thresh, tau, nothing, times, ns, nothing,
    ns_ffwd, forwardInputsE, forwardInputsI, nothing, forwardInputsEPrev,
    forwardInputsIPrev, nothing, nothing, nothing, xedecay, xidecay, nothing,
    synInputBalanced, synInput, nothing, nothing, bias, nothing, nothing, lastSpike,
    nothing, nothing, nothing, nothing, nothing, v, rng, noise, rndFfwd, sig,
    nothing, nothing, w0Index, w0Weights, nc0, nothing, nothing, nothing,
    nothing, nothing, nothing, nothing, nothing, nothing, nothing, uavg,
    utmp, ffwdRate)
"""
    convert_ragged_arraytodense(times)

Flatten a ragged vector-of-vectors of per-cell spike times into two parallel
dense vectors: `(converttimes, convertnodes)`, where `converttimes[k]` is a
spike time and `convertnodes[k]` is the index of the cell that emitted it.

BUG FIX: the original called `append!(converttimes, i)` inside the inner
loop, appending the ENTIRE spike vector `i` once per spike it contained —
producing duplicated times and a `converttimes` longer than `convertnodes`.
It now pushes the single spike time `t`, keeping the two outputs aligned.
"""
function convert_ragged_arraytodense(times)
    converttimes = Float32[]  # spike times are stored as Float32 above
    convertnodes = Int[]
    for (cell, spikes) in enumerate(times)
        for t in spikes
            push!(converttimes, t)
            push!(convertnodes, cell)
        end
    end
    return (converttimes, convertnodes)
end
# Flatten the ragged per-cell spike lists into parallel (times, nodes)
# vectors for plotting/analysis.
(times,nodes) = convert_ragged_arraytodense(times)
# Build the plastic (learned) weights, seeded by the static connectivity and
# the spike counts measured during the init loop.
wpWeightFfwd, wpWeightIn, wpIndexIn, ncpIn =
    genPlasticWeights(p.genPlasticWeights_args, w0Index, nc0, ns0)
if verbose
    # NOTE(review): requires UnicodePlots, whose `using` is commented out at
    # the top of the file — confirm it is in scope before running.
    UnicodePlots.spy(wpWeightIn) |> display
    #UnicodePlots.spy(wpIndexIn) |> display
end
# get indices of postsynaptic cells for each presynaptic cell
# wpIndexConvert[post, i] maps the i-th incoming plastic connection of cell
# `post` to its position in the presynaptic cell's outgoing list.
wpIndexConvert = zeros(Int, p.Ncells,p.Ncells)# p.Lexc+p.Linh)
# NOTE(review): untyped Dict of untyped vectors; Dict{Int,Vector{Int}} would
# be faster and clearer — left as-is to preserve behavior.
wpIndexOutD = Dict()
ncpOut = Array{Int}(undef, p.Ncells)
for i = 1:p.Ncells
    wpIndexOutD[i] = []
end
for postCell = 1:p.Ncells
    # NOTE(review): postCell ranges over 1:p.Ncells, so this guard is always
    # true (dead check); same for the preCell guard below.
    if postCell!=0
        for i = 1:ncpIn[postCell]
            ##
            # wpIndexIn
            # incompatible with indexs provided.
            ##
            # NOTE(review): sizeof() returns BYTES (length * element size),
            # not element count — this bound is almost certainly meant to be
            # size(wpIndexIn, 1) or length(wpIndexIn[:,i]).  The author flags
            # it as a deliberate hack, so it is documented rather than
            # changed; confirm intended semantics before fixing.
            if postCell <= sizeof(wpIndexIn[:,i]) # unfortunate hack
                # to do figure out how to get rid of unfortunate hack.
                preCell = wpIndexIn[postCell,i] # This line fails
                if preCell!=0
                    # record postCell as a target of preCell, and remember
                    # its slot for the reverse (convert) mapping
                    push!(wpIndexOutD[preCell], postCell)
                    wpIndexConvert[postCell,i] = length(wpIndexOutD[preCell])
                end
            end
        end
    end
end
# count outgoing plastic connections per presynaptic cell
for preCell = 1:p.Ncells
    if preCell!=0
        ncpOut[preCell] = length(wpIndexOutD[preCell])
    end
end
# get weight, index of outgoing connections
# Column `preCell` of wpIndexOut lists that cell's postsynaptic targets,
# zero-padded up to the maximum out-degree.
wpIndexOut = zeros(Int, maximum(ncpOut),p.Ncells)
# BUG FIX: the original looped over 1:p.Ncells-1, leaving the last
# presynaptic cell's column all zeros even though ncpOut and wpIndexOutD
# cover all p.Ncells cells.
for preCell = 1:p.Ncells
    wpIndexOut[1:ncpOut[preCell],preCell] = wpIndexOutD[preCell]
end
# --- synaptic current targets --- #
# number of learning steps between stimulus offset and end of training
Ntime = floor(Int, (p.train_time-p.stim_off)/p.learn_every)
# BUG FIX: the CLI registers "--utarg_file" (see @add_arg_table! above), but
# this code looked up parsed_args["xtarg_file"], which raises a KeyError at
# runtime.  All lookups below now use the "utarg_file" key.
if parsed_args["utarg_file"] !== nothing
    # Load targets from the user-supplied JLD file (take its first variable).
    xtarg_dict = load(parsed_args["utarg_file"])
    xtarg = xtarg_dict[first(keys(xtarg_dict))]
    if size(xtarg,1) != Ntime
        error(parsed_args["utarg_file"],
              " should have (train_time-stim_off)/learn_every = ",
              Ntime, " rows")
    end
    # promote a 2-D target array to 3-D with a singleton task dimension
    ndims(xtarg)==2 && (xtarg = xtarg[:,:,[CartesianIndex()]])
    if any(parsed_args["itasks"] .> size(xtarg,3))
        error("an element of --itasks exceeds the size of the third dimension of ",
              parsed_args["utarg_file"])
    end
    xtarg = xtarg[:,:,parsed_args["itasks"]]
elseif parsed_args["spikerate_file"] !== nothing
    # Convert spike rates to synaptic currents.
    # NOTE(review): rate2synInput is not obviously defined by the included
    # rate2utarg.jl — confirm the function name.
    xtarg = rate2synInput(p, p.K==0 ? sig : (ustd / sqrt(p.tauedecay * 1.3))) # factor 1.3 was calibrated manually
else
    # Neither file given: synthesize targets (sinusoids) for each task.
    xtarg = Array{Float64}(undef, Ntime, p.Ncells, length(parsed_args["itasks"]))
    for itask = 1:length(parsed_args["itasks"])
        # NOTE(review): genTarget / p.genTarget_args do not match the
        # included genUTarget_file (which presumably defines genUTarget /
        # genUTarget_args) — confirm the intended generator name.
        xtarg[:,:,itask] = genTarget(p.genTarget_args, uavg)
    end
end
# --- stimulus --- #
# number of simulation steps during which the stimulus is on
timeSteps = round(Int, (p.stim_off - p.stim_on) / dt)
# one (timeSteps x Ncells) stimulus matrix per requested task
stim = Array{Float64}(undef, timeSteps, p.Ncells, length(parsed_args["itasks"]))
for itask = 1:length(parsed_args["itasks"])
    # NOTE(review): genStim / p.genStim_args do not match the included
    # genXStim_file (which presumably defines genXStim / genXStim_args) —
    # confirm the intended generator name before running.
    stim[:,:,itask] = genStim(p.genStim_args)
end
# --- set up correlation matrix --- #
# Build the inverse-prior penalty matrix for the RLS learning rule, then
# normalize it.  Downstream code reads only Pinv_norm.
pLrec = p.Lexc + p.Linh;
# L2 (ridge) penalty on the recurrent weights
ridge = Diagonal(fill(p.penlambda, pLrec));
# row-sum penalty: indicator vectors selecting the excitatory and the
# inhibitory inputs, penalized separately via their outer products
excMask = [ones(p.Lexc); zeros(p.Linh)];
inhMask = [zeros(p.Lexc); ones(p.Linh)];
rowsum = p.penmu * (excMask * excMask' + inhMask * inhMask');
# total recurrent penalty = ridge + row-sum
Pinv_rec = ridge + rowsum;
# feed-forward weights get their own ridge penalty
Pinv_ffwd = Diagonal(fill(p.penlamFF, p.Lffwd))
# assemble the block-diagonal penalty: recurrent block then feed-forward block
Pinv = zeros(pLrec + p.Lffwd, pLrec + p.Lffwd)
Pinv[1:pLrec, 1:pLrec] = Pinv_rec
Pinv[pLrec+1:pLrec+p.Lffwd, pLrec+1:pLrec+p.Lffwd] = Pinv_ffwd
# invert the upper-triangular factor and symmetrize, in the user's PType
Pinv_norm = p.PType(Symmetric(UpperTriangular(Pinv) \ I));
println("got here")
#----------- save initialization --------------#
# Persist each initialized data structure to its own JLD2 file under the
# user-supplied data directory, one (filename, key, value) triple per file.
let dir = parsed_args["data_dir"]
    for (file, key, value) in [
            ("param.jld2",          "p",              p),
            ("w0Index.jld2",        "w0Index",        w0Index),
            ("w0Weights.jld2",      "w0Weights",      w0Weights),
            ("nc0.jld2",            "nc0",            nc0),
            ("stim.jld2",           "stim",           stim),
            ("xtarg.jld2",          "xtarg",          xtarg),
            ("wpIndexIn.jld2",      "wpIndexIn",      wpIndexIn),
            ("wpIndexOut.jld2",     "wpIndexOut",     wpIndexOut),
            ("wpIndexConvert.jld2", "wpIndexConvert", wpIndexConvert),
            ("wpWeightFfwd.jld2",   "wpWeightFfwd",   wpWeightFfwd),
            ("wpWeightIn.jld2",     "wpWeightIn",     wpWeightIn),
            ("ncpIn.jld2",          "ncpIn",          ncpIn),
            ("ncpOut.jld2",         "ncpOut",         ncpOut),
            ("P.jld2",              "P",              Pinv_norm),
            ("ffwdRate.jld2",       "ffwdRate",       ffwdRate)]
        save(joinpath(dir, file), key, value)
    end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment