PyTorch Tech Share (July 06 2020) - Simple PyTorch distributed computation functionality testing with `pytest-xdist`.


This note describes one of many possible approaches to testing custom distributed computation functionality by emulating multiple processes.

What is a "distributed setting" in PyTorch?

  • Communication between N application processes
    • send/receive tensors (see the minimal sketch below)
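
For illustration, here is a minimal sketch (not part of the original gist) of point-to-point communication between two processes. It assumes the default process group has already been initialized with a world size of at least 2:

import torch
import torch.distributed as dist

def ping(rank):
    # assumes dist.init_process_group(...) was already called in every process
    t = torch.zeros(1)
    if rank == 0:
        t += 1.0
        dist.send(t, dst=1)   # blocking send of the tensor to rank 1
    elif rank == 1:
        dist.recv(t, src=0)   # blocking receive of the tensor from rank 0
    return t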

Model and Data Parallelism

(figure: diagram illustrating model parallelism vs. data parallelism)
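
As a rough data-parallel sketch (an illustration, not from the original gist), each process can wrap the same model in DistributedDataParallel so that gradients are averaged across processes during backward(); here a CPU model with the gloo backend is assumed:

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# assumes the process group is already initialized (e.g. gloo backend on CPU)
model = nn.Linear(10, 2)
ddp_model = DDP(model)  # gradients are all-reduced across processes on backward()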

Commonly used terms

  • WORLD_SIZE : total number of processes used in the computation, e.g. the number of GPUs across all machines
    • For example, with 2 machines of 4 GPUs each, WORLD_SIZE is 2 * 4 = 8
  • RANK : unique identifier of a process, ranging from 0 to WORLD_SIZE - 1
  • local rank : machine-wise process identifier, e.g. the GPU index within a node (see the sketch below for how these values are obtained at runtime)
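
A small sketch of how these values are typically obtained at runtime. Note that the LOCAL_RANK environment variable is a common convention used by process launchers, not something required by PyTorch itself:

import os
import torch.distributed as dist

# after dist.init_process_group(...) has been called in every process
world_size = dist.get_world_size()   # e.g. 8 for 2 machines with 4 GPUs each
rank = dist.get_rank()               # global id in [0, world_size - 1]
local_rank = int(os.environ.get("LOCAL_RANK", 0))  # usually set by the launcher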

Why and when do I need this?

Deep learning applications:

  • Metric computation (e.g. accuracy) in a distributed setting
  • Helper tools to extend current PyTorch functionality
  • Yet Another Distributed Training framework (like torch.distributed, Horovod, ...)

Code and Test

Accuracy metric implementation:

import torch
import torch.distributed as dist


class Accuracy:

    def __init__(self):
        self.num_samples = 0
        self.num_correct = 0

    def update(self, y_pred, y):
        self.num_samples += y_pred.shape[0]
        self.num_correct += (y_pred == y).sum().item()

    def compute(self):
        # Collect `num_correct` and `num_samples` across all participating processes
        t = torch.tensor([self.num_correct, self.num_samples])
        dist.all_reduce(t)
        num_correct, num_samples = t.tolist()
        accuracy = num_correct / num_samples
        return accuracy
Test of the accuracy metric using pytest-xdist fixtures (tests/test_accuracy.py):

import os

import pytest
import torch
import torch.distributed as dist


@pytest.fixture()
def local_rank(worker_id):
    """Derive a distinct local rank from the pytest-xdist worker id."""
    if "gw" in worker_id:
        lrank = int(worker_id.replace("gw", ""))
    elif "master" == worker_id:
        lrank = 0
    else:
        raise RuntimeError("Cannot get rank from worker_id={}".format(worker_id))

    yield lrank


@pytest.fixture()
def distributed_context(local_rank):
    # Rendezvous address for single-node runs (assumes the port is free)
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")

    rank = local_rank
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    yield {
        "local_rank": local_rank,
        "rank": rank,
        "world_size": world_size,
    }
    dist.destroy_process_group()


def test_accuracy(distributed_context):
    # the Accuracy class defined above is assumed to be importable here
    rank = distributed_context["rank"]
    world_size = distributed_context["world_size"]

    def rank_data(r):
        # deterministic per-rank predictions and targets
        g = torch.Generator().manual_seed(12 + r)
        y_pred = torch.randint(0, 10, size=(100,), generator=g)
        y = torch.randint(0, 10, size=(100,), generator=g)
        return y_pred, y

    # setup y_pred and y dependent on rank
    y_pred, y = rank_data(rank)

    acc = Accuracy()
    acc.update(y_pred, y)

    # reference accuracy recomputed locally over the data of all ranks
    pairs = [rank_data(r) for r in range(world_size)]
    true_acc = sum((p == t).sum().item() for p, t in pairs) / (100 * world_size)

    assert acc.compute() == pytest.approx(true_acc)
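
Optionally (a sketch not present in the original gist), distributed tests can be guarded so they are skipped when WORLD_SIZE is not set, e.g. when the file is run without pytest-xdist. The marker name below is hypothetical:

import os
import pytest

# hypothetical guard: apply with @requires_distributed above a test function
requires_distributed = pytest.mark.skipif(
    "WORLD_SIZE" not in os.environ,
    reason="WORLD_SIZE is not set; run under pytest-xdist (see command below)",
)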

Execute the test

Single node, multiple processes

  • Run 4 worker processes: --dist=each sends each test to every worker, and --tx 4*popen//python=python3.7 spawns 4 local Python 3.7 subprocesses
WORLD_SIZE=4 pytest --dist=each --tx 4*popen//python=python3.7 -vvv tests/test_accuracy.py

How do others do it?

PyTorch

Not clear

Horovod

  • They use their own runner application:
horovodrun -n 2 -H localhost:2 --gloo  pytest -v test/test_torch.py