Skip to content

Instantly share code, notes, and snippets.

@vmoens
Created July 27, 2023 15:09
Show Gist options
  • Save vmoens/95d6427fcb5fa5714291b3dbfa7daa15 to your computer and use it in GitHub Desktop.
Save vmoens/95d6427fcb5fa5714291b3dbfa7daa15 to your computer and use it in GitHub Desktop.
action_mask.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/vmoens/95d6427fcb5fa5714291b3dbfa7daa15/action_mask.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"# Use the %pip magic (not !pip) so the package is installed into the\n",
"# environment of the running kernel rather than whatever pip the shell finds.\n",
"%pip install git+https://github.com/pytorch-labs/tensordict.git"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dVg87jT115ch",
"outputId": "d4ccbbae-cb61-4999-9cfe-3a4eb67ed9e1"
},
"id": "dVg87jT115ch",
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting git+https://github.com/pytorch-labs/tensordict.git\n",
" Cloning https://github.com/pytorch-labs/tensordict.git to /tmp/pip-req-build-03nm82up\n",
" Running command git clone --filter=blob:none --quiet https://github.com/pytorch-labs/tensordict.git /tmp/pip-req-build-03nm82up\n",
" Resolved https://github.com/pytorch-labs/tensordict.git to commit e411324a2c524faad777895371be9d3c3f9c3a41\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Use the %pip magic (not !pip) so the package is installed into the\n",
"# environment of the running kernel. Installs the masked_actions branch.\n",
"%pip install git+https://github.com/pytorch/rl@masked_actions"
],
"metadata": {
"id": "KP0Z8H3L4osA"
},
"id": "KP0Z8H3L4osA",
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "e556e3b4-8a2e-40ff-b468-0f97532ea8ea",
"metadata": {
"tags": [],
"id": "e556e3b4-8a2e-40ff-b468-0f97532ea8ea"
},
"outputs": [],
"source": [
"from torchrl.modules.distributions import MaskedCategorical\n",
"from torchrl.modules import ProbabilisticActor\n",
"from tensordict import TensorDict\n",
"from tensordict.nn import TensorDictModule as Mod, ProbabilisticTensorDictModule as Prob, ProbabilisticTensorDictSequential as Seq\n",
"import torch\n",
"from torch import nn"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e0cdc25-ebb0-4f6c-8741-4cd2a5cef9f6",
"metadata": {
"tags": [],
"id": "9e0cdc25-ebb0-4f6c-8741-4cd2a5cef9f6"
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"source": [
"# Masked categorical usage with a stochastic policy"
],
"metadata": {
"id": "6ZH4dPvV3viU"
},
"id": "6ZH4dPvV3viU"
},
{
"cell_type": "code",
"execution_count": null,
"id": "d11cedda-461a-4b3f-a45d-0915bcf84924",
"metadata": {
"tags": [],
"id": "d11cedda-461a-4b3f-a45d-0915bcf84924"
},
"outputs": [],
"source": [
"torch.manual_seed(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "940d4146-eea8-4efc-b0b5-866734e52c88",
"metadata": {
"tags": [],
"id": "940d4146-eea8-4efc-b0b5-866734e52c88"
},
"outputs": [],
"source": [
"# Linear layer reading `obs` (3 features) and writing 4 values to `logits`.\n",
"module = Mod(nn.Linear(3, 4), in_keys=[\"obs\"], out_keys=[\"logits\"])\n",
"\n",
"# Builds a MaskedCategorical from the `logits` and `mask` entries and\n",
"# writes the selected action under `action`.\n",
"prob = Prob(\n",
"    distribution_class=MaskedCategorical,\n",
"    in_keys=[\"logits\", \"mask\"],\n",
"    out_keys=[\"action\"],\n",
")\n",
"\n",
"# Chain both modules into a single policy.\n",
"actor = Seq(module, prob)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2925c427-d149-435c-a5e4-5b5c64dae88e",
"metadata": {
"tags": [],
"id": "2925c427-d149-435c-a5e4-5b5c64dae88e"
},
"outputs": [],
"source": [
"# Boolean identity matrix: row i has only entry i set to True.\n",
"diag = torch.eye(4, dtype=torch.bool)\n",
"# Stack 4 + 4 + 2 rows so the mask has 10 rows, matching the batch size below.\n",
"mask = torch.cat([diag, diag, diag[:2]], dim=0)\n",
"td = TensorDict(\n",
"    {\"obs\": torch.randn(10, 3), \"mask\": mask},\n",
"    batch_size=[10],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "369ee2bb-7cd7-4a0e-8854-2b0fe5679608",
"metadata": {
"tags": [],
"id": "369ee2bb-7cd7-4a0e-8854-2b0fe5679608"
},
"outputs": [],
"source": [
"out = actor(td)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "983f1faf-fa79-440d-8a70-f5ec8bf87a05",
"metadata": {
"tags": [],
"id": "983f1faf-fa79-440d-8a70-f5ec8bf87a05"
},
"outputs": [],
"source": [
"out[\"action\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d193a197-1928-4cef-a0e0-08bb898c61bc",
"metadata": {
"tags": [],
"id": "d193a197-1928-4cef-a0e0-08bb898c61bc"
},
"outputs": [],
"source": [
"out[\"mask\"]"
]
},
{
"cell_type": "markdown",
"source": [
"# Masked actions in env: ActionMask transform"
],
"metadata": {
"id": "97clcOih302T"
},
"id": "97clcOih302T"
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2168e5d-f4e0-490b-b973-7edf11a07cad",
"metadata": {
"tags": [],
"id": "b2168e5d-f4e0-490b-b973-7edf11a07cad"
},
"outputs": [],
"source": [
"from torchrl.envs import EnvBase\n",
"import torch\n",
"from torchrl.data.tensor_specs import DiscreteTensorSpec, BinaryDiscreteTensorSpec, UnboundedContinuousTensorSpec, CompositeSpec"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "32d20431-3527-4911-96ea-7a5385e0dc06",
"metadata": {
"tags": [],
"id": "32d20431-3527-4911-96ea-7a5385e0dc06"
},
"outputs": [],
"source": [
"class MaskedEnv(EnvBase):\n",
"    \"\"\"Toy environment with 4 discrete actions and a boolean action mask.\n",
"\n",
"    Each step clears the mask entry of the action just taken; ``done`` is set\n",
"    once no mask entry is left. Observations and rewards are random samples\n",
"    drawn from their specs. The env itself does not reject masked actions.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, *args, **kwargs):\n",
"        super().__init__(*args, **kwargs)\n",
"        self.action_spec = DiscreteTensorSpec(4)\n",
"        # The mask lives in the state spec: one boolean flag per action.\n",
"        self.state_spec = CompositeSpec(mask=BinaryDiscreteTensorSpec(4, dtype=torch.bool))\n",
"        self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec(3))\n",
"        self.reward_spec = UnboundedContinuousTensorSpec(1)\n",
"\n",
"    def _reset(self, data):\n",
"        \"\"\"Return a random observation together with an all-True mask.\"\"\"\n",
"        td = self.observation_spec.rand()\n",
"        # ones_like on a state sample yields a mask where every action is allowed.\n",
"        td.update(torch.ones_like(self.state_spec.rand()))\n",
"        return td\n",
"\n",
"    def _step(self, data):\n",
"        \"\"\"Clear the mask entry of the chosen action and report the transition.\n",
"\n",
"        The returned tensordict nests the new ``obs``, ``mask``, ``reward`` and\n",
"        ``done`` entries under the ``next`` key.\n",
"        \"\"\"\n",
"        td = self.observation_spec.rand()\n",
"        mask = data.get(\"mask\")\n",
"        action = data.get(\"action\")\n",
"        # scatter needs an index with as many dims as mask, hence the unsqueeze.\n",
"        mask = mask.scatter(-1, action.unsqueeze(-1), 0)\n",
"\n",
"        td.set(\"mask\", mask)\n",
"        td.set(\"reward\", self.reward_spec.rand())\n",
"        # Done once every mask entry has been cleared.\n",
"        td.set(\"done\", ~mask.any().view(1))\n",
"        return td.empty().set(\"next\", td)\n",
"\n",
"    def _set_seed(self, seed):\n",
"        \"\"\"No seeding is performed; the seed is returned unchanged.\"\"\"\n",
"        return seed"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "59f68c75-2ad0-4628-9c56-332a8260b714",
"metadata": {
"tags": [],
"id": "59f68c75-2ad0-4628-9c56-332a8260b714"
},
"outputs": [],
"source": [
"base_env = MaskedEnv()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "558c86d7-b9ec-41c3-8c3b-8c3200bbb187",
"metadata": {
"tags": [],
"id": "558c86d7-b9ec-41c3-8c3b-8c3200bbb187"
},
"outputs": [],
"source": [
"td = base_env.reset()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "330146b0-7114-429c-ae11-2993316a65f6",
"metadata": {
"tags": [],
"id": "330146b0-7114-429c-ae11-2993316a65f6"
},
"outputs": [],
"source": [
"base_env.rand_step(td)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "84bd7c7e-94c3-46a6-9c14-c72aa6488c31",
"metadata": {
"tags": [],
"id": "84bd7c7e-94c3-46a6-9c14-c72aa6488c31"
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "5efbc23a-7e0a-4d34-873e-8db1c07a2050",
"metadata": {
"tags": [],
"id": "5efbc23a-7e0a-4d34-873e-8db1c07a2050"
},
"outputs": [],
"source": [
"from torchrl.envs.transforms import ActionMask, TransformedEnv"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82c75678-13e2-4190-b9ea-fa391bbaecd8",
"metadata": {
"tags": [],
"id": "82c75678-13e2-4190-b9ea-fa391bbaecd8"
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "66eb3dfb-00e4-4b22-beef-9145d6eab6e9",
"metadata": {
"id": "66eb3dfb-00e4-4b22-beef-9145d6eab6e9"
},
"outputs": [],
"source": [
"# Wrap the base env with the ActionMask transform; it is expected to keep\n",
"# sampled actions within the current mask (see the torchrl docs).\n",
"env = TransformedEnv(base_env, ActionMask())\n",
"\n",
"# Roll out for at most 10 steps and show the collected trajectory.\n",
"r = env.rollout(max_steps=10)\n",
"print(r)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c8568f3f-5c58-4fba-bb8a-7ca423b6b637",
"metadata": {
"tags": [],
"id": "c8568f3f-5c58-4fba-bb8a-7ca423b6b637"
},
"outputs": [],
"source": [
"r[\"action\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fa9a97e5-472e-4aa0-9344-45cc23419b40",
"metadata": {
"tags": [],
"id": "fa9a97e5-472e-4aa0-9344-45cc23419b40"
},
"outputs": [],
"source": [
"r[\"mask\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "101cabf0-64db-4598-94aa-0ca32f9d0037",
"metadata": {
"tags": [],
"id": "101cabf0-64db-4598-94aa-0ca32f9d0037"
},
"outputs": [],
"source": [
"base_env.action_spec.mask"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "75807c5b-8ff0-443c-869c-df9c2f522953",
"metadata": {
"tags": [],
"id": "75807c5b-8ff0-443c-869c-df9c2f522953"
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "5e65bfb0-74d5-4990-985b-83479a6ebb2f",
"metadata": {
"tags": [],
"id": "5e65bfb0-74d5-4990-985b-83479a6ebb2f"
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "19930288-a2ad-4b9f-a5ab-5e53c2bc39c8",
"metadata": {
"tags": [],
"id": "19930288-a2ad-4b9f-a5ab-5e53c2bc39c8"
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "2b8ef5d8-6fa1-4704-bd6f-ae2aef48f45e",
"metadata": {
"id": "2b8ef5d8-6fa1-4704-bd6f-ae2aef48f45e"
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
},
"colab": {
"provenance": [],
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment