
@nikola-j
Last active September 22, 2023 02:51
Atan2 PyTorch ONNX
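import numpy as np
import torch

def my_atan2(y, x):
    # pi as a tensor on the same device and dtype as the inputs
    pi = torch.from_numpy(np.array([np.pi])).to(y.device, y.dtype)
    # atan(y / x) is already correct for x > 0; the epsilon avoids 0 / 0 = nan at the origin
    ans = torch.atan(y / (x + 1e-6))
    # x < 0: shift by +pi (y >= 0, including the negative x-axis) or by -pi (y < 0)
    ans += ((y >= 0) & (x < 0)) * pi
    ans -= ((y < 0) & (x < 0)) * pi
    # x == 0, y > 0: replace the value with pi / 2
    ans *= (1 - ((y > 0) & (x == 0)) * 1.0)
    ans += ((y > 0) & (x == 0)) * (pi / 2)
    # x == 0, y < 0: replace the value with -pi / 2
    ans *= (1 - ((y < 0) & (x == 0)) * 1.0)
    ans += ((y < 0) & (x == 0)) * (-pi / 2)
    return ans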
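A worked example of why the ±π corrections are needed (ignoring the tiny epsilon): for x < 0, arctan(y / x) returns the angle of the opposite point (-x, -y), so adding or subtracting π moves the result into the correct quadrant.

```latex
\begin{aligned}
\operatorname{atan2}(1, -1) &= \arctan\left(\frac{1}{-1}\right) + \pi = -\frac{\pi}{4} + \pi = \frac{3\pi}{4} \\
\operatorname{atan2}(-1, -1) &= \arctan\left(\frac{-1}{-1}\right) - \pi = \frac{\pi}{4} - \pi = -\frac{3\pi}{4}
\end{aligned}
```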
@nikola-j
Author

nikola-j commented Aug 7, 2020

An atan2 implementation in PyTorch that can be exported to ONNX.
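One way to use it in a model that will be exported: call the function inside a module's forward wherever torch.atan2 would otherwise appear, so the traced graph should contain only elementwise ops such as atan, comparisons, and arithmetic. This is a minimal sketch under assumptions: PolarAngle is a hypothetical module name, the copy of my_atan2 below uses y >= 0 in the x < 0 branch so that points on the negative x-axis map to π (as torch.atan2 does), and the torch.onnx.export call itself is not shown.

```python
import numpy as np
import torch
from torch import nn


def my_atan2(y, x):
    # Copy of the gist's function, using y >= 0 in the x < 0 branch
    pi = torch.from_numpy(np.array([np.pi])).to(y.device, y.dtype)
    ans = torch.atan(y / (x + 1e-6))
    ans += ((y >= 0) & (x < 0)) * pi
    ans -= ((y < 0) & (x < 0)) * pi
    ans *= (1 - ((y > 0) & (x == 0)) * 1.0)
    ans += ((y > 0) & (x == 0)) * (pi / 2)
    ans *= (1 - ((y < 0) & (x == 0)) * 1.0)
    ans += ((y < 0) & (x == 0)) * (-pi / 2)
    return ans


class PolarAngle(nn.Module):
    """Hypothetical module: maps points of shape (N, 2), stored as (x, y), to their angles."""

    def forward(self, points):
        x, y = points[:, 0], points[:, 1]
        # Call my_atan2 where torch.atan2 would otherwise be used
        return my_atan2(y, x)


# Points covering all four quadrants, both axes, and the origin
vals = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
points = torch.cartesian_prod(vals, vals)
angles = PolarAngle()(points)
expected = torch.atan2(points[:, 1], points[:, 0])
print(torch.allclose(angles, expected, atol=1e-5))
```

Comparing against torch.atan2 on a small grid like this is a cheap way to catch a missed quadrant or axis case before exporting.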

@mrhe13

mrhe13 commented Mar 29, 2022

Thanks a lot! atan2 is not supported in ONNX, and I had been stuck on this for a long time.

@evi-Genius

Thanks! It works.

@gzchenjiajun

How do I integrate this into my project?

@kometa-triatlon

It's better to add an epsilon when dividing by x:

ans = torch.atan(y / (x + 1e-6))

@nikola-j
Author

@kometa-triatlon No need; torch.atan handles inf, e.g.:

import numpy as np
import torch

a = torch.from_numpy(np.array([1.0]))
b = torch.from_numpy(np.array([0.0]))
torch.atan(a/b)

Result:

tensor([1.5708], dtype=torch.float64)
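Both sides of this exchange follow from floating-point division: a nonzero number divided by zero gives ±inf, which torch.atan maps to ±π/2, but 0 / 0 gives nan, which torch.atan passes through unchanged. A small sketch of the three cases, plus the effect of the epsilon (the eps value here is just the one used in the gist):

```python
import torch

num = torch.tensor([1.0, -1.0, 0.0])
den = torch.zeros(3)

# Nonzero / 0 gives +/-inf and atan(+/-inf) = +/-pi/2, but 0 / 0 gives nan
ratio = num / den
angles = torch.atan(ratio)
print(ratio)
print(angles)

# With a small epsilon in the denominator, 0 / eps = 0 and atan(0) = 0,
# while the nonzero cases stay very close to +/-pi/2
eps = 1e-6
angles_eps = torch.atan(num / (den + eps))
print(angles_eps)
```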

@kometa-triatlon

kometa-triatlon commented Sep 23, 2022

There can be a combination of parameters that would surprise you ;)

import torch

zero = torch.tensor([0.])
one = torch.tensor([1.])

print(my_atan2(zero, one), torch.atan2(zero, one))
print(my_atan2(one, zero), torch.atan2(one, zero))
print(my_atan2(one, one), torch.atan2(one, one))
print(my_atan2(zero, zero), torch.atan2(zero, zero))

Result:

tensor([0.]) tensor([0.])
tensor([1.5708]) tensor([1.5708])
tensor([0.7854]) tensor([0.7854])
tensor([nan]) tensor([0.])

With the epsilon added (ans = torch.atan(y / (x + 1e-10))):

tensor([0.]) tensor([0.])
tensor([1.5708]) tensor([1.5708])
tensor([0.7854]) tensor([0.7854])
tensor([0.]) tensor([0.])

@nikola-j
Author

nikola-j commented Sep 24, 2022

But atan2(0, 0) is undefined, not 0. I think it's better to return nan; that way you would know that something is wrong.
Edit: Okay, I see most libraries return 0 for (0, 0). I'll add the epsilon; thanks for the suggestion.
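For reference, a quick check of the (0, 0) convention in three common libraries (Python's math module, NumPy, and PyTorch), each of which returns 0 rather than nan when both inputs are +0:

```python
import math

import numpy as np
import torch

# atan2(0, 0) is mathematically undefined, but these libraries all return 0 for +0 inputs
py_result = math.atan2(0.0, 0.0)
np_result = np.arctan2(0.0, 0.0)
torch_result = torch.atan2(torch.tensor(0.0), torch.tensor(0.0)).item()
print(py_result, np_result, torch_result)
```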

@nnbtam99

nnbtam99 commented Oct 6, 2022

Hi @nikola-j, thank you for sharing. Have you tried this implementation with complex tensors? If possible, could you share how you derived the above algorithm?
Thank you in advance!

Edit: I found it here: https://en.wikipedia.org/wiki/Atan2. Thank you!
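The piecewise definition from that article, written with y as the first argument (as in torch.atan2). Each condition corresponds to one of the boolean masks in my_atan2, and the last case is the one the epsilon turns into 0:

```latex
\operatorname{atan2}(y, x) =
\begin{cases}
\arctan\left(\frac{y}{x}\right) & \text{if } x > 0, \\
\arctan\left(\frac{y}{x}\right) + \pi & \text{if } x < 0 \text{ and } y \ge 0, \\
\arctan\left(\frac{y}{x}\right) - \pi & \text{if } x < 0 \text{ and } y < 0, \\
+\frac{\pi}{2} & \text{if } x = 0 \text{ and } y > 0, \\
-\frac{\pi}{2} & \text{if } x = 0 \text{ and } y < 0, \\
\text{undefined} & \text{if } x = 0 \text{ and } y = 0.
\end{cases}
```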

@candlewill

This optimized version includes the following improvements:

Added comments in English to explain each step in the code.
Used torch.tensor to create the pi tensor directly instead of using torch.from_numpy.
Defined eps as a separate variable, making it easier to adjust if needed.
These changes aim to make the code more readable; the computation follows the same approach as the original.

import numpy as np
import torch

def onnx_atan2(y, x):
    # Create a pi tensor with the same device and data type as y
    pi = torch.tensor(np.pi, device=y.device, dtype=y.dtype)
    half_pi = pi / 2
    eps = 1e-6

    # Compute the arctangent of y/x (eps avoids 0/0 = nan at the origin)
    ans = torch.atan(y / (x + eps))

    # Create boolean tensors for the signs of y and x
    y_positive = y > 0
    y_nonnegative = y >= 0
    y_negative = y < 0
    x_negative = x < 0
    x_zero = x == 0

    # Adjust ans based on the signs of y and x
    ans += torch.where(y_nonnegative & x_negative, pi, torch.zeros_like(ans))  # Quadrant II and the negative x-axis
    ans -= torch.where(y_negative & x_negative, pi, torch.zeros_like(ans))  # Quadrant III
    ans = torch.where(y_positive & x_zero, half_pi, ans)  # Positive y-axis
    ans = torch.where(y_negative & x_zero, -half_pi, ans)  # Negative y-axis

    return ans
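Taking the torch.where idea one step further, the ±π shifts can also be written as torch.where calls instead of in-place += and -= updates. This is a minimal sketch under assumptions, not checked against the ONNX exporter here: where_atan2 is a hypothetical name, math.pi stands in for the NumPy constant, and the x < 0 branch uses y >= 0 so that the negative x-axis maps to π, matching torch.atan2. The grid at the end compares it with torch.atan2 on all quadrants, both axes, and the origin.

```python
import math

import torch


def where_atan2(y, x, eps=1e-6):
    # Hypothetical variant: every correction is a torch.where, no in-place updates
    base = torch.atan(y / (x + eps))
    # x < 0: shift into quadrant II / negative x-axis (y >= 0) or quadrant III (y < 0)
    shifted = torch.where(y >= 0, base + math.pi, base - math.pi)
    out = torch.where(x < 0, shifted, base)
    # x == 0: exact values on the y-axis; the origin keeps atan(0 / eps) = 0
    out = torch.where((x == 0) & (y > 0), torch.full_like(base, math.pi / 2), out)
    out = torch.where((x == 0) & (y < 0), torch.full_like(base, -math.pi / 2), out)
    return out


# Compare against torch.atan2 on a grid that covers every quadrant, both axes, and the origin
vals = torch.tensor([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0])
grid = torch.cartesian_prod(vals, vals)
y, x = grid[:, 0], grid[:, 1]
max_err = (where_atan2(y, x) - torch.atan2(y, x)).abs().max().item()
print(f"max abs error vs torch.atan2: {max_err:.2e}")
```

The remaining error comes from the epsilon in the denominator, which perturbs the angle slightly when the point is close to the origin.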
