-
-
Save vmoens/95d6427fcb5fa5714291b3dbfa7daa15 to your computer and use it in GitHub Desktop.
action_mask.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/vmoens/95d6427fcb5fa5714291b3dbfa7daa15/action_mask.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install git+https://github.com/pytorch-labs/tensordict.git" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "dVg87jT115ch", | |
"outputId": "d4ccbbae-cb61-4999-9cfe-3a4eb67ed9e1" | |
}, | |
"id": "dVg87jT115ch", | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Collecting git+https://github.com/pytorch-labs/tensordict.git\n", | |
" Cloning https://github.com/pytorch-labs/tensordict.git to /tmp/pip-req-build-03nm82up\n", | |
" Running command git clone --filter=blob:none --quiet https://github.com/pytorch-labs/tensordict.git /tmp/pip-req-build-03nm82up\n", | |
" Resolved https://github.com/pytorch-labs/tensordict.git to commit e411324a2c524faad777895371be9d3c3f9c3a41\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install git+https://github.com/pytorch/rl@masked_actions" | |
], | |
"metadata": { | |
"id": "KP0Z8H3L4osA" | |
}, | |
"id": "KP0Z8H3L4osA", | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "e556e3b4-8a2e-40ff-b468-0f97532ea8ea", | |
"metadata": { | |
"tags": [], | |
"id": "e556e3b4-8a2e-40ff-b468-0f97532ea8ea" | |
}, | |
"outputs": [], | |
"source": [ | |
"from torchrl.modules.distributions import MaskedCategorical\n", | |
"from torchrl.modules import ProbabilisticActor\n", | |
"from tensordict import TensorDict\n", | |
"from tensordict.nn import TensorDictModule as Mod, ProbabilisticTensorDictModule as Prob, ProbabilisticTensorDictSequential as Seq\n", | |
"import torch\n", | |
"from torch import nn" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "9e0cdc25-ebb0-4f6c-8741-4cd2a5cef9f6", | |
"metadata": { | |
"tags": [], | |
"id": "9e0cdc25-ebb0-4f6c-8741-4cd2a5cef9f6" | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Using `MaskedCategorical` in a stochastic policy" | |
], | |
"metadata": { | |
"id": "6ZH4dPvV3viU" | |
}, | |
"id": "6ZH4dPvV3viU" | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "d11cedda-461a-4b3f-a45d-0915bcf84924", | |
"metadata": { | |
"tags": [], | |
"id": "d11cedda-461a-4b3f-a45d-0915bcf84924" | |
}, | |
"outputs": [], | |
"source": [ | |
"torch.manual_seed(0)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "940d4146-eea8-4efc-b0b5-866734e52c88", | |
"metadata": { | |
"tags": [], | |
"id": "940d4146-eea8-4efc-b0b5-866734e52c88" | |
}, | |
"outputs": [], | |
"source": [ | |
"module = Mod(\n", | |
" nn.Linear(3, 4),\n", | |
" in_keys=[\"obs\"],\n", | |
" out_keys=[\"logits\"],\n", | |
")\n", | |
"prob = Prob(\n", | |
" in_keys=[\"logits\", \"mask\"],\n", | |
" out_keys=[\"action\"],\n", | |
" distribution_class=MaskedCategorical,\n", | |
")\n", | |
"actor = Seq(module, prob)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "2925c427-d149-435c-a5e4-5b5c64dae88e", | |
"metadata": { | |
"tags": [], | |
"id": "2925c427-d149-435c-a5e4-5b5c64dae88e" | |
}, | |
"outputs": [], | |
"source": [ | |
"mask = torch.ones(4, 4, dtype=torch.bool).tril().triu()\n", | |
"mask = torch.cat([mask, mask, mask[:2]], 0)\n", | |
"td = TensorDict({\n", | |
" \"obs\": torch.randn(10, 3),\n", | |
" \"mask\": mask\n", | |
"}, [10])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "369ee2bb-7cd7-4a0e-8854-2b0fe5679608", | |
"metadata": { | |
"tags": [], | |
"id": "369ee2bb-7cd7-4a0e-8854-2b0fe5679608" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Runs the logits head, then the masked distribution module, which writes \"action\" (its out_key).\n", | |
"out = actor(td)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "983f1faf-fa79-440d-8a70-f5ec8bf87a05", | |
"metadata": { | |
"tags": [], | |
"id": "983f1faf-fa79-440d-8a70-f5ec8bf87a05" | |
}, | |
"outputs": [], | |
"source": [ | |
"out[\"action\"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "d193a197-1928-4cef-a0e0-08bb898c61bc", | |
"metadata": { | |
"tags": [], | |
"id": "d193a197-1928-4cef-a0e0-08bb898c61bc" | |
}, | |
"outputs": [], | |
"source": [ | |
"# The mask the actions above were drawn under: one-hot rows, the 4x4 identity pattern repeated.\n", | |
"out[\"mask\"]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Masked actions in an environment: the `ActionMask` transform" | |
], | |
"metadata": { | |
"id": "97clcOih302T" | |
}, | |
"id": "97clcOih302T" | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "b2168e5d-f4e0-490b-b973-7edf11a07cad", | |
"metadata": { | |
"tags": [], | |
"id": "b2168e5d-f4e0-490b-b973-7edf11a07cad" | |
}, | |
"outputs": [], | |
"source": [ | |
"from torchrl.envs import EnvBase\n", | |
"import torch\n", | |
"from torchrl.data.tensor_specs import DiscreteTensorSpec, BinaryDiscreteTensorSpec, UnboundedContinuousTensorSpec, CompositeSpec" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "32d20431-3527-4911-96ea-7a5385e0dc06", | |
"metadata": { | |
"tags": [], | |
"id": "32d20431-3527-4911-96ea-7a5385e0dc06" | |
}, | |
"outputs": [], | |
"source": [ | |
"class MaskedEnv(EnvBase):\n", | |
" def __init__(self, *args, **kwargs):\n", | |
" super().__init__(*args, **kwargs)\n", | |
" self.action_spec = DiscreteTensorSpec(4)\n", | |
" self.state_spec = CompositeSpec(mask=BinaryDiscreteTensorSpec(4, dtype=torch.bool))\n", | |
" self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec(3))\n", | |
" self.reward_spec = UnboundedContinuousTensorSpec(1)\n", | |
"\n", | |
" def _reset(self, data):\n", | |
" td = self.observation_spec.rand()\n", | |
" td.update(torch.ones_like(self.state_spec.rand()))\n", | |
" return td\n", | |
"\n", | |
" def _step(self, data):\n", | |
" td = self.observation_spec.rand()\n", | |
" mask = data.get(\"mask\")\n", | |
" action = data.get(\"action\")\n", | |
" mask = mask.scatter(-1, action.unsqueeze(-1), 0)\n", | |
"\n", | |
" td.set(\"mask\", mask)\n", | |
" td.set(\"reward\", self.reward_spec.rand())\n", | |
" td.set(\"done\", ~mask.any().view(1))\n", | |
" return td.empty().set(\"next\", td)\n", | |
"\n", | |
" def _set_seed(self, seed):\n", | |
" return seed" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "59f68c75-2ad0-4628-9c56-332a8260b714", | |
"metadata": { | |
"tags": [], | |
"id": "59f68c75-2ad0-4628-9c56-332a8260b714" | |
}, | |
"outputs": [], | |
"source": [ | |
"# The unwrapped environment; the ActionMask transform is attached to it further down.\n", | |
"base_env = MaskedEnv()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "558c86d7-b9ec-41c3-8c3b-8c3200bbb187", | |
"metadata": { | |
"tags": [], | |
"id": "558c86d7-b9ec-41c3-8c3b-8c3200bbb187" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Reset output includes a random \"obs\" and an all-True \"mask\" (see MaskedEnv._reset).\n", | |
"td = base_env.reset()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "330146b0-7114-429c-ae11-2993316a65f6", | |
"metadata": { | |
"tags": [], | |
"id": "330146b0-7114-429c-ae11-2993316a65f6" | |
}, | |
"outputs": [], | |
"source": [ | |
"# One random step on the unwrapped env: at this point base_env.action_spec is a plain DiscreteTensorSpec(4),\n", | |
"# so nothing restricts the random action to entries where \"mask\" is True.\n", | |
"base_env.rand_step(td)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "84bd7c7e-94c3-46a6-9c14-c72aa6488c31", | |
"metadata": { | |
"tags": [], | |
"id": "84bd7c7e-94c3-46a6-9c14-c72aa6488c31" | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "5efbc23a-7e0a-4d34-873e-8db1c07a2050", | |
"metadata": { | |
"tags": [], | |
"id": "5efbc23a-7e0a-4d34-873e-8db1c07a2050" | |
}, | |
"outputs": [], | |
"source": [ | |
"from torchrl.envs.transforms import ActionMask, TransformedEnv" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "82c75678-13e2-4190-b9ea-fa391bbaecd8", | |
"metadata": { | |
"tags": [], | |
"id": "82c75678-13e2-4190-b9ea-fa391bbaecd8" | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "66eb3dfb-00e4-4b22-beef-9145d6eab6e9", | |
"metadata": { | |
"id": "66eb3dfb-00e4-4b22-beef-9145d6eab6e9" | |
}, | |
"outputs": [], | |
"source": [ | |
"env = TransformedEnv(base_env, ActionMask())\n", | |
"r = env.rollout(10)\n", | |
"print(r)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "c8568f3f-5c58-4fba-bb8a-7ca423b6b637", | |
"metadata": { | |
"tags": [], | |
"id": "c8568f3f-5c58-4fba-bb8a-7ca423b6b637" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Actions taken during the rollout; with the mask applied, no action should repeat within the episode.\n", | |
"r[\"action\"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "fa9a97e5-472e-4aa0-9344-45cc23419b40", | |
"metadata": { | |
"tags": [], | |
"id": "fa9a97e5-472e-4aa0-9344-45cc23419b40" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Mask in effect at each step, before that step's action: entries of previously taken actions should be False.\n", | |
"r[\"mask\"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "101cabf0-64db-4598-94aa-0ca32f9d0037", | |
"metadata": { | |
"tags": [], | |
"id": "101cabf0-64db-4598-94aa-0ca32f9d0037" | |
}, | |
"outputs": [], | |
"source": [ | |
"# NOTE(review): presumably ActionMask has written the latest \"mask\" onto the base env's action spec - confirm.\n", | |
"base_env.action_spec.mask" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "75807c5b-8ff0-443c-869c-df9c2f522953", | |
"metadata": { | |
"tags": [], | |
"id": "75807c5b-8ff0-443c-869c-df9c2f522953" | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "5e65bfb0-74d5-4990-985b-83479a6ebb2f", | |
"metadata": { | |
"tags": [], | |
"id": "5e65bfb0-74d5-4990-985b-83479a6ebb2f" | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "19930288-a2ad-4b9f-a5ab-5e53c2bc39c8", | |
"metadata": { | |
"tags": [], | |
"id": "19930288-a2ad-4b9f-a5ab-5e53c2bc39c8" | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "2b8ef5d8-6fa1-4704-bd6f-ae2aef48f45e", | |
"metadata": { | |
"id": "2b8ef5d8-6fa1-4704-bd6f-ae2aef48f45e" | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.17" | |
}, | |
"colab": { | |
"provenance": [], | |
"include_colab_link": true | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment