Skip to content

Instantly share code, notes, and snippets.

@fehiepsi
Last active April 17, 2018 05:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fehiepsi/f0ccc0f40091ee15b35e756c4091ac8e to your computer and use it in GitHub Desktop.
Save fehiepsi/f0ccc0f40091ee15b35e756c4091ac8e to your computer and use it in GitHub Desktop.
test cuda trtrs
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import unittest\n",
"import torch\n",
"torch.set_default_tensor_type(torch.cuda.float64)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class TEST(unittest.TestCase):\n",
" def test_trtrs(self):\n",
" a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),\n",
" (-6.05, -3.30, 5.36, -4.44, 1.08),\n",
" (-0.45, 2.58, -2.70, 0.27, 9.04),\n",
" (8.32, 2.71, 4.35, -7.17, 2.14),\n",
" (-9.67, -5.14, -7.26, 6.08, -6.87))).t()\n",
" b = torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03),\n",
" (-1.56, 4.00, -8.67, 1.75, 2.86),\n",
" (9.81, -4.09, -4.57, -8.61, 8.99))).t()\n",
" \n",
" print(a.dtype)\n",
"\n",
" U = torch.triu(a)\n",
" L = torch.tril(a)\n",
"\n",
" # solve Ux = b\n",
" x = torch.trtrs(b, U)[0]\n",
" self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12)\n",
" x = torch.trtrs(b, U, True, False, False)[0]\n",
" self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12)\n",
"\n",
" # solve Lx = b\n",
" x = torch.trtrs(b, L, False)[0]\n",
" self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12)\n",
" x = torch.trtrs(b, L, False, False, False)[0]\n",
" self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12)\n",
"\n",
" # solve U'x = b\n",
" x = torch.trtrs(b, U, True, True)[0]\n",
" self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12)\n",
" x = torch.trtrs(b, U, True, True, False)[0]\n",
" self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12)\n",
"\n",
" # solve U'x = b by manual transposition\n",
" y = torch.trtrs(b, U.t(), False, False)[0]\n",
" self.assertLessEqual(x.dist(y), 1e-12)\n",
"\n",
" # solve L'x = b\n",
" x = torch.trtrs(b, L, False, True)[0]\n",
" self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12)\n",
" x = torch.trtrs(b, L, False, True, False)[0]\n",
" self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12)\n",
"\n",
" # solve L'x = b by manual transposition\n",
" y = torch.trtrs(b, L.t(), True, False)[0]\n",
" self.assertLessEqual(x.dist(y), 1e-12)\n",
"\n",
" # test reuse\n",
" res1 = torch.trtrs(b, a)[0]\n",
" ta = torch.Tensor()\n",
" tb = torch.Tensor()\n",
" torch.trtrs(b, a, out=(tb, ta))\n",
" self.assertEqual(res1.dist(tb), 0)\n",
" tb.zero_()\n",
" torch.trtrs(b, a, out=(tb, ta))\n",
" self.assertEqual(res1.dist(tb), 0)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.cuda.float64\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
".\n",
"----------------------------------------------------------------------\n",
"Ran 1 test in 1.827s\n",
"\n",
"OK\n"
]
},
{
"data": {
"text/plain": [
"<unittest.main.TestProgram at 0x7f230e5e8f98>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"unittest.main(argv=['first-arg-is-ignored'], exit=False)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (pyro)",
"language": "python",
"name": "pyro"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment