@hangingman
Last active May 7, 2023 05:38
import os
import time
import torch
def format_to_gb(item, precision=4):
    """quick function to format numbers to gigabyte and round to (default) 4 digit precision"""
    metric_num = item / gigabyte_size
    metric_num = round(metric_num, ndigits=precision)
    return metric_num
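# Illustrative usage (hypothetical values, using the constants defined below):
#   format_to_gb(3 * gigabyte_size)    -> 3.0
#   format_to_gb(512 * megabyte_size)  -> 0.5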
gigabyte_size = 1073741824
megabyte_size = 1048576
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)
# Download MNIST, the handwritten-digit image dataset
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, data_home=".", parser='auto', as_frame=False)
X = mnist.data / 255  # normalize 0-255 down to 0-1
y = mnist.target
# Due to a change in the MNIST dataset, the labels are no longer numeric,
# so convert them to a NumPy integer array below
import numpy as np
y = np.array(y)
y = y.astype(np.int32)
# 2. Create the DataLoader
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
# 2.1 Split the data into train and test sets (6:1)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=1/7, random_state=0)
# 2.2 Convert the data to PyTorch Tensors
X_train = torch.Tensor(X_train).to(device)
X_test = torch.Tensor(X_test).to(device)
y_train = torch.LongTensor(y_train).to(device)
y_test = torch.LongTensor(y_test).to(device)
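# Note (added): since these tensors now live on `device`, the per-batch
# .to(device) calls in train()/test() below are effectively no-ops; for a
# dataset larger than GPU memory you would keep the tensors on the CPU and
# move only each mini-batch.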
# 2.3 Create Datasets pairing the data with its labels
ds_train = TensorDataset(X_train, y_train)
ds_test = TensorDataset(X_test, y_test)
# 2.4 Create DataLoaders over the datasets with the given mini-batch size
# Similar to Chainer's iterators.SerialIterator
loader_train = DataLoader(ds_train, batch_size=64, shuffle=True)
loader_test = DataLoader(ds_test, batch_size=64, shuffle=False)
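# Illustrative shapes (given the sizes above): each iteration of loader_train
# yields a (data, targets) pair with data.shape == (64, 784) and
# targets.shape == (64,), except for a smaller final batch (60000 % 64 == 32).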
# 3.1 Build the network
from torch import nn
model = nn.Sequential()
model.add_module('fc1', nn.Linear(28*28*1, 100))
model.add_module('relu1', nn.ReLU())
model.add_module('fc2', nn.Linear(100, 100))
model.add_module('relu2', nn.ReLU())
model.add_module('fc3', nn.Linear(100, 10))
model.to(device)
print(model)
# 4. Set the loss function and the optimization method
# Set the loss function
# Choose the optimizer used to learn the weights
from torch import optim
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
# 5. Set up training and inference
# 5-1. Define what one training epoch does
def train(epoch):
    model.train()  # switch the network to training mode
    # Pull one mini-batch at a time from the data loader
    for data, targets in loader_train:
        data = data.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()  # reset previously computed gradients to 0
        outputs = model(data)  # feed the input data forward and get the outputs
        loss = loss_fn(outputs, targets)  # compute the loss between the outputs and the training labels
        loss.backward()  # backpropagate the loss
        optimizer.step()  # update the weights using the backpropagated gradients
    print("epoch{}: done\n".format(epoch))
# 5-2. Define what one inference run does
# Since run_pippy is used, pp_ranks and args are passed in by default
def test(pp_ranks, args):
    # PiPPy parameter setup (pipeline configuration)
    # According to the docs, the right order is to create the pipe after sending the model to cuda
    from pippy import split_into_equal_size
    from pippy.IR import Pipe, MultiUseParameterConfig
    from pippy.utils import exclude_master
    MULTI_USE_PARAM_CONFIG = MultiUseParameterConfig.REPLICATE
    print(f'REPLICATE config: {MULTI_USE_PARAM_CONFIG}')
    number_of_workers = len(pp_ranks)  # the number of GPUs to use
    print(f"number_of_workers = {number_of_workers}")
    args.model.eval()  # switch the network to inference mode
    model_init_start = time.time()
    split_policy = split_into_equal_size(number_of_workers)
    # Set up the pipe driver
    import pippy
    # All ranks call into it
    driver, stage_mod = pippy.all_compile(
        args.model,
        num_ranks=args.world_size,
        num_chunks=args.world_size,
        schedule="FillDrain",
        split_policy=split_policy,
    )
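    # Note (added, inferred from the guard below and the log at the bottom):
    # every rank compiles here, but only rank 0 goes on to use `driver`; the
    # other ranks return early and just keep serving their pipeline stage.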
    if args.rank != 0:
        return
    model_init_end = time.time()
    print("Model initialization time")
    print("=========================")
    print("{} seconds".format(model_init_end - model_init_start))
    memory_reserved = format_to_gb(torch.cuda.memory_reserved())
    memory_allocated = format_to_gb(torch.cuda.memory_allocated())
    print("memory_reserved after model initialized with pipelines on each rank")
    print("===================================================================")
    print(" {} GB".format(memory_reserved))
    print("memory_allocated after model initialized with pipelines on each rank")
    print("===================================================================")
    print(" {} GB".format(memory_allocated))
    this_file_name = os.path.splitext(os.path.basename(__file__))[0]
    print('Running model pipeline.')
    # Pull one mini-batch at a time from the data loader
    correct = 0
    with torch.no_grad():
        for data, targets in loader_test:
            data = data.to(device)
            targets = targets.to(device)
            outputs = driver(data)
            _, predicted = torch.max(outputs.data, 1)  # take the label with the highest probability
            correct += predicted.eq(targets.data.view_as(predicted)).sum()  # count up when it matches the answer
    print('Inference is finished')
    # Print the accuracy
    data_num = len(loader_test.dataset)  # total number of data points
    print('\nAccuracy on test data: {}/{} ({:.0f}%)\n'.format(correct,
                                                              data_num, 100. * correct / data_num))
# 5-3. Set up inference so it runs through PiPPy
from pippy import run_pippy
# def run_pippy(run_func, args, *extra_args):
# This function is provided on the PiPPy side, so use it at inference time
# Read args so they can also be set from the command line
# ref: https://qiita.com/uenonuenon/items/09fa620426b4c5d4acf9
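# Note (added, inferred from the log in the comment at the bottom): with RANK
# left at its -1 default and args.gspmd = 1, run_pippy appears to spawn one
# local process per rank and call test(pp_ranks, args) in each of them, so the
# script can be launched without torchrun.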
import argparse
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 2)))
    parser.add_argument('--pp_group_size', type=int, default=2)
    parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1)))
    parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost'))
    parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500'))
    parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available()))
    args = parser.parse_args(args=[])
    assert args.world_size % args.pp_group_size == 0
    args.dp_group_size = args.world_size // args.pp_group_size
    args.gspmd = 1
    args.model = model
    # Try inference on the test data without any training
    run_pippy(test, args)
    # 6. Run training and inference
    t1 = time.time()
    for epoch in range(10):
        train(epoch)
    run_pippy(test, args)
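# Per the log in the comment below, this script is launched simply as
# `python test.py`; run_pippy handles worker-process startup itself.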
hangingman commented May 7, 2023

Writing down the log to mark a successful test run on two GPUs.
Since this is MNIST the amount of data is small, but this should really show its power with larger models.
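For intuition, here is a minimal sketch of what the two-stage split amounts to, done by hand without PiPPy (assumptions: two GPUs cuda:0/cuda:1, and a split point after relu1 — PiPPy picks the actual split via split_into_equal_size):

import torch
from torch import nn

# hand-split version of the model above (illustrative only)
stage0 = nn.Sequential(nn.Linear(28*28, 100), nn.ReLU()).to("cuda:0")
stage1 = nn.Sequential(nn.Linear(100, 100), nn.ReLU(), nn.Linear(100, 10)).to("cuda:1")

def pipeline_forward(x):              # x: (batch, 784) on cuda:0
    h = stage0(x)                     # stage 0 runs on cuda:0
    return stage1(h.to("cuda:1"))     # activations hop to cuda:1; stage 1 finishes there

On top of this, PiPPy splits each batch into micro-batches and schedules them (the "FillDrain" schedule above) so both GPUs work concurrently instead of one idling while the other computes.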

  • Inference on the test data without training
python test.py
cuda
Sequential(
  (fc1): Linear(in_features=784, out_features=100, bias=True)
  (relu1): ReLU()
  (fc2): Linear(in_features=100, out_features=100, bias=True)
  (relu2): ReLU()
  (fc3): Linear(in_features=100, out_features=10, bias=True)
)
[PiPPy] World size: 2, DP group size: 1, PP group size: 2
cuda
cuda
Sequential(
  (fc1): Linear(in_features=784, out_features=100, bias=True)
  (relu1): ReLU()
  (fc2): Linear(in_features=100, out_features=100, bias=True)
  (relu2): ReLU()
  (fc3): Linear(in_features=100, out_features=10, bias=True)
)
rank = 1 host/pid/device = cowgirl/132259/cuda:1
Sequential(
  (fc1): Linear(in_features=784, out_features=100, bias=True)
  (relu1): ReLU()
  (fc2): Linear(in_features=100, out_features=100, bias=True)
  (relu2): ReLU()
  (fc3): Linear(in_features=100, out_features=10, bias=True)
)
rank = 0 host/pid/device = cowgirl/132258/cuda:0
REPLICATE config: MultiUseParameterConfig.REPLICATE
number_of_workers = 2
REPLICATE config: MultiUseParameterConfig.REPLICATE
number_of_workers = 2
Model initialization time
=========================
0.02849125862121582 seconds
memory_reserved after model initialized with pipelines on each rank
===================================================================
 0.207 GB
memory_allocated after model initialized with pipelines on each rank
===================================================================
 0.2059 GB
Running model pipeline.
Inference is finished

Accuracy on test data: 889/10000 (9%)
  • Running training
epoch0: done
epoch1: done
epoch2: done
epoch3: done
epoch4: done
epoch5: done
epoch6: done
epoch7: done
epoch8: done
epoch9: done
  • Running inference again
[PiPPy] World size: 2, DP group size: 1, PP group size: 2
cuda
cuda
Sequential(
  (fc1): Linear(in_features=784, out_features=100, bias=True)
  (relu1): ReLU()
  (fc2): Linear(in_features=100, out_features=100, bias=True)
  (relu2): ReLU()
  (fc3): Linear(in_features=100, out_features=10, bias=True)
)
rank = 0 host/pid/device = cowgirl/133527/cuda:0
Sequential(
  (fc1): Linear(in_features=784, out_features=100, bias=True)
  (relu1): ReLU()
  (fc2): Linear(in_features=100, out_features=100, bias=True)
  (relu2): ReLU()
  (fc3): Linear(in_features=100, out_features=10, bias=True)
)
rank = 1 host/pid/device = cowgirl/133528/cuda:1
REPLICATE config: MultiUseParameterConfig.REPLICATE
number_of_workers = 2
REPLICATE config: MultiUseParameterConfig.REPLICATE
number_of_workers = 2
Model initialization time
=========================
0.03170156478881836 seconds
memory_reserved after model initialized with pipelines on each rank
===================================================================
 0.207 GB
memory_allocated after model initialized with pipelines on each rank
===================================================================
 0.2059 GB
Running model pipeline.
Inference is finished

Accuracy on test data: 9662/10000 (97%)

[W CudaIPCTypes.cpp:15] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]
