Created March 16, 2020 00:06
-
-
Save jido/974525f5e83d4f1768b16e6f3411f42b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Idea one: respect the lifetime of the provided allocator (a single arena shared by the graph, the session, and the expected tensor)
test "gradient sigmoid" {
    // Idea one: every allocation (graph, session and the expected tensor) goes
    // through one arena that outlives all of them. Because `defer` also runs on
    // error exit, the arena is released even if any `try` below fails.
    var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
    // Deferred first, so it runs last: after session and graph are torn down.
    defer arena.deinit();

    var graph = try Graph.init(&arena.allocator);
    defer graph.deinit();

    // Build c = mean(sigmoid(a)) for a 2x2 input.
    const a = try constant(f64, &graph, .{
        .{ 1, 2 },
        .{ 3, 4 },
    });
    const b = try sigmoid(&graph, a);
    // The sigmoid node must keep the 2x2 shape of its input.
    std.testing.expect(std.mem.eql(usize, b.shape, &[_]usize{ 2, 2 }));
    const c = try mean(&graph, b);
    const gradients = try gradient(&graph, c, &[_]Tensor{a});

    var session = try Session.init(&arena.allocator, &graph);
    defer session.deinit();
    const actual = try session.run(gradients, .{});

    // dc/da = sigmoid(a) * (1 - sigmoid(a)) / 4 (mean over 4 elements),
    // rounded to 4 decimal places. NOTE(review): this relies on
    // expectEqual comparing with a tolerance; confirm.
    const expected = try eager.constant(f64, &arena.allocator, .{
        .{ 0.0492, 0.0262 },
        .{ 0.0113, 0.0044 },
    });
    expectEqual(f64, actual[0].f64, expected);
}
Idea two: let the graph handle allocation duties (the expected tensor is allocated through the graph instead of through a separate allocator)
test "gradient sigmoid" {
    // Idea two: no arena. The graph and the session receive a plain page
    // allocator, and the expected tensor is created through the graph itself,
    // so memory is presumably released by the deferred graph.deinit() and
    // session.deinit() calls (confirm against Graph/Session).
    // NOTE(review): a leak-checking allocator such as std.testing.allocator
    // would verify that claim, if the targeted Zig version provides it.
    const allocator = std.heap.page_allocator;
    var graph = try Graph.init(allocator);
    defer graph.deinit();

    // Build c = mean(sigmoid(a)) for a 2x2 input.
    const a = try constant(f64, &graph, .{
        .{ 1, 2 },
        .{ 3, 4 },
    });
    const b = try sigmoid(&graph, a);
    // The sigmoid node must keep the 2x2 shape of its input.
    std.testing.expect(std.mem.eql(usize, b.shape, &[_]usize{ 2, 2 }));
    const c = try mean(&graph, b);
    const gradients = try gradient(&graph, c, &[_]Tensor{a});

    var session = try Session.init(allocator, &graph);
    defer session.deinit();
    const actual = try session.run(gradients, .{});

    // dc/da = sigmoid(a) * (1 - sigmoid(a)) / 4 (mean over 4 elements),
    // rounded to 4 decimal places. Allocated on the graph, not an allocator.
    const expected = try eager.constant(f64, &graph, .{
        .{ 0.0492, 0.0262 },
        .{ 0.0113, 0.0044 },
    });
    expectEqual(f64, actual[0].f64, expected);
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.