@sherwoac
Last active January 19, 2022 10:42
_get_unit_square_intercepts
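import torch


def _get_unit_square_intercepts(self, slopes, intercept):
    """
    Return the points where each line y = a*x + b crosses the border of the
    unit square [0, 1] x [0, 1], for slopes a and intercepts b.

    Each line is intersected with the four lines bounding the square:
        right:  x = 1,            y = a + b
        left:   x = 0,            y = b
        top:    x = (1 - b) / a,  y = 1
        bottom: x = -b / a,       y = 0
    A candidate is kept only if both of its coordinates lie in [0, 1].

    Assumes exactly two candidates survive for every line. A line through a
    corner (e.g. y = x) keeps four (each corner twice) and a line that misses
    the square keeps none; with such rows the output is no longer b x 2 x 2,
    or the reshape fails.

    :param slopes: b x 1 tensor of slopes (a)
    :param intercept: b x 1 tensor of intercepts (b)
    :return: b x 2 x 2 tensor; [:, 0, :] holds the x and [:, 1, :] the y
             coordinates of the two intersection points
    """
    batches = slopes.size(0)
    # match the inputs' dtype and device so the stacking also works on GPU
    ones = torch.ones(batches, dtype=slopes.dtype, device=slopes.device)
    zeros = torch.zeros_like(ones)
    # candidate coordinates for the right, left, top and bottom borders: b x 4
    # (a zero slope gives +-inf or nan for top/bottom; the range check rejects them)
    x = torch.column_stack([ones,
                            zeros,
                            torch.divide(1 - intercept, slopes),
                            torch.divide(-intercept, slopes)])
    y = torch.column_stack([slopes + intercept,
                            intercept,
                            ones,
                            zeros])
    # keep only candidates that lie on the border of the square
    acceptance = (y >= 0) & (y <= 1) & (x >= 0) & (x <= 1)
    return torch.column_stack((x[acceptance].reshape(batches, 1, -1),
                               y[acceptance].reshape(batches, 1, -1)))  # b x 2 (x, y) x 2 points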
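For testing the batched method, or for lines that pass through a corner of the square or miss it entirely, a scalar reference can help. The sketch below is a hypothetical, dependency-free helper (`unit_square_intercepts` is not part of the original gist). It uses the same four border candidates, skips the top and bottom borders when the slope is zero, and drops duplicate corners, so it returns zero, one or two points. The `eps` tolerance is an added assumption; the batched version compares exactly.

```python
from typing import List, Tuple

Point = Tuple[float, float]


def unit_square_intercepts(slope: float, intercept: float,
                           eps: float = 1e-12) -> List[Point]:
    """Hypothetical scalar reference: the distinct points where
    y = slope * x + intercept meets the border of [0, 1] x [0, 1]."""
    candidates: List[Point] = [
        (1.0, slope + intercept),  # right border, x = 1
        (0.0, intercept),          # left border, x = 0
    ]
    if slope != 0.0:
        # a horizontal line never reaches y = 1 or y = 0 at a finite x
        candidates.append(((1.0 - intercept) / slope, 1.0))  # top border, y = 1
        candidates.append((-intercept / slope, 0.0))         # bottom border, y = 0
    points: List[Point] = []
    for x, y in candidates:
        # assumed tolerance so points on the border are not lost to rounding
        inside = -eps <= x <= 1.0 + eps and -eps <= y <= 1.0 + eps
        # a line through a corner hits two borders at the same point
        duplicate = any(abs(x - px) <= eps and abs(y - py) <= eps
                        for px, py in points)
        if inside and not duplicate:
            points.append((x, y))
    return sorted(points)


print(unit_square_intercepts(1.0, 0.0))    # y = x, through two corners
print(unit_square_intercepts(-1.0, 1.5))   # crosses the top and right borders
print(unit_square_intercepts(1.0, 2.0))    # misses the square
```

Row `i` of the batched method's b x 2 x 2 output corresponds to the points `(out[i, 0, k], out[i, 1, k])` for `k = 0, 1`, so the two can be compared point by point after sorting.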