Skip to content

Instantly share code, notes, and snippets.

@kongzii
Last active January 30, 2020 07:40
Show Gist options
  • Save kongzii/62b9d978a6536bb97095ed3fb74e30fd to your computer and use it in GitHub Desktop.
Save kongzii/62b9d978a6536bb97095ed3fb74e30fd to your computer and use it in GitHub Desktop.
Example of saving trained weights of model in the Swift For Tensorflow
import Foundation
import Python
import TensorFlow
public struct MyModel : Layer {
public var conv1d: Conv1D<Float>
public var dense1: Dense<Float>
public var dropout: Dropout<Float>
public var denseOut: Dense<Float>
public init() {
self.conv1d = Conv1D<Float>(filterShape: (2, 300, 100))
self.dense1 = Dense<Float>(inputSize: 100,
outputSize: 50,
activation: relu)
self.dropout = Dropout<Float>(probability: 0.02)
self.denseOut = Dense<Float>(inputSize: 50,
outputSize: 2,
activation: sigmoid)
}
@differentiable
public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
let l1 = self.conv1d(input)
let l2 = self.dense1(l1)
let l3 = self.dropout(l2)
let out = self.denseOut(l3)
return out.squeezingShape()
}
}
extension Layer {
mutating public func loadWeights(numpyFile: String) {
let np = Python.import("numpy")
let weights = np.load(numpyFile, allow_pickle: true)
for (index, kp) in self.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self).enumerated() {
self[keyPath: kp] = Tensor<Float>(numpy: weights[index])!
}
}
public func saveWeights(numpyFile: String) {
var weights: Array<PythonObject> = []
for kp in self.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
weights.append(self[keyPath: kp].makeNumpyArray())
}
let np = Python.import("numpy")
np.save(numpyFile, np.array(weights))
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment