Skip to content

Instantly share code, notes, and snippets.

View ssnl's full-sized avatar
🍘

Tongzhou Wang ssnl

🍘
View GitHub Profile
name: llm_agent
channels:
- pytorch
- nvidia
- nvidia/label/cuda-11.8.0
- conda-forge
- https://conda.anaconda.org/gurobi
- defaults
dependencies:
- _libgcc_mutex=0.1=main
absl-py==2.1.0
aiofiles==22.1.0
aiosqlite==0.18.0
antlr4-python3-runtime==4.9.3
anyio==4.2.0
appdirs==1.4.4
appnope==0.1.2
archspec==0.2.1
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
index b353c39..1c95df3 100644
--- a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
@@ -54,6 +54,22 @@ from . import CriticLossBase, CriticBatchInfo
# return f"weight={self.weight:g}, softplus_beta={self.softplus_beta:g}, softplus_offset={self.softplus_offset:g}"
+@attrs.define(kw_only=True)
+class TaskCriticInfo:
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
index b353c39..1ffe375 100644
--- a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
@@ -54,6 +54,17 @@ from . import CriticLossBase, CriticBatchInfo
# return f"weight={self.weight:g}, softplus_beta={self.softplus_beta:g}, softplus_offset={self.softplus_offset:g}"
+@attrs.define(kw_only=True)
+class Task:
diff --git a/quasimetric_rl/base_conf.py b/quasimetric_rl/base_conf.py
index 32de62f..df6afea 100644
--- a/quasimetric_rl/base_conf.py
+++ b/quasimetric_rl/base_conf.py
@@ -140,8 +140,8 @@ class BaseConf(abc.ABC):
]
if self.agent.quasimetric_critic.losses.dynamics_lagrange_mult_optim.lr > 0:
specs[-1] += '-opt'
- if self.agent.num_critics > 1:
- specs.append(f'{self.agent.num_critics}critic')
diff --git a/quasimetric_rl/base_conf.py b/quasimetric_rl/base_conf.py
index 32de62f..df6afea 100644
--- a/quasimetric_rl/base_conf.py
+++ b/quasimetric_rl/base_conf.py
@@ -140,8 +140,8 @@ class BaseConf(abc.ABC):
]
if self.agent.quasimetric_critic.losses.dynamics_lagrange_mult_optim.lr > 0:
specs[-1] += '-opt'
- if self.agent.num_critics > 1:
- specs.append(f'{self.agent.num_critics}critic')
diff --git a/quasimetric_rl/base_conf.py b/quasimetric_rl/base_conf.py
index 32de62f..df6afea 100644
--- a/quasimetric_rl/base_conf.py
+++ b/quasimetric_rl/base_conf.py
@@ -140,8 +140,8 @@ class BaseConf(abc.ABC):
]
if self.agent.quasimetric_critic.losses.dynamics_lagrange_mult_optim.lr > 0:
specs[-1] += '-opt'
- if self.agent.num_critics > 1:
- specs.append(f'{self.agent.num_critics}critic')
diff --git a/quasimetric_rl/data/base.py b/quasimetric_rl/data/base.py
index 09711c4..8412f7d 100644
--- a/quasimetric_rl/data/base.py
+++ b/quasimetric_rl/data/base.py
@@ -208,6 +208,7 @@ class Dataset(torch.utils.data.Dataset):
kind: str = MISSING # d4rl, gcrl, etc.
name: str = MISSING # maze2d-umaze-v1, etc.
+ horizon: int = attrs.field(default=1, validator=attrs.validators.gt(0)) # type: ignore
diff --git a/offline/main.py b/offline/main.py
index 5502749..1581d30 100644
--- a/offline/main.py
+++ b/offline/main.py
@@ -43,6 +43,7 @@ class Conf(BaseConf):
total_optim_steps: int = attrs.field(default=int(2e5), validator=attrs.validators.gt(0))
log_steps: int = attrs.field(default=250, validator=attrs.validators.gt(0)) # type: ignore
+ eval_before_training: bool = False
eval_steps: int = attrs.field(default=20000, validator=attrs.validators.gt(0)) # type: ignore
diff --git a/offline/main.py b/offline/main.py
index 5502749..1581d30 100644
--- a/offline/main.py
+++ b/offline/main.py
@@ -43,6 +43,7 @@ class Conf(BaseConf):
total_optim_steps: int = attrs.field(default=int(2e5), validator=attrs.validators.gt(0))
log_steps: int = attrs.field(default=250, validator=attrs.validators.gt(0)) # type: ignore
+ eval_before_training: bool = False
eval_steps: int = attrs.field(default=20000, validator=attrs.validators.gt(0)) # type: ignore