@thsunkid
Last active September 5, 2021 15:11
XLA compilation error for TPU sample code
import os
import math
import time
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch_xla
import torch_xla.distributed.parallel_loader as pl
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from tqdm.auto import tqdm
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold
import timm
import albumentations
from albumentations.pytorch import ToTensorV2
from nnAudio.Spectrogram import CQT1992v2
OUTPUT_DIR = './'
# ====================================================
# CFG
# ====================================================
class CFG:
    num_workers=4
    model_name='tf_efficientnet_b7_ns'
    epochs=1
    T_max=3
    lr=1e-4
    min_lr=1e-6
    batch_size=48
    weight_decay=1e-6
    gradient_accumulation_steps=1
    max_grad_norm=1000
    qtransform_params={"sr": 2048, "fmin": 20, "fmax": 1024, "hop_length": 32, "bins_per_octave": 8, "n_bins": None}
    target_size=1
    train=True
# ====================================================
# Utils
# ====================================================
def get_score(y_true, y_pred):
    score = roc_auc_score(y_true, y_pred)
    return score

def seed_torch(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
# ====================================================
# Dataset
# ====================================================
class TrainDataset(Dataset):
    def __init__(self, transform=None):
        self.wave_transform = CQT1992v2(**CFG.qtransform_params)
        self.transform = transform

    def __len__(self):
        return 2000

    def apply_qtransform(self, waves, transform):
        waves = np.hstack(waves)
        waves = waves / np.max(waves)
        waves = torch.from_numpy(waves).float()
        image = transform(waves)
        return image

    def __getitem__(self, idx):
        waves = np.random.randn(3, 4096)
        image = self.apply_qtransform(waves, self.wave_transform)
        image = image.squeeze().numpy()
        if self.transform:
            image = self.transform(image=image)['image']
        label = torch.tensor(1).float()
        return image, label
# ====================================================
# Transforms
# ====================================================
def get_transforms(*, data):
    if data == 'train':
        return albumentations.Compose([
            ToTensorV2(),
        ])
    elif data == 'valid':
        return albumentations.Compose([
            ToTensorV2(),
        ])
# ====================================================
# MODEL
# ====================================================
class CustomModel(nn.Module):
    def __init__(self, cfg, pretrained=False):
        super().__init__()
        self.cfg = cfg
        self.model = timm.create_model(self.cfg.model_name, pretrained=pretrained, in_chans=1)
        self.n_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(self.n_features, self.cfg.target_size)

    def forward(self, x):
        output = self.model(x)
        return output
# ====================================================
# Helper functions
# ====================================================
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def loss_fn(outputs, targets):
    return nn.BCEWithLogitsLoss()(outputs, targets.view(-1, 1))
def train_fn(fold, train_loader, model, criterion, optimizer, epoch, scheduler, device):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    # switch to train mode
    model.train()
    start = end = time.time()
    global_step = 0
    xm.master_print("Training time ... ")
    train_loader = tqdm(train_loader)
    for step, (images, labels) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        y_preds = model(images)
        loss = criterion(y_preds.view(-1), labels)
        # record loss
        losses.update(loss.item(), batch_size)
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
        if (step + 1) % CFG.gradient_accumulation_steps == 0:
            xm.optimizer_step(optimizer)
            optimizer.zero_grad()
            global_step += 1
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
    return losses.avg
def valid_fn(valid_loader, model, criterion, device):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    # switch to evaluation mode
    model.eval()
    preds = []
    start = end = time.time()
    valid_labels = []
    for step, (images, labels) in enumerate(valid_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        # compute loss
        with torch.no_grad():
            y_preds = model(images)
        xm.mark_step()
        loss = loss_fn(y_preds, labels)
        losses.update(loss.item(), batch_size)
        # record accuracy
        preds.append(y_preds.sigmoid().to('cpu').numpy())
        valid_labels.append(labels.to('cpu').numpy())
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
    preds = np.concatenate(preds)
    valid_labels = np.concatenate(valid_labels)
    score = get_score(valid_labels, preds)
    return losses.avg, preds, score
# ====================================================
# Train loop
# ====================================================
def train_loop(tid):
    device = xm.xla_device()
    train_dataset = TrainDataset(transform=get_transforms(data='train'))
    valid_dataset = TrainDataset(transform=get_transforms(data='train'))
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False,
    )
    train_loader = DataLoader(train_dataset,
                              batch_size=CFG.batch_size,
                              sampler=train_sampler,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=CFG.batch_size,
                              sampler=valid_sampler,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=False)
    xm.master_print("Dataloader")

    # ====================================================
    # model & optimizer
    # ====================================================
    model = CustomModel(CFG, pretrained=True)
    model.to(device)
    optimizer = Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay, amsgrad=False)
    scheduler = CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)
    criterion = nn.BCEWithLogitsLoss()

    best_score = 0.
    best_loss = np.inf
    fold = 0
    for epoch in range(CFG.epochs):
        start_time = time.time()
        # train
        para_loader = pl.ParallelLoader(train_loader, [device])
        avg_loss = train_fn(fold, para_loader.per_device_loader(device), model, criterion, optimizer, epoch, scheduler, device)
        # eval
        para_loader = pl.ParallelLoader(valid_loader, [device])
        avg_val_loss, preds, score = valid_fn(para_loader.per_device_loader(device), model, criterion, device)
        scheduler.step()
        elapsed = time.time() - start_time
        if score > best_score:
            best_score = score
            xm.save(model.state_dict(), OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_score.pth')
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            xm.save(model.state_dict(), OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_loss.pth')
    return best_score, best_loss
if __name__ == "__main__":
    seed_torch()
    torch.set_default_tensor_type('torch.FloatTensor')
    xmp.spawn(train_loop, args=(), nprocs=None)
thsunkid commented Sep 5, 2021

The script runs without errors if you use EfficientNet B6 instead of B7, so I suspect this is a memory-related issue. The full log of the error is below.
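For reference, a possible workaround sketch for the CFG block (only the B6 model name is confirmed to compile; the smaller batch size and accumulation steps are just my guess for reducing memory pressure, not something I have verified):

class CFG:
    num_workers=4
    model_name='tf_efficientnet_b6_ns'  # b6 compiles fine; b7 triggers the crash below
    epochs=1
    T_max=3
    lr=1e-4
    min_lr=1e-6
    batch_size=24                       # assumption: halve the per-core batch
    weight_decay=1e-6
    gradient_accumulation_steps=2       # assumption: keep the effective batch size at 48
    max_grad_norm=1000
    qtransform_params={"sr": 2048, "fmin": 20, "fmax": 1024, "hop_length": 32, "bins_per_octave": 8, "n_bins": None}
    target_size=1
    train=True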

2021-09-05 14:47:40.900394: F tensorflow/core/tpu/kernels/tpu_program_group.cc:86] Check failed: xla_tpu_programs.size() > 0 (0 vs. 0)
https://symbolize.stripped_domain/r/?trace=7fb1ea14f18b,7fb1ea14f20f,7faf7606064f,7faf70e9fc97,7faf70e94b01,7faf70eb529e,7faf70eb4e0b,7faf6dac893d,7faf723a32a8,7faf75b56580,7faf75b58943,7faf76031f71,7faf760317a0,7faf7601b32b,7fb1ea0ef608&map=c5ea6dcea9ec73900e238cf37efee14d75fd7749:7faf69206000-7faf78b74e28
*** SIGABRT received by PID 96579 (TID 98659) on cpu 50 from PID 96579; stack trace: ***
PC: @ 0x7fb1ea14f18b (unknown) raise
@ 0x7faf687071e0 976 (unknown)
@ 0x7fb1ea14f210 3968 (unknown)
@ 0x7faf76060650 16 tensorflow::internal::LogMessageFatal::~LogMessageFatal()
@ 0x7faf70e9fc98 592 tensorflow::tpu::TpuProgramGroup::Initialize()
@ 0x7faf70e94b02 1360 tensorflow::tpu::TpuCompilationCacheExternal::InitializeEntry()
@ 0x7faf70eb529f 800 tensorflow::tpu::TpuCompilationCacheInterface::CompileIfKeyAbsentHelper()
@ 0x7faf70eb4e0c 128 tensorflow::tpu::TpuCompilationCacheInterface::CompileIfKeyAbsent()
@ 0x7faf6dac893e 944 tensorflow::XRTCompileOp::Compute()
@ 0x7faf723a32a9 432 tensorflow::XlaDevice::Compute()
@ 0x7faf75b56581 2080 tensorflow::(anonymous namespace)::ExecutorState<>::Process()
@ 0x7faf75b58944 48 std::_Function_handler<>::_M_invoke()
@ 0x7faf76031f72 128 Eigen::ThreadPoolTempl<>::WorkerLoop()
@ 0x7faf760317a1 48 tensorflow::thread::EigenEnvironment::CreateThread()::{lambda()#1}::operator()()
@ 0x7faf7601b32c 80 tensorflow::(anonymous namespace)::PThread::ThreadFn()
@ 0x7fb1ea0ef609 (unknown) start_thread
https://symbolize.stripped_domain/r/?trace=7fb1ea14f18b,7faf687071df,7fb1ea14f20f,7faf7606064f,7faf70e9fc97,7faf70e94b01,7faf70eb529e,7faf70eb4e0b,7faf6dac893d,7faf723a32a8,7faf75b56580,7faf75b58943,7faf76031f71,7faf760317a0,7faf7601b32b,7fb1ea0ef608&map=c5ea6dcea9ec73900e238cf37efee14d75fd7749:7faf69206000-7faf78b74e28,ca1b7ab241ee28147b3d590cadb5dc1b:7faf5ba08000-7faf68a3ab20
E0905 14:47:41.115423 98659 coredump_hook.cc:292] RAW: Remote crash data gathering hook invoked.
E0905 14:47:41.115521 98659 coredump_hook.cc:384] RAW: Skipping coredump since rlimit was 0 at process start.
E0905 14:47:41.115538 98659 client.cc:222] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E0905 14:47:41.115546 98659 coredump_hook.cc:447] RAW: Sending fingerprint to remote end.
E0905 14:47:41.115559 98659 coredump_socket.cc:124] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E0905 14:47:41.115585 98659 coredump_hook.cc:451] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E0905 14:47:41.115593 98659 coredump_hook.cc:525] RAW: Discarding core.
E0905 14:47:41.446840 98659 process_state.cc:771] RAW: Raising signal 6 with default behavior
2021-09-05 14:47:42.253931: W tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc:159] RPC failed with status = "Unavailable: Socket closed" and grpc_error_string = "{"created":"@1630853262.253726561","description":"Error received from peer ipv4:127.0.0.1:51011","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Socket closed","grpc_status":14}", maybe retrying the RPC
2021-09-05 14:47:42.253981: W tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc:159] RPC failed with status = "Unavailable: Socket closed" and grpc_error_string = "{"created":"@1630853262.253798251","description":"Error received from peer ipv4:127.0.0.1:51011","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Socket closed","grpc_status":14}", maybe retrying the RPC
2021-09-05 14:47:42.254173: W tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc:159] RPC failed with status = "Unavailable: Socket closed" and grpc_error_string = "{"created":"@1630853262.253982083","description":"Error received from peer ipv4:127.0.0.1:51011","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Socket closed","grpc_status":14}", maybe retrying the RPC
2021-09-05 14:47:42.254316: W tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc:159] RPC failed with status = "Unavailable: Socket closed" and grpc_error_string = "{"created":"@1630853262.254106547","description":"Error received from peer ipv4:127.0.0.1:51011","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Socket closed","grpc_status":14}", maybe retrying the RPC
2021-09-05 14:47:42.254401: W tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc:159] RPC failed with status = "Unavailable: Socket closed" and grpc_error_string = "{"created":"@1630853262.254233239","description":"Error received from peer ipv4:127.0.0.1:51011","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Socket closed","grpc_status":14}", maybe retrying the RPC
2021-09-05 14:47:42.254586: W tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc:159] RPC failed with status = "Unavailable: Socket closed" and grpc_error_string = "{"created":"@1630853262.254379039","description":"Error received from peer ipv4:127.0.0.1:51011","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Socket closed","grpc_status":14}", maybe retrying the RPC
2021-09-05 14:47:42.254681: W tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc:159] RPC failed with status = "Unavailable: Socket closed" and grpc_error_string = "{"created":"@1630853262.254371953","description":"Error received from peer ipv4:127.0.0.1:51011","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Socket closed","grpc_status":14}", maybe retrying the RPC
https://symbolize.stripped_domain/r/?trace=7f48afeae376,7f48aff0720f,0&map=
*** SIGTERM received by PID 96580 (TID 96580) on cpu 76 from PID 96418; stack trace: ***
PC: @ 0x7f48afeae376 (unknown) pthread_cond_wait@@GLIBC_2.3.2
@ 0x7f46406031e0 976 (unknown)
@ 0x7f48aff07210 (unknown) (unknown)
@ 0x1 (unknown) (unknown)
https://symbolize.stripped_domain/r/?trace=7f48afeae376,7f46406031df,7f48aff0720f,0&map=ca1b7ab241ee28147b3d590cadb5dc1b:7f4633904000-7f4640936b20
E0905 14:47:44.622399 96580 coredump_hook.cc:250] RAW: Remote crash gathering disabled for SIGTERM.
E0905 14:47:44.632334 96580 process_state.cc:771] RAW: Raising signal 15 with default behavior
https://symbolize.stripped_domain/r/?trace=7fa259b48376,7fa259ba120f,0&map=
*** SIGTERM received by PID 96581 (TID 96581) on cpu 16 from PID 96418; stack trace: ***
PC: @ 0x7fa259b48376 (unknown) pthread_cond_wait@@GLIBC_2.3.2
@ 0x7f9fea29d1e0 976 (unknown)
@ 0x7fa259ba1210 (unknown) (unknown)
@ 0x1 (unknown) (unknown)
https://symbolize.stripped_domain/r/?trace=7fa259b48376,7f9fea29d1df,7fa259ba120f,0&map=ca1b7ab241ee28147b3d590cadb5dc1b:7f9fdd59e000-7f9fea5d0b20
E0905 14:47:44.692978 96581 coredump_hook.cc:250] RAW: Remote crash gathering disabled for SIGTERM.
E0905 14:47:44.703157 96581 process_state.cc:771] RAW: Raising signal 15 with default behavior
https://symbolize.stripped_domain/r/?trace=7f253582d376,7f253588620f,0&map=
*** SIGTERM received by PID 96582 (TID 96582) on cpu 50 from PID 96418; stack trace: ***
PC: @ 0x7f253582d376 (unknown) pthread_cond_wait@@GLIBC_2.3.2
@ 0x7f22b3e3e1e0 976 (unknown)
@ 0x7f2535886210 (unknown) (unknown)
@ 0x1 (unknown) (unknown)
https://symbolize.stripped_domain/r/?trace=7f253582d376,7f22b3e3e1df,7f253588620f,0&map=ca1b7ab241ee28147b3d590cadb5dc1b:7f22a713f000-7f22b4171b20
E0905 14:47:44.770392 96582 coredump_hook.cc:250] RAW: Remote crash gathering disabled for SIGTERM.
E0905 14:47:44.781140 96582 process_state.cc:771] RAW: Raising signal 15 with default behavior
https://symbolize.stripped_domain/r/?trace=7fa22cd6e376,7fa22cdc720f,0&map=
*** SIGTERM received by PID 96583 (TID 96583) on cpu 94 from PID 96418; stack trace: ***
PC: @ 0x7fa22cd6e376 (unknown) pthread_cond_wait@@GLIBC_2.3.2
@ 0x7f9fbd4c31e0 976 (unknown)
@ 0x7fa22cdc7210 (unknown) (unknown)
@ 0x1 (unknown) (unknown)
https://symbolize.stripped_domain/r/?trace=7fa22cd6e376,7f9fbd4c31df,7fa22cdc720f,0&map=ca1b7ab241ee28147b3d590cadb5dc1b:7f9fb07c4000-7f9fbd7f6b20
E0905 14:47:44.850877 96583 coredump_hook.cc:250] RAW: Remote crash gathering disabled for SIGTERM.
E0905 14:47:44.860808 96583 process_state.cc:771] RAW: Raising signal 15 with default behavior
https://symbolize.stripped_domain/r/?trace=7f2422297376,7f24222f020f,0&map=
*** SIGTERM received by PID 96584 (TID 96584) on cpu 95 from PID 96418; stack trace: ***
PC: @ 0x7f2422297376 (unknown) pthread_cond_wait@@GLIBC_2.3.2
@ 0x7f21a08a81e0 976 (unknown)
@ 0x7f24222f0210 (unknown) (unknown)
@ 0x1 (unknown) (unknown)
https://symbolize.stripped_domain/r/?trace=7f2422297376,7f21a08a81df,7f24222f020f,0&map=ca1b7ab241ee28147b3d590cadb5dc1b:7f2193ba9000-7f21a0bdbb20
E0905 14:47:44.925262 96584 coredump_hook.cc:250] RAW: Remote crash gathering disabled for SIGTERM.
E0905 14:47:44.935406 96584 process_state.cc:771] RAW: Raising signal 15 with default behavior
https://symbolize.stripped_domain/r/?trace=7fe72afed376,7fe72b04620f,0&map=
*** SIGTERM received by PID 96585 (TID 96585) on cpu 72 from PID 96418; stack trace: ***
PC: @ 0x7fe72afed376 (unknown) pthread_cond_wait@@GLIBC_2.3.2
@ 0x7fe4bb7421e0 976 (unknown)
@ 0x7fe72b046210 (unknown) (unknown)
@ 0x1 (unknown) (unknown)
https://symbolize.stripped_domain/r/?trace=7fe72afed376,7fe4bb7421df,7fe72b04620f,0&map=ca1b7ab241ee28147b3d590cadb5dc1b:7fe4aea43000-7fe4bba75b20
E0905 14:47:44.998456 96585 coredump_hook.cc:250] RAW: Remote crash gathering disabled for SIGTERM.
E0905 14:47:45.007977 96585 process_state.cc:771] RAW: Raising signal 15 with default behavior
https://symbolize.stripped_domain/r/?trace=7fce73026376,7fce7307f20f,0&map=
*** SIGTERM received by PID 96586 (TID 96586) on cpu 0 from PID 96418; stack trace: ***
PC: @ 0x7fce73026376 (unknown) pthread_cond_wait@@GLIBC_2.3.2
@ 0x7fcbf16371e0 976 (unknown)
@ 0x7fce7307f210 (unknown) (unknown)
@ 0x1 (unknown) (unknown)
https://symbolize.stripped_domain/r/?trace=7fce73026376,7fcbf16371df,7fce7307f20f,0&map=ca1b7ab241ee28147b3d590cadb5dc1b:7fcbe4938000-7fcbf196ab20
E0905 14:47:45.068494 96586 coredump_hook.cc:250] RAW: Remote crash gathering disabled for SIGTERM.
E0905 14:47:45.078588 96586 process_state.cc:771] RAW: Raising signal 15 with default behavior
Traceback (most recent call last):
  File "b7test.py", line 344, in <module>
    xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=None)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 388, in spawn
    return torch.multiprocessing.start_processes(
  File "/usr/local/lib/python3.8/dist-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/usr/local/lib/python3.8/dist-packages/torch/multiprocessing/spawn.py", line 130, in join
    raise ProcessExitedException(
torch.multiprocessing.spawn.ProcessExitedException: process 0 terminated with signal SIGABRT
/usr/lib/python3.8/multiprocessing/resource_tracker.py:216: UserWarning: resource_tracker: There appear to be 8 leaked semaphore objects to clean up at shutdown
warnings.warn('resource_tracker: There appear to be %d '
