Atan2 pytorch onnx
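```python
import numpy as np
import torch

def my_atan2(y, x):
    pi = torch.from_numpy(np.array([np.pi])).to(y.device, y.dtype)
    ans = torch.atan(y / (x + 1e-6))
    # Quadrant II and the negative x-axis (y >= 0, x < 0): add pi
    ans += ((y >= 0) & (x < 0)) * pi
    # Quadrant III (y < 0, x < 0): subtract pi
    ans -= ((y < 0) & (x < 0)) * pi
    # Positive y-axis: replace the result with pi/2
    ans *= (1 - ((y > 0) & (x == 0)) * 1.0)
    ans += ((y > 0) & (x == 0)) * (pi / 2)
    # Negative y-axis: replace the result with -pi/2
    ans *= (1 - ((y < 0) & (x == 0)) * 1.0)
    ans += ((y < 0) & (x == 0)) * (-pi / 2)
    return ans
```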
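One thing worth noting about the `x + 1e-6` trick: for x in (-1e-6, 0) the shifted denominator becomes positive, so `torch.atan` lands on the wrong branch before the x < 0 correction adds pi (for example, y = 1 and x = -5e-7 gives roughly 3π/2 instead of about π/2). Below is a minimal sketch of one alternative, assuming exact zeros are the only values that need guarding: replace x only where it is exactly 0, since those entries are overwritten by the y-axis cases anyway. `atan2_via_atan` is a hypothetical name. It sticks to elementwise ops (`torch.atan`, `torch.where`, comparisons), but whether it exports cleanly for a given ONNX opset is an assumption to check.

```python
import math

import torch


def atan2_via_atan(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical atan2 built from torch.atan plus quadrant fix-ups."""
    # Use a dummy denominator only where x is exactly 0; those entries are
    # overwritten by the y-axis cases below, so the value does not matter.
    safe_x = torch.where(x == 0, torch.ones_like(x), x)
    ans = torch.atan(y / safe_x)
    pi = torch.full_like(ans, math.pi)
    zero = torch.zeros_like(ans)
    # Quadrant II and the negative x-axis: shift by +pi
    ans = ans + torch.where((x < 0) & (y >= 0), pi, zero)
    # Quadrant III: shift by -pi
    ans = ans - torch.where((x < 0) & (y < 0), pi, zero)
    # Positive and negative y-axis
    ans = torch.where((x == 0) & (y > 0), pi / 2, ans)
    ans = torch.where((x == 0) & (y < 0), -pi / 2, ans)
    # The origin (x == 0, y == 0) falls through as atan(0 / 1) == 0
    return ans


# Grid that includes the small values +/-5e-7 and exact zeros
vals = torch.tensor([-2.0, -5e-7, 0.0, 5e-7, 1.5], dtype=torch.float64)
y, x = torch.meshgrid(vals, vals, indexing="ij")
err = (atan2_via_atan(y, x) - torch.atan2(y, x)).abs().max()
print(err)  # should be tiny (float64 rounding error)
```

Signed zeros are not handled: `torch.atan2(-0.0, -1.0)` returns -π, while this sketch returns π, because `-0.0 >= 0` is true.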
Oct 6, 2022
Hi @nikola-j, thank you for sharing. Have you tried this implementation with complex tensors? If possible, could you share how you derived the above algorithm?
Thank you in advance!
Edited: I found it here: https://en.wikipedia.org/wiki/Atan2. Thank you!!!
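For reference, this is the piecewise definition given on that Wikipedia page, which the masks in the code above are modeled on:

$$
\operatorname{atan2}(y, x) =
\begin{cases}
\arctan\left(\frac{y}{x}\right) & \text{if } x > 0,\\
\arctan\left(\frac{y}{x}\right) + \pi & \text{if } x < 0 \text{ and } y \ge 0,\\
\arctan\left(\frac{y}{x}\right) - \pi & \text{if } x < 0 \text{ and } y < 0,\\
+\frac{\pi}{2} & \text{if } x = 0 \text{ and } y > 0,\\
-\frac{\pi}{2} & \text{if } x = 0 \text{ and } y < 0,\\
\text{undefined} & \text{if } x = 0 \text{ and } y = 0.
\end{cases}
$$

Each boolean mask in the code picks out one of these cases; the case x > 0 needs no correction, so it has no mask.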
Apr 27, 2023
This optimized version includes the following improvements:
- Added comments in English to explain each step in the code.
- Used `torch.tensor` to create the pi tensor directly instead of using `torch.from_numpy`.
- Defined `eps` as a separate variable, making it easier to adjust if needed.

These changes make the code more readable; the computation itself follows the same steps as the original.
```python
import numpy as np
import torch

def onnx_atan2(y, x):
    # Create a pi tensor with the same device and data type as y
    pi = torch.tensor(np.pi, device=y.device, dtype=y.dtype)
    half_pi = pi / 2
    eps = 1e-6
    # Compute the arctangent of y/x
    ans = torch.atan(y / (x + eps))
    # Create boolean tensors for the sign of y and x
    y_nonnegative = y >= 0
    y_positive = y > 0
    y_negative = y < 0
    x_negative = x < 0
    x_zero = x == 0
    # Adjust ans based on the quadrant or axis that (x, y) lies in
    ans += torch.where(y_nonnegative & x_negative, pi, torch.zeros_like(ans))  # Quadrant II and negative x-axis
    ans -= torch.where(y_negative & x_negative, pi, torch.zeros_like(ans))  # Quadrant III
    ans = torch.where(y_positive & x_zero, half_pi, ans)  # Positive y-axis
    ans = torch.where(y_negative & x_zero, -half_pi, ans)  # Negative y-axis
    return ans
```
But atan2(0, 0) is undefined, not 0. I think it's better to return NaN; that way you would know that something is wrong.
Edit: Okay, I see most libraries handle (0, 0) as 0. I'll add the epsilon. Thanks for the suggestion!
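As an illustration of that convention, `torch.atan2` itself returns 0 at the origin (it follows the C `atan2` convention). If you would rather have the NaN behavior argued for above, a minimal sketch is to post-process the result. `flag_origin_as_nan` is a hypothetical helper name, not part of PyTorch, and the sketch assumes `ans`, `y`, and `x` share a shape.

```python
import torch


def flag_origin_as_nan(ans: torch.Tensor, y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: report atan2(0, 0) as NaN instead of 0."""
    origin = (x == 0) & (y == 0)
    return torch.where(origin, torch.full_like(ans, float("nan")), ans)


y = torch.tensor([0.0, 1.0, 0.0])
x = torch.tensor([0.0, 0.0, -1.0])
ans = torch.atan2(y, x)  # 0 at the origin, then pi/2 and pi
flagged = flag_origin_as_nan(ans, y, x)
print(ans)
print(flagged)
```

The same wrapper works on the output of either hand-written version in this thread, since both also return 0 at (0, 0).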