Skip to content

Instantly share code, notes, and snippets.

import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.stablehlo import exported_program_to_stablehlo
import torch
import torch._export
import tempfile
import unittest
[WONJOO, tensor_methods.cpp]: inputIR {
%0 = f64[] xla::device_data(), location=__init__@_tensor_str.py:90, device=CPU:0, ROOT=0
%1 = f64[] prim::Constant(), value=1, ROOT=1
%2 = f64[5,6,10]{2,1,0} aten::expand(%1), size=(5, 6, 10), ROOT=2
%3 = f64[80,10]{1,0} xla::device_data(), location=convert@module.py:905, device=CPU:0, ROOT=3
%4 = f64[10,80]{0,1} aten::permute(%3), dims=(1, 0), ROOT=4
%5 = f64[80,10]{1,0} aten::permute(%4), dims=(1, 0), ROOT=5
%6 = f64[] prim::Constant(), location=forward@rnn.py:709, value=1, ROOT=6
%7 = f64[6,80]{1,0} aten::expand(%6), location=forward@rnn.py:709, size=(6, 80), ROOT=7
%8 = f64[80]{0} xla::device_data(), location=convert@module.py:905, device=CPU:0, ROOT=8
======================================================================
ERROR: test_rnn_retain_variables_xla_float64 (__main__.TestNNDeviceTypeXLA)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/usr/local/google/home/wonjoo/anaconda3/lib/python3.8/site-packages/torch/testing/_comparison.py", line 1062, in assert_equal
pair.compare()
File "/usr/local/google/home/wonjoo/anaconda3/lib/python3.8/site-packages/torch/testing/_comparison.py", line 605, in compare
self._compare_values(actual, expected)
File "/usr/local/google/home/wonjoo/anaconda3/lib/python3.8/site-packages/torch/testing/_comparison.py", line 699, in _compare_values
compare_fn(actual, expected, rtol=self.rtol, atol=self.atol, equal_nan=self.equal_nan)