Skip to content

Instantly share code, notes, and snippets.

@ManfeiBai
Created July 12, 2024 21:06
Show Gist options
  • Save ManfeiBai/3b13d4ea1068e2b497e60ccb80b66a1b to your computer and use it in GitHub Desktop.
Save ManfeiBai/3b13d4ea1068e2b497e60ccb80b66a1b to your computer and use it in GitHub Desktop.
3
(torch310) root@b7b12c30e894:/pytorch/xla# cat test_7665.py
import torch
import torch_xla
import torch_xla.experimental.fori_loop
from torch._higher_order_ops.while_loop import while_loop
import torch_xla.core.xla_model as xm
class TPUComputation:
    """Minimal reproduction that runs ``while_loop`` on an XLA device.

    The loop carries four tensors ``(iteri, x, y, z)``. Each iteration
    decrements ``iteri`` by one, copies ``x`` unchanged, adds 1 to ``y`` and
    adds ``self.quantity`` to ``z``; it stops once ``iteri`` is no longer
    greater than zero.
    """

    def __init__(self):
        # Every tensor is created on the device returned by xm.xla_device().
        self.device = xm.xla_device()
        self.init_x = torch.tensor([1], device=self.device)
        self.init_y = torch.tensor([1], device=self.device)
        self.init_z = torch.tensor([1], device=self.device)
        # Loop counter: a 0-d tensor counted down by body_fn.
        self.iteri = torch.tensor(10, device=self.device)
        # Increment applied to z each iteration; read from ``self`` inside
        # body_fn rather than being passed in as a carried loop input.
        self.quantity = torch.tensor(3, device=self.device)

    def cond_fn(self, iteri, x, y, z, q=None):
        """Return ``iteri > 0``; the remaining arguments are unused here."""
        return iteri > 0

    def body_fn(self, iteri, x, y, z, q=None):
        """Return the next ``(iteri, x, y, z)`` carried values.

        ``q`` is accepted but not used by this body.
        """
        # Problematic line (flagged by the original author): ``z`` is updated
        # with ``self.quantity``, a tensor that is not among the carried
        # inputs handed to ``while_loop``.
        return iteri - 1, x.clone(), y.add(1), z + self.quantity

    def compute(self):
        """Run the loop from the initial state and return its final values."""
        result = while_loop(
            self.cond_fn,
            self.body_fn,
            (self.iteri, self.init_x, self.init_y, self.init_z),
        )
        return result
if __name__ == "__main__":
    # Script entry point: build the computation, run the while_loop once,
    # and print the tuple of final carried values.
    computation = TPUComputation()
    result = computation.compute()
    print(result)
(torch310) root@b7b12c30e894:/pytorch/xla# PJRT_DEVICE=TPU python test_7665.py
body computation: !!!!!!!!!
HloModule PyLoweringContext.20.28, entry_computation_layout={((s64[], s64[], s64[1]{0}, s64[1]{0}, s64[1]{0}))->(s64[], s64[], s64[1]{0}, s64[1]{0}, s64[1]{0})}
%PyLoweringContext.7 (p0.10: s64[], p1.13: s64[], p2.14: s64[1], p3.17: s64[1], p4.22: s64[1]) -> (s64[], s64[], s64[1], s64[1], s64[1]) {
%p0.10 = s64[] parameter(0)
%constant.9 = s64[] constant(1)
%constant.8 = s64[] constant(1)
%multiply.11 = s64[] multiply(s64[] %constant.9, s64[] %constant.8)
%subtract.12 = s64[] subtract(s64[] %p0.10, s64[] %multiply.11)
%p1.13 = s64[] parameter(1)
%p2.14 = s64[1]{0} parameter(2)
%p3.17 = s64[1]{0} parameter(3)
%constant.16 = s64[] constant(1)
%constant.15 = s64[] constant(1)
%multiply.18 = s64[] multiply(s64[] %constant.16, s64[] %constant.15)
%broadcast.19 = s64[1]{0} broadcast(s64[] %multiply.18), dimensions={}
%add.20 = s64[1]{0} add(s64[1]{0} %p3.17, s64[1]{0} %broadcast.19)
%p4.22 = s64[1]{0} parameter(4)
%constant.21 = s64[] constant(1)
%multiply.23 = s64[] multiply(s64[] %p1.13, s64[] %constant.21)
%broadcast.24 = s64[1]{0} broadcast(s64[] %multiply.23), dimensions={}
%add.25 = s64[1]{0} add(s64[1]{0} %p4.22, s64[1]{0} %broadcast.24)
ROOT %tuple.26 = (s64[], s64[], s64[1]{0}, s64[1]{0}, s64[1]{0}) tuple(s64[] %subtract.12, s64[] %p1.13, s64[1]{0} %p2.14, s64[1]{0} %add.20, s64[1]{0} %add.25)
}
ENTRY %PyLoweringContext.20.28 (in.1: (s64[], s64[], s64[1], s64[1], s64[1])) -> (s64[], s64[], s64[1], s64[1], s64[1]) {
%in.1 = (s64[], s64[], s64[1]{0}, s64[1]{0}, s64[1]{0}) parameter(0)
%get-tuple-element.2 = s64[] get-tuple-element((s64[], s64[], s64[1]{0}, s64[1]{0}, s64[1]{0}) %in.1), index=0
%get-tuple-element.3 = s64[] get-tuple-element((s64[], s64[], s64[1]{0}, s64[1]{0}, s64[1]{0}) %in.1), index=1
%get-tuple-element.4 = s64[1]{0} get-tuple-element((s64[], s64[], s64[1]{0}, s64[1]{0}, s64[1]{0}) %in.1), index=2
%get-tuple-element.5 = s64[1]{0} get-tuple-element((s64[], s64[], s64[1]{0}, s64[1]{0}, s64[1]{0}) %in.1), index=3
%get-tuple-element.6 = s64[1]{0} get-tuple-element((s64[], s64[], s64[1]{0}, s64[1]{0}, s64[1]{0}) %in.1), index=4
ROOT %call.27 = (s64[], s64[], s64[1]{0}, s64[1]{0}, s64[1]{0}) call(s64[] %get-tuple-element.2, s64[] %get-tuple-element.3, s64[1]{0} %get-tuple-element.4, s64[1]{0} %get-tuple-element.5, s64[1]{0} %get-tuple-element.6), to_apply=%PyLoweringContext.7
}
cond computation: !!!!!!!!!
HloModule PyLoweringContext.8.16, entry_computation_layout={((s64[], s64[], s64[1]{0}, s64[1]{0}, s64[1]{0}))->pred[]}
%PyLoweringContext.7 (p0.9: s64[], UnusedArgumentsPlaceholder.11: s64[], UnusedArgumentsPlaceholder.12: s64[1], UnusedArgumentsPlaceholder.13: s64[1], UnusedArgumentsPlaceholder.14: s64[1]) -> pred[] {
%p0.9 = s64[] parameter(0)
%constant.8 = s64[] constant(0)
ROOT %compare.10 = pred[] compare(s64[] %p0.9, s64[] %constant.8), direction=GT
%UnusedArgumentsPlaceholder.11 = s64[] parameter(1)
%UnusedArgumentsPlaceholder.12 = s64[1]{0} parameter(2)
%UnusedArgumentsPlaceholder.13 = s64[1]{0} parameter(3)
%UnusedArgumentsPlaceholder.14 = s64[1]{0} parameter(4)
}
ENTRY %PyLoweringContext.8.16 (in.1: (s64[], s64[], s64[1], s64[1], s64[1])) -> pred[] {
%in.1 = (s64[], s64[], s64[1]{0}, s64[1]{0}, s64[1]{0}) parameter(0)
%get-tuple-element.2 = s64[] get-tuple-element((s64[], s64[], s64[1]{0}, s64[1]{0}, s64[1]{0}) %in.1), index=0
%get-tuple-element.3 = s64[] get-tuple-element((s64[], s64[], s64[1]{0}, s64[1]{0}, s64[1]{0}) %in.1), index=1
%get-tuple-element.4 = s64[1]{0} get-tuple-element((s64[], s64[], s64[1]{0}, s64[1]{0}, s64[1]{0}) %in.1), index=2
%get-tuple-element.5 = s64[1]{0} get-tuple-element((s64[], s64[], s64[1]{0}, s64[1]{0}, s64[1]{0}) %in.1), index=3
%get-tuple-element.6 = s64[1]{0} get-tuple-element((s64[], s64[], s64[1]{0}, s64[1]{0}, s64[1]{0}) %in.1), index=4
ROOT %call.15 = pred[] call(s64[] %get-tuple-element.2, s64[] %get-tuple-element.3, s64[1]{0} %get-tuple-element.4, s64[1]{0} %get-tuple-element.5, s64[1]{0} %get-tuple-element.6), to_apply=%PyLoweringContext.7
}
(FunctionalTensor(lvl=0, value=\
tensor(0, device='xla:0')), FunctionalTensor(lvl=0, value=\
tensor([1], device='xla:0')), FunctionalTensor(lvl=0, value=\
tensor([11], device='xla:0')), FunctionalTensor(lvl=0, value=\
tensor([31], device='xla:0')))
(torch310) root@b7b12c30e894:/pytorch/xla#
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment