@nikola-j
Last active September 22, 2023 02:51
Atan2 pytorch onnx
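import numpy as np
import torch


def my_atan2(y, x):
    # pi as a tensor on the same device and with the same dtype as y
    pi = torch.from_numpy(np.array([np.pi])).to(y.device, y.dtype)
    # Principal value; the small offset avoids division by zero at x == 0
    ans = torch.atan(y / (x + 1e-6))
    # Corrections for x < 0 (y >= 0 so that atan2(0, x < 0) gives pi)
    ans += ((y >= 0) & (x < 0)) * pi
    ans -= ((y < 0) & (x < 0)) * pi
    # On the y-axis (x == 0), overwrite with +pi/2 or -pi/2
    ans *= (1 - ((y > 0) & (x == 0)) * 1.0)
    ans += ((y > 0) & (x == 0)) * (pi / 2)
    ans *= (1 - ((y < 0) & (x == 0)) * 1.0)
    ans += ((y < 0) & (x == 0)) * (-pi / 2)
    return ans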
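
To make the branch logic easier to check, here is a minimal scalar sketch in plain Python (only the `math` module, no PyTorch). The name `atan2_reference` is hypothetical and not part of the gist. It assumes the standard convention that the negative x-axis (`x < 0`, `y == 0`) maps to pi, handles `x == 0` with explicit branches instead of an epsilon offset, and returns 0 at the origin, as `math.atan2` does.

```python
import math


def atan2_reference(y, x):
    """Scalar atan2 built from atan plus per-quadrant corrections."""
    if x > 0:
        return math.atan(y / x)
    if x < 0:
        # Left half-plane: shift by +pi (y >= 0) or -pi (y < 0)
        return math.atan(y / x) + (math.pi if y >= 0 else -math.pi)
    if y > 0:
        return math.pi / 2
    if y < 0:
        return -math.pi / 2
    return 0.0  # convention chosen here for the undefined point (0, 0)


# Compare against the standard library on a small grid of (y, x) points.
values = (-2.0, -0.5, 0.0, 0.5, 2.0)
points = [(y, x) for y in values for x in values]
max_err = max(abs(atan2_reference(y, x) - math.atan2(y, x)) for y, x in points)
print(f"max abs error vs math.atan2: {max_err:.2e}")
```

The tensor version applies the same branches with boolean masks rather than `if` statements, which is what keeps it expressible with basic ONNX operators.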
@nnbtam99

nnbtam99 commented Oct 6, 2022

Hi @nikola-j, thank you for sharing. Have you tried this implementation with complex tensors? If possible, could you share how you derived the above algorithm?
Thank you in advance!

Edited: I found it here: https://en.wikipedia.org/wiki/Atan2. Thank you!!!
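
For reference, the piecewise definition of atan2 in terms of arctan, as commonly stated (including on the page linked above), can be written as follows. The masked corrections in the code correspond to the cases with x < 0 and x = 0. The code does not branch on the origin, where it simply returns arctan(0) = 0.

```latex
\operatorname{atan2}(y, x) =
\begin{cases}
\arctan\left(\frac{y}{x}\right)       & \text{if } x > 0, \\
\arctan\left(\frac{y}{x}\right) + \pi & \text{if } x < 0 \text{ and } y \ge 0, \\
\arctan\left(\frac{y}{x}\right) - \pi & \text{if } x < 0 \text{ and } y < 0, \\
+\frac{\pi}{2}                        & \text{if } x = 0 \text{ and } y > 0, \\
-\frac{\pi}{2}                        & \text{if } x = 0 \text{ and } y < 0, \\
\text{undefined}                      & \text{if } x = 0 \text{ and } y = 0.
\end{cases}
```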

@candlewill

This optimized version includes the following improvements:

Added comments in English to explain each step in the code.
Used torch.tensor to create the pi tensor directly instead of using torch.from_numpy.
Defined eps as a separate variable, making it easier to adjust if needed.
These changes make the code more readable and are not intended to change its results.

import numpy as np
import torch


def onnx_atan2(y, x):
    # Create a pi tensor with the same device and data type as y
    pi = torch.tensor(np.pi, device=y.device, dtype=y.dtype)
    half_pi = pi / 2
    eps = 1e-6

    # Compute the arctangent of y/x
    ans = torch.atan(y / (x + eps))

    # Create boolean tensors representing positive, negative, and zero values of y and x
    y_positive = y > 0
    y_negative = y < 0
    x_negative = x < 0
    x_zero = x == 0

    # Adjust ans based on the positive, negative, and zero values of y and x
    ans += torch.where((y >= 0) & x_negative, pi, torch.zeros_like(ans))  # Quadrant II and the negative x-axis (x < 0, y >= 0)
    ans -= torch.where(y_negative & x_negative, pi, torch.zeros_like(ans))  # Quadrant III (x < 0, y < 0)
    ans = torch.where(y_positive & x_zero, half_pi, ans)  # Positive y-axis
    ans = torch.where(y_negative & x_zero, -half_pi, ans)  # Negative y-axis

    return ans
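
One caveat with the `x + eps` guard is worth illustrating: for a negative `x` whose magnitude is smaller than `eps`, the shifted denominator becomes positive while the `x < 0` mask still applies the pi correction. The sketch below mirrors the tensor logic on plain Python floats (only `math`), so it is an approximation of the behavior rather than the PyTorch code itself. The name `shifted_atan2` is hypothetical.

```python
import math

EPS = 1e-6  # same offset as in the functions above


def shifted_atan2(y, x):
    """Scalar mirror of the tensor code: atan(y / (x + EPS)) plus masked corrections."""
    ans = math.atan(y / (x + EPS))
    if y > 0 and x < 0:
        ans += math.pi
    if y < 0 and x < 0:
        ans -= math.pi
    if y > 0 and x == 0:
        ans = math.pi / 2
    if y < 0 and x == 0:
        ans = -math.pi / 2
    return ans


# Ordinary input: the result stays close to the standard library.
ok_err = abs(shifted_atan2(1.0, -1.0) - math.atan2(1.0, -1.0))

# Tiny negative x: x + EPS is positive, but the x < 0 mask still adds pi.
tiny_err = abs(shifted_atan2(1.0, -5e-7) - math.atan2(1.0, -5e-7))

print(f"error at x = -1.0:  {ok_err:.2e}")
print(f"error at x = -5e-7: {tiny_err:.2e}")
```

The second error comes out close to pi. If inputs that near zero matter, one possible alternative (not verified here against ONNX export) is to drop the offset and divide by `torch.where(x == 0, torch.ones_like(x), x)`, since the `x == 0` cases are overwritten afterwards anyway.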
