@alberduris
Last active August 6, 2020 15:41
A Python class that implements a base Node for building Trees with PyTorch tensors #Others #JupyterNotebook #CodeSnippet
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Node class for Trees\n",
"\n",
"A Python class that implements a base `Node` for building Trees with PyTorch tensors"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"class Node:\n",
"    \"\"\"\n",
"    Class that implements a Node of a Tree\n",
"    \"\"\"\n",
"    def __init__(self, name=None, data=None, children=None):\n",
"        self.name = name  # Name of the node, used as an ID (optional)\n",
"        # Each node carries a differentiable one-element tensor, randomly initialized by default\n",
"        self.data = data\n",
"        if self.data is None:\n",
"            self.data = torch.rand((1,), requires_grad=True)\n",
"        # Avoid the mutable-default-argument pitfall: a shared default list\n",
"        # would leak children across instances\n",
"        self.children = children if children is not None else []\n",
"\n",
"    def get_paths(self, node=None, path=None):\n",
"        \"\"\"\n",
"        Get all the root-to-leaf paths of the Tree\n",
"        \"\"\"\n",
"        if node is None:\n",
"            node = self\n",
"        paths = []\n",
"        if path is None:\n",
"            path = []\n",
"        path.append(node)\n",
"        if node.children:\n",
"            for child in node.children:\n",
"                paths.extend(self.get_paths(child, path[:]))\n",
"        else:\n",
"            paths.append(path)\n",
"        return paths\n",
"\n",
"    def traverse(self, f, node=None):\n",
"        \"\"\"\n",
"        Traverse the Tree recursively, applying the function f to each Node's data and updating it in place\n",
"        \"\"\"\n",
"        if f is None:  # Sanity check\n",
"            raise ValueError(\"The function passed to traverse must be a callable, not None.\")\n",
"        if node is None:\n",
"            node = self\n",
"        node.data = f(node)\n",
"        for child in node.children:\n",
"            self.traverse(f, child)\n",
"\n",
"    def __str__(self):\n",
"        return '{}: {}'.format(self.name, self.data)\n",
"\n",
"    __repr__ = __str__\n"
]
},
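{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hypothetical quick check (not part of the original gist): on a tiny tree with one root and two leaves, `get_paths` should return one root-to-leaf path per leaf, so its length should be 2."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical sanity check: a root with two leaf children yields two paths\n",
"tiny = Node(name='root', children=[Node(name='left'), Node(name='right')])\n",
"len(tiny.get_paths())"
]
},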
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create your own Tree\n",
"\n",
"I have built the generic tree you sent me over WhatsApp.\n",
"\n",
"Of course, it could also be constructed programmatically..."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"tree = Node(name='x_1^1', children=[\n",
"    Node(name='x_1^2', children=[\n",
"        Node(name='x_1^3', children=[\n",
"            Node(name='x_1^4', children=[]), Node(name='x_2^4', children=[])\n",
"        ]),\n",
"        Node(name='x_2^3', children=[\n",
"            Node(name='x_3^4', children=[]), Node(name='x_4^4', children=[])\n",
"        ])\n",
"    ]),\n",
"    Node(name='x_2^2', children=[\n",
"        Node(name='x_3^3', children=[\n",
"            Node(name='x_5^4', children=[]), Node(name='x_6^4', children=[])\n",
"        ]),\n",
"        Node(name='x_4^3', children=[\n",
"            Node(name='x_7^4', children=[]), Node(name='x_8^4', children=[])\n",
"        ])\n",
"    ])\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Functions\n",
"\n",
"Some functions to try out...\n",
"\n",
"`mult_sigmoid` corresponds to the one you mentioned earlier, a.k.a. `\"weight times a number, and the activation function is another number\"`"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"def sigmoid(node):\n",
"    \"\"\"\n",
"    Apply the Sigmoid function to a given Node's data\n",
"    \"\"\"\n",
"    sigm = nn.Sigmoid()\n",
"    return sigm(node.data)\n",
"\n",
"def mult_sigmoid(node):\n",
"    \"\"\"\n",
"    Multiply a given Node's data by a random weight, then apply the Sigmoid\n",
"    \"\"\"\n",
"    sigm = nn.Sigmoid()\n",
"    return sigm(torch.rand((1,)) * node.data)"
]
},
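{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hypothetical check (not part of the original gist) that gradients actually flow through `mult_sigmoid`: each fresh `Node` holds a leaf tensor with `requires_grad=True`, so applying the function and calling `backward()` on its one-element output should populate the node tensor's `.grad`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical gradient check: node.data is a leaf tensor with requires_grad=True,\n",
"# so backpropagating through mult_sigmoid should fill in node.data.grad\n",
"node = Node(name='leaf')\n",
"out = mult_sigmoid(node)\n",
"out.backward()  # allowed without an explicit gradient: out has a single element\n",
"node.data.grad"
]
},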
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Explore & Traverse the Tree\n",
"\n",
"1. Explore the initial random Tree\n",
"2. Traverse the Tree applying the `\"weight times a number, and the activation function is another number\"` function\n",
"3. Explore the resulting Tree"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[[x_1^1: tensor([0.2864], requires_grad=True),\n",
" x_1^2: tensor([0.1287], requires_grad=True),\n",
" x_1^3: tensor([0.0580], requires_grad=True),\n",
" x_1^4: tensor([0.8677], requires_grad=True)],\n",
" [x_1^1: tensor([0.2864], requires_grad=True),\n",
" x_1^2: tensor([0.1287], requires_grad=True),\n",
" x_1^3: tensor([0.0580], requires_grad=True),\n",
" x_2^4: tensor([0.5528], requires_grad=True)],\n",
" [x_1^1: tensor([0.2864], requires_grad=True),\n",
" x_1^2: tensor([0.1287], requires_grad=True),\n",
" x_2^3: tensor([0.6109], requires_grad=True),\n",
" x_3^4: tensor([0.2566], requires_grad=True)],\n",
" [x_1^1: tensor([0.2864], requires_grad=True),\n",
" x_1^2: tensor([0.1287], requires_grad=True),\n",
" x_2^3: tensor([0.6109], requires_grad=True),\n",
" x_4^4: tensor([0.1304], requires_grad=True)],\n",
" [x_1^1: tensor([0.2864], requires_grad=True),\n",
" x_2^2: tensor([0.5555], requires_grad=True),\n",
" x_3^3: tensor([0.9111], requires_grad=True),\n",
" x_5^4: tensor([0.9655], requires_grad=True)],\n",
" [x_1^1: tensor([0.2864], requires_grad=True),\n",
" x_2^2: tensor([0.5555], requires_grad=True),\n",
" x_3^3: tensor([0.9111], requires_grad=True),\n",
" x_6^4: tensor([0.9909], requires_grad=True)],\n",
" [x_1^1: tensor([0.2864], requires_grad=True),\n",
" x_2^2: tensor([0.5555], requires_grad=True),\n",
" x_4^3: tensor([0.8007], requires_grad=True),\n",
" x_7^4: tensor([0.7274], requires_grad=True)],\n",
" [x_1^1: tensor([0.2864], requires_grad=True),\n",
" x_2^2: tensor([0.5555], requires_grad=True),\n",
" x_4^3: tensor([0.8007], requires_grad=True),\n",
" x_8^4: tensor([0.3543], requires_grad=True)]]"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 1. Explore the initial random Tree\n",
"tree.get_paths()"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"# 2. Traverse the Tree\n",
"tree.traverse(f=mult_sigmoid)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[[x_1^1: tensor([0.5006], grad_fn=<SigmoidBackward>),\n",
" x_1^2: tensor([0.5165], grad_fn=<SigmoidBackward>),\n",
" x_1^3: tensor([0.5013], grad_fn=<SigmoidBackward>),\n",
" x_1^4: tensor([0.5470], grad_fn=<SigmoidBackward>)],\n",
" [x_1^1: tensor([0.5006], grad_fn=<SigmoidBackward>),\n",
" x_1^2: tensor([0.5165], grad_fn=<SigmoidBackward>),\n",
" x_1^3: tensor([0.5013], grad_fn=<SigmoidBackward>),\n",
" x_2^4: tensor([0.6257], grad_fn=<SigmoidBackward>)],\n",
" [x_1^1: tensor([0.5006], grad_fn=<SigmoidBackward>),\n",
" x_1^2: tensor([0.5165], grad_fn=<SigmoidBackward>),\n",
" x_2^3: tensor([0.5691], grad_fn=<SigmoidBackward>),\n",
" x_3^4: tensor([0.5629], grad_fn=<SigmoidBackward>)],\n",
" [x_1^1: tensor([0.5006], grad_fn=<SigmoidBackward>),\n",
" x_1^2: tensor([0.5165], grad_fn=<SigmoidBackward>),\n",
" x_2^3: tensor([0.5691], grad_fn=<SigmoidBackward>),\n",
" x_4^4: tensor([0.5016], grad_fn=<SigmoidBackward>)],\n",
" [x_1^1: tensor([0.5006], grad_fn=<SigmoidBackward>),\n",
" x_2^2: tensor([0.5077], grad_fn=<SigmoidBackward>),\n",
" x_3^3: tensor([0.6667], grad_fn=<SigmoidBackward>),\n",
" x_5^4: tensor([0.5391], grad_fn=<SigmoidBackward>)],\n",
" [x_1^1: tensor([0.5006], grad_fn=<SigmoidBackward>),\n",
" x_2^2: tensor([0.5077], grad_fn=<SigmoidBackward>),\n",
" x_3^3: tensor([0.6667], grad_fn=<SigmoidBackward>),\n",
" x_6^4: tensor([0.6168], grad_fn=<SigmoidBackward>)],\n",
" [x_1^1: tensor([0.5006], grad_fn=<SigmoidBackward>),\n",
" x_2^2: tensor([0.5077], grad_fn=<SigmoidBackward>),\n",
" x_4^3: tensor([0.5693], grad_fn=<SigmoidBackward>),\n",
" x_7^4: tensor([0.5114], grad_fn=<SigmoidBackward>)],\n",
" [x_1^1: tensor([0.5006], grad_fn=<SigmoidBackward>),\n",
" x_2^2: tensor([0.5077], grad_fn=<SigmoidBackward>),\n",
" x_4^3: tensor([0.5693], grad_fn=<SigmoidBackward>),\n",
" x_8^4: tensor([0.5736], grad_fn=<SigmoidBackward>)]]"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 3. Explore the resulting Tree\n",
"tree.get_paths()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}