@ssnl
Created March 15, 2024 17:44
diff --git a/offline/main.py b/offline/main.py
index 5502749..1581d30 100644
--- a/offline/main.py
+++ b/offline/main.py
@@ -43,6 +43,7 @@ class Conf(BaseConf):
total_optim_steps: int = attrs.field(default=int(2e5), validator=attrs.validators.gt(0))
log_steps: int = attrs.field(default=250, validator=attrs.validators.gt(0)) # type: ignore
+ eval_before_training: bool = False
eval_steps: int = attrs.field(default=20000, validator=attrs.validators.gt(0)) # type: ignore
save_steps: int = attrs.field(default=50000, validator=attrs.validators.gt(0)) # type: ignore
num_eval_episodes: int = attrs.field(default=50, validator=attrs.validators.ge(0)) # type: ignore
@@ -57,7 +58,7 @@ cs.store(name='config', node=Conf())
@hydra.main(version_base=None, config_name="config")
def train(dict_cfg: DictConfig):
cfg: Conf = Conf.from_DictConfig(dict_cfg) # type: ignore
- cfg.setup_for_experiment() # checking & setup logging
+ wandb_run = cfg.setup_for_experiment() # checking & setup logging
dataset = cfg.env.make()
@@ -116,7 +117,7 @@ def train(dict_cfg: DictConfig):
logging.info(f"Checkpointed to {relpath}")
def eval(epoch, it, optim_steps):
- val_result_allenvs = trainer.evaluate()
+ val_result_allenvs = trainer.evaluate(desc=f"opt{optim_steps:08d}")
val_results.clear()
val_results.append(dict(
epoch=epoch,
@@ -128,25 +129,10 @@ def train(dict_cfg: DictConfig):
epoch=epoch,
it=it,
optim_steps=optim_steps,
- result={},
+ result={
+ k: val_result.summarize() for k, val_result in val_result_allenvs.items()
+ },
))
- for k, val_result in val_result_allenvs.items():
- succ_rate_ts = (
- None if val_result.timestep_is_success is None
- else torch.stack([_x.mean(dtype=torch.float32) for _x in val_result.timestep_is_success])
- )
- hitting_time = val_result.capped_hitting_time
- summary = dict(
- epi_return=val_result.episode_return,
- epi_score=val_result.episode_score,
- succ_rate_ts=succ_rate_ts,
- succ_rate=val_result.is_success,
- hitting_time=hitting_time,
- )
- for kk, v in val_result.extra_timestep_results.items():
- summary[kk] = torch.stack([_v.mean(dtype=torch.float32) for _v in v])
- summary[kk + '_last'] = torch.stack([_v[-1] for _v in v])
- val_summaries[-1]['result'][k] = summary
averaged_info = cfg.stage_log_info(dict(eval=val_summaries[-1]), optim_steps)
with open(os.path.join(cfg.output_dir, 'eval.log'), 'a') as f: # type: ignore
print(json.dumps(averaged_info), file=f)
@@ -184,15 +170,18 @@ def train(dict_cfg: DictConfig):
)
num_total_epochs = int(np.ceil(cfg.total_optim_steps / trainer.num_batches))
+ logging.info(f"save folder: {cfg.output_dir}")
+ logging.info(f"wandb: {wandb_run.get_url()}")
# Training loop
optim_steps = 0
- # eval(0, 0, optim_steps)
+ if cfg.eval_before_training:
+ eval(0, 0, optim_steps)
save(0, 0, optim_steps)
if start_epoch < num_total_epochs:
for epoch in range(num_total_epochs):
epoch_desc = f"Train epoch {epoch:05d}/{num_total_epochs:05d}"
- for it, (data, data_info) in enumerate(tqdm(trainer.iter_training_data(), total=trainer.num_batches, desc=epoch_desc)):
+ for it, (data, data_info) in enumerate(tqdm(trainer.iter_training_data(), total=trainer.num_batches, desc=epoch_desc, leave=False)):
step_counter.update_then_record_alerts()
optim_steps += 1
@@ -231,4 +220,9 @@ if __name__ == '__main__':
# set up some hydra flags before parsing
os.environ['HYDRA_FULL_ERROR'] = str(int(FLAGS.DEBUG))
- train()
+ try:
+ train()
+ except:
+ import wandb
+ wandb.finish(1) # sometimes crashes are not reported?? let's be safe
+ raise
diff --git a/offline/trainer.py b/offline/trainer.py
index e7f48cf..c75abed 100644
--- a/offline/trainer.py
+++ b/offline/trainer.py
@@ -11,52 +11,7 @@ import torch
import torch.utils.data
from quasimetric_rl.modules import QRLConf, QRLAgent, QRLLosses, InfoT
-from quasimetric_rl.data import BatchData, Dataset, EpisodeData, EnvSpec, OfflineEnv
-
-
-def first_nonzero(arr: torch.Tensor, dim: int = -1, invalid_val: int = -1):
- mask = (arr != 0)
- return torch.where(mask.any(dim=dim), mask.to(torch.uint8).argmax(dim=dim), invalid_val)
-
-
-@attrs.define(kw_only=True)
-class EvalEpisodeResult:
- timestep_reward: List[torch.Tensor]
- episode_return: torch.Tensor
- episode_score: torch.Tensor
- timestep_is_success: Optional[List[torch.Tensor]]
- is_success: Optional[torch.Tensor]
- hitting_time: Optional[torch.Tensor]
- extra_timestep_results: Mapping[str, List[torch.Tensor]]
-
- @property
- def capped_hitting_time(self) -> Optional[torch.Tensor]:
- # if not hit -> |ts| + 1
- if self.hitting_time is None:
- return None
- assert self.timestep_is_success is not None
- return torch.stack([torch.where(_x < 0, _succ.shape[0] + 1, _x) for _x, _succ in zip(self.hitting_time, self.timestep_is_success)])
-
- @classmethod
- def from_timestep_reward_is_success(cls, dataset: Dataset,
- timestep_reward: List[torch.Tensor],
- timestep_is_success: Optional[List[torch.Tensor]],
- extra_timestep_results) -> Self:
- return cls(
- timestep_reward=timestep_reward,
- episode_return=torch.stack([r.sum() for r in timestep_reward]),
- episode_score=dataset.normalize_score(timestep_reward),
- timestep_is_success=timestep_is_success,
- is_success=(
- None if timestep_is_success is None
- else torch.stack([_x.any(dim=-1) for _x in timestep_is_success])
- ),
- hitting_time=(
- None if timestep_is_success is None
- else torch.stack([first_nonzero(_x, dim=-1) for _x in timestep_is_success])
- ), # NB this is off by 1
- extra_timestep_results=dict(extra_timestep_results),
- )
+from quasimetric_rl.data import BatchData, Dataset, EpisodeData, EnvSpec, OfflineEnv, interaction
class Trainer(object):
@@ -133,32 +88,20 @@ class Trainer(object):
adistn = self.agent.actor(obs[None].to(self.device), goal[None].to(self.device))
return adistn.mode.cpu().numpy()[0]
- rollout = Dataset.collect_rollout_general(
+ rollout = interaction.collect_rollout(
actor, env=env, env_spec=EnvSpec.from_env(env),
max_episode_length=env.max_episode_steps)
return rollout
- def evaluate(self) -> Mapping[str, EvalEpisodeResult]:
+ def evaluate(self, desc=None) -> Mapping[str, interaction.EvalEpisodeResult]:
envs = self.dataset.create_eval_envs(self.eval_seed)
- results: Dict[str, EvalEpisodeResult] = {}
+ results: Dict[str, interaction.EvalEpisodeResult] = {}
for k, env in envs.items():
rollouts: List[EpisodeData] = []
- for _ in tqdm(range(self.num_eval_episodes), desc=f'eval/{k}'):
+ this_desc = f'eval/{k}'
+ if desc is not None:
+ this_desc = f'{desc}/{this_desc}'
+ for _ in tqdm(range(self.num_eval_episodes), desc=this_desc):
rollouts.append(self.collect_eval_rollout(env=env))
- results[k] = EvalEpisodeResult.from_timestep_reward_is_success(
- self.dataset,
- timestep_reward=[rollout.rewards for rollout in rollouts],
- timestep_is_success=(
- None
- if len(rollouts) == 0 or 'is_success' not in rollouts[0].transition_infos
- else [rollout.transition_infos['is_success'] for rollout in rollouts]
- ),
- extra_timestep_results=(
- {} if len(rollouts) == 0 else
- {
- k: [rollout.transition_infos[k] for rollout in rollouts]
- for k in rollouts[0].transition_infos.keys() if k != 'is_success'
- }
- ),
- )
+ results[k] = interaction.EvalEpisodeResult.from_episode_rollouts(self.dataset, rollouts)
return results
diff --git a/online/main.py b/online/main.py
index cd58f7a..c16e094 100644
--- a/online/main.py
+++ b/online/main.py
@@ -50,7 +50,7 @@ cs.store(name='config', node=Conf())
@hydra.main(version_base=None, config_name="config")
def train(dict_cfg: DictConfig):
cfg: Conf = Conf.from_DictConfig(dict_cfg) # type: ignore
- cfg.setup_for_experiment() # checking & setup logging
+ wandb_run = cfg.setup_for_experiment() # checking & setup logging
replay_buffer = cfg.env.make()
@@ -85,7 +85,7 @@ def train(dict_cfg: DictConfig):
logging.info(f"Checkpointed to {relpath}")
def eval(env_steps, optim_steps):
- val_result_allenvs = trainer.evaluate()
+ val_result_allenvs = trainer.evaluate(desc=f'env{env_steps:08d}_opt{optim_steps:08d}')
val_results.clear()
val_results.append(dict(
env_steps=env_steps,
@@ -95,18 +95,10 @@ def train(dict_cfg: DictConfig):
val_summaries.append(dict(
env_steps=env_steps,
optim_steps=optim_steps,
- result={},
+ result={
+ k: val_result.summarize() for k, val_result in val_result_allenvs.items()
+ },
))
- for k, val_result in val_result_allenvs.items():
- succ_rate_ts = val_result.timestep_is_success.mean(dtype=torch.float32, dim=-1)
- hitting_time = val_result.capped_hitting_time
- val_summaries[-1]['result'][k] = dict(
- epi_return=val_result.episode_return,
- epi_score=val_result.episode_score,
- succ_rate_ts=succ_rate_ts,
- succ_rate=val_result.is_success,
- hitting_time=hitting_time,
- )
averaged_info = cfg.stage_log_info(dict(eval=val_summaries[-1]), optim_steps)
with open(os.path.join(cfg.output_dir, 'eval.log'), 'a') as f: # type: ignore
print(json.dumps(averaged_info), file=f)
@@ -121,6 +113,8 @@ def train(dict_cfg: DictConfig):
),
)
+ logging.info(f"save folder: {cfg.output_dir}")
+ logging.info(f"wandb: {wandb_run.get_url()}")
# Training loop
eval(0, 0); save(0, 0)
for optim_steps, (env_steps, next_iter_new_env_step, data, data_info) in enumerate(trainer.iter_training_data(), start=1):
@@ -162,4 +156,9 @@ if __name__ == '__main__':
# set up some hydra flags before parsing
os.environ['HYDRA_FULL_ERROR'] = str(int(FLAGS.DEBUG))
- train()
+ try:
+ train()
+ except:
+ import wandb
+ wandb.finish(1) # sometimes crashes are not reported?? let's be safe
+ raise
diff --git a/online/trainer.py b/online/trainer.py
index 01faad8..8926871 100644
--- a/online/trainer.py
+++ b/online/trainer.py
@@ -12,43 +12,11 @@ import torch
import torch.utils.data
from quasimetric_rl.modules import QRLConf, QRLAgent, QRLLosses, InfoT
-from quasimetric_rl.data import Dataset, BatchData, EpisodeData, MultiEpisodeData
+from quasimetric_rl.data import BatchData, EpisodeData, interaction
from quasimetric_rl.data.online import ReplayBuffer, OnlineFixedLengthEnv
from quasimetric_rl.utils import tqdm
-def first_nonzero(arr: torch.Tensor, dim: int = -1, invalid_val: int = -1):
- mask = (arr != 0)
- return torch.where(mask.any(dim=dim), mask.to(torch.uint8).argmax(dim=dim), invalid_val)
-
-
-@attrs.define(kw_only=True)
-class EvalEpisodeResult:
- timestep_reward: torch.Tensor
- episode_return: torch.Tensor
- episode_score: torch.Tensor
- timestep_is_success: torch.Tensor
- is_success: torch.Tensor
- hitting_time: torch.Tensor
-
- @property
- def capped_hitting_time(self) -> torch.Tensor:
- # if not hit -> |ts| + 1
- return torch.stack([torch.where(_x < 0, self.timestep_is_success.shape[-1] + 1, _x) for _x in self.hitting_time])
-
- @classmethod
- def from_timestep_reward_is_success(cls, dataset: Dataset, timestep_reward: torch.Tensor,
- timestep_is_success: torch.Tensor) -> Self:
- return cls(
- timestep_reward=timestep_reward,
- episode_return=timestep_reward.sum(-1),
- episode_score=dataset.normalize_score(cast(Sequence[torch.Tensor], timestep_reward)),
- timestep_is_success=timestep_is_success,
- is_success=timestep_is_success.any(dim=-1),
- hitting_time=first_nonzero(timestep_is_success, dim=-1), # NB this is off by 1
- )
-
-
@attrs.define(kw_only=True)
class InteractionConf:
total_env_steps: int = attrs.field(default=int(1e6), validator=attrs.validators.gt(0))
@@ -165,23 +133,18 @@ class Trainer(object):
self.replay.add_rollout(rollout)
return rollout
- def evaluate(self) -> Mapping[str, EvalEpisodeResult]:
+ def evaluate(self, desc=None) -> Mapping[str, interaction.EvalEpisodeResult]:
envs = self.make_evaluate_envs()
- results: Dict[str, EvalEpisodeResult] = {}
+ results: Dict[str, interaction.EvalEpisodeResult] = {}
for k, env in envs.items():
rollouts = []
- for _ in tqdm(range(self.num_eval_episodes), desc=f'eval/{k}'):
+ this_desc = f'eval/{k}'
+ if desc is not None:
+ this_desc = f'{desc}/{this_desc}'
+ for _ in tqdm(range(self.num_eval_episodes), desc=this_desc):
rollouts.append(self.collect_rollout(eval=True, store=False, env=env))
- mrollouts = MultiEpisodeData.cat(rollouts)
- results[k] = EvalEpisodeResult.from_timestep_reward_is_success(
- self.replay,
- mrollouts.rewards.reshape(
- self.num_eval_episodes, env.episode_length,
- ),
- mrollouts.transition_infos['is_success'].reshape(
- self.num_eval_episodes, env.episode_length,
- ),
- )
+ results[k] = interaction.EvalEpisodeResult.from_episode_rollouts(
+ self.replay, rollouts)
return results
def iter_training_data(self) -> Iterator[Tuple[int, bool, BatchData, InfoT]]:
@@ -197,7 +160,7 @@ class Trainer(object):
"""
def yield_data():
num_transitions = self.replay.num_transitions_realized
- for icyc in tqdm(range(self.num_samples_per_cycle), desc=f"{num_transitions} env steps, train batches"):
+ for icyc in tqdm(range(self.num_samples_per_cycle), desc=f"{num_transitions} env steps, train batches", leave=False):
data_t0 = time.time()
data = self.sample()
info = dict(
diff --git a/quasimetric_rl/base_conf.py b/quasimetric_rl/base_conf.py
index b3a332e..e9f2202 100644
--- a/quasimetric_rl/base_conf.py
+++ b/quasimetric_rl/base_conf.py
@@ -172,7 +172,7 @@ class BaseConf(abc.ABC):
else:
raise RuntimeError(f'Output directory {self.output_dir} exists and is complete')
- wandb.init(
+ run = wandb.init(
project=self.wandb_project,
name=self.output_folder.replace('/', '__') + '__' + datetime.now().strftime(r"%Y%m%d_%H:%M:%S"),
config=yaml.safe_load(OmegaConf.to_yaml(self)),
@@ -219,3 +219,6 @@ class BaseConf(abc.ABC):
if self.device.type == 'cuda' and self.device.index is not None:
torch.cuda.set_device(self.device.index)
+
+ assert run is not None
+ return run
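
setup_for_experiment() now returns the run object created by wandb.init, so callers can surface the run URL next to the output directory. Minimal usage, mirroring the offline/main.py hunk above (cfg is an already-constructed Conf; logging is assumed to be set up):

    wandb_run = cfg.setup_for_experiment()   # checking & setup logging
    logging.info(f"save folder: {cfg.output_dir}")
    logging.info(f"wandb: {wandb_run.get_url()}")
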
diff --git a/quasimetric_rl/data/__init__.py b/quasimetric_rl/data/__init__.py
index 41fc413..519796b 100644
--- a/quasimetric_rl/data/__init__.py
+++ b/quasimetric_rl/data/__init__.py
@@ -5,8 +5,9 @@ from .env_spec import EnvSpec
from . import online
from .online import register_online_env, OnlineFixedLengthEnv
from .offline import OfflineEnv
+from . import interaction
__all__ = [
'BatchData', 'EpisodeData', 'MultiEpisodeData', 'Dataset', 'register_offline_env',
- 'EnvSpec', 'online', 'register_online_env', 'OnlineFixedLengthEnv', 'OfflineEnv',
+ 'EnvSpec', 'online', 'register_online_env', 'OnlineFixedLengthEnv', 'OfflineEnv', 'interaction',
]
diff --git a/quasimetric_rl/data/base.py b/quasimetric_rl/data/base.py
index c86f525..09711c4 100644
--- a/quasimetric_rl/data/base.py
+++ b/quasimetric_rl/data/base.py
@@ -36,6 +36,7 @@ class BatchData(TensorCollectionAttrsMixin): # TensorCollectionAttrsMixin has s
timeouts: torch.Tensor
future_observations: torch.Tensor # sampled!
+ future_tdelta: torch.Tensor
@property
def device(self) -> torch.device:
@@ -143,6 +144,12 @@ class EpisodeData(MultiEpisodeData):
rewards=self.rewards[:t],
terminals=self.terminals[:t],
timeouts=self.timeouts[:t],
+ observation_infos={
+ k: v[:t + 1] for k, v in self.observation_infos.items()
+ },
+ transition_infos={
+ k: v[:t] for k, v in self.transition_infos.items()
+ },
)
@@ -235,6 +242,7 @@ class Dataset(torch.utils.data.Dataset):
indices_to_episode_timesteps: torch.Tensor
max_episode_length: int
# -----
+ device: torch.device
def create_env(self, *, dict_obseravtion: Optional[bool] = None, seed: Optional[int] = None, **kwargs) -> 'OfflineEnv':
from .offline import OfflineEnv
@@ -272,6 +280,7 @@ class Dataset(torch.utils.data.Dataset):
def __init__(self, kind: str, name: str, *,
future_observation_discount: float,
dummy: bool = False, # when you don't want to load data, e.g., in analysis
+ device: torch.device = torch.device('cpu'), # FIXME: get some heuristic
) -> None:
self.kind = kind
self.name = name
@@ -298,97 +307,22 @@ class Dataset(torch.utils.data.Dataset):
indices_to_episode_timesteps.append(torch.arange(l, dtype=torch.int64))
assert len(episodes) > 0, "must have at least one episode"
- self.raw_data = MultiEpisodeData.cat(episodes)
+ self.raw_data = MultiEpisodeData.cat(episodes).to(device)
- self.obs_indices_to_obs_index_in_episode = torch.cat(obs_indices_to_obs_index_in_episode, dim=0)
- self.indices_to_episode_indices = torch.cat(indices_to_episode_indices, dim=0)
- self.indices_to_episode_timesteps = torch.cat(indices_to_episode_timesteps, dim=0)
+ self.obs_indices_to_obs_index_in_episode = torch.cat(obs_indices_to_obs_index_in_episode, dim=0).to(device)
+ self.indices_to_episode_indices = torch.cat(indices_to_episode_indices, dim=0).to(device)
+ self.indices_to_episode_timesteps = torch.cat(indices_to_episode_timesteps, dim=0).to(device)
self.max_episode_length = int(self.raw_data.episode_lengths.max().item())
+ self.device = device
- def get_observations(self, obs_indices: torch.Tensor):
- return self.raw_data.all_observations[obs_indices]
+ # def max_bytes_used(self):
+ # return self
- @classmethod
- def collect_rollout_general(cls, actor: Callable[[torch.Tensor, torch.Tensor, gym.Space], np.ndarray], *,
- env: gym.Env, env_spec: EnvSpec, max_episode_length: int,
- assert_exact_episode_length: bool = False, extra_transition_info_keys: Collection[str] = []) -> EpisodeData:
- from .utils import get_empty_episode
-
- epi = get_empty_episode(env_spec, max_episode_length)
-
- # check observation space
- obs_dict_keys = {'observation', 'achieved_goal', 'desired_goal'}
- WRONG_OBS_ERR_MESSAGE = (
- f"{cls.__name__} collect_rollout only supports Dict "
- f"observation space with keys {obs_dict_keys}, but got {env.observation_space}"
- )
- assert isinstance(env.observation_space, gym.spaces.Dict), WRONG_OBS_ERR_MESSAGE
- assert set(env.observation_space.spaces.keys()) == {'observation', 'achieved_goal', 'desired_goal'}, WRONG_OBS_ERR_MESSAGE
-
- observation_dict = cast(Mapping[str, np.ndarray], env.reset())
- observation: torch.Tensor = torch.as_tensor(observation_dict['observation'], dtype=torch.float32)
-
- goal: torch.Tensor = torch.as_tensor(observation_dict['desired_goal'], dtype=torch.float32)
- agoal: torch.Tensor = torch.as_tensor(observation_dict['achieved_goal'], dtype=torch.float32)
- epi.all_observations[0] = observation
- epi.observation_infos['desired_goals'][0] = goal
- epi.observation_infos['achieved_goals'][0] = agoal
- if len(extra_transition_info_keys):
- epi.transition_infos = dict(epi.transition_infos)
- epi.transition_infos.update({
- k: torch.empty([max_episode_length], dtype=torch.float32) for k in extra_transition_info_keys
- })
-
- t = 0
- timeout = False
- terminal = False
- while not timeout and not terminal:
- assert t < max_episode_length
-
- action = actor(
- observation,
- goal,
- env_spec.action_space,
- )
- transition_out = env.step(np.asarray(action))
- observation_dict, reward, terminal, info = transition_out[:3] + transition_out[-1:] # some BC
-
- observation = torch.tensor(observation_dict['observation'], dtype=torch.float32) # copy just in case
-
- goal: torch.Tensor = torch.as_tensor(observation_dict['desired_goal'], dtype=torch.float32)
- agoal: torch.Tensor = torch.as_tensor(observation_dict['achieved_goal'], dtype=torch.float32)
-
- if 'is_success' in info:
- is_success: bool = info['is_success']
- epi.transition_infos['is_success'][t] = is_success
- else:
- if t == 0:
- # remove field
- transition_infos = dict(epi.transition_infos)
- del transition_infos['is_success']
- epi.transition_infos = transition_infos
-
- for k in extra_transition_info_keys:
- epi.transition_infos[k][t] = info[k]
-
- epi.all_observations[t + 1] = observation
- epi.actions[t] = torch.as_tensor(action, dtype=torch.float32)
- epi.rewards[t] = reward
- epi.observation_infos['desired_goals'][t + 1] = goal
- epi.observation_infos['achieved_goals'][t + 1] = agoal
-
- t += 1
- timeout = info.get('TimeLimit.truncated', False)
- if assert_exact_episode_length:
- assert (timeout or terminal) == (t == max_episode_length)
-
- if t < max_episode_length:
- epi = epi.first_t(t)
-
- return epi
+ def get_observations(self, obs_indices: torch.Tensor):
+ return self.raw_data.all_observations[obs_indices.to(self.device)]
def __getitem__(self, indices: torch.Tensor) -> BatchData:
- indices = torch.as_tensor(indices)
+ indices = torch.as_tensor(indices, device=self.device)
eindices = self.indices_to_episode_indices[indices]
obs_indices = indices + eindices # index for `observation`: skip the s_last from previous episodes
obs = self.get_observations(obs_indices)
@@ -398,7 +332,7 @@ class Dataset(torch.utils.data.Dataset):
tindices = self.indices_to_episode_timesteps[indices]
epilengths = self.raw_data.episode_lengths[eindices] # max idx is this
- deltas = torch.arange(self.max_episode_length)
+ deltas = torch.arange(self.max_episode_length, device=self.device)
pdeltas = torch.where(
# test tidx + 1 + delta <= max_idx = epi_length
(tindices[:, None] + deltas) < epilengths[:, None],
@@ -407,14 +341,16 @@ class Dataset(torch.utils.data.Dataset):
)
deltas = torch.distributions.Categorical(
probs=pdeltas,
- ).sample()
- future_observations = self.get_observations(obs_indices + 1 + deltas)
+ validate_args=False,
+ ).sample() + 1
+ future_observations = self.get_observations(obs_indices + deltas)
return BatchData(
observations=obs,
actions=self.raw_data.actions[indices],
next_observations=nobs,
future_observations=future_observations,
+ future_tdelta=deltas,
rewards=self.raw_data.rewards[indices],
terminals=terminals,
timeouts=self.raw_data.timeouts[indices],
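
In the sampling change above, the +1 step is folded into the sampled offset, so the new `future_tdelta` field is exactly the number of environment steps between `observations` and `future_observations`, and the future observation is indexed by `obs_indices + deltas` rather than `obs_indices + 1 + deltas`. A standalone sketch of the sampling rule with illustrative values (`gamma` stands in for `future_observation_discount`; this is not the Dataset code itself):

    import torch

    gamma = 0.99                 # future_observation_discount (illustrative)
    max_episode_length = 5
    tindex, epi_length = 1, 4    # current timestep and its episode's length

    deltas = torch.arange(max_episode_length)
    pdeltas = torch.where(
        (tindex + deltas) < epi_length,   # future index must stay inside the episode
        gamma ** deltas,                  # geometrically decaying preference for nearer goals
        torch.zeros(()),
    )
    future_tdelta = torch.distributions.Categorical(probs=pdeltas, validate_args=False).sample() + 1
    # the sampled future observation lies `future_tdelta` steps ahead of the current one
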
diff --git a/quasimetric_rl/data/interaction.py b/quasimetric_rl/data/interaction.py
new file mode 100644
index 0000000..699fa61
--- /dev/null
+++ b/quasimetric_rl/data/interaction.py
@@ -0,0 +1,178 @@
+from __future__ import annotations
+from typing import *
+
+import attrs
+
+import gym
+import gym.spaces
+import numpy as np
+import torch
+import torch.utils.data
+
+from . import Dataset, EpisodeData, EnvSpec
+
+
+def first_nonzero(arr: torch.Tensor, dim: int = -1, invalid_val: int = -1):
+ mask = (arr != 0)
+ return torch.where(mask.any(dim=dim), mask.to(torch.uint8).argmax(dim=dim), invalid_val)
+
+
+@attrs.define(kw_only=True)
+class EvalEpisodeResult:
+ timestep_reward: List[torch.Tensor]
+ episode_return: torch.Tensor
+ episode_score: torch.Tensor
+ timestep_is_success: Optional[List[torch.Tensor]]
+ is_success: Optional[torch.Tensor]
+ hitting_time: Optional[torch.Tensor]
+ extra_timestep_results: Mapping[str, List[torch.Tensor]]
+
+ @property
+ def capped_hitting_time(self) -> Optional[torch.Tensor]:
+ # if not hit -> |ts| + 1
+ if self.hitting_time is None:
+ return None
+ assert self.timestep_is_success is not None
+ return torch.stack([torch.where(_x < 0, _succ.shape[0] + 1, _x) for _x, _succ in zip(self.hitting_time, self.timestep_is_success)])
+
+ @classmethod
+ def from_timestep_reward_is_success(cls, dataset: Dataset,
+ timestep_reward: List[torch.Tensor],
+ timestep_is_success: Optional[List[torch.Tensor]],
+ extra_timestep_results) -> Self:
+ return cls(
+ timestep_reward=timestep_reward,
+ episode_return=torch.stack([r.sum() for r in timestep_reward]),
+ episode_score=dataset.normalize_score(timestep_reward),
+ timestep_is_success=timestep_is_success,
+ is_success=(
+ None if timestep_is_success is None
+ else torch.stack([_x.any(dim=-1) for _x in timestep_is_success])
+ ),
+ hitting_time=(
+ None if timestep_is_success is None
+ else torch.stack([first_nonzero(_x, dim=-1) for _x in timestep_is_success])
+ ), # NB this is off by 1
+ extra_timestep_results=dict(extra_timestep_results),
+ )
+
+ @classmethod
+ def from_episode_rollouts(cls, dataset: Dataset, rollouts: Sequence[EpisodeData]) -> Self:
+ return cls.from_timestep_reward_is_success(
+ dataset,
+ timestep_reward=[rollout.rewards for rollout in rollouts],
+ timestep_is_success=(
+ None
+ if len(rollouts) == 0 or 'is_success' not in rollouts[0].transition_infos
+ else [rollout.transition_infos['is_success'] for rollout in rollouts]
+ ),
+ extra_timestep_results=(
+ {} if len(rollouts) == 0 else
+ {
+ k: [rollout.transition_infos[k] for rollout in rollouts]
+ for k in rollouts[0].transition_infos.keys() if k != 'is_success'
+ }
+ ),
+ )
+
+ def summarize(self) -> Mapping[str, Union[torch.Tensor, float, None]]:
+ succ_rate_ts = (
+ None if self.timestep_is_success is None
+ else torch.stack([_x.mean(dtype=torch.float32) for _x in self.timestep_is_success])
+ )
+ hitting_time = self.capped_hitting_time
+ summary = dict(
+ epi_return=self.episode_return,
+ epi_score=self.episode_score,
+ succ_rate_ts=succ_rate_ts,
+ succ_rate=self.is_success,
+ hitting_time=hitting_time,
+ )
+ for kk, v in self.extra_timestep_results.items():
+ summary[kk] = torch.stack([_v.mean(dtype=torch.float32) for _v in v])
+ summary[kk + '_last'] = torch.stack([_v[-1] for _v in v])
+
+ return summary
+
+
+def collect_rollout(actor: Callable[[torch.Tensor, torch.Tensor, gym.Space], np.ndarray], *,
+ env: gym.Env, env_spec: EnvSpec, max_episode_length: int,
+ assert_exact_episode_length: bool = False) -> EpisodeData:
+ # NOTE: extra tracked info can be specified by env.tracked_info_keys
+
+ from .utils import get_empty_episode
+
+ epi = get_empty_episode(env_spec, max_episode_length)
+
+ # check observation space
+ obs_dict_keys = {'observation', 'achieved_goal', 'desired_goal'}
+ WRONG_OBS_ERR_MESSAGE = (
+ f"collect_rollout only supports Dict "
+ f"observation space with keys {obs_dict_keys}, but got {env.observation_space}"
+ )
+ assert isinstance(env.observation_space, gym.spaces.Dict), WRONG_OBS_ERR_MESSAGE
+ assert set(env.observation_space.spaces.keys()) == {'observation', 'achieved_goal', 'desired_goal'}, WRONG_OBS_ERR_MESSAGE
+
+ observation_dict = cast(Mapping[str, np.ndarray], env.reset())
+ observation: torch.Tensor = torch.as_tensor(observation_dict['observation'], dtype=torch.float32)
+
+ goal: torch.Tensor = torch.as_tensor(observation_dict['desired_goal'], dtype=torch.float32)
+ agoal: torch.Tensor = torch.as_tensor(observation_dict['achieved_goal'], dtype=torch.float32)
+ epi.all_observations[0] = observation
+ epi.observation_infos['desired_goals'][0] = goal
+ epi.observation_infos['achieved_goals'][0] = agoal
+
+ extra_transition_info_keys = getattr(env, 'tracked_info_keys', [])
+ if len(extra_transition_info_keys):
+ epi.transition_infos = dict(epi.transition_infos)
+ epi.transition_infos.update({
+ k: torch.empty([max_episode_length], dtype=torch.float32) for k in extra_transition_info_keys
+ })
+
+ t = 0
+ timeout = False
+ terminal = False
+ while not timeout and not terminal:
+ assert t < max_episode_length
+
+ action = actor(
+ observation,
+ goal,
+ env_spec.action_space,
+ )
+ transition_out = env.step(np.asarray(action))
+ observation_dict, reward, terminal, info = transition_out[:3] + transition_out[-1:] # some BC
+
+ observation = torch.tensor(observation_dict['observation'], dtype=torch.float32) # copy just in case
+
+ goal: torch.Tensor = torch.as_tensor(observation_dict['desired_goal'], dtype=torch.float32)
+ agoal: torch.Tensor = torch.as_tensor(observation_dict['achieved_goal'], dtype=torch.float32)
+
+ if 'is_success' in info:
+ is_success: bool = info['is_success']
+ epi.transition_infos['is_success'][t] = is_success
+ else:
+ if t == 0:
+ # remove field
+ transition_infos = dict(epi.transition_infos)
+ del transition_infos['is_success']
+ epi.transition_infos = transition_infos
+
+ for k in extra_transition_info_keys:
+ epi.transition_infos[k][t] = info[k]
+
+ epi.all_observations[t + 1] = observation
+ epi.actions[t] = torch.as_tensor(action, dtype=torch.float32)
+ epi.rewards[t] = reward
+ epi.observation_infos['desired_goals'][t + 1] = goal
+ epi.observation_infos['achieved_goals'][t + 1] = agoal
+
+ t += 1
+ timeout = info.get('TimeLimit.truncated', False)
+ if assert_exact_episode_length:
+ assert (timeout or terminal) == (t == max_episode_length)
+
+ if t < max_episode_length:
+ epi = epi.first_t(t)
+
+ return epi
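
quasimetric_rl/data/interaction.py consolidates the rollout collection and evaluation-result helpers previously duplicated across offline/trainer.py and online/trainer.py. A minimal consumption sketch matching the evaluate() hunks above (`dataset` and `rollouts` are assumed to already exist):

    from quasimetric_rl.data import interaction

    # rollouts: a List[EpisodeData] collected via interaction.collect_rollout(...)
    result = interaction.EvalEpisodeResult.from_episode_rollouts(dataset, rollouts)
    summary = result.summarize()
    # summary holds epi_return, epi_score, succ_rate_ts, succ_rate, hitting_time,
    # plus a mean and '_last' entry for each extra tracked transition-info key
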
diff --git a/quasimetric_rl/data/offline/__init__.py b/quasimetric_rl/data/offline/__init__.py
index 22bb42c..c816a7b 100644
--- a/quasimetric_rl/data/offline/__init__.py
+++ b/quasimetric_rl/data/offline/__init__.py
@@ -60,6 +60,10 @@ class OfflineGoalCondEnv(gym.ObservationWrapper, OfflineEnv): # type: ignore
self.get_goal_fn = get_goal_fn
self.extra_info_fns = extra_info_fns
+ @property
+ def tracked_info_keys(self):
+ return tuple(self.extra_info_fns.keys())
+
def observation(self, observation: np.ndarray):
o, g = observation, self.get_goal_fn(self.env)
if self.is_image_based:
diff --git a/quasimetric_rl/data/offline/d4rl/antmaze.py b/quasimetric_rl/data/offline/d4rl/antmaze.py
index c48ab34..dbf5198 100644
--- a/quasimetric_rl/data/offline/d4rl/antmaze.py
+++ b/quasimetric_rl/data/offline/d4rl/antmaze.py
@@ -182,12 +182,13 @@ def create_env_antmaze(name, dict_obseravtion: Optional[bool] = None, *, random_
return env
-def load_episodes_antmaze(name):
+def load_episodes_antmaze(name, normalize_observation=True):
env = load_environment(name)
d4rl_dataset = cached_d4rl_dataset(name)
- # normalize
- d4rl_dataset['observations'] = obs_norm(name, d4rl_dataset['observations'])
- d4rl_dataset['next_observations'] = obs_norm(name, d4rl_dataset['next_observations'])
+ if normalize_observation:
+ # normalize
+ d4rl_dataset['observations'] = obs_norm(name, d4rl_dataset['observations'])
+ d4rl_dataset['next_observations'] = obs_norm(name, d4rl_dataset['next_observations'])
yield from convert_dict_to_EpisodeData_iter(
sequence_dataset(
env,
@@ -206,3 +207,11 @@ for name in ['antmaze-umaze-v2', 'antmaze-umaze-diverse-v2',
normalize_score_fn=functools.partial(get_normalized_score, name),
eval_specs=dict(single_task=dict(random_start_goal=False), multi_task=dict(random_start_goal=True)),
)
+ register_offline_env(
+ 'd4rl', name + '-nonorm',
+ create_env_fn=functools.partial(create_env_antmaze, name, normalize_observation=False),
+ load_episodes_fn=functools.partial(load_episodes_antmaze, name, normalize_observation=False),
+ normalize_score_fn=functools.partial(get_normalized_score, name),
+ eval_specs=dict(single_task=dict(random_start_goal=False, normalize_observation=False),
+ multi_task=dict(random_start_goal=True, normalize_observation=False)),
+ )
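
The new '-nonorm' registrations expose the AntMaze datasets with observation normalization disabled, under the same 'd4rl' kind. A hedged sketch of loading one through the Dataset constructor shown in data/base.py above (the discount value is illustrative):

    from quasimetric_rl.data import Dataset

    dataset = Dataset(
        kind='d4rl',
        name='antmaze-umaze-v2-nonorm',    # unnormalized-observation variant registered above
        future_observation_discount=0.99,  # illustrative
    )
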
diff --git a/quasimetric_rl/data/online/memory.py b/quasimetric_rl/data/online/memory.py
index b7cb94b..a3445f4 100644
--- a/quasimetric_rl/data/online/memory.py
+++ b/quasimetric_rl/data/online/memory.py
@@ -12,6 +12,7 @@ import gym.spaces
from . import OnlineFixedLengthEnv
from ..base import EpisodeData, MultiEpisodeData, Dataset, BatchData
+from ..interaction import collect_rollout
from ..utils import get_empty_episode, get_empty_episodes
@@ -153,7 +154,7 @@ class ReplayBuffer(Dataset):
get_empty_episodes(
self.env_spec, self.episode_length,
int(np.ceil(self.increment_num_transitions / self.episode_length)),
- ),
+ ).to(self.device),
],
dim=0,
)
@@ -164,19 +165,19 @@ class ReplayBuffer(Dataset):
# indices_to_episode_timesteps: torch.Tensor
self.indices_to_episode_indices = torch.cat([
self.indices_to_episode_indices,
- torch.repeat_interleave(torch.arange(original_capacity, new_capacity), self.episode_length),
+ torch.repeat_interleave(torch.arange(original_capacity, new_capacity, device=self.device), self.episode_length),
], dim=0)
self.indices_to_episode_timesteps = torch.cat([
self.indices_to_episode_timesteps,
- torch.arange(self.episode_length).repeat(new_capacity - original_capacity),
+ torch.arange(self.episode_length, device=self.device).repeat(new_capacity - original_capacity),
], dim=0)
logging.info(f'ReplayBuffer: Expanded from capacity={original_capacity} to {new_capacity} episodes')
def collect_rollout(self, actor: Callable[[torch.Tensor, torch.Tensor, gym.Space], np.ndarray], *,
env: Optional[OnlineFixedLengthEnv] = None) -> EpisodeData:
- return self.collect_rollout_general(actor, env=(env or self.env), env_spec=self.env_spec,
- max_episode_length=self.episode_length, assert_exact_episode_length=True)
+ return collect_rollout(actor, env=(env or self.env), env_spec=self.env_spec,
+ max_episode_length=self.episode_length, assert_exact_episode_length=True)
def add_rollout(self, episode: EpisodeData):
if self.num_episodes_realized == self.episodes_capacity:
@@ -214,7 +215,8 @@ class ReplayBuffer(Dataset):
def sample(self, batch_size: int) -> BatchData:
indices = torch.as_tensor(
- np.random.choice(self.num_transitions_realized, size=[batch_size])
+ np.random.choice(self.num_transitions_realized, size=[batch_size]),
+ device=self.device,
)
return self[indices]
diff --git a/quasimetric_rl/flags.py b/quasimetric_rl/flags.py
index 2fb0fab..2f7578a 100644
--- a/quasimetric_rl/flags.py
+++ b/quasimetric_rl/flags.py
@@ -25,31 +25,33 @@ FLAGS = FlagsDefinition()
def pdb_if_DEBUG(fn: Callable):
@functools.wraps(fn)
def wrapped(*args, **kwargs):
- try:
- fn(*args, **kwargs)
- except:
- # follow ABSL:
- # https://github.com/abseil/abseil-py/blob/a0ae31683e6cf3667886c500327f292c893a1740/absl/app.py#L311-L327
-
- exc = sys.exc_info()[1]
- if isinstance(exc, KeyboardInterrupt):
- raise
-
- # Don't try to post-mortem debug successful SystemExits, since those
- # mean there wasn't actually an error. In particular, the test framework
- # raises SystemExit(False) even if all tests passed.
- if isinstance(exc, SystemExit) and not exc.code:
- raise
-
- # Check the tty so that we don't hang waiting for input in an
- # non-interactive scenario.
- if FLAGS.DEBUG:
+ if not FLAGS.DEBUG: # check here, in case it is set after decorator call
+ return fn(*args, **kwargs)
+ else:
+ try:
+ return fn(*args, **kwargs)
+ except:
+ # follow ABSL:
+ # https://github.com/abseil/abseil-py/blob/a0ae31683e6cf3667886c500327f292c893a1740/absl/app.py#L311-L327
+
+ exc = sys.exc_info()[1]
+ if isinstance(exc, KeyboardInterrupt):
+ raise
+
+ # Don't try to post-mortem debug successful SystemExits, since those
+ # mean there wasn't actually an error. In particular, the test framework
+ # raises SystemExit(False) even if all tests passed.
+ if isinstance(exc, SystemExit) and not exc.code:
+ raise
+
+ # Check the tty so that we don't hang waiting for input in an
+ # non-interactive scenario.
traceback.print_exc()
print()
print(' *** Entering post-mortem debugging ***')
print()
pdb.post_mortem()
- raise
+ raise
return wrapped
diff --git a/quasimetric_rl/modules/__init__.py b/quasimetric_rl/modules/__init__.py
index 79fd35f..1ed2315 100644
--- a/quasimetric_rl/modules/__init__.py
+++ b/quasimetric_rl/modules/__init__.py
@@ -32,13 +32,17 @@ class QRLLosses(Module):
critic_losses: Collection[quasimetric_critic.QuasimetricCriticLosses],
critics_total_grad_clip_norm: Optional[float],
recompute_critic_for_actor_loss: bool,
- critics_share_embedding: bool):
+ critics_share_embedding: bool,
+ critic_losses_use_target_encoder: bool,
+ actor_loss_uses_target_encoder: bool):
super().__init__()
self.add_module('actor_loss', actor_loss)
self.critic_losses = torch.nn.ModuleList(critic_losses) # type: ignore
self.critics_total_grad_clip_norm = critics_total_grad_clip_norm
self.recompute_critic_for_actor_loss = recompute_critic_for_actor_loss
self.critics_share_embedding = critics_share_embedding
+ self.critic_losses_use_target_encoder = critic_losses_use_target_encoder
+ self.actor_loss_uses_target_encoder = actor_loss_uses_target_encoder
def forward(self, agent: QRLAgent, data: BatchData, *, optimize: bool = True) -> LossResult:
# compute CriticBatchInfo
@@ -57,8 +61,8 @@ class QRLLosses(Module):
)
else:
zx = critic.encoder(data.observations)
- zy = critic.target_encoder(data.next_observations)
- if critic.has_separate_target_encoder:
+ zy = critic.get_encoder(target=self.critic_losses_use_target_encoder)(data.next_observations)
+ if critic.has_separate_target_encoder and self.critic_losses_use_target_encoder:
assert not zy.requires_grad
critic_batch_info = quasimetric_critic.CriticBatchInfo(
critic=critic,
@@ -85,8 +89,9 @@ class QRLLosses(Module):
assert agent.actor is not None
with torch.no_grad(), torch.inference_mode():
for idx, critic in enumerate(agent.critics):
- if self.recompute_critic_for_actor_loss or critic.has_separate_target_encoder:
- zx, zy = critic.target_encoder(torch.stack([data.observations, data.next_observations], dim=0)).unbind(0)
+ if self.recompute_critic_for_actor_loss or (critic.has_separate_target_encoder and self.actor_loss_uses_target_encoder):
+ zx, zy = critic.get_encoder(target=self.actor_loss_uses_target_encoder)(
+ torch.stack([data.observations, data.next_observations], dim=0)).unbind(0)
critic_batch_infos[idx] = quasimetric_critic.CriticBatchInfo(
critic=critic,
zx=zx,
@@ -142,7 +147,13 @@ class QRLLosses(Module):
critic_loss.dynamics_lagrange_mult_sched.load_state_dict(optim_scheds[f"critic_{idx:02d}"]['dynamics_lagrange_mult_sched'])
def extra_repr(self) -> str:
- return f'recompute_critic_for_actor_loss={self.recompute_critic_for_actor_loss}'
+ return '\n'.join([
+ f'recompute_critic_for_actor_loss={self.recompute_critic_for_actor_loss}',
+ f'critics_share_embedding={self.critics_share_embedding}',
+ f'critics_total_grad_clip_norm={self.critics_total_grad_clip_norm}',
+ f'critic_losses_use_target_encoder={self.critic_losses_use_target_encoder}',
+ f'actor_loss_uses_target_encoder={self.actor_loss_uses_target_encoder}',
+ ])
@attrs.define(kw_only=True)
@@ -155,6 +166,8 @@ class QRLConf:
default=None, validator=attrs.validators.optional(attrs.validators.gt(0)),
)
recompute_critic_for_actor_loss: bool = False
+ critic_losses_use_target_encoder: bool = True
+ actor_loss_uses_target_encoder: bool = True
def make(self, *, env_spec: EnvSpec, total_optim_steps: int) -> Tuple[QRLAgent, QRLLosses]:
if self.actor is None:
@@ -174,9 +187,13 @@ class QRLConf:
critics.append(critic)
critic_losses.append(critic_loss)
- return QRLAgent(actor=actor, critics=critics), QRLLosses(actor_loss=actor_losses, critic_losses=critic_losses,
- critics_share_embedding=self.critics_share_embedding,
- critics_total_grad_clip_norm=self.critics_total_grad_clip_norm,
- recompute_critic_for_actor_loss=self.recompute_critic_for_actor_loss)
+ return QRLAgent(actor=actor, critics=critics), QRLLosses(
+ actor_loss=actor_losses, critic_losses=critic_losses,
+ critics_share_embedding=self.critics_share_embedding,
+ critics_total_grad_clip_norm=self.critics_total_grad_clip_norm,
+ recompute_critic_for_actor_loss=self.recompute_critic_for_actor_loss,
+ critic_losses_use_target_encoder=self.critic_losses_use_target_encoder,
+ actor_loss_uses_target_encoder=self.actor_loss_uses_target_encoder,
+ )
__all__ = ['QRLAgent', 'QRLLosses', 'QRLConf', 'InfoT', 'InfoValT']
diff --git a/quasimetric_rl/modules/actor/losses/awr.py b/quasimetric_rl/modules/actor/losses/awr.py
index c69e79c..f6cf4de 100644
--- a/quasimetric_rl/modules/actor/losses/awr.py
+++ b/quasimetric_rl/modules/actor/losses/awr.py
@@ -3,11 +3,10 @@ from typing import *
import attrs
import torch
-import torch.nn as nn
-from ....data import BatchData, EnvSpec
+from ....data import BatchData
-from ...utils import LatentTensor, LossResult, grad_mul
+from ...utils import LatentTensor, LossResult, bcast_bshape
from ..model import Actor
from ...quasimetric_critic import QuasimetricCritic, CriticBatchInfo
@@ -140,14 +139,23 @@ class AWRLoss(ActorLossBase):
for idx, actor_obs_goal_critic_info in enumerate(actor_obs_goal_critic_infos):
critic = actor_obs_goal_critic_info.critic
- zo = actor_obs_goal_critic_info.zo.detach()
- zg = actor_obs_goal_critic_info.zg.detach()
+ zo = actor_obs_goal_critic_info.zo.detach() # [B,D]
+ zg = actor_obs_goal_critic_info.zg.detach() # [2?,B,D]
with torch.no_grad(), critic.mode(False):
- zp = critic.latent_dynamics(data.observations, zo, data.actions)
- z = torch.stack([zo, zp], dim=0)
- z = z[(slice(None),) + (None,) * (zg.ndim - zo.ndim) + (Ellipsis,)] # if zg is batched, add enough dim to be broadcast-able
- dist_noact, dist = critic.quasimetric_model(z, zg).unbind(0)
- dist_noact = dist_noact.detach()
+ zp = critic.latent_dynamics(data.observations, zo, data.actions) # [B,D]
+ if not critic.borrowing_embedding:
+ zo, zp, zg = bcast_bshape(
+ (zo, 1),
+ (zp, 1),
+ (zg, 1),
+ )
+ z = torch.stack([zo, zp], dim=0) # [2,2?,B,D]
+ # z = z[(slice(None),) + (None,) * (zg.ndim - zo.ndim) + (Ellipsis,)] # if zg is batched, add enough dim to be broadcast-able
+ dist_noact, dist = critic.quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B]
+ dist_noact = dist_noact.detach()
+ else:
+ dist = critic.quasimetric_model(zp, zg)
+ dist_noact = dists_noact[0]
info[f'dist_delta_{idx:02d}'] = (dist_noact - dist).mean()
info[f'dist_{idx:02d}'] = dist.mean()
dists_noact.append(dist_noact)
diff --git a/quasimetric_rl/modules/actor/losses/min_dist.py b/quasimetric_rl/modules/actor/losses/min_dist.py
index 0f90cff..bc36ee8 100644
--- a/quasimetric_rl/modules/actor/losses/min_dist.py
+++ b/quasimetric_rl/modules/actor/losses/min_dist.py
@@ -147,13 +147,18 @@ class MinDistLoss(ActorLossBase):
for idx, actor_obs_goal_critic_info in enumerate(actor_obs_goal_critic_infos):
critic = actor_obs_goal_critic_info.critic
- zo = actor_obs_goal_critic_info.zo.detach()
- zg = actor_obs_goal_critic_info.zg.detach()
+ zo = actor_obs_goal_critic_info.zo.detach() # [B,D]
+ zg = actor_obs_goal_critic_info.zg.detach() # [2?,B,D]
with critic.requiring_grad(False), critic.mode(False):
- zp = critic.latent_dynamics(data.observations, zo, action)
- z = torch.stack(torch.broadcast_tensors(zo, zp), dim=0)
- dist_noact, dist = critic.quasimetric_model(z, zg).unbind(0)
- dist_noact = dist_noact.detach()
+ zp = critic.latent_dynamics(data.observations, zo, action) # [2?,B,D]
+ if not critic.borrowing_embedding:
+ # action: [2?,B,A]
+ z = torch.stack(torch.broadcast_tensors(zo, zp), dim=0) # [2,2?,B,D]
+ dist_noact, dist = critic.quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B]
+ dist_noact = dist_noact.detach()
+ else:
+ dist = critic.quasimetric_model(zp, zg) # [2?,B]
+ dist_noact = dists_noact[0]
info[f'dist_delta_{idx:02d}'] = (dist_noact - dist).mean()
info[f'dist_{idx:02d}'] = dist.mean()
dists_noact.append(dist_noact)
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py b/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py
index fa8eba0..48ed700 100644
--- a/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py
@@ -31,7 +31,7 @@ class CriticLossBase(LossBase):
return super().__call__(data, critic_batch_info)
-from .global_push import GlobalPushLoss, GlobalPushLinearLoss, GlobalPushLogLoss, GlobalPushRBFLoss
+from .global_push import GlobalPushLoss, GlobalPushLinearLoss, GlobalPushLogLoss, GlobalPushRBFLoss, GlobalPushNextMSELoss
from .local_constraint import LocalConstraintLoss
from .latent_dynamics import LatentDynamicsLoss
@@ -41,6 +41,7 @@ class QuasimetricCriticLosses(CriticLossBase):
class Conf:
global_push: GlobalPushLoss.Conf = GlobalPushLoss.Conf()
global_push_linear: GlobalPushLinearLoss.Conf = GlobalPushLinearLoss.Conf()
+ global_push_next_mse: GlobalPushNextMSELoss.Conf = GlobalPushNextMSELoss.Conf()
global_push_log: GlobalPushLogLoss.Conf = GlobalPushLogLoss.Conf()
global_push_rbf: GlobalPushRBFLoss.Conf = GlobalPushRBFLoss.Conf()
local_constraint: LocalConstraintLoss.Conf = LocalConstraintLoss.Conf()
@@ -54,7 +55,10 @@ class QuasimetricCriticLosses(CriticLossBase):
local_lagrange_mult_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=1e-2)
dynamics_lagrange_mult_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=0)
- scale_with_best_local_fit: bool = False
+ quasimetric_scale: Optional[str] = attrs.field(
+ default=None, validator=attrs.validators.optional(attrs.validators.in_(
+ ['best_local_fit', 'best_local_fit_clip5', 'best_local_fit_clip10',
+ 'best_local_fit_detach']))) # type: ignore
def make(self, critic: QuasimetricCritic, total_optim_steps: int,
share_embedding_from: Optional[QuasimetricCritic] = None) -> 'QuasimetricCriticLosses':
@@ -65,6 +69,7 @@ class QuasimetricCriticLosses(CriticLossBase):
# global losses
global_push=self.global_push.make(),
global_push_linear=self.global_push_linear.make(),
+ global_push_next_mse=self.global_push_next_mse.make(),
global_push_log=self.global_push_log.make(),
global_push_rbf=self.global_push_rbf.make(),
# local loss
@@ -81,12 +86,13 @@ class QuasimetricCriticLosses(CriticLossBase):
local_lagrange_mult_optim_spec=self.local_lagrange_mult_optim.make(),
dynamics_lagrange_mult_optim_spec=self.dynamics_lagrange_mult_optim.make(),
#
- scale_with_best_local_fit=self.scale_with_best_local_fit,
+ quasimetric_scale=self.quasimetric_scale,
)
borrowing_embedding: bool
global_push: Optional[GlobalPushLoss]
global_push_linear: Optional[GlobalPushLinearLoss]
+ global_push_next_mse: Optional[GlobalPushNextMSELoss]
global_push_log: Optional[GlobalPushLogLoss]
global_push_rbf: Optional[GlobalPushRBFLoss]
local_constraint: Optional[LocalConstraintLoss]
@@ -98,12 +104,13 @@ class QuasimetricCriticLosses(CriticLossBase):
local_lagrange_mult_sched: LRScheduler
dynamics_lagrange_mult_optim: OptimWrapper
dynamics_lagrange_mult_sched: LRScheduler
- scale_with_best_local_fit: bool
+ quasimetric_scale: Optional[str]
def __init__(self, critic: QuasimetricCritic, *, total_optim_steps: int,
share_embedding_from: Optional[QuasimetricCritic] = None,
global_push: Optional[GlobalPushLoss], global_push_linear: Optional[GlobalPushLinearLoss],
- global_push_log: Optional[GlobalPushLogLoss], global_push_rbf: Optional[GlobalPushRBFLoss],
+ global_push_next_mse: Optional[GlobalPushNextMSELoss], global_push_log: Optional[GlobalPushLogLoss],
+ global_push_rbf: Optional[GlobalPushRBFLoss],
local_constraint: Optional[LocalConstraintLoss], latent_dynamics: LatentDynamicsLoss,
critic_optim_spec: AdamWSpec,
latent_dynamics_lr_mul: float,
@@ -112,7 +119,7 @@ class QuasimetricCriticLosses(CriticLossBase):
quasimetric_head_lr_mul: float,
local_lagrange_mult_optim_spec: AdamWSpec,
dynamics_lagrange_mult_optim_spec: AdamWSpec,
- scale_with_best_local_fit: bool):
+ quasimetric_scale: Optional[str]):
super().__init__()
self.borrowing_embedding = share_embedding_from is not None
if self.borrowing_embedding:
@@ -123,6 +130,7 @@ class QuasimetricCriticLosses(CriticLossBase):
local_constraint = None
self.add_module('global_push', global_push)
self.add_module('global_push_linear', global_push_linear)
+ self.add_module('global_push_next_mse', global_push_next_mse)
self.add_module('global_push_log', global_push_log)
self.add_module('global_push_rbf', global_push_rbf)
self.add_module('local_constraint', local_constraint)
@@ -147,7 +155,7 @@ class QuasimetricCriticLosses(CriticLossBase):
self.dynamics_lagrange_mult_optim, self.dynamics_lagrange_mult_sched = dynamics_lagrange_mult_optim_spec.create_optim_scheduler(
latent_dynamics.parameters(), total_optim_steps)
assert len(list(latent_dynamics.parameters())) == 1
- self.scale_with_best_local_fit = scale_with_best_local_fit
+ self.quasimetric_scale = quasimetric_scale
def optimizers(self) -> Iterable[OptimWrapper]:
return [self.critic_optim, self.local_lagrange_mult_optim, self.dynamics_lagrange_mult_optim]
@@ -155,23 +163,34 @@ class QuasimetricCriticLosses(CriticLossBase):
def schedulers(self) -> Iterable[LRScheduler]:
return [self.critic_sched, self.local_lagrange_mult_sched, self.dynamics_lagrange_mult_sched]
- @torch.no_grad()
- def compute_best_quasimetric_scale(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> torch.Tensor:
+ def compute_best_quasimetric_scale(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> Tuple[torch.Tensor, torch.Tensor]:
assert self.local_constraint is not None and not self.borrowing_embedding
+ critic_batch_info.critic.quasimetric_model.quasimetric_head.scale.detach_().fill_(1) # reset
dist = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, critic_batch_info.zy)
- return (self.local_constraint.step_cost * (dist.mean() / dist.square().mean().clamp_min_(1e-8))).detach().clamp_(1e-3, 1e3)
+ return dist, (self.local_constraint.step_cost * (dist.mean() / dist.square().mean().clamp_min(1e-12))) # .detach().clamp_(1e-1, 1e1)
def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
extra_info: Dict[str, torch.Tensor] = {}
- if self.scale_with_best_local_fit and not self.borrowing_embedding:
- scale = extra_info['quasimetric_autoscale'] = self.compute_best_quasimetric_scale(data, critic_batch_info)
- critic_batch_info.critic.quasimetric_model.quasimetric_head.scale.copy_(scale)
+ if self.quasimetric_scale is not None and not self.borrowing_embedding:
+ unscaled_dist, scale = self.compute_best_quasimetric_scale(data, critic_batch_info)
+ assert scale.grad_fn is not None # allow bp
+ if self.quasimetric_scale == 'best_local_fit_detach':
+ scale = scale.detach()
+ elif self.quasimetric_scale == 'best_local_fit_clip5':
+ scale = scale.clamp(1 / 5, 5)
+ elif self.quasimetric_scale == 'best_local_fit_clip10':
+ scale = scale.clamp(1 / 10, 10)
+ extra_info['unscaled_dist'] = unscaled_dist
+ extra_info['quasimetric_autoscale'] = scale
+ critic_batch_info.critic.quasimetric_model.quasimetric_head.scale = scale
loss_results: Dict[str, LossResult] = {}
if self.global_push is not None:
loss_results.update(global_push=self.global_push(data, critic_batch_info))
if self.global_push_linear is not None:
loss_results.update(global_push_linear=self.global_push_linear(data, critic_batch_info))
+ if self.global_push_next_mse is not None:
+ loss_results.update(global_push_next_mse=self.global_push_next_mse(data, critic_batch_info))
if self.global_push_log is not None:
loss_results.update(global_push_log=self.global_push_log(data, critic_batch_info))
if self.global_push_rbf is not None:
@@ -189,4 +208,4 @@ class QuasimetricCriticLosses(CriticLossBase):
return torch.nn.Module.__call__(self, data, critic_batch_info)
def extra_repr(self) -> str:
- return f"borrowing_embedding={self.borrowing_embedding}"
+ return f"borrowing_embedding={self.borrowing_embedding}, quasimetric_scale={self.quasimetric_scale!r}"
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
index 7c3f9f3..e469f13 100644
--- a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
@@ -53,16 +53,95 @@ from . import CriticLossBase, CriticBatchInfo
# return f"weight={self.weight:g}, softplus_beta={self.softplus_beta:g}, softplus_offset={self.softplus_offset:g}"
-
-class GlobalPushLoss(CriticLossBase):
+class GlobalPushLossBase(CriticLossBase):
@attrs.define(kw_only=True)
- class Conf:
+ class Conf(abc.ABC):
# config / argparse uses this to specify behavior
enabled: bool = True
detach_goal: bool = False
detach_proj_goal: bool = False
+ detach_qmet: bool = False
+ step_cost: float = attrs.field(default=1., validator=attrs.validators.gt(0))
weight: float = attrs.field(default=1., validator=attrs.validators.gt(0))
+ weight_future_goal: float = attrs.field(default=0., validator=attrs.validators.ge(0))
+ clamp_max_future_goal: bool = True
+
+ @abc.abstractmethod
+ def make(self) -> Optional['GlobalPushLossBase']:
+ if not self.enabled:
+ return None
+
+ weight: float
+ weight_future_goal: float
+ detach_goal: bool
+ detach_proj_goal: bool
+ detach_qmet: bool
+ step_cost: float
+ clamp_max_future_goal: bool
+
+ def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool,
+ detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool):
+ super().__init__()
+ self.weight = weight
+ self.weight_future_goal = weight_future_goal
+ self.detach_goal = detach_goal
+ self.detach_proj_goal = detach_proj_goal
+ self.detach_qmet = detach_qmet
+ self.step_cost = step_cost
+ self.clamp_max_future_goal = clamp_max_future_goal
+
+ def generate_dist_weight(self, data: BatchData, critic_batch_info: CriticBatchInfo):
+ def get_dist(za: torch.Tensor, zb: torch.Tensor):
+ if self.detach_goal:
+ zb = zb.detach()
+ with critic_batch_info.critic.quasimetric_model.requiring_grad(not self.detach_qmet):
+ return critic_batch_info.critic.quasimetric_model(za, zb, proj_grad_enabled=(True, not self.detach_proj_goal))
+
+ # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy
+ # are latents of randomly ordered random batches.
+ zgoal = torch.roll(critic_batch_info.zy, 1, dims=0)
+ yield (
+ 'random_goal',
+ zgoal,
+ get_dist(critic_batch_info.zx, zgoal),
+ self.weight,
+ )
+ if self.weight_future_goal > 0:
+ zgoal = critic_batch_info.critic.encoder(data.future_observations)
+ dist = get_dist(critic_batch_info.zx, zgoal)
+ if self.clamp_max_future_goal:
+ dist = dist.clamp_max(self.step_cost * data.future_tdelta)
+ yield (
+ 'future_goal',
+ zgoal,
+ dist,
+ self.weight_future_goal,
+ )
+
+ @abc.abstractmethod
+ def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult:
+ raise NotImplementedError
+
+ def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
+ return LossResult.combine(
+ {
+ name: self.compute_loss(data, critic_batch_info, zgoal, dist, weight)
+ for name, zgoal, dist, weight in self.generate_dist_weight(data, critic_batch_info)
+ },
+ )
+
+ def extra_repr(self) -> str:
+ return '\n'.join([
+ f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}",
+ f"weight_future_goal={self.weight_future_goal:g}, detach_qmet={self.detach_qmet}",
+ f"step_cost={self.step_cost:g}, clamp_max_future_goal={self.clamp_max_future_goal}",
+ ])
+
+
+class GlobalPushLoss(GlobalPushLossBase):
+ @attrs.define(kw_only=True)
+ class Conf(GlobalPushLossBase.Conf):
# smaller => smoother loss
softplus_beta: float = attrs.field(default=0.1, validator=attrs.validators.gt(0))
@@ -75,99 +154,161 @@ class GlobalPushLoss(CriticLossBase):
return None
return GlobalPushLoss(
weight=self.weight,
+ weight_future_goal=self.weight_future_goal,
detach_goal=self.detach_goal,
detach_proj_goal=self.detach_proj_goal,
+ detach_qmet=self.detach_qmet,
+ step_cost=self.step_cost,
+ clamp_max_future_goal=self.clamp_max_future_goal,
softplus_beta=self.softplus_beta,
softplus_offset=self.softplus_offset,
)
- weight: float
- detach_goal: bool
- detach_proj_goal: bool
softplus_beta: float
softplus_offset: float
- def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool,
+ def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool,
softplus_beta: float, softplus_offset: float):
- super().__init__()
- self.weight = weight
- self.detach_goal = detach_goal
- self.detach_proj_goal = detach_proj_goal
+ super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal,
+ detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal)
self.softplus_beta = softplus_beta
self.softplus_offset = softplus_offset
- def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
- # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy
- # are latents of randomly ordered random batches.
- zgoal = torch.roll(critic_batch_info.zy, 1, dims=0)
- if self.detach_goal:
- zgoal = zgoal.detach()
- dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal,
- proj_grad_enabled=(True, not self.detach_proj_goal))
+ def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult:
# Sec 3.2. Transform so that we penalize large distances less.
- tsfm_dist: torch.Tensor = F.softplus(self.softplus_offset - dists, beta=self.softplus_beta) # type: ignore
+ tsfm_dist: torch.Tensor = F.softplus(self.softplus_offset - dist, beta=self.softplus_beta) # type: ignore
tsfm_dist = tsfm_dist.mean()
- return LossResult(loss=tsfm_dist * self.weight, info=dict(dist=dists.mean(), tsfm_dist=tsfm_dist)) # type: ignore
+ return LossResult(loss=tsfm_dist * weight, info=dict(dist=dist.mean(), tsfm_dist=tsfm_dist))
def extra_repr(self) -> str:
- return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}, softplus_beta={self.softplus_beta:g}, softplus_offset={self.softplus_offset:g}"
-
+ return '\n'.join([
+ super().extra_repr(),
+ f"softplus_beta={self.softplus_beta:g}, softplus_offset={self.softplus_offset:g}",
+ ])
-class GlobalPushLinearLoss(CriticLossBase):
+class GlobalPushLinearLoss(GlobalPushLossBase):
@attrs.define(kw_only=True)
- class Conf:
- # config / argparse uses this to specify behavior
-
+ class Conf(GlobalPushLossBase.Conf):
enabled: bool = False
- detach_goal: bool = False
- detach_proj_goal: bool = False
- weight: float = attrs.field(default=1., validator=attrs.validators.gt(0))
+
+ clamp_max: Optional[float] = attrs.field(default=None, validator=attrs.validators.optional(attrs.validators.gt(0)))
def make(self) -> Optional['GlobalPushLinearLoss']:
if not self.enabled:
return None
return GlobalPushLinearLoss(
weight=self.weight,
+ weight_future_goal=self.weight_future_goal,
detach_goal=self.detach_goal,
detach_proj_goal=self.detach_proj_goal,
+ detach_qmet=self.detach_qmet,
+ step_cost=self.step_cost,
+ clamp_max_future_goal=self.clamp_max_future_goal,
+ clamp_max=self.clamp_max,
)
- weight: float
- detach_goal: bool
- detach_proj_goal: bool
-
- def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool):
- super().__init__()
- self.weight = weight
- self.detach_goal = detach_goal
- self.detach_proj_goal = detach_proj_goal
-
- def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
- # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy
- # are latents of randomly ordered random batches.
- zgoal = torch.roll(critic_batch_info.zy, 1, dims=0)
- if self.detach_goal:
- zgoal = zgoal.detach()
- dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal,
- proj_grad_enabled=(True, not self.detach_proj_goal))
- dists = dists.mean()
- return LossResult(loss=dists * (-self.weight), info=dict(dist=dists)) # type: ignore
+ clamp_max: Optional[float]
+
+ def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool,
+ clamp_max: Optional[float]):
+ super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal,
+ detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal)
+ self.clamp_max = clamp_max
+
+ def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult:
+ info: Dict[str, torch.Tensor]
+ if self.clamp_max is None:
+ dist = dist.mean()
+ info = dict(dist=dist)
+ neg_loss = dist
+ else:
+ info = dict(dist=dist.mean())
+ tsfm_dist = dist.clamp_max(self.clamp_max)
+ info.update(
+ tsfm_dist=tsfm_dist.mean(),
+ exceed_rate=(dist >= self.clamp_max).mean(dtype=torch.float32),
+ )
+ neg_loss = tsfm_dist
+ return LossResult(loss=neg_loss * weight, info=info)
def extra_repr(self) -> str:
- return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}"
+ return '\n'.join([
+ super().extra_repr(),
+ f"clamp_max={self.clamp_max!r}",
+ ])
-class GlobalPushLogLoss(CriticLossBase):
+
+class GlobalPushNextMSELoss(GlobalPushLossBase):
@attrs.define(kw_only=True)
- class Conf:
- # config / argparse uses this to specify behavior
+ class Conf(GlobalPushLossBase.Conf):
+ enabled: bool = False
+ detach_target_dist: bool = True
+ allow_gt: bool = False
+ gamma: Optional[float] = attrs.field(
+ default=None, validator=attrs.validators.optional(attrs.validators.and_(
+ attrs.validators.gt(0),
+ attrs.validators.lt(1),
+ )))
+
+ def make(self) -> Optional['GlobalPushNextMSELoss']:
+ if not self.enabled:
+ return None
+ return GlobalPushNextMSELoss(
+ weight=self.weight,
+ weight_future_goal=self.weight_future_goal,
+ detach_goal=self.detach_goal,
+ detach_proj_goal=self.detach_proj_goal,
+ detach_qmet=self.detach_qmet,
+ clamp_max_future_goal=self.clamp_max_future_goal,
+ step_cost=self.step_cost,
+ detach_target_dist=self.detach_target_dist,
+ allow_gt=self.allow_gt,
+ gamma=self.gamma,
+ )
+
+ def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool,
+ detach_target_dist: bool, allow_gt: bool, gamma: Optional[float]):
+ super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal,
+ detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal)
+ self.detach_target_dist = detach_target_dist
+ self.allow_gt = allow_gt
+ self.gamma = gamma
+
+ def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult:
+        with torch.set_grad_enabled(not self.detach_target_dist):
+            # by the triangle inequality, the actual cost can't be larger than step_cost + d(s', g)
+ next_dist = critic_batch_info.critic.quasimetric_model(
+ critic_batch_info.zy, zgoal, proj_grad_enabled=(True, not self.detach_proj_goal)
+ )
+ if self.detach_target_dist:
+ next_dist = next_dist.detach()
+ target_dist = self.step_cost + next_dist
+
+ if self.allow_gt:
+ dist = dist.clamp_max(target_dist)
+
+ if self.gamma is None:
+ loss = F.mse_loss(dist, target_dist)
+ else:
+ loss = F.mse_loss(self.gamma ** dist, self.gamma ** target_dist)
+
+        return LossResult(loss=loss * weight, info=dict(loss=loss, dist=dist.mean(), target_dist=target_dist.mean()))
+
+ def extra_repr(self) -> str:
+ return '\n'.join([
+ super().extra_repr(),
+ "detach_target_dist={self.detach_target_dist}",
+ ])
+
+class GlobalPushLogLoss(GlobalPushLossBase):
+ @attrs.define(kw_only=True)
+ class Conf(GlobalPushLossBase.Conf):
enabled: bool = False
- detach_goal: bool = False
- detach_proj_goal: bool = False
- weight: float = attrs.field(default=1., validator=attrs.validators.gt(0))
+
offset: float = attrs.field(default=1., validator=attrs.validators.gt(0))
def make(self) -> Optional['GlobalPushLogLoss']:
@@ -175,54 +316,43 @@ class GlobalPushLogLoss(CriticLossBase):
return None
return GlobalPushLogLoss(
weight=self.weight,
+ weight_future_goal=self.weight_future_goal,
detach_goal=self.detach_goal,
detach_proj_goal=self.detach_proj_goal,
+ detach_qmet=self.detach_qmet,
+ step_cost=self.step_cost,
+ clamp_max_future_goal=self.clamp_max_future_goal,
offset=self.offset,
)
- weight: float
- detach_goal: bool
- detach_proj_goal: bool
offset: float
- def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool, offset: float):
- super().__init__()
- self.weight = weight
- self.detach_goal = detach_goal
- self.detach_proj_goal = detach_proj_goal
+ def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool,
+ offset: float):
+ super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal,
+ detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal)
self.offset = offset
- def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
- # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy
- # are latents of randomly ordered random batches.
- zgoal = torch.roll(critic_batch_info.zy, 1, dims=0)
- if self.detach_goal:
- zgoal = zgoal.detach()
- dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal,
- proj_grad_enabled=(True, not self.detach_proj_goal))
- # Sec 3.2. Transform so that we penalize large distances less.
- tsfm_dist: torch.Tensor = -dists.add(self.offset).log()
+ def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult:
+ tsfm_dist: torch.Tensor = -dist.add(self.offset).log()
tsfm_dist = tsfm_dist.mean()
- return LossResult(loss=tsfm_dist * self.weight, info=dict(dist=dists.mean(), tsfm_dist=tsfm_dist)) # type: ignore
+ return LossResult(loss=tsfm_dist * weight, info=dict(dist=dist.mean(), tsfm_dist=tsfm_dist))
def extra_repr(self) -> str:
- return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}, offset={self.offset:g}"
-
+ return '\n'.join([
+ super().extra_repr(),
+ f"offset={self.offset:g}",
+ ])
-class GlobalPushRBFLoss(CriticLossBase):
+class GlobalPushRBFLoss(GlobalPushLossBase):
    # say E[opt T] approx sqrt(2)/2 * timeout, so E[opt T^2] approx 1/2 timeout^2
    # to emulate log E exp(-2 d^2), where 2 d^2 is around 4, we scale the model's T by r and use log E exp(-r^2 T^2), letting r^2 T^2 be around 4
    # so r^2 approx 8 / timeout^2 and r approx 2.83 / timeout, i.e. inv_scale = 1/r approx timeout / 2.83. If timeout = 850, inv_scale approx 300
@attrs.define(kw_only=True)
- class Conf:
- # config / argparse uses this to specify behavior
-
+ class Conf(GlobalPushLossBase.Conf):
enabled: bool = False
- detach_goal: bool = False
- detach_proj_goal: bool = False
- weight: float = attrs.field(default=1., validator=attrs.validators.gt(0))
inv_scale: float = attrs.field(default=300., validator=attrs.validators.ge(1e-3))
@@ -231,37 +361,33 @@ class GlobalPushRBFLoss(CriticLossBase):
return None
return GlobalPushRBFLoss(
weight=self.weight,
+ weight_future_goal=self.weight_future_goal,
detach_goal=self.detach_goal,
detach_proj_goal=self.detach_proj_goal,
+ detach_qmet=self.detach_qmet,
+ step_cost=self.step_cost,
+ clamp_max_future_goal=self.clamp_max_future_goal,
inv_scale=self.inv_scale,
)
- weight: float
- detach_goal: bool
- detach_proj_goal: bool
inv_scale: float
- def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool, inv_scale: float):
- super().__init__()
- self.weight = weight
- self.detach_goal = detach_goal
- self.detach_proj_goal = detach_proj_goal
+ def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool,
+ inv_scale: float):
+ super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal,
+ detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal)
self.inv_scale = inv_scale
- def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
- # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy
- # are latents of randomly ordered random batches.
- zgoal = torch.roll(critic_batch_info.zy, 1, dims=0)
- if self.detach_goal:
- zgoal = zgoal.detach()
- dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal,
- proj_grad_enabled=(True, not self.detach_proj_goal))
- inv_scale = dists.detach().square().mean().div(2).sqrt().clamp(1e-3, self.inv_scale) # make E[d^2]/r^2 approx 2
- tsfm_dist: torch.Tensor = (dists / inv_scale).square().neg().exp()
+ def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult:
+ inv_scale = dist.detach().square().mean().div(2).sqrt().clamp(1e-3, self.inv_scale) # make E[d^2]/r^2 approx 2
+ tsfm_dist: torch.Tensor = (dist / inv_scale).square().neg().exp()
rbf_potential = tsfm_dist.mean().log()
return LossResult(loss=rbf_potential * self.weight,
- info=dict(dist=dists.mean(), inv_scale=inv_scale,
+ info=dict(dist=dist.mean(), inv_scale=inv_scale,
tsfm_dist=tsfm_dist, rbf_potential=rbf_potential)) # type: ignore
def extra_repr(self) -> str:
- return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}, inv_scale={self.inv_scale:g}"
+ return '\n'.join([
+ super().extra_repr(),
+ f"inv_scale={self.inv_scale:g}",
+ ])
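For orientation, here is a minimal standalone sketch of the two objective shapes introduced above: the softplus push transform and the triangle-inequality MSE target of GlobalPushNextMSELoss. The constants, shapes, and tensors are illustrative stand-ins, not the repo's actual defaults.

import torch
import torch.nn.functional as F

B = 4
dist = torch.rand(B) * 100          # stand-in for d(z_x, z_goal)
next_dist = torch.rand(B) * 100     # stand-in for d(z_{s'}, z_goal)
step_cost = 1.0

# GlobalPushLoss: softplus(offset - d, beta) keeps pushing small distances up,
# while distances well past the offset contribute almost nothing to the loss.
softplus_offset, softplus_beta = 500.0, 0.1
push_loss = F.softplus(softplus_offset - dist, beta=softplus_beta).mean()

# GlobalPushNextMSELoss: by the triangle inequality, d(s, g) <= step_cost + d(s', g),
# so the distance is regressed toward that bound (optionally in gamma-discounted form).
target_dist = (step_cost + next_dist).detach()
mse_loss = F.mse_loss(dist, target_dist)
gamma = 0.99
mse_loss_gamma = F.mse_loss(gamma ** dist, gamma ** target_dist)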
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py b/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py
index 2701e6f..f5e8248 100644
--- a/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py
@@ -35,6 +35,7 @@ class LatentDynamicsLoss(CriticLossBase):
detach_sp: bool = False
detach_proj_sp: bool = False
detach_qmet: bool = False
+ non_quasimetric_dim_mse_weight: float = attrs.field(default=0., validator=attrs.validators.ge(0))
def make(self) -> 'LatentDynamicsLoss':
return LatentDynamicsLoss(
@@ -45,6 +46,7 @@ class LatentDynamicsLoss(CriticLossBase):
detach_sp=self.detach_sp,
detach_proj_sp=self.detach_proj_sp,
detach_qmet=self.detach_qmet,
+ non_quasimetric_dim_mse_weight=self.non_quasimetric_dim_mse_weight,
init_lagrange_multiplier=self.init_lagrange_multiplier,
)
@@ -55,13 +57,15 @@ class LatentDynamicsLoss(CriticLossBase):
detach_sp: bool
detach_proj_sp: bool
detach_qmet: bool
+ non_quasimetric_dim_mse_weight: float
c: float
init_lagrange_multiplier: float
def __init__(self, *, epsilon: float, bidirectional: bool,
gamma: Optional[float], init_lagrange_multiplier: float,
# weight: float,
- detach_qmet: bool, detach_proj_sp: bool, detach_sp: bool):
+ detach_qmet: bool, detach_proj_sp: bool, detach_sp: bool,
+ non_quasimetric_dim_mse_weight: float):
super().__init__()
# self.weight = weight
self.epsilon = epsilon
@@ -70,6 +74,7 @@ class LatentDynamicsLoss(CriticLossBase):
self.detach_qmet = detach_qmet
self.detach_sp = detach_sp
self.detach_proj_sp = detach_proj_sp
+ self.non_quasimetric_dim_mse_weight = non_quasimetric_dim_mse_weight
self.init_lagrange_multiplier = init_lagrange_multiplier
self.raw_lagrange_multiplier = nn.Parameter(
torch.tensor(softplus_inv_float(init_lagrange_multiplier), dtype=torch.float32))
@@ -85,6 +90,7 @@ class LatentDynamicsLoss(CriticLossBase):
dists = critic.quasimetric_model(zy, pred_zy, bidirectional=self.bidirectional, # at least optimize d(s', hat{s'})
proj_grad_enabled=(self.detach_proj_sp, True))
+
lagrange_mult = F.softplus(self.raw_lagrange_multiplier) # make positive
# lagrange multiplier is minimax training, so grad_mul -1
lagrange_mult = grad_mul(lagrange_mult, -1)
@@ -101,6 +107,16 @@ class LatentDynamicsLoss(CriticLossBase):
loss = violation * lagrange_mult
info.update(violation=violation, lagrange_mult=lagrange_mult)
+ if self.non_quasimetric_dim_mse_weight > 0:
+ assert critic.quasimetric_model.input_slice_size < critic.quasimetric_model.input_size, \
+ "non-quasimetric dim mse only makes sense if input_slice_size < input_size, but got " \
+ f"{critic.quasimetric_model.input_slice_size} >= {critic.quasimetric_model.input_size}"
+ _zy = zy[..., critic.quasimetric_model.input_slice_size:]
+ _pred_zy = pred_zy[..., critic.quasimetric_model.input_slice_size:]
+ non_quasimetric_dim_mse = F.mse_loss(_pred_zy, _zy)
+ info.update(non_quasimetric_dim_mse=non_quasimetric_dim_mse)
+ loss += non_quasimetric_dim_mse * self.non_quasimetric_dim_mse_weight
+
if self.bidirectional:
dist_p2n, dist_n2p = dists.flatten(0, -2).mean(0).unbind(-1)
info.update(dist_p2n=dist_p2n, dist_n2p=dist_n2p)
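A minimal sketch of the non_quasimetric_dim_mse term added above, assuming an 8-dim latent of which only the first 6 dims feed the quasimetric head; the names and sizes are illustrative.

import torch
import torch.nn.functional as F

input_size, input_slice_size = 8, 6
zy = torch.randn(32, input_size)        # encoder latent of s'
pred_zy = torch.randn(32, input_size)   # latent-dynamics prediction of s'

# Only the trailing dims (ignored by the quasimetric head) get an extra MSE penalty.
non_qmet_mse = F.mse_loss(pred_zy[..., input_slice_size:], zy[..., input_slice_size:])
loss = non_qmet_mse * 0.1               # scaled by non_quasimetric_dim_mse_weight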
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/__init__.py b/quasimetric_rl/modules/quasimetric_critic/models/__init__.py
index f9e05c4..8789800 100644
--- a/quasimetric_rl/modules/quasimetric_critic/models/__init__.py
+++ b/quasimetric_rl/modules/quasimetric_critic/models/__init__.py
@@ -91,6 +91,9 @@ class QuasimetricCritic(Module):
else:
return self.encoder
+ def get_encoder(self, target: bool = False) -> Encoder:
+ return self.target_encoder if target else self.encoder
+
@torch.no_grad()
def update_target_encoder_(self):
if not self.borrowing_embedding and self.target_encoder_ema is not None:
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py b/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py
index 5d19754..92fbc58 100644
--- a/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py
+++ b/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py
@@ -41,11 +41,11 @@ def create_quasimetric_head_from_spec(spec: str) -> torchqmet.QuasimetricBase:
assert dim % components == 0, "IQE: dim must be divisible by components"
return torchqmet.IQE(dim, dim // components)
- def iqe2(*, dim: int, components: int, scale: bool = False, norm_delta: bool = False, fake_grad: bool = False) -> torchqmet.IQE:
+ def iqe2(*, dim: int, components: int, norm_delta: bool = False, fake_grad: bool = False, reduction: str = 'maxl12_sm') -> torchqmet.IQE:
assert dim % components == 0, "IQE: dim must be divisible by components"
return torchqmet.IQE2(
dim, dim // components,
- reduction='maxl12_sm' if not scale else 'maxl12_sm_scale',
+ reduction=reduction,
learned_delta=True,
learned_div=False,
div_init_mul=0.25,
@@ -80,6 +80,7 @@ class QuasimetricModel(Module):
class Conf:
# config / argparse uses this to specify behavior
+ input_slice_size: Optional[int] = attrs.field(default=None, validator=attrs.validators.optional(attrs.validators.gt(0))) # take the first n dims
projector_arch: Optional[Tuple[int, ...]] = (512,)
projector_layer_norm: bool = True
projector_dropout: float = attrs.field(default=0., validator=attrs.validators.ge(0)) # TD-MPC2 uses 0.01
@@ -88,8 +89,11 @@ class QuasimetricModel(Module):
quasimetric_head_spec: str = 'iqe(dim=2048,components=64)'
def make(self, *, input_size: int) -> 'QuasimetricModel':
+ if self.input_slice_size is not None:
+ assert self.input_slice_size <= input_size, f'input_slice_size={self.input_slice_size} > input_size={input_size}'
return QuasimetricModel(
input_size=input_size,
+ input_slice_size=self.input_slice_size or input_size,
projector_arch=self.projector_arch,
projector_layer_norm=self.projector_layer_norm,
projector_dropout=self.projector_dropout,
@@ -99,22 +103,24 @@ class QuasimetricModel(Module):
)
input_size: int
+ input_slice_size: int
projector: Union[Identity, MLP]
quasimetric_head: torchqmet.QuasimetricBase
- def __init__(self, *, input_size: int, projector_arch: Optional[Tuple[int, ...]],
+ def __init__(self, *, input_size: int, input_slice_size: int, projector_arch: Optional[Tuple[int, ...]],
projector_layer_norm: bool, projector_dropout: float, projector_weight_norm: bool,
projector_unit_norm: bool, quasimetric_head_spec: str):
super().__init__()
self.input_size = input_size
+ self.input_slice_size = input_slice_size
self.quasimetric_head = create_quasimetric_head_from_spec(quasimetric_head_spec)
if projector_arch is None:
- assert input_size == self.quasimetric_head.input_size, \
- f'no projector but latent input_size={input_size}, quasimetric_head.input_size={self.quasimetric_head.input_size}'
+ assert input_slice_size == self.quasimetric_head.input_size, \
+ f'no projector but latent input_slice_size={input_slice_size}, quasimetric_head.input_size={self.quasimetric_head.input_size}'
self.projector = Identity()
else:
self.projector = MLP(
- input_size, self.quasimetric_head.input_size,
+ input_slice_size, self.quasimetric_head.input_size,
hidden_sizes=projector_arch,
layer_norm=projector_layer_norm,
dropout=projector_dropout,
@@ -123,6 +129,8 @@ class QuasimetricModel(Module):
def forward(self, zx: LatentTensor, zy: LatentTensor, *, bidirectional: bool = False,
proj_grad_enabled: Tuple[bool, bool] = (True, True)) -> torch.Tensor:
+ zx = zx[..., :self.input_slice_size]
+ zy = zy[..., :self.input_slice_size]
with self.projector.requiring_grad(proj_grad_enabled[0]):
px = self.projector(zx) # [B x D]
with self.projector.requiring_grad(proj_grad_enabled[1]):
@@ -149,4 +157,4 @@ class QuasimetricModel(Module):
return super().__call__(zx, zy, bidirectional=bidirectional, proj_grad_enabled=proj_grad_enabled)
def extra_repr(self) -> str:
- return f"input_size={self.input_size}"
+ return f"input_size={self.input_size}, input_slice_size={self.input_slice_size}"
diff --git a/quasimetric_rl/utils/logging.py b/quasimetric_rl/utils/logging.py
index 2a9e23f..6efab8d 100644
--- a/quasimetric_rl/utils/logging.py
+++ b/quasimetric_rl/utils/logging.py
@@ -4,6 +4,8 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
+from typing import *
+
import sys
import os
import logging
@@ -34,10 +36,12 @@ class TqdmLoggingHandler(logging.Handler):
class MultiLineFormatter(logging.Formatter):
+ _fmt: str
def __init__(self, fmt=None, datefmt=None, style='%'):
assert style == '%'
super(MultiLineFormatter, self).__init__(fmt, datefmt, style)
+ assert fmt is not None
self.multiline_fmt = fmt
def format(self, record):
@@ -75,7 +79,7 @@ class MultiLineFormatter(logging.Formatter):
output += '\n'.join(
self.multiline_fmt % dict(record.__dict__, message=line)
for index, line
- in enumerate(record.exc_text.decode(sys.getfilesystemencoding(), 'replace').splitlines())
+ in enumerate(record.exc_text.decode(sys.getfilesystemencoding(), 'replace').splitlines()) # type: ignore
)
return output
@@ -96,7 +100,7 @@ def configure(logging_file, log_level=logging.INFO, level_prefix='', prefix='',
sys.excepthook = handle_exception # automatically log uncaught errors
- handlers = []
+ handlers: List[logging.Handler] = []
if write_to_stdout:
handlers.append(TqdmLoggingHandler())
Submodule third_party/torch-quasimetric c5213ff..0fce12e:
diff --git a/third_party/torch-quasimetric/torchqmet/__init__.py b/third_party/torch-quasimetric/torchqmet/__init__.py
index 0008b7a..3afb467 100644
--- a/third_party/torch-quasimetric/torchqmet/__init__.py
+++ b/third_party/torch-quasimetric/torchqmet/__init__.py
@@ -58,7 +58,10 @@ class QuasimetricBase(nn.Module, metaclass=abc.ABCMeta):
assert x.shape[-1] == y.shape[-1] == self.input_size
d = self.compute_components(x, y)
d: torch.Tensor = self.transforms(d)
- return self.reduction(d) * self.scale
+ scale = self.scale
+ if not self.training:
+ scale = scale.detach()
+ return self.reduction(d) * scale
def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# Manually define for typing
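The torchqmet change above stops gradients from flowing into the learned scale whenever the module is in eval mode; a hedged sketch of that pattern on a toy module (not the actual QuasimetricBase code).

import torch
import torch.nn as nn

class ScaledReduction(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(1.0))

    def forward(self, d: torch.Tensor) -> torch.Tensor:
        scale = self.scale
        if not self.training:        # eval-mode calls leave the scale untouched by backprop
            scale = scale.detach()
        return d * scale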
diff --git a/third_party/torch-quasimetric/torchqmet/iqe.py b/third_party/torch-quasimetric/torchqmet/iqe.py
index bc03f05..a8e8c92 100644
--- a/third_party/torch-quasimetric/torchqmet/iqe.py
+++ b/third_party/torch-quasimetric/torchqmet/iqe.py
@@ -12,7 +12,6 @@ from . import QuasimetricBase
# The PQELH function.
-@torch.jit.script
def f_PQELH(h: torch.Tensor): # PQELH: strictly monotonically increasing mapping from [0, +infty) -> [0, 1)
return -torch.expm1(-h)
@@ -21,22 +20,22 @@ def iqe_tensor_delta(x: torch.Tensor, y: torch.Tensor, delta: torch.Tensor, div_
D = x.shape[-1] # D: component_dim
# ignore pairs that x >= y
- valid = (x < y)
+ valid = (x < y) # [..., K, D]
# sort to better count
- xy = torch.cat(torch.broadcast_tensors(x, y), dim=-1)
+ xy = torch.cat(torch.broadcast_tensors(x, y), dim=-1) # [..., K, 2D]
sxy, ixy = xy.sort(dim=-1)
# neg_inc: the **negated** increment of **input** of f at sorted locations
# inc = torch.gather(delta * valid, dim=-1, index=ixy % D) * torch.where(ixy < D, 1, -1)
- neg_inc = torch.gather(delta * valid, dim=-1, index=ixy % D) * torch.where(ixy < D, -1, 1)
+ neg_inc = torch.gather(delta * valid, dim=-1, index=ixy % D) * torch.where(ixy < D, -1, 1) # [..., K, 2D-sort]
# neg_incf: the **negated** increment of **output** of f at sorted locations
neg_f_input = torch.cumsum(neg_inc, dim=-1) / div_pre_f[:, None]
if fake_grad:
neg_f_input__grad_path = neg_f_input.clone()
- neg_f_input__grad_path.data.clamp_(max=17) # fake grad
+ neg_f_input__grad_path.data.clamp_(min=-15) # fake grad
neg_f_input = neg_f_input__grad_path + (
neg_f_input - neg_f_input__grad_path
).detach()
@@ -95,13 +94,29 @@ def iqe(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return (sxy * neg_incf).sum(-1)
-if torch.__version__ >= '2.0.1' and False: # well, broken process pool in notebooks
- iqe = torch.compile(iqe)
- iqe_tensor_delta = torch.compile(iqe_tensor_delta)
+def is_notebook():
+ r"""
+ Inspired by
+ https://github.com/tqdm/tqdm/blob/cc372d09dcd5a5eabdc6ed4cf365bdb0be004d44/tqdm/autonotebook.py
+ """
+ import sys
+ try:
+ get_ipython = sys.modules['IPython'].get_ipython
+ if 'IPKernelApp' not in get_ipython().config: # pragma: no cover
+ raise ImportError("console")
+ except Exception:
+ return False
+ else: # pragma: no cover
+ return True
+
+
+if torch.__version__ >= '2.0.1' and not is_notebook(): # well, broken process pool in notebooks
+ iqe = torch.compile(iqe, mode="max-autotune")
+ iqe_tensor_delta = torch.compile(iqe_tensor_delta, mode="max-autotune")
# iqe = torch.compile(iqe, dynamic=True)
else:
- iqe = torch.jit.script(iqe)
- iqe_tensor_delta = torch.jit.script(iqe_tensor_delta)
+ iqe = torch.jit.script(iqe) # type: ignore
+ iqe_tensor_delta = torch.jit.script(iqe_tensor_delta) # type: ignore
class IQE(QuasimetricBase):
@@ -231,8 +246,8 @@ class IQE2(IQE):
ema_weight: float = 0.95):
super().__init__(input_size, dim_per_component, transforms=transforms, reduction=reduction,
discount=discount, warn_if_not_quasimetric=warn_if_not_quasimetric)
- self.component_dropout_thresh = tuple(component_dropout_thresh)
- self.dropout_p_thresh = tuple(dropout_p_thresh)
+ self.component_dropout_thresh = tuple(component_dropout_thresh) # type: ignore
+ self.dropout_p_thresh = tuple(dropout_p_thresh) # type: ignore
self.dropout_batch_frac = float(dropout_batch_frac)
self.fake_grad = fake_grad
assert 0 <= self.dropout_batch_frac <= 1
@@ -249,7 +264,7 @@ class IQE2(IQE):
# )
self.register_parameter(
'raw_delta',
- torch.nn.Parameter(
+ torch.nn.Parameter( # type: ignore
torch.zeros(self.latent_2d_shape).requires_grad_()
)
)
@@ -270,7 +285,7 @@ class IQE2(IQE):
self.register_parameter(
'raw_div',
- torch.nn.Parameter(torch.zeros(self.num_components).requires_grad_())
+ torch.nn.Parameter(torch.zeros(self.num_components).requires_grad_()) # type: ignore
)
else:
self.register_buffer(
@@ -285,8 +300,8 @@ class IQE2(IQE):
self.div_init_mul = div_init_mul
self.mul_kind = mul_kind
- self.last_components = None
- self.last_drop_p = None
+ self.last_components = None # type: ignore
+ self.last_drop_p = None # type: ignore
def compute_components(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
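The compile guard above prefers torch.compile outside notebooks and falls back to TorchScript inside them; a small self-contained sketch of the same pattern, where the compiled function is just a placeholder.

import sys
import torch

def in_notebook() -> bool:
    # rough stand-in for the is_notebook() helper in the patch
    try:
        return 'IPKernelApp' in sys.modules['IPython'].get_ipython().config
    except Exception:
        return False

def pairwise_gap(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return (y - x).clamp_min(0).sum(-1)

if torch.__version__ >= '2.0.1' and not in_notebook():
    pairwise_gap = torch.compile(pairwise_gap, mode="max-autotune")
else:
    pairwise_gap = torch.jit.script(pairwise_gap)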
diff --git a/third_party/torch-quasimetric/torchqmet/reductions.py b/third_party/torch-quasimetric/torchqmet/reductions.py
index 7681242..a87be8b 100644
--- a/third_party/torch-quasimetric/torchqmet/reductions.py
+++ b/third_party/torch-quasimetric/torchqmet/reductions.py
@@ -59,6 +59,11 @@ class Mean(ReductionBase):
return d.mean(dim=-1)
+class L2(ReductionBase):
+ def reduce_distance(self, d: torch.Tensor) -> torch.Tensor:
+ return d.norm(p=2, dim=-1)
+
+
class MaxMean(ReductionBase):
r'''
`maxmean` from Neural Norms paper:
@@ -144,7 +149,7 @@ class MaxL12_PGsm(ReductionBase):
super().__init__(input_num_components=input_num_components, discount=discount)
self.raw_alpha = nn.Parameter(torch.tensor([0., 0., 0., 0.], dtype=torch.float32).requires_grad_()) # pre normalizing
self.raw_alpha_w = nn.Parameter(torch.tensor([0., 0., 0.], dtype=torch.float32).requires_grad_()) # pre normalizing
- self.last_logp = None
+ self.last_logp = None # type: ignore
self.on_pi = True
# self.last_p = None
@@ -222,7 +227,7 @@ class MaxL12_PG3(ReductionBase):
super().__init__(input_num_components=input_num_components, discount=discount)
self.raw_alpha = nn.Parameter(torch.tensor([0., 0., 0.], dtype=torch.float32).requires_grad_()) # pre normalizing
self.raw_alpha_w = torch.tensor([], dtype=torch.float32) # just to make logging easier
- self.last_logp = None
+ self.last_logp = None # type: ignore
self.on_pi = True
# self.last_p = None
@@ -322,6 +327,7 @@ class DeepLinearNetWeightedSum(ReductionBase):
REDUCTIONS: Mapping[str, Type[ReductionBase]] = dict(
sum=Sum,
mean=Mean,
+ l2=L2,
maxmean=MaxMean,
maxl12=MaxL12,
maxl12_sm=MaxL12_sm,
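The newly registered l2 reduction collapses the per-component distances with an L2 norm; a tiny sketch of what it computes, assuming d holds non-negative component values of shape [batch, num_components].

import torch

d = torch.rand(32, 64)              # per-component quasimetric values
reduced = d.norm(p=2, dim=-1)       # what the new L2 reduction returns
assert torch.allclose(reduced, d.square().sum(-1).sqrt())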