@o-jill
Created September 27, 2024 03:46
Overwrite weights in a VarStore.
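use tch::{nn::VarStore, Device, Tensor};

// `net` is the model-building function, defined elsewhere.
let vs = VarStore::new(Device::Cpu);
let nnet = net(&vs.root());
{
    // Lock the VarStore's variable table for the duration of this block.
    let mut var = vs.variables_.as_ref().lock().unwrap();

    // Dump every named variable before the change.
    for k in var.named_variables.keys() {
        println!("key:{k}");
        var.named_variables.get(k).unwrap().print();
    }

    // Replace the tensor registered under "layer2.bias".
    // Note: this swaps the handle stored in the map. Layers already built
    // from the VarStore (such as `nnet`) likely keep their own handle to the
    // old tensor, so an in-place `copy_` inside `tch::no_grad` may be needed
    // if the network itself should see the new values.
    let newten = Tensor::from_slice(&[1.234f32]);
    *var.named_variables.get_mut("layer2.bias").unwrap() = newten;
    var.named_variables.get("layer2.bias").unwrap().print();

    // Dump every named variable again to confirm the change.
    for k in var.named_variables.keys() {
        println!("key:{k}");
        var.named_variables.get(k).unwrap().print();
    }
}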
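The snippet above assigns a new tensor to a map entry. It may help to see how that differs from writing into the existing tensor. The sketch below uses only the standard library and does not call tch. `Handle`, `Store`, `replace_entry`, and `overwrite_in_place` are hypothetical names made up for illustration. `Handle` stands in for a tensor whose clones share one buffer, which is an assumption about how shallow-cloned tensors behave.

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

// Hypothetical stand-in for a tensor handle: clones share one buffer.
type Handle = Arc<Mutex<Vec<f32>>>;

// Hypothetical stand-in for the store's name-to-variable map.
type Store = HashMap<String, Handle>;

/// Swap the map entry for a brand-new handle.
/// Anyone still holding the old handle keeps seeing the old values.
fn replace_entry(store: &mut Store, key: &str, values: Vec<f32>) {
    store.insert(key.to_string(), Arc::new(Mutex::new(values)));
}

/// Write new values into the existing buffer.
/// Every holder of the handle sees the update.
/// Returns false if the key is not present.
fn overwrite_in_place(store: &Store, key: &str, values: &[f32]) -> bool {
    match store.get(key) {
        Some(handle) => {
            let mut buf = handle.lock().unwrap();
            buf.clear();
            buf.extend_from_slice(values);
            true
        }
        None => false,
    }
}
```

If a "layer" holds a clone of the handle registered under `layer2.bias`, `overwrite_in_place` changes what that layer reads, while `replace_entry` only changes what the store returns for that key.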