Created
June 13, 2022 11:23
-
-
Save sjchoi86/acb56e0d246eedc4c1fc0c1ea22188b2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "530490ae", | |
"metadata": {}, | |
"source": [ | |
"### Model ID" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "5e246896", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"PyTorch version:[1.12.0.dev20220519].\n", | |
"device:[mps].\n" | |
] | |
} | |
], | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"%matplotlib inline\n", | |
"%config InlineBackend.figure_format='retina'\n", | |
"print (\"PyTorch version:[%s].\"%(torch.__version__))\n", | |
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", | |
"device = 'mps'\n", | |
"print (\"device:[%s].\"%(device))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "218f72ed", | |
"metadata": {}, | |
"source": [ | |
"### Utils" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "e0bfe063", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Done.\n" | |
] | |
} | |
], | |
"source": [ | |
"def t2pr(T=torch.zeros(128,4,4)):\n", | |
" ps = T[:,0:3,3]\n", | |
" Rs = T[:,0:3,0:3]\n", | |
" return (ps,Rs)\n", | |
"\n", | |
"def t2p(Ts=torch.zeros(128,4,4)):\n", | |
" device = Ts.device\n", | |
" ps = Ts[:,0:3,3]\n", | |
" return ps\n", | |
"\n", | |
"def pr2t(ps=torch.zeros(128,1,3),Rs=torch.zeros(128,3,3)):\n", | |
" device = ps.device\n", | |
" Ts = torch.zeros(ps.size()[0],4,4,dtype=torch.float).to(device)\n", | |
" Ts[:,0:3,0:3] = Rs[:]\n", | |
" Ts[:,0:3,3] = ps[:]\n", | |
" Ts[:,3,3] = 1\n", | |
" return Ts\n", | |
"\n", | |
"def rpy2r(rpys=torch.zeros(128,3)): # [radian]\n", | |
" device = rpys.device\n", | |
" Rs = torch.zeros(rpys.size()[0],3,3,dtype=torch.float).to(device)\n", | |
" rs = rpys[:,0]\n", | |
" ps = rpys[:,1]\n", | |
" ys = rpys[:,2]\n", | |
" Rs[:,0,:] = torch.vstack([\n", | |
" torch.cos(ys)*torch.cos(ps),\n", | |
" -torch.sin(ys)*torch.cos(rs) + torch.cos(ys)*torch.sin(ps)*torch.sin(rs),\n", | |
" torch.sin(ys)*torch.sin(rs) + torch.cos(ys)*torch.sin(ps)*torch.cos(rs)\n", | |
" ]).transpose(0,1)\n", | |
" Rs[:,1,:] = torch.vstack([\n", | |
" torch.sin(ys)*torch.cos(ps),\n", | |
" torch.cos(ys)*torch.cos(rs) + torch.sin(ys)*torch.sin(ps)*torch.sin(rs),\n", | |
" -torch.cos(ys)*torch.sin(rs) + torch.sin(ys)*torch.sin(ps)*torch.cos(rs)\n", | |
" ]).transpose(0,1)\n", | |
" Rs[:,2,:] = torch.vstack([\n", | |
" -torch.sin(ps),\n", | |
" torch.cos(ps)*torch.sin(rs),\n", | |
" torch.cos(ps)*torch.cos(rs)\n", | |
" ]).transpose(0,1) \n", | |
" return Rs\n", | |
"\n", | |
"def skew(ps=torch.zeros(128,3)):\n", | |
" device = ps.device\n", | |
" skew_ps = torch.zeros(ps.size()[0],3,3,dtype=torch.float).to(device)\n", | |
" zeros = torch.zeros(ps.size()[0],dtype=torch.float).to(device)\n", | |
" skew_ps[:,0,:] = torch.vstack([zeros, -ps[:,2], ps[:,1]]).transpose(0,1)\n", | |
" skew_ps[:,1,:] = torch.vstack([ps[:,2], zeros, -ps[:,0]]).transpose(0,1)\n", | |
" skew_ps[:,2,:] = torch.vstack([-ps[:,1], ps[:,0],zeros]).transpose(0,1)\n", | |
" return skew_ps\n", | |
"\n", | |
"def rodrigues(w=torch.tensor([1.0,0,0]),qs=torch.zeros(128),VERBOSE=False):\n", | |
" eps = 1e-10\n", | |
" device = qs.device\n", | |
" batch_size = qs.size()[0]\n", | |
"\n", | |
" if torch.norm(w) < eps:\n", | |
" Rs = torch.tile(torch.eye(3),(batch_size,1,1)).to(device)\n", | |
" return Rs\n", | |
" if abs(torch.norm(w)-1) > eps:\n", | |
" if VERBOSE:\n", | |
" print(\"Warning: [rodirgues] >> joint twist not normalized\")\n", | |
"\n", | |
" theta = torch.norm(w)\n", | |
" w = w/theta # [3]\n", | |
" qs = qs*theta # [N]\n", | |
" w_skew = skew(w.unsqueeze(0)).squeeze(0) # []\n", | |
" Rs = torch.tensordot(\n", | |
" torch.ones_like(qs).unsqueeze(0),torch.eye(3).unsqueeze(0).to(device),dims=([0],[0])\n", | |
" ) \\\n", | |
" + torch.tensordot(\n", | |
" torch.sin(qs).unsqueeze(0),w_skew.unsqueeze(0),dims=([0],[0])\n", | |
" )\\\n", | |
" + torch.tensordot(\n", | |
" (1-torch.cos(qs)).unsqueeze(0),(w_skew@w_skew).unsqueeze(0),dims=([0],[0])\n", | |
" )\n", | |
" return Rs\n", | |
"\n", | |
"print (\"Done.\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "8f3da235", | |
"metadata": {}, | |
"source": [ | |
"### Kinematic Chain" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"id": "aae9a173", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Done.\n" | |
] | |
} | |
], | |
"source": [ | |
"# Prismatic joint\n", | |
"class PrismaticJointClass(nn.Module):\n", | |
" def __init__(self,\n", | |
" name ='PJ',\n", | |
" p_offset_val = 0.1*torch.randn(1,3),\n", | |
" rpy_offset_val = 0.1*torch.randn(1,3), # [radian]\n", | |
" axis_val = torch.randn(3)\n", | |
" ):\n", | |
" super(PrismaticJointClass,self).__init__()\n", | |
" self.name = name\n", | |
" self.p_offset = nn.Parameter(p_offset_val)\n", | |
" self.rpy_offset = nn.Parameter(rpy_offset_val)\n", | |
" self.axis = nn.Parameter(nn.functional.normalize(\n", | |
" axis_val,p=2.0,dim=0,eps=1e-4))\n", | |
" \n", | |
" def forward(self,qs=torch.zeros(128)):\n", | |
" batch_size = qs.size()[0] \n", | |
" device = qs.device\n", | |
" T_offset = pr2t(self.p_offset,rpy2r(self.rpy_offset))\n", | |
" Ts = pr2t(ps=torch.outer(qs,self.axis),\n", | |
" Rs=torch.tile(torch.eye(3),(batch_size,1,1)).to(device)\n", | |
" )\n", | |
" return T_offset,Ts\n", | |
"\n", | |
"# Revolute joint\n", | |
"class RevoluteJointClass(nn.Module):\n", | |
" def __init__(self,\n", | |
" name ='RJ',\n", | |
" p_offset_val = 0.1*torch.randn(1,3),\n", | |
" rpy_offset_val = 0.1*torch.randn(1,3), # [radian]\n", | |
" axis_val = torch.randn(3)\n", | |
" ):\n", | |
" super(RevoluteJointClass,self).__init__()\n", | |
" self.name = name\n", | |
" self.p_offset = nn.Parameter(p_offset_val)\n", | |
" self.rpy_offset = nn.Parameter(rpy_offset_val)\n", | |
" self.axis = nn.Parameter(nn.functional.normalize(\n", | |
" axis_val,p=2.0,dim=0,eps=1e-4))\n", | |
" \n", | |
" def forward(self,qs=torch.zeros(128)):\n", | |
" batch_size = qs.size()[0] \n", | |
" device = qs.device\n", | |
" T_offset = pr2t(self.p_offset,rpy2r(self.rpy_offset))\n", | |
" Ts = pr2t(ps=torch.zeros(batch_size,3).to(device),\n", | |
" Rs=rodrigues(self.axis,qs)\n", | |
" )\n", | |
" return T_offset,Ts\n", | |
"\n", | |
"# Fixed joint\n", | |
"class FixedJointClass(nn.Module):\n", | |
" def __init__(self,\n", | |
" name ='FJ',\n", | |
" p_offset_val = torch.zeros(1,3),\n", | |
" rpy_offset_val = torch.zeros(1,3) # [radian]\n", | |
" ):\n", | |
" super(FixedJointClass,self).__init__()\n", | |
" self.name = name\n", | |
" self.p_offset = p_offset_val\n", | |
" self.rpy_offset = rpy_offset_val\n", | |
" \n", | |
" def forward(self,qs=torch.zeros(128)):\n", | |
" batch_size = qs.size()[0] \n", | |
" device = qs.device\n", | |
" T_offset = pr2t(self.p_offset,rpy2r(self.rpy_offset))\n", | |
" Ts = torch.tile(torch.eye(4),dims=(batch_size,1,1)).to(device)\n", | |
" return T_offset,Ts\n", | |
" \n", | |
"# Joint Sequence\n", | |
"class SequenceOfJointsClass(nn.Module):\n", | |
" def __init__(self,\n", | |
" name = 'SoJ',\n", | |
" joints = []\n", | |
" ):\n", | |
" super(SequenceOfJointsClass,self).__init__()\n", | |
" self.name = name\n", | |
" self.joints = joints\n", | |
" \n", | |
"# Actuation Network\n", | |
"class ActuationNetworkClass(nn.Module):\n", | |
" def __init__(self,\n", | |
" name = 'AN',\n", | |
" actuation_type = 'Revolute', # ['Revolute','Prismatic','Fixed']\n", | |
" adim = 4,\n", | |
" hdims = [16],\n", | |
" actv = torch.nn.ReLU()\n", | |
" ):\n", | |
" super(ActuationNetworkClass,self).__init__()\n", | |
" self.name = name\n", | |
" self.actuation_type = actuation_type\n", | |
" self.adim = 4\n", | |
" self.hdims = hdims\n", | |
" self.actv = actv\n", | |
" # Initialize layers\n", | |
" self.init_layers()\n", | |
" \n", | |
" def init_layers(self):\n", | |
" self.layers = []\n", | |
" prev_hdim = self.adim\n", | |
" for hdim in self.hdims:\n", | |
" self.layers.append(nn.Linear(prev_hdim,hdim,bias=True))\n", | |
" self.layers.append(self.actv) # activation\n", | |
" prev_hdim = hdim\n", | |
" self.layers.append(nn.Linear(prev_hdim,1))\n", | |
" # Append to net\n", | |
" self.net = nn.Sequential()\n", | |
" for l_idx,layer in enumerate(self.layers):\n", | |
" layer_name = \"%s_%02d\"%(type(layer).__name__.lower(),l_idx)\n", | |
" self.net.add_module(layer_name,layer)\n", | |
" # Initialize parameters\n", | |
" self.init_params()\n", | |
" \n", | |
" def init_params(self):\n", | |
" for m in self.modules():\n", | |
" if isinstance(m,nn.Conv2d): # init conv\n", | |
" nn.init.kaiming_normal_(m.weight)\n", | |
" nn.init.zeros_(m.bias)\n", | |
" elif isinstance(m,nn.BatchNorm2d): # init BN\n", | |
" nn.init.constant_(m.weight,1)\n", | |
" nn.init.constant_(m.bias,0)\n", | |
" elif isinstance(m,nn.Linear): # lnit dense\n", | |
" nn.init.kaiming_normal_(m.weight)\n", | |
" nn.init.zeros_(m.bias)\n", | |
" \n", | |
" def forward(self,x):\n", | |
" return self.net(x)\n", | |
" \n", | |
"print (\"Done.\") " | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "4321fc7b", | |
"metadata": {}, | |
"source": [ | |
"### Instantiate" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"id": "e0e6b68b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# One joint of each kind, built with the constructors' default arguments.\n", | |
"PJ = PrismaticJointClass()\n", | |
"RJ = RevoluteJointClass()\n", | |
"FJ = FixedJointClass()\n", | |
"# Group the joints in the order revolute -> prismatic -> fixed.\n", | |
"SoJ = SequenceOfJointsClass(joints=[RJ,PJ,FJ])\n", | |
"# Actuation MLP with its default sizes.\n", | |
"AN = ActuationNetworkClass()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"id": "f7abb366", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<bound method Module.parameters of ActuationNetworkClass(\n", | |
" (actv): ReLU()\n", | |
" (net): Sequential(\n", | |
" (linear_00): Linear(in_features=4, out_features=16, bias=True)\n", | |
" (relu_01): ReLU()\n", | |
" (linear_02): Linear(in_features=16, out_features=1, bias=True)\n", | |
" )\n", | |
")>" | |
] | |
}, | |
"execution_count": 33, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"AN.parameters" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "3a34d307", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "aa7c81ea", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "5fb54120", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "a9246bdc", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "115ba05f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "ab7e3a68", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "0ab2b5bd", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "c6bb91c2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.13" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment