{
"cells": [
{
"cell_type": "markdown",
"id": "530490ae",
"metadata": {},
"source": [
"### Model ID"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5e246896",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch version:[1.12.0.dev20220519].\n",
"device:[mps].\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_format='retina'\n",
"print (\"PyTorch version:[%s].\"%(torch.__version__))\n",
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
"device = 'mps'\n",
"print (\"device:[%s].\"%(device))"
]
},
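{
"cell_type": "markdown",
"id": "added-device-note",
"metadata": {},
"source": [
"*Added sketch:* instead of hard-coding `'mps'`, the device can be chosen with a fallback to CUDA or CPU when the Metal backend is unavailable (assumes PyTorch 1.12+, where `torch.backends.mps` exists)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-device-check",
"metadata": {},
"outputs": [],
"source": [
"# Device-selection sketch: prefer MPS, then CUDA, then CPU\n",
"if torch.backends.mps.is_available():\n",
"    device = torch.device('mps')\n",
"elif torch.cuda.is_available():\n",
"    device = torch.device('cuda:0')\n",
"else:\n",
"    device = torch.device('cpu')\n",
"print (\"device:[%s].\"%(device))"
]
},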
{
"cell_type": "markdown",
"id": "218f72ed",
"metadata": {},
"source": [
"### Utils"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "e0bfe063",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n"
]
}
],
"source": [
"def t2pr(T=torch.zeros(128,4,4)):\n",
" ps = T[:,0:3,3]\n",
" Rs = T[:,0:3,0:3]\n",
" return (ps,Rs)\n",
"\n",
"def t2p(Ts=torch.zeros(128,4,4)):\n",
" device = Ts.device\n",
" ps = Ts[:,0:3,3]\n",
" return ps\n",
"\n",
"def pr2t(ps=torch.zeros(128,1,3),Rs=torch.zeros(128,3,3)):\n",
" device = ps.device\n",
" Ts = torch.zeros(ps.size()[0],4,4,dtype=torch.float).to(device)\n",
" Ts[:,0:3,0:3] = Rs[:]\n",
" Ts[:,0:3,3] = ps[:]\n",
" Ts[:,3,3] = 1\n",
" return Ts\n",
"\n",
"def rpy2r(rpys=torch.zeros(128,3)): # [radian]\n",
" device = rpys.device\n",
" Rs = torch.zeros(rpys.size()[0],3,3,dtype=torch.float).to(device)\n",
" rs = rpys[:,0]\n",
" ps = rpys[:,1]\n",
" ys = rpys[:,2]\n",
" Rs[:,0,:] = torch.vstack([\n",
" torch.cos(ys)*torch.cos(ps),\n",
" -torch.sin(ys)*torch.cos(rs) + torch.cos(ys)*torch.sin(ps)*torch.sin(rs),\n",
" torch.sin(ys)*torch.sin(rs) + torch.cos(ys)*torch.sin(ps)*torch.cos(rs)\n",
" ]).transpose(0,1)\n",
" Rs[:,1,:] = torch.vstack([\n",
" torch.sin(ys)*torch.cos(ps),\n",
" torch.cos(ys)*torch.cos(rs) + torch.sin(ys)*torch.sin(ps)*torch.sin(rs),\n",
" -torch.cos(ys)*torch.sin(rs) + torch.sin(ys)*torch.sin(ps)*torch.cos(rs)\n",
" ]).transpose(0,1)\n",
" Rs[:,2,:] = torch.vstack([\n",
" -torch.sin(ps),\n",
" torch.cos(ps)*torch.sin(rs),\n",
" torch.cos(ps)*torch.cos(rs)\n",
" ]).transpose(0,1) \n",
" return Rs\n",
"\n",
"def skew(ps=torch.zeros(128,3)):\n",
" device = ps.device\n",
" skew_ps = torch.zeros(ps.size()[0],3,3,dtype=torch.float).to(device)\n",
" zeros = torch.zeros(ps.size()[0],dtype=torch.float).to(device)\n",
" skew_ps[:,0,:] = torch.vstack([zeros, -ps[:,2], ps[:,1]]).transpose(0,1)\n",
" skew_ps[:,1,:] = torch.vstack([ps[:,2], zeros, -ps[:,0]]).transpose(0,1)\n",
" skew_ps[:,2,:] = torch.vstack([-ps[:,1], ps[:,0],zeros]).transpose(0,1)\n",
" return skew_ps\n",
"\n",
"def rodrigues(w=torch.tensor([1.0,0,0]),qs=torch.zeros(128),VERBOSE=False):\n",
" eps = 1e-10\n",
" device = qs.device\n",
" batch_size = qs.size()[0]\n",
"\n",
" if torch.norm(w) < eps:\n",
" Rs = torch.tile(torch.eye(3),(batch_size,1,1)).to(device)\n",
" return Rs\n",
" if abs(torch.norm(w)-1) > eps:\n",
" if VERBOSE:\n",
" print(\"Warning: [rodirgues] >> joint twist not normalized\")\n",
"\n",
" theta = torch.norm(w)\n",
" w = w/theta # [3]\n",
" qs = qs*theta # [N]\n",
" w_skew = skew(w.unsqueeze(0)).squeeze(0) # []\n",
" Rs = torch.tensordot(\n",
" torch.ones_like(qs).unsqueeze(0),torch.eye(3).unsqueeze(0).to(device),dims=([0],[0])\n",
" ) \\\n",
" + torch.tensordot(\n",
" torch.sin(qs).unsqueeze(0),w_skew.unsqueeze(0),dims=([0],[0])\n",
" )\\\n",
" + torch.tensordot(\n",
" (1-torch.cos(qs)).unsqueeze(0),(w_skew@w_skew).unsqueeze(0),dims=([0],[0])\n",
" )\n",
" return Rs\n",
"\n",
"print (\"Done.\")"
]
},
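{
"cell_type": "markdown",
"id": "added-utils-check-note",
"metadata": {},
"source": [
"*Added sanity-check sketch* using only the utilities above: `rodrigues` should return orthonormal rotation matrices, and `pr2t`/`t2pr` should round-trip positions and rotations. The names `qs_test`, `Rs_test`, etc. are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-utils-check",
"metadata": {},
"outputs": [],
"source": [
"# Batched rotations about the z-axis over a range of angles\n",
"qs_test = torch.linspace(-3.14,3.14,16)\n",
"Rs_test = rodrigues(w=torch.tensor([0.0,0.0,1.0]),qs=qs_test)\n",
"# Orthonormality: R^T R should equal the identity for every batch element\n",
"orth_err = torch.norm(Rs_test.transpose(1,2)@Rs_test - torch.eye(3),dim=(1,2)).max()\n",
"# Round-trip positions and rotations through pr2t / t2pr\n",
"ps_test = torch.randn(16,3)\n",
"ps_back,Rs_back = t2pr(pr2t(ps=ps_test,Rs=Rs_test))\n",
"rt_err = torch.norm(ps_back-ps_test) + torch.norm(Rs_back-Rs_test)\n",
"print (\"max orthonormality error:[%.2e].\"%(orth_err.item()))\n",
"print (\"round-trip error:[%.2e].\"%(rt_err.item()))"
]
},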
{
"cell_type": "markdown",
"id": "8f3da235",
"metadata": {},
"source": [
"### Kinematic Chain"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "aae9a173",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n"
]
}
],
"source": [
"# Prismatic joint\n",
"class PrismaticJointClass(nn.Module):\n",
" def __init__(self,\n",
" name ='PJ',\n",
" p_offset_val = 0.1*torch.randn(1,3),\n",
" rpy_offset_val = 0.1*torch.randn(1,3), # [radian]\n",
" axis_val = torch.randn(3)\n",
" ):\n",
" super(PrismaticJointClass,self).__init__()\n",
" self.name = name\n",
" self.p_offset = nn.Parameter(p_offset_val)\n",
" self.rpy_offset = nn.Parameter(rpy_offset_val)\n",
" self.axis = nn.Parameter(nn.functional.normalize(\n",
" axis_val,p=2.0,dim=0,eps=1e-4))\n",
" \n",
" def forward(self,qs=torch.zeros(128)):\n",
" batch_size = qs.size()[0] \n",
" device = qs.device\n",
" T_offset = pr2t(self.p_offset,rpy2r(self.rpy_offset))\n",
" Ts = pr2t(ps=torch.outer(qs,self.axis),\n",
" Rs=torch.tile(torch.eye(3),(batch_size,1,1)).to(device)\n",
" )\n",
" return T_offset,Ts\n",
"\n",
"# Revolute joint\n",
"class RevoluteJointClass(nn.Module):\n",
" def __init__(self,\n",
" name ='RJ',\n",
" p_offset_val = 0.1*torch.randn(1,3),\n",
" rpy_offset_val = 0.1*torch.randn(1,3), # [radian]\n",
" axis_val = torch.randn(3)\n",
" ):\n",
" super(RevoluteJointClass,self).__init__()\n",
" self.name = name\n",
" self.p_offset = nn.Parameter(p_offset_val)\n",
" self.rpy_offset = nn.Parameter(rpy_offset_val)\n",
" self.axis = nn.Parameter(nn.functional.normalize(\n",
" axis_val,p=2.0,dim=0,eps=1e-4))\n",
" \n",
" def forward(self,qs=torch.zeros(128)):\n",
" batch_size = qs.size()[0] \n",
" device = qs.device\n",
" T_offset = pr2t(self.p_offset,rpy2r(self.rpy_offset))\n",
" Ts = pr2t(ps=torch.zeros(batch_size,3).to(device),\n",
" Rs=rodrigues(self.axis,qs)\n",
" )\n",
" return T_offset,Ts\n",
"\n",
"# Fixed joint\n",
"class FixedJointClass(nn.Module):\n",
" def __init__(self,\n",
" name ='FJ',\n",
" p_offset_val = torch.zeros(1,3),\n",
" rpy_offset_val = torch.zeros(1,3) # [radian]\n",
" ):\n",
" super(FixedJointClass,self).__init__()\n",
" self.name = name\n",
" self.p_offset = p_offset_val\n",
" self.rpy_offset = rpy_offset_val\n",
" \n",
" def forward(self,qs=torch.zeros(128)):\n",
" batch_size = qs.size()[0] \n",
" device = qs.device\n",
" T_offset = pr2t(self.p_offset,rpy2r(self.rpy_offset))\n",
" Ts = torch.tile(torch.eye(4),dims=(batch_size,1,1)).to(device)\n",
" return T_offset,Ts\n",
" \n",
"# Joint Sequence\n",
"class SequenceOfJointsClass(nn.Module):\n",
" def __init__(self,\n",
" name = 'SoJ',\n",
" joints = []\n",
" ):\n",
" super(SequenceOfJointsClass,self).__init__()\n",
" self.name = name\n",
" self.joints = joints\n",
" \n",
"# Actuation Network\n",
"class ActuationNetworkClass(nn.Module):\n",
" def __init__(self,\n",
" name = 'AN',\n",
" actuation_type = 'Revolute', # ['Revolute','Prismatic','Fixed']\n",
" adim = 4,\n",
" hdims = [16],\n",
" actv = torch.nn.ReLU()\n",
" ):\n",
" super(ActuationNetworkClass,self).__init__()\n",
" self.name = name\n",
" self.actuation_type = actuation_type\n",
" self.adim = 4\n",
" self.hdims = hdims\n",
" self.actv = actv\n",
" # Initialize layers\n",
" self.init_layers()\n",
" \n",
" def init_layers(self):\n",
" self.layers = []\n",
" prev_hdim = self.adim\n",
" for hdim in self.hdims:\n",
" self.layers.append(nn.Linear(prev_hdim,hdim,bias=True))\n",
" self.layers.append(self.actv) # activation\n",
" prev_hdim = hdim\n",
" self.layers.append(nn.Linear(prev_hdim,1))\n",
" # Append to net\n",
" self.net = nn.Sequential()\n",
" for l_idx,layer in enumerate(self.layers):\n",
" layer_name = \"%s_%02d\"%(type(layer).__name__.lower(),l_idx)\n",
" self.net.add_module(layer_name,layer)\n",
" # Initialize parameters\n",
" self.init_params()\n",
" \n",
" def init_params(self):\n",
" for m in self.modules():\n",
" if isinstance(m,nn.Conv2d): # init conv\n",
" nn.init.kaiming_normal_(m.weight)\n",
" nn.init.zeros_(m.bias)\n",
" elif isinstance(m,nn.BatchNorm2d): # init BN\n",
" nn.init.constant_(m.weight,1)\n",
" nn.init.constant_(m.bias,0)\n",
" elif isinstance(m,nn.Linear): # lnit dense\n",
" nn.init.kaiming_normal_(m.weight)\n",
" nn.init.zeros_(m.bias)\n",
" \n",
" def forward(self,x):\n",
" return self.net(x)\n",
" \n",
"print (\"Done.\") "
]
},
{
"cell_type": "markdown",
"id": "4321fc7b",
"metadata": {},
"source": [
"### Instantiate"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "e0e6b68b",
"metadata": {},
"outputs": [],
"source": [
"PJ = PrismaticJointClass()\n",
"RJ = RevoluteJointClass()\n",
"FJ = FixedJointClass()\n",
"SoJ = SequenceOfJointsClass(joints=[RJ,PJ,FJ])\n",
"AN = ActuationNetworkClass()"
]
},
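{
"cell_type": "markdown",
"id": "added-usage-note",
"metadata": {},
"source": [
"*Added usage sketch:* each joint's forward pass returns the fixed offset transform `T_offset` and a batch of joint transforms `Ts`, which can be chained as `T_offset @ Ts` to get per-sample poses; the actuation network maps an `adim`-dimensional input to a scalar joint value. The names `qs_demo` and `x_demo` are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-usage-demo",
"metadata": {},
"outputs": [],
"source": [
"# Forward pass through the revolute and prismatic joints for a batch of 8 joint values\n",
"qs_demo = 0.5*torch.randn(8)\n",
"T_offset_R,Ts_R = RJ(qs_demo) # rotation about RJ.axis\n",
"T_offset_P,Ts_P = PJ(qs_demo) # translation along PJ.axis\n",
"T_link = T_offset_R@Ts_R # [8,4,4] poses after the revolute joint\n",
"print (\"T_link shape:[%s].\"%(str(T_link.shape)))\n",
"# Actuation network: maps a 4-dim input to a scalar actuation value\n",
"x_demo = torch.randn(8,4)\n",
"print (\"AN(x_demo) shape:[%s].\"%(str(AN(x_demo).shape)))"
]
},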
{
"cell_type": "code",
"execution_count": 33,
"id": "f7abb366",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<bound method Module.parameters of ActuationNetworkClass(\n",
" (actv): ReLU()\n",
" (net): Sequential(\n",
" (linear_00): Linear(in_features=4, out_features=16, bias=True)\n",
" (relu_01): ReLU()\n",
" (linear_02): Linear(in_features=16, out_features=1, bias=True)\n",
" )\n",
")>"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"AN.parameters"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}