Created May 24, 2024 14:30
-
-
Save mypy-play/f46acc15a837142a137759b4ca0a14d3 to your computer and use it in GitHub Desktop.
Shared via mypy Playground
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import typing | |
from typing import Any, Never, overload | |
class NDArray:
    """Minimal stand-in for an n-dimensional array type (presumably modeled on
    numpy's ndarray), declaring only the members this snippet uses: a
    ``shape`` attribute and item assignment.
    """

    # A class-level default gives every instance a ``shape`` at runtime; a bare
    # annotation alone declares the attribute for the type checker but never
    # creates it, so reading ``.shape`` would raise AttributeError.
    shape: tuple = ()

    def __setitem__(self, idx: "NDArray", v: object) -> None:
        # The annotation is quoted because ``NDArray`` is not bound yet while
        # its own class body executes; unquoted, defining the class raises
        # NameError. Type checkers accept the quoted form identically.
        ...
def random(x: object) -> NDArray:
    """Create and return a brand-new placeholder ``NDArray``.

    The argument ``x`` is accepted but never read; the call sites in this
    file pass a shape-like tuple such as ``(50, 50)``.
    """
    fresh = NDArray()
    return fresh
# Distinct nominal types over NDArray so the type checker can tell an image
# apart from a mask; at runtime both simply return their argument.
Image = typing.NewType('Image', NDArray)
Mask = typing.NewType('Mask', NDArray)

# Sample inputs: create the two raw arrays first, then label one as an Image
# and the other as a Mask.
array = random((50, 50))
mask_array = random((50, 50))
image = Image(array)
mask = Mask(mask_array)
@overload
def zero(image: Mask, mask: NDArray) -> Never: ...
@overload
def zero(image: NDArray, mask: Image) -> Never: ...
@overload
def zero(image: NDArray, mask: NDArray) -> Image: ...
def zero(image: NDArray, mask: NDArray) -> Image:
    """Assign 0 to the entries of ``image`` selected by ``mask``, in place.

    The overloads above encode the argument-role contract for type checkers:
    passing a ``Mask`` as ``image`` or an ``Image`` as ``mask`` resolves to a
    ``Never`` return, while any other pair of arrays resolves to ``Image``.

    The implementation signature is the widest one, ``(NDArray, NDArray)``,
    so it accepts every overload's arguments without a ``type: ignore``.

    Args:
        image: Array to modify. It is mutated in place and also returned.
        mask: Index passed to ``image.__setitem__``. It must have the same
            ``shape`` as ``image``.

    Returns:
        The same ``image`` object, typed as ``Image``.

    Raises:
        ValueError: If ``image.shape`` and ``mask.shape`` differ.
    """
    if image.shape != mask.shape:
        raise ValueError("Image and mask shapes must match")
    image[mask] = 0
    # NewType is the identity function at runtime, so this returns the very
    # object passed in; it only relabels a plain NDArray as Image for mypy.
    return Image(image)
x = random((50, 50))

# Type-check expectations for the ``zero`` overloads (checked by mypy, not at
# runtime). mypy treats every statement after a call that returns ``Never``
# as unreachable and skips type-checking it. The expected-failure case
# therefore goes last; otherwise it would silently hide the cases after it.

# should pass (explicit Image and Mask)
zero(Image(x), Mask(x)).shape
# should pass (untyped ndarrays)
zero(x, x).shape
# should fail (image passed as mask): resolves to the ``Never`` overload.
# mypy reports no error on this call itself; the ``Never`` return only marks
# any code after it as unreachable.
zero(Image(x), Image(x)).shape
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.