Last active
December 2, 2022 23:51
-
-
Save subtleGradient/d1908ca1b1a4a2a4cb6631af0a743a59 to your computer and use it in GitHub Desktop.
Hack pyannote to work on M1/M2 (arm64) macOS using MPS (Metal Performance Shaders)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_devices(needs: int = None):
    """Get devices that can be used by the pipeline

    CUDA GPUs are used when at least one is present. Otherwise the Apple
    Metal Performance Shaders (MPS) backend is used when torch reports it
    as available, and the CPU is the final fallback.

    Parameters
    ----------
    needs : int, optional
        Number of devices needed by the pipeline

    Returns
    -------
    devices : list of torch.device
        List of available devices.
        When `needs` is provided, returns that many devices, cycling
        through the available ones when there are fewer than `needs`.
        A `needs` of zero or less yields an empty list.
    """
    num_gpus = torch.cuda.device_count()
    if num_gpus > 0:
        devices = [torch.device(f"cuda:{index:d}") for index in range(num_gpus)]
    else:
        devices = [torch.device("cpu")]
        # Older torch releases have no `torch.backends.mps`; look it up
        # defensively instead of hiding every possible error behind a
        # bare `except`.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            devices = [torch.device("mps")]
            print("Using Apple Metal Performance Shaders")

    if needs is None:
        return devices

    # Repeat the available devices until exactly `needs` entries exist.
    # Clamping at zero keeps the "empty list" result for non-positive
    # requests, since islice() rejects negative stop values.
    return list(itertools.islice(itertools.cycle(devices), max(needs, 0)))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Tell the user which installed pyannote file has to be edited by hand.
print('Hack this file: "/Users/tom/opt/anaconda3/envs/nightly/lib/python3.9/site-packages/pyannote/audio/pipelines/utils/getter.py"')

import importlib

import pyannote.audio.pipelines.utils.getter

# Reload the module so the current on-disk version of getter.py is the one
# in use (presumably after the manual edit mentioned above).
getter = importlib.reload(pyannote.audio.pipelines.utils.getter)

# Quick check: request two devices from the reloaded get_devices.
getter.get_devices(needs=2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment