Skip to content

Instantly share code, notes, and snippets.

@bawahakim
Created May 26, 2023 13:58
Show Gist options
  • Save bawahakim/11dcec17e9ce1a9b392df2da3d33284c to your computer and use it in GitHub Desktop.
Save bawahakim/11dcec17e9ce1a9b392df2da3d33284c to your computer and use it in GitHub Desktop.
Atan2 for tflite
def atan2(y: float, x: float):
    """Approximate ``tf.atan2(y, x)`` using only basic element-wise ops.

    The angle is built from arithmetic, comparisons, ``tf.sqrt`` and
    ``tf.where`` so it does not depend on a native atan/atan2 kernel.

    Args:
        y: Ordinate(s). A Python number, NumPy value or tensor.
        x: Abscissa(s). A Python number, NumPy value or tensor that is
            broadcast-compatible with ``y``.

    Returns:
        A tensor holding the angle in radians in ``[-pi, pi]``. The dtype is
        that of the first floating-point argument, or ``float32`` when both
        arguments are plain Python numbers.

    Special cases: ``atan2(0, 0) == 0``, ``atan2(0, x < 0) == pi`` and the
    sign of a zero input is ignored (``-0.0`` is treated like ``+0.0``).
    """

    def taylor_atan(ratio, terms: int = 18):
        """Return atan(ratio) element-wise for ``0 <= ratio <= 1``.

        The half-angle identity atan(t) = 2 * atan(t / (1 + sqrt(1 + t^2)))
        first shrinks the argument to at most tan(pi/8) ~= 0.414. There the
        Maclaurin series converges quickly: with the default 18 terms the
        truncation error is below 1e-15, so accuracy is limited by the
        rounding of the tensor dtype rather than by the series.
        """
        reduced = ratio / (1.0 + tf.sqrt(1.0 + ratio * ratio))
        squared = reduced * reduced
        # Horner evaluation of sum((-1)^n * u^n / (2n + 1)) with u = reduced^2.
        # The Python loop only unrolls a fixed number of multiply-adds, so no
        # data-dependent control flow ends up in the graph.
        series = 0.0
        for n in reversed(range(terms)):
            series = series * squared + (-1.0) ** n / (2 * n + 1)
        return 2.0 * reduced * series

    # Work in the dtype of the first floating-point input; plain Python
    # numbers fall back to float32.
    dtype = tf.float32
    for arg in (y, x):
        arg_dtype = getattr(arg, "dtype", None)
        if arg_dtype is not None and tf.as_dtype(arg_dtype).is_floating:
            dtype = tf.as_dtype(arg_dtype)
            break
    y = tf.cast(y, dtype)
    x = tf.cast(x, dtype)

    abs_y = tf.abs(y)
    abs_x = tf.abs(x)
    larger = tf.maximum(abs_x, abs_y)
    smaller = tf.minimum(abs_x, abs_y)
    # Dividing the smaller magnitude by the larger keeps the ratio in [0, 1]
    # without an epsilon in the denominator. At the origin the denominator is
    # bumped from 0 to 1, giving a ratio (and therefore an angle) of 0.
    safe_larger = larger + tf.cast(tf.equal(larger, 0), dtype)
    angle = taylor_atan(smaller / safe_larger)

    # Angle of (|x|, |y|) in the first quadrant: when |y| > |x| the ratio was
    # inverted, so apply atan(a) = pi/2 - atan(1/a).
    angle = tf.where(abs_y > abs_x, np.pi / 2 - angle, angle)
    # Reflect into the left half-plane, then into the lower half-plane. These
    # two steps also yield the axis values: +-pi/2 for x == 0, pi for
    # (y == 0, x < 0) and 0 for (y == 0, x >= 0).
    angle = tf.where(x < 0, np.pi - angle, angle)
    angle = tf.where(y < 0, -angle, angle)
    return angle
import itertools
import numpy as np
import pytest
import tensorflow as tf
from common.tf_helper import TfHelper
# Magnitudes tried for each coordinate: zero, tiny, below one, exactly one,
# and two values well above one.
values = [0.0, 0.004, 0.453, 1.0, 11.32, 103.65]
# Each magnitude is combined with both signs.
signs = [-1, 1]
# Every (y magnitude, y sign, x magnitude, x sign) combination, with the
# rightmost field varying fastest.
test_values = [
    (y_magnitude, y_direction, x_magnitude, x_direction)
    for y_magnitude in values
    for y_direction in signs
    for x_magnitude in values
    for x_direction in signs
]
@pytest.mark.parametrize("y_val, y_sign, x_val, x_sign", test_values)
def test_atan2(y_val, y_sign, x_val, x_sign):
    """Check TfHelper.atan2 against tf.atan2 to two decimal places.

    Each coordinate is a magnitude multiplied by a sign, so zero magnitudes
    produce both +0.0 and -0.0. Only the negative-zero combinations are
    skipped; positive zeros are compared like any other value.
    """
    y = y_val * y_sign
    x = x_val * x_sign
    # `v == -0.0` is also true for +0.0, so comparing against -0.0 would skip
    # every zero input. The sign bit is what tells the two zeros apart.
    y_is_negative_zero = y == 0 and np.signbit(y)
    x_is_negative_zero = x == 0 and np.signbit(x)
    if y_is_negative_zero or x_is_negative_zero:
        pytest.skip("negative-zero inputs are not compared against tf.atan2")
    result = TfHelper.atan2(y, x)
    np.testing.assert_almost_equal(result.numpy(), tf.atan2(y, x), 2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment