- Allen Lavoie (@allenlavoie)
- Paige Bailey (@dynamicwebpaige)
TensorFlow's current native support for forward-mode differentiation computes a JVP by calling backward on a function twice ("double backward"). Executing this inside a tf.function avoids retracing, and the extra backward pass gets pruned. The remaining work is to optimize the current API and improve the user experience. Integrating tf.vectorized_map to batch tangents serves both goals. Once tangents can be batched, Hessian matrices can be computed efficiently by combining forward- and backward-mode differentiation, avoiding workarounds such as computing gradients/Jacobians inside a GradientTape context.
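As a concrete illustration of the forward-over-backward pattern described above, the sketch below computes a Hessian-vector product with the tf.autodiff.ForwardAccumulator API (as it later landed in the public API; assuming TensorFlow 2.x). The function f and the direction v are illustrative choices, not from the original text.

```python
import tensorflow as tf

# Forward-over-backward: a forward-mode pass over a backward-mode
# gradient yields a Hessian-vector product without materializing
# the full Hessian.
def f(x):
    return tf.reduce_sum(x ** 3)

x = tf.constant([1.0, 2.0])
v = tf.constant([1.0, 0.0])  # direction for the Hessian-vector product

with tf.autodiff.ForwardAccumulator(primals=x, tangents=v) as acc:
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = f(x)
    grad = tape.gradient(y, x)  # backward pass: dy/dx = 3x^2
hvp = acc.jvp(grad)             # forward pass over grad: H @ v = 6x * v
```

For this f the Hessian is diag(6x), so the product with v = [1, 0] picks out 6 * x[0].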
- Support batches of tangents for regular ops in _jvp_helper.
- Add a ForwardAccumulator variant that batches tangents.
- Test the new accumulator over a variety of ops.
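To make concrete what batching tangents buys, the sketch below computes a full Jacobian the unbatched way: one forward pass per basis tangent, each recovering one Jacobian column. A batching accumulator (the goal of the items above) would collapse this Python loop into a single vectorized pass via tf.vectorized_map. The function f is an illustrative choice; the accumulator API is today's public tf.autodiff.ForwardAccumulator.

```python
import tensorflow as tf

def f(x):
    return x * x  # elementwise square; Jacobian is diag(2x)

x = tf.constant([1.0, 2.0, 3.0])

# Unbatched: one ForwardAccumulator (one forward pass) per basis
# tangent e_i, each yielding column i of the Jacobian.
cols = []
for i in range(3):
    tangent = tf.one_hot(i, 3)
    with tf.autodiff.ForwardAccumulator(primals=x, tangents=tangent) as acc:
        y = f(x)
    cols.append(acc.jvp(y))
jacobian = tf.stack(cols, axis=1)
```

With batched tangents, all three basis vectors would be pushed through a single forward pass instead of three.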
- Since any function inside the context of a ForwardAccumulator runs as a tf.function, the tangents of multiple nested accumulators are passed through forward-function wrapping in function.py. Currently this passes individual tangents and needs to be changed to accommodate batched tangents.
- These are some tests that fail because of this.
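For context on the function.py interaction described above, here is a minimal sketch of the path that exercises it: a ForwardAccumulator watching the inputs of a traced tf.function, whose tangents must flow through the function's placeholders. With a single tangent this works today; batched tangents change the placeholder shapes the wrapping has to set up.

```python
import tensorflow as tf

@tf.function
def f(x):
    return tf.sin(x)

x = tf.constant(1.0)

# A single tangent passed through the traced function's forward
# wrapping; batching would send a stack of tangents instead.
with tf.autodiff.ForwardAccumulator(primals=x, tangents=tf.constant(1.0)) as acc:
    y = f(x)
jvp = acc.jvp(y)  # d/dx sin(x) = cos(x) at x = 1.0
```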
- Solve the placeholder issue mentioned above by extending support to tf.function.
- Test _batch_accumulator extensively and include it in the public API.
- Extend batching support to custom gradients.
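Single-tangent JVPs through tf.custom_gradient already work, since double backward differentiates the user-supplied backward function; batching would extend this path. A minimal sketch of the case that has to keep working (log1pexp is the illustrative example from the tf.custom_gradient documentation, not from this report):

```python
import tensorflow as tf

# An op with a user-supplied backward function (a VJP); forward-mode
# support derives the JVP by differentiating this grad function.
@tf.custom_gradient
def log1pexp(x):
    e = tf.exp(x)
    def grad(upstream):
        return upstream * (1 - 1 / (1 + e))  # numerically stable form
    return tf.math.log(1 + e), grad

x = tf.constant(0.0)
with tf.autodiff.ForwardAccumulator(primals=x, tangents=tf.constant(1.0)) as acc:
    y = log1pexp(x)
jvp = acc.jvp(y)  # d/dx log(1 + e^x) = sigmoid(x), i.e. 0.5 at x = 0
```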