
@ssnl
Created March 15, 2024 14:59
diff --git a/offline/main.py b/offline/main.py
index 1581d30..5502749 100644
--- a/offline/main.py
+++ b/offline/main.py
@@ -43,7 +43,6 @@ class Conf(BaseConf):
total_optim_steps: int = attrs.field(default=int(2e5), validator=attrs.validators.gt(0))
log_steps: int = attrs.field(default=250, validator=attrs.validators.gt(0)) # type: ignore
- eval_before_training: bool = False
eval_steps: int = attrs.field(default=20000, validator=attrs.validators.gt(0)) # type: ignore
save_steps: int = attrs.field(default=50000, validator=attrs.validators.gt(0)) # type: ignore
num_eval_episodes: int = attrs.field(default=50, validator=attrs.validators.ge(0)) # type: ignore
@@ -58,7 +57,7 @@ cs.store(name='config', node=Conf())
@hydra.main(version_base=None, config_name="config")
def train(dict_cfg: DictConfig):
cfg: Conf = Conf.from_DictConfig(dict_cfg) # type: ignore
- wandb_run = cfg.setup_for_experiment() # checking & setup logging
+ cfg.setup_for_experiment() # checking & setup logging
dataset = cfg.env.make()
@@ -117,7 +116,7 @@ def train(dict_cfg: DictConfig):
logging.info(f"Checkpointed to {relpath}")
def eval(epoch, it, optim_steps):
- val_result_allenvs = trainer.evaluate(desc=f"opt{optim_steps:08d}")
+ val_result_allenvs = trainer.evaluate()
val_results.clear()
val_results.append(dict(
epoch=epoch,
@@ -129,10 +128,25 @@ def train(dict_cfg: DictConfig):
epoch=epoch,
it=it,
optim_steps=optim_steps,
- result={
- k: val_result.summarize() for k, val_result in val_result_allenvs.items()
- },
+ result={},
))
+ for k, val_result in val_result_allenvs.items():
+ succ_rate_ts = (
+ None if val_result.timestep_is_success is None
+ else torch.stack([_x.mean(dtype=torch.float32) for _x in val_result.timestep_is_success])
+ )
+ hitting_time = val_result.capped_hitting_time
+ summary = dict(
+ epi_return=val_result.episode_return,
+ epi_score=val_result.episode_score,
+ succ_rate_ts=succ_rate_ts,
+ succ_rate=val_result.is_success,
+ hitting_time=hitting_time,
+ )
+ for kk, v in val_result.extra_timestep_results.items():
+ summary[kk] = torch.stack([_v.mean(dtype=torch.float32) for _v in v])
+ summary[kk + '_last'] = torch.stack([_v[-1] for _v in v])
+ val_summaries[-1]['result'][k] = summary
averaged_info = cfg.stage_log_info(dict(eval=val_summaries[-1]), optim_steps)
with open(os.path.join(cfg.output_dir, 'eval.log'), 'a') as f: # type: ignore
print(json.dumps(averaged_info), file=f)
@@ -170,18 +184,15 @@ def train(dict_cfg: DictConfig):
)
num_total_epochs = int(np.ceil(cfg.total_optim_steps / trainer.num_batches))
- logging.info(f"save folder: {cfg.output_dir}")
- logging.info(f"wandb: {wandb_run.get_url()}")
# Training loop
optim_steps = 0
- if cfg.eval_before_training:
- eval(0, 0, optim_steps)
+ # eval(0, 0, optim_steps)
save(0, 0, optim_steps)
if start_epoch < num_total_epochs:
for epoch in range(num_total_epochs):
epoch_desc = f"Train epoch {epoch:05d}/{num_total_epochs:05d}"
- for it, (data, data_info) in enumerate(tqdm(trainer.iter_training_data(), total=trainer.num_batches, desc=epoch_desc, leave=False)):
+ for it, (data, data_info) in enumerate(tqdm(trainer.iter_training_data(), total=trainer.num_batches, desc=epoch_desc)):
step_counter.update_then_record_alerts()
optim_steps += 1
@@ -220,9 +231,4 @@ if __name__ == '__main__':
# set up some hydra flags before parsing
os.environ['HYDRA_FULL_ERROR'] = str(int(FLAGS.DEBUG))
- try:
- train()
- except:
- import wandb
- wandb.finish(1) # sometimes crashes are not reported?? let's be safe
- raise
+ train()
diff --git a/offline/trainer.py b/offline/trainer.py
index c75abed..e7f48cf 100644
--- a/offline/trainer.py
+++ b/offline/trainer.py
@@ -11,7 +11,52 @@ import torch
import torch.utils.data
from quasimetric_rl.modules import QRLConf, QRLAgent, QRLLosses, InfoT
-from quasimetric_rl.data import BatchData, Dataset, EpisodeData, EnvSpec, OfflineEnv, interaction
+from quasimetric_rl.data import BatchData, Dataset, EpisodeData, EnvSpec, OfflineEnv
+
+
+def first_nonzero(arr: torch.Tensor, dim: int = -1, invalid_val: int = -1):
+ mask = (arr != 0)
+ return torch.where(mask.any(dim=dim), mask.to(torch.uint8).argmax(dim=dim), invalid_val)
+
+
+@attrs.define(kw_only=True)
+class EvalEpisodeResult:
+ timestep_reward: List[torch.Tensor]
+ episode_return: torch.Tensor
+ episode_score: torch.Tensor
+ timestep_is_success: Optional[List[torch.Tensor]]
+ is_success: Optional[torch.Tensor]
+ hitting_time: Optional[torch.Tensor]
+ extra_timestep_results: Mapping[str, List[torch.Tensor]]
+
+ @property
+ def capped_hitting_time(self) -> Optional[torch.Tensor]:
+ # if not hit -> |ts| + 1
+ if self.hitting_time is None:
+ return None
+ assert self.timestep_is_success is not None
+ return torch.stack([torch.where(_x < 0, _succ.shape[0] + 1, _x) for _x, _succ in zip(self.hitting_time, self.timestep_is_success)])
+
+ @classmethod
+ def from_timestep_reward_is_success(cls, dataset: Dataset,
+ timestep_reward: List[torch.Tensor],
+ timestep_is_success: Optional[List[torch.Tensor]],
+ extra_timestep_results) -> Self:
+ return cls(
+ timestep_reward=timestep_reward,
+ episode_return=torch.stack([r.sum() for r in timestep_reward]),
+ episode_score=dataset.normalize_score(timestep_reward),
+ timestep_is_success=timestep_is_success,
+ is_success=(
+ None if timestep_is_success is None
+ else torch.stack([_x.any(dim=-1) for _x in timestep_is_success])
+ ),
+ hitting_time=(
+ None if timestep_is_success is None
+ else torch.stack([first_nonzero(_x, dim=-1) for _x in timestep_is_success])
+ ), # NB this is off by 1
+ extra_timestep_results=dict(extra_timestep_results),
+ )
class Trainer(object):
@@ -88,20 +133,32 @@ class Trainer(object):
adistn = self.agent.actor(obs[None].to(self.device), goal[None].to(self.device))
return adistn.mode.cpu().numpy()[0]
- rollout = interaction.collect_rollout(
+ rollout = Dataset.collect_rollout_general(
actor, env=env, env_spec=EnvSpec.from_env(env),
max_episode_length=env.max_episode_steps)
return rollout
- def evaluate(self, desc=None) -> Mapping[str, interaction.EvalEpisodeResult]:
+ def evaluate(self) -> Mapping[str, EvalEpisodeResult]:
envs = self.dataset.create_eval_envs(self.eval_seed)
- results: Dict[str, interaction.EvalEpisodeResult] = {}
+ results: Dict[str, EvalEpisodeResult] = {}
for k, env in envs.items():
rollouts: List[EpisodeData] = []
- this_desc = f'eval/{k}'
- if desc is not None:
- this_desc = f'{desc}/{this_desc}'
- for _ in tqdm(range(self.num_eval_episodes), desc=this_desc):
+ for _ in tqdm(range(self.num_eval_episodes), desc=f'eval/{k}'):
rollouts.append(self.collect_eval_rollout(env=env))
- results[k] = interaction.EvalEpisodeResult.from_episode_rollouts(self.dataset, rollouts)
+ results[k] = EvalEpisodeResult.from_timestep_reward_is_success(
+ self.dataset,
+ timestep_reward=[rollout.rewards for rollout in rollouts],
+ timestep_is_success=(
+ None
+ if len(rollouts) == 0 or 'is_success' not in rollouts[0].transition_infos
+ else [rollout.transition_infos['is_success'] for rollout in rollouts]
+ ),
+ extra_timestep_results=(
+ {} if len(rollouts) == 0 else
+ {
+ k: [rollout.transition_infos[k] for rollout in rollouts]
+ for k in rollouts[0].transition_infos.keys() if k != 'is_success'
+ }
+ ),
+ )
return results
diff --git a/online/main.py b/online/main.py
index c16e094..cd58f7a 100644
--- a/online/main.py
+++ b/online/main.py
@@ -50,7 +50,7 @@ cs.store(name='config', node=Conf())
@hydra.main(version_base=None, config_name="config")
def train(dict_cfg: DictConfig):
cfg: Conf = Conf.from_DictConfig(dict_cfg) # type: ignore
- wandb_run = cfg.setup_for_experiment() # checking & setup logging
+ cfg.setup_for_experiment() # checking & setup logging
replay_buffer = cfg.env.make()
@@ -85,7 +85,7 @@ def train(dict_cfg: DictConfig):
logging.info(f"Checkpointed to {relpath}")
def eval(env_steps, optim_steps):
- val_result_allenvs = trainer.evaluate(desc=f'env{env_steps:08d}_opt{optim_steps:08d}')
+ val_result_allenvs = trainer.evaluate()
val_results.clear()
val_results.append(dict(
env_steps=env_steps,
@@ -95,10 +95,18 @@ def train(dict_cfg: DictConfig):
val_summaries.append(dict(
env_steps=env_steps,
optim_steps=optim_steps,
- result={
- k: val_result.summarize() for k, val_result in val_result_allenvs.items()
- },
+ result={},
))
+ for k, val_result in val_result_allenvs.items():
+ succ_rate_ts = val_result.timestep_is_success.mean(dtype=torch.float32, dim=-1)
+ hitting_time = val_result.capped_hitting_time
+ val_summaries[-1]['result'][k] = dict(
+ epi_return=val_result.episode_return,
+ epi_score=val_result.episode_score,
+ succ_rate_ts=succ_rate_ts,
+ succ_rate=val_result.is_success,
+ hitting_time=hitting_time,
+ )
averaged_info = cfg.stage_log_info(dict(eval=val_summaries[-1]), optim_steps)
with open(os.path.join(cfg.output_dir, 'eval.log'), 'a') as f: # type: ignore
print(json.dumps(averaged_info), file=f)
@@ -113,8 +121,6 @@ def train(dict_cfg: DictConfig):
),
)
- logging.info(f"save folder: {cfg.output_dir}")
- logging.info(f"wandb: {wandb_run.get_url()}")
# Training loop
eval(0, 0); save(0, 0)
for optim_steps, (env_steps, next_iter_new_env_step, data, data_info) in enumerate(trainer.iter_training_data(), start=1):
@@ -156,9 +162,4 @@ if __name__ == '__main__':
# set up some hydra flags before parsing
os.environ['HYDRA_FULL_ERROR'] = str(int(FLAGS.DEBUG))
- try:
- train()
- except:
- import wandb
- wandb.finish(1) # sometimes crashes are not reported?? let's be safe
- raise
+ train()
diff --git a/online/trainer.py b/online/trainer.py
index 8926871..01faad8 100644
--- a/online/trainer.py
+++ b/online/trainer.py
@@ -12,11 +12,43 @@ import torch
import torch.utils.data
from quasimetric_rl.modules import QRLConf, QRLAgent, QRLLosses, InfoT
-from quasimetric_rl.data import BatchData, EpisodeData, interaction
+from quasimetric_rl.data import Dataset, BatchData, EpisodeData, MultiEpisodeData
from quasimetric_rl.data.online import ReplayBuffer, OnlineFixedLengthEnv
from quasimetric_rl.utils import tqdm
+def first_nonzero(arr: torch.Tensor, dim: int = -1, invalid_val: int = -1):
+ mask = (arr != 0)
+ return torch.where(mask.any(dim=dim), mask.to(torch.uint8).argmax(dim=dim), invalid_val)
+
+
+@attrs.define(kw_only=True)
+class EvalEpisodeResult:
+ timestep_reward: torch.Tensor
+ episode_return: torch.Tensor
+ episode_score: torch.Tensor
+ timestep_is_success: torch.Tensor
+ is_success: torch.Tensor
+ hitting_time: torch.Tensor
+
+ @property
+ def capped_hitting_time(self) -> torch.Tensor:
+ # if not hit -> |ts| + 1
+ return torch.stack([torch.where(_x < 0, self.timestep_is_success.shape[-1] + 1, _x) for _x in self.hitting_time])
+
+ @classmethod
+ def from_timestep_reward_is_success(cls, dataset: Dataset, timestep_reward: torch.Tensor,
+ timestep_is_success: torch.Tensor) -> Self:
+ return cls(
+ timestep_reward=timestep_reward,
+ episode_return=timestep_reward.sum(-1),
+ episode_score=dataset.normalize_score(cast(Sequence[torch.Tensor], timestep_reward)),
+ timestep_is_success=timestep_is_success,
+ is_success=timestep_is_success.any(dim=-1),
+ hitting_time=first_nonzero(timestep_is_success, dim=-1), # NB this is off by 1
+ )
+
+
@attrs.define(kw_only=True)
class InteractionConf:
total_env_steps: int = attrs.field(default=int(1e6), validator=attrs.validators.gt(0))
@@ -133,18 +165,23 @@ class Trainer(object):
self.replay.add_rollout(rollout)
return rollout
- def evaluate(self, desc=None) -> Mapping[str, interaction.EvalEpisodeResult]:
+ def evaluate(self) -> Mapping[str, EvalEpisodeResult]:
envs = self.make_evaluate_envs()
- results: Dict[str, interaction.EvalEpisodeResult] = {}
+ results: Dict[str, EvalEpisodeResult] = {}
for k, env in envs.items():
rollouts = []
- this_desc = f'eval/{k}'
- if desc is not None:
- this_desc = f'{desc}/{this_desc}'
- for _ in tqdm(range(self.num_eval_episodes), desc=this_desc):
+ for _ in tqdm(range(self.num_eval_episodes), desc=f'eval/{k}'):
rollouts.append(self.collect_rollout(eval=True, store=False, env=env))
- results[k] = interaction.EvalEpisodeResult.from_episode_rollouts(
- self.replay, rollouts)
+ mrollouts = MultiEpisodeData.cat(rollouts)
+ results[k] = EvalEpisodeResult.from_timestep_reward_is_success(
+ self.replay,
+ mrollouts.rewards.reshape(
+ self.num_eval_episodes, env.episode_length,
+ ),
+ mrollouts.transition_infos['is_success'].reshape(
+ self.num_eval_episodes, env.episode_length,
+ ),
+ )
return results
def iter_training_data(self) -> Iterator[Tuple[int, bool, BatchData, InfoT]]:
@@ -160,7 +197,7 @@ class Trainer(object):
"""
def yield_data():
num_transitions = self.replay.num_transitions_realized
- for icyc in tqdm(range(self.num_samples_per_cycle), desc=f"{num_transitions} env steps, train batches", leave=False):
+ for icyc in tqdm(range(self.num_samples_per_cycle), desc=f"{num_transitions} env steps, train batches"):
data_t0 = time.time()
data = self.sample()
info = dict(
diff --git a/quasimetric_rl/base_conf.py b/quasimetric_rl/base_conf.py
index e9f2202..b3a332e 100644
--- a/quasimetric_rl/base_conf.py
+++ b/quasimetric_rl/base_conf.py
@@ -172,7 +172,7 @@ class BaseConf(abc.ABC):
else:
raise RuntimeError(f'Output directory {self.output_dir} exists and is complete')
- run = wandb.init(
+ wandb.init(
project=self.wandb_project,
name=self.output_folder.replace('/', '__') + '__' + datetime.now().strftime(r"%Y%m%d_%H:%M:%S"),
config=yaml.safe_load(OmegaConf.to_yaml(self)),
@@ -219,6 +219,3 @@ class BaseConf(abc.ABC):
if self.device.type == 'cuda' and self.device.index is not None:
torch.cuda.set_device(self.device.index)
-
- assert run is not None
- return run
diff --git a/quasimetric_rl/data/__init__.py b/quasimetric_rl/data/__init__.py
index 519796b..41fc413 100644
--- a/quasimetric_rl/data/__init__.py
+++ b/quasimetric_rl/data/__init__.py
@@ -5,9 +5,8 @@ from .env_spec import EnvSpec
from . import online
from .online import register_online_env, OnlineFixedLengthEnv
from .offline import OfflineEnv
-from . import interaction
__all__ = [
'BatchData', 'EpisodeData', 'MultiEpisodeData', 'Dataset', 'register_offline_env',
- 'EnvSpec', 'online', 'register_online_env', 'OnlineFixedLengthEnv', 'OfflineEnv', 'interaction',
+ 'EnvSpec', 'online', 'register_online_env', 'OnlineFixedLengthEnv', 'OfflineEnv',
]
diff --git a/quasimetric_rl/data/base.py b/quasimetric_rl/data/base.py
index 09711c4..c86f525 100644
--- a/quasimetric_rl/data/base.py
+++ b/quasimetric_rl/data/base.py
@@ -36,7 +36,6 @@ class BatchData(TensorCollectionAttrsMixin): # TensorCollectionAttrsMixin has s
timeouts: torch.Tensor
future_observations: torch.Tensor # sampled!
- future_tdelta: torch.Tensor
@property
def device(self) -> torch.device:
@@ -144,12 +143,6 @@ class EpisodeData(MultiEpisodeData):
rewards=self.rewards[:t],
terminals=self.terminals[:t],
timeouts=self.timeouts[:t],
- observation_infos={
- k: v[:t + 1] for k, v in self.observation_infos.items()
- },
- transition_infos={
- k: v[:t] for k, v in self.transition_infos.items()
- },
)
@@ -242,7 +235,6 @@ class Dataset(torch.utils.data.Dataset):
indices_to_episode_timesteps: torch.Tensor
max_episode_length: int
# -----
- device: torch.device
def create_env(self, *, dict_obseravtion: Optional[bool] = None, seed: Optional[int] = None, **kwargs) -> 'OfflineEnv':
from .offline import OfflineEnv
@@ -280,7 +272,6 @@ class Dataset(torch.utils.data.Dataset):
def __init__(self, kind: str, name: str, *,
future_observation_discount: float,
dummy: bool = False, # when you don't want to load data, e.g., in analysis
- device: torch.device = torch.device('cpu'), # FIXME: get some heuristic
) -> None:
self.kind = kind
self.name = name
@@ -307,22 +298,97 @@ class Dataset(torch.utils.data.Dataset):
indices_to_episode_timesteps.append(torch.arange(l, dtype=torch.int64))
assert len(episodes) > 0, "must have at least one episode"
- self.raw_data = MultiEpisodeData.cat(episodes).to(device)
+ self.raw_data = MultiEpisodeData.cat(episodes)
- self.obs_indices_to_obs_index_in_episode = torch.cat(obs_indices_to_obs_index_in_episode, dim=0).to(device)
- self.indices_to_episode_indices = torch.cat(indices_to_episode_indices, dim=0).to(device)
- self.indices_to_episode_timesteps = torch.cat(indices_to_episode_timesteps, dim=0).to(device)
+ self.obs_indices_to_obs_index_in_episode = torch.cat(obs_indices_to_obs_index_in_episode, dim=0)
+ self.indices_to_episode_indices = torch.cat(indices_to_episode_indices, dim=0)
+ self.indices_to_episode_timesteps = torch.cat(indices_to_episode_timesteps, dim=0)
self.max_episode_length = int(self.raw_data.episode_lengths.max().item())
- self.device = device
-
- # def max_bytes_used(self):
- # return self
def get_observations(self, obs_indices: torch.Tensor):
- return self.raw_data.all_observations[obs_indices.to(self.device)]
+ return self.raw_data.all_observations[obs_indices]
+
+ @classmethod
+ def collect_rollout_general(cls, actor: Callable[[torch.Tensor, torch.Tensor, gym.Space], np.ndarray], *,
+ env: gym.Env, env_spec: EnvSpec, max_episode_length: int,
+ assert_exact_episode_length: bool = False, extra_transition_info_keys: Collection[str] = []) -> EpisodeData:
+ from .utils import get_empty_episode
+
+ epi = get_empty_episode(env_spec, max_episode_length)
+
+ # check observation space
+ obs_dict_keys = {'observation', 'achieved_goal', 'desired_goal'}
+ WRONG_OBS_ERR_MESSAGE = (
+ f"{cls.__name__} collect_rollout only supports Dict "
+ f"observation space with keys {obs_dict_keys}, but got {env.observation_space}"
+ )
+ assert isinstance(env.observation_space, gym.spaces.Dict), WRONG_OBS_ERR_MESSAGE
+ assert set(env.observation_space.spaces.keys()) == {'observation', 'achieved_goal', 'desired_goal'}, WRONG_OBS_ERR_MESSAGE
+
+ observation_dict = cast(Mapping[str, np.ndarray], env.reset())
+ observation: torch.Tensor = torch.as_tensor(observation_dict['observation'], dtype=torch.float32)
+
+ goal: torch.Tensor = torch.as_tensor(observation_dict['desired_goal'], dtype=torch.float32)
+ agoal: torch.Tensor = torch.as_tensor(observation_dict['achieved_goal'], dtype=torch.float32)
+ epi.all_observations[0] = observation
+ epi.observation_infos['desired_goals'][0] = goal
+ epi.observation_infos['achieved_goals'][0] = agoal
+ if len(extra_transition_info_keys):
+ epi.transition_infos = dict(epi.transition_infos)
+ epi.transition_infos.update({
+ k: torch.empty([max_episode_length], dtype=torch.float32) for k in extra_transition_info_keys
+ })
+
+ t = 0
+ timeout = False
+ terminal = False
+ while not timeout and not terminal:
+ assert t < max_episode_length
+
+ action = actor(
+ observation,
+ goal,
+ env_spec.action_space,
+ )
+ transition_out = env.step(np.asarray(action))
+ observation_dict, reward, terminal, info = transition_out[:3] + transition_out[-1:] # some BC
+
+ observation = torch.tensor(observation_dict['observation'], dtype=torch.float32) # copy just in case
+
+ goal: torch.Tensor = torch.as_tensor(observation_dict['desired_goal'], dtype=torch.float32)
+ agoal: torch.Tensor = torch.as_tensor(observation_dict['achieved_goal'], dtype=torch.float32)
+
+ if 'is_success' in info:
+ is_success: bool = info['is_success']
+ epi.transition_infos['is_success'][t] = is_success
+ else:
+ if t == 0:
+ # remove field
+ transition_infos = dict(epi.transition_infos)
+ del transition_infos['is_success']
+ epi.transition_infos = transition_infos
+
+ for k in extra_transition_info_keys:
+ epi.transition_infos[k][t] = info[k]
+
+ epi.all_observations[t + 1] = observation
+ epi.actions[t] = torch.as_tensor(action, dtype=torch.float32)
+ epi.rewards[t] = reward
+ epi.observation_infos['desired_goals'][t + 1] = goal
+ epi.observation_infos['achieved_goals'][t + 1] = agoal
+
+ t += 1
+ timeout = info.get('TimeLimit.truncated', False)
+ if assert_exact_episode_length:
+ assert (timeout or terminal) == (t == max_episode_length)
+
+ if t < max_episode_length:
+ epi = epi.first_t(t)
+
+ return epi
def __getitem__(self, indices: torch.Tensor) -> BatchData:
- indices = torch.as_tensor(indices, device=self.device)
+ indices = torch.as_tensor(indices)
eindices = self.indices_to_episode_indices[indices]
obs_indices = indices + eindices # index for `observation`: skip the s_last from previous episodes
obs = self.get_observations(obs_indices)
@@ -332,7 +398,7 @@ class Dataset(torch.utils.data.Dataset):
tindices = self.indices_to_episode_timesteps[indices]
epilengths = self.raw_data.episode_lengths[eindices] # max idx is this
- deltas = torch.arange(self.max_episode_length, device=self.device)
+ deltas = torch.arange(self.max_episode_length)
pdeltas = torch.where(
# test tidx + 1 + delta <= max_idx = epi_length
(tindices[:, None] + deltas) < epilengths[:, None],
@@ -341,16 +407,14 @@ class Dataset(torch.utils.data.Dataset):
)
deltas = torch.distributions.Categorical(
probs=pdeltas,
- validate_args=False,
- ).sample() + 1
- future_observations = self.get_observations(obs_indices + deltas)
+ ).sample()
+ future_observations = self.get_observations(obs_indices + 1 + deltas)
return BatchData(
observations=obs,
actions=self.raw_data.actions[indices],
next_observations=nobs,
future_observations=future_observations,
- future_tdelta=deltas,
rewards=self.raw_data.rewards[indices],
terminals=terminals,
timeouts=self.raw_data.timeouts[indices],
diff --git a/quasimetric_rl/data/interaction.py b/quasimetric_rl/data/interaction.py
deleted file mode 100644
index 699fa61..0000000
--- a/quasimetric_rl/data/interaction.py
+++ /dev/null
@@ -1,178 +0,0 @@
-from __future__ import annotations
-from typing import *
-
-import attrs
-
-import gym
-import gym.spaces
-import numpy as np
-import torch
-import torch.utils.data
-
-from . import Dataset, EpisodeData, EnvSpec
-
-
-def first_nonzero(arr: torch.Tensor, dim: int = -1, invalid_val: int = -1):
- mask = (arr != 0)
- return torch.where(mask.any(dim=dim), mask.to(torch.uint8).argmax(dim=dim), invalid_val)
-
-
-@attrs.define(kw_only=True)
-class EvalEpisodeResult:
- timestep_reward: List[torch.Tensor]
- episode_return: torch.Tensor
- episode_score: torch.Tensor
- timestep_is_success: Optional[List[torch.Tensor]]
- is_success: Optional[torch.Tensor]
- hitting_time: Optional[torch.Tensor]
- extra_timestep_results: Mapping[str, List[torch.Tensor]]
-
- @property
- def capped_hitting_time(self) -> Optional[torch.Tensor]:
- # if not hit -> |ts| + 1
- if self.hitting_time is None:
- return None
- assert self.timestep_is_success is not None
- return torch.stack([torch.where(_x < 0, _succ.shape[0] + 1, _x) for _x, _succ in zip(self.hitting_time, self.timestep_is_success)])
-
- @classmethod
- def from_timestep_reward_is_success(cls, dataset: Dataset,
- timestep_reward: List[torch.Tensor],
- timestep_is_success: Optional[List[torch.Tensor]],
- extra_timestep_results) -> Self:
- return cls(
- timestep_reward=timestep_reward,
- episode_return=torch.stack([r.sum() for r in timestep_reward]),
- episode_score=dataset.normalize_score(timestep_reward),
- timestep_is_success=timestep_is_success,
- is_success=(
- None if timestep_is_success is None
- else torch.stack([_x.any(dim=-1) for _x in timestep_is_success])
- ),
- hitting_time=(
- None if timestep_is_success is None
- else torch.stack([first_nonzero(_x, dim=-1) for _x in timestep_is_success])
- ), # NB this is off by 1
- extra_timestep_results=dict(extra_timestep_results),
- )
-
- @classmethod
- def from_episode_rollouts(cls, dataset: Dataset,rollouts: Sequence[EpisodeData]) -> Self:
- return cls.from_timestep_reward_is_success(
- dataset,
- timestep_reward=[rollout.rewards for rollout in rollouts],
- timestep_is_success=(
- None
- if len(rollouts) == 0 or 'is_success' not in rollouts[0].transition_infos
- else [rollout.transition_infos['is_success'] for rollout in rollouts]
- ),
- extra_timestep_results=(
- {} if len(rollouts) == 0 else
- {
- k: [rollout.transition_infos[k] for rollout in rollouts]
- for k in rollouts[0].transition_infos.keys() if k != 'is_success'
- }
- ),
- )
-
- def summarize(self) -> Mapping[str, Union[torch.Tensor, float, None]]:
- succ_rate_ts = (
- None if self.timestep_is_success is None
- else torch.stack([_x.mean(dtype=torch.float32) for _x in self.timestep_is_success])
- )
- hitting_time = self.capped_hitting_time
- summary = dict(
- epi_return=self.episode_return,
- epi_score=self.episode_score,
- succ_rate_ts=succ_rate_ts,
- succ_rate=self.is_success,
- hitting_time=hitting_time,
- )
- for kk, v in self.extra_timestep_results.items():
- summary[kk] = torch.stack([_v.mean(dtype=torch.float32) for _v in v])
- summary[kk + '_last'] = torch.stack([_v[-1] for _v in v])
-
- return summary
-
-
-def collect_rollout(actor: Callable[[torch.Tensor, torch.Tensor, gym.Space], np.ndarray], *,
- env: gym.Env, env_spec: EnvSpec, max_episode_length: int,
- assert_exact_episode_length: bool = False) -> EpisodeData:
- # NOTE: extra tracked info can be specified by env.tracked_info_keys
-
- from .utils import get_empty_episode
-
- epi = get_empty_episode(env_spec, max_episode_length)
-
- # check observation space
- obs_dict_keys = {'observation', 'achieved_goal', 'desired_goal'}
- WRONG_OBS_ERR_MESSAGE = (
- f"collect_rollout only supports Dict "
- f"observation space with keys {obs_dict_keys}, but got {env.observation_space}"
- )
- assert isinstance(env.observation_space, gym.spaces.Dict), WRONG_OBS_ERR_MESSAGE
- assert set(env.observation_space.spaces.keys()) == {'observation', 'achieved_goal', 'desired_goal'}, WRONG_OBS_ERR_MESSAGE
-
- observation_dict = cast(Mapping[str, np.ndarray], env.reset())
- observation: torch.Tensor = torch.as_tensor(observation_dict['observation'], dtype=torch.float32)
-
- goal: torch.Tensor = torch.as_tensor(observation_dict['desired_goal'], dtype=torch.float32)
- agoal: torch.Tensor = torch.as_tensor(observation_dict['achieved_goal'], dtype=torch.float32)
- epi.all_observations[0] = observation
- epi.observation_infos['desired_goals'][0] = goal
- epi.observation_infos['achieved_goals'][0] = agoal
-
- extra_transition_info_keys = getattr(env, 'tracked_info_keys', [])
- if len(extra_transition_info_keys):
- epi.transition_infos = dict(epi.transition_infos)
- epi.transition_infos.update({
- k: torch.empty([max_episode_length], dtype=torch.float32) for k in extra_transition_info_keys
- })
-
- t = 0
- timeout = False
- terminal = False
- while not timeout and not terminal:
- assert t < max_episode_length
-
- action = actor(
- observation,
- goal,
- env_spec.action_space,
- )
- transition_out = env.step(np.asarray(action))
- observation_dict, reward, terminal, info = transition_out[:3] + transition_out[-1:] # some BC
-
- observation = torch.tensor(observation_dict['observation'], dtype=torch.float32) # copy just in case
-
- goal: torch.Tensor = torch.as_tensor(observation_dict['desired_goal'], dtype=torch.float32)
- agoal: torch.Tensor = torch.as_tensor(observation_dict['achieved_goal'], dtype=torch.float32)
-
- if 'is_success' in info:
- is_success: bool = info['is_success']
- epi.transition_infos['is_success'][t] = is_success
- else:
- if t == 0:
- # remove field
- transition_infos = dict(epi.transition_infos)
- del transition_infos['is_success']
- epi.transition_infos = transition_infos
-
- for k in extra_transition_info_keys:
- epi.transition_infos[k][t] = info[k]
-
- epi.all_observations[t + 1] = observation
- epi.actions[t] = torch.as_tensor(action, dtype=torch.float32)
- epi.rewards[t] = reward
- epi.observation_infos['desired_goals'][t + 1] = goal
- epi.observation_infos['achieved_goals'][t + 1] = agoal
-
- t += 1
- timeout = info.get('TimeLimit.truncated', False)
- if assert_exact_episode_length:
- assert (timeout or terminal) == (t == max_episode_length)
-
- if t < max_episode_length:
- epi = epi.first_t(t)
-
- return epi
diff --git a/quasimetric_rl/data/offline/__init__.py b/quasimetric_rl/data/offline/__init__.py
index c816a7b..22bb42c 100644
--- a/quasimetric_rl/data/offline/__init__.py
+++ b/quasimetric_rl/data/offline/__init__.py
@@ -60,10 +60,6 @@ class OfflineGoalCondEnv(gym.ObservationWrapper, OfflineEnv): # type: ignore
self.get_goal_fn = get_goal_fn
self.extra_info_fns = extra_info_fns
- @property
- def tracked_info_keys(self):
- return tuple(self.extra_info_fns.keys())
-
def observation(self, observation: np.ndarray):
o, g = observation, self.get_goal_fn(self.env)
if self.is_image_based:
diff --git a/quasimetric_rl/data/offline/d4rl/antmaze.py b/quasimetric_rl/data/offline/d4rl/antmaze.py
index dbf5198..c48ab34 100644
--- a/quasimetric_rl/data/offline/d4rl/antmaze.py
+++ b/quasimetric_rl/data/offline/d4rl/antmaze.py
@@ -182,13 +182,12 @@ def create_env_antmaze(name, dict_obseravtion: Optional[bool] = None, *, random_
return env
-def load_episodes_antmaze(name, normalize_observation=True):
+def load_episodes_antmaze(name):
env = load_environment(name)
d4rl_dataset = cached_d4rl_dataset(name)
- if normalize_observation:
- # normalize
- d4rl_dataset['observations'] = obs_norm(name, d4rl_dataset['observations'])
- d4rl_dataset['next_observations'] = obs_norm(name, d4rl_dataset['next_observations'])
+ # normalize
+ d4rl_dataset['observations'] = obs_norm(name, d4rl_dataset['observations'])
+ d4rl_dataset['next_observations'] = obs_norm(name, d4rl_dataset['next_observations'])
yield from convert_dict_to_EpisodeData_iter(
sequence_dataset(
env,
@@ -207,11 +206,3 @@ for name in ['antmaze-umaze-v2', 'antmaze-umaze-diverse-v2',
normalize_score_fn=functools.partial(get_normalized_score, name),
eval_specs=dict(single_task=dict(random_start_goal=False), multi_task=dict(random_start_goal=True)),
)
- register_offline_env(
- 'd4rl', name + '-nonorm',
- create_env_fn=functools.partial(create_env_antmaze, name, normalize_observation=False),
- load_episodes_fn=functools.partial(load_episodes_antmaze, name, normalize_observation=False),
- normalize_score_fn=functools.partial(get_normalized_score, name),
- eval_specs=dict(single_task=dict(random_start_goal=False, normalize_observation=False),
- multi_task=dict(random_start_goal=True, normalize_observation=False)),
- )
diff --git a/quasimetric_rl/data/online/memory.py b/quasimetric_rl/data/online/memory.py
index a3445f4..b7cb94b 100644
--- a/quasimetric_rl/data/online/memory.py
+++ b/quasimetric_rl/data/online/memory.py
@@ -12,7 +12,6 @@ import gym.spaces
from . import OnlineFixedLengthEnv
from ..base import EpisodeData, MultiEpisodeData, Dataset, BatchData
-from ..interaction import collect_rollout
from ..utils import get_empty_episode, get_empty_episodes
@@ -154,7 +153,7 @@ class ReplayBuffer(Dataset):
get_empty_episodes(
self.env_spec, self.episode_length,
int(np.ceil(self.increment_num_transitions / self.episode_length)),
- ).to(self.device),
+ ),
],
dim=0,
)
@@ -165,19 +164,19 @@ class ReplayBuffer(Dataset):
# indices_to_episode_timesteps: torch.Tensor
self.indices_to_episode_indices = torch.cat([
self.indices_to_episode_indices,
- torch.repeat_interleave(torch.arange(original_capacity, new_capacity, device=self.device), self.episode_length),
+ torch.repeat_interleave(torch.arange(original_capacity, new_capacity), self.episode_length),
], dim=0)
self.indices_to_episode_timesteps = torch.cat([
self.indices_to_episode_timesteps,
- torch.arange(self.episode_length, device=self.device).repeat(new_capacity - original_capacity),
+ torch.arange(self.episode_length).repeat(new_capacity - original_capacity),
], dim=0)
logging.info(f'ReplayBuffer: Expanded from capacity={original_capacity} to {new_capacity} episodes')
def collect_rollout(self, actor: Callable[[torch.Tensor, torch.Tensor, gym.Space], np.ndarray], *,
env: Optional[OnlineFixedLengthEnv] = None) -> EpisodeData:
- return collect_rollout(actor, env=(env or self.env), env_spec=self.env_spec,
- max_episode_length=self.episode_length, assert_exact_episode_length=True)
+ return self.collect_rollout_general(actor, env=(env or self.env), env_spec=self.env_spec,
+ max_episode_length=self.episode_length, assert_exact_episode_length=True)
def add_rollout(self, episode: EpisodeData):
if self.num_episodes_realized == self.episodes_capacity:
@@ -215,8 +214,7 @@ class ReplayBuffer(Dataset):
def sample(self, batch_size: int) -> BatchData:
indices = torch.as_tensor(
- np.random.choice(self.num_transitions_realized, size=[batch_size]),
- device=self.device,
+ np.random.choice(self.num_transitions_realized, size=[batch_size])
)
return self[indices]
diff --git a/quasimetric_rl/flags.py b/quasimetric_rl/flags.py
index 2f7578a..2fb0fab 100644
--- a/quasimetric_rl/flags.py
+++ b/quasimetric_rl/flags.py
@@ -25,33 +25,31 @@ FLAGS = FlagsDefinition()
def pdb_if_DEBUG(fn: Callable):
@functools.wraps(fn)
def wrapped(*args, **kwargs):
- if not FLAGS.DEBUG: # check here, in case it is set after decorator call
- return fn(*args, **kwargs)
- else:
- try:
- return fn(*args, **kwargs)
- except:
- # follow ABSL:
- # https://github.com/abseil/abseil-py/blob/a0ae31683e6cf3667886c500327f292c893a1740/absl/app.py#L311-L327
-
- exc = sys.exc_info()[1]
- if isinstance(exc, KeyboardInterrupt):
- raise
-
- # Don't try to post-mortem debug successful SystemExits, since those
- # mean there wasn't actually an error. In particular, the test framework
- # raises SystemExit(False) even if all tests passed.
- if isinstance(exc, SystemExit) and not exc.code:
- raise
-
- # Check the tty so that we don't hang waiting for input in an
- # non-interactive scenario.
+ try:
+ fn(*args, **kwargs)
+ except:
+ # follow ABSL:
+ # https://github.com/abseil/abseil-py/blob/a0ae31683e6cf3667886c500327f292c893a1740/absl/app.py#L311-L327
+
+ exc = sys.exc_info()[1]
+ if isinstance(exc, KeyboardInterrupt):
+ raise
+
+ # Don't try to post-mortem debug successful SystemExits, since those
+ # mean there wasn't actually an error. In particular, the test framework
+ # raises SystemExit(False) even if all tests passed.
+ if isinstance(exc, SystemExit) and not exc.code:
+ raise
+
+ # Check the tty so that we don't hang waiting for input in an
+ # non-interactive scenario.
+ if FLAGS.DEBUG:
traceback.print_exc()
print()
print(' *** Entering post-mortem debugging ***')
print()
pdb.post_mortem()
- raise
+ raise
return wrapped
diff --git a/quasimetric_rl/modules/__init__.py b/quasimetric_rl/modules/__init__.py
index 81b5443..79fd35f 100644
--- a/quasimetric_rl/modules/__init__.py
+++ b/quasimetric_rl/modules/__init__.py
@@ -9,7 +9,6 @@ from . import actor, quasimetric_critic
from ..data import EnvSpec, BatchData
from .utils import LossResult, Module, InfoT, InfoValT
-from ..flags import FLAGS
class QRLAgent(Module):
@@ -33,15 +32,13 @@ class QRLLosses(Module):
critic_losses: Collection[quasimetric_critic.QuasimetricCriticLosses],
critics_total_grad_clip_norm: Optional[float],
recompute_critic_for_actor_loss: bool,
- critics_share_embedding: bool,
- critic_losses_use_target_encoder: bool):
+ critics_share_embedding: bool):
super().__init__()
self.add_module('actor_loss', actor_loss)
self.critic_losses = torch.nn.ModuleList(critic_losses) # type: ignore
self.critics_total_grad_clip_norm = critics_total_grad_clip_norm
self.recompute_critic_for_actor_loss = recompute_critic_for_actor_loss
self.critics_share_embedding = critics_share_embedding
- self.critic_losses_use_target_encoder = critic_losses_use_target_encoder
def forward(self, agent: QRLAgent, data: BatchData, *, optimize: bool = True) -> LossResult:
# compute CriticBatchInfo
@@ -53,106 +50,53 @@ class QRLLosses(Module):
stack.enter_context(critic_loss.optim_update_context(optimize=optimize))
if self.critics_share_embedding and idx > 0:
- critic_batch_info = attrs.evolve(critic_batch_infos[0], critic=critic)
+ critic_batch_info = quasimetric_critic.CriticBatchInfo(
+ critic=critic,
+ zx=critic_batch_infos[0].zx,
+ zy=critic_batch_infos[0].zy,
+ )
else:
zx = critic.encoder(data.observations)
- zy = critic.get_encoder(target=self.critic_losses_use_target_encoder)(data.next_observations)
- if critic.has_separate_target_encoder and self.critic_losses_use_target_encoder:
+ zy = critic.target_encoder(data.next_observations)
+ if critic.has_separate_target_encoder:
assert not zy.requires_grad
critic_batch_info = quasimetric_critic.CriticBatchInfo(
critic=critic,
zx=zx,
zy=zy,
- zy_from_target_encoder=self.critic_losses_use_target_encoder,
)
loss_results[f"critic_{idx:02d}"] = critic_loss(data, critic_batch_info) # we update together to handle shared embedding
critic_batch_infos.append(critic_batch_info)
- critic_grad_norm: InfoValT = {}
-
- if FLAGS.DEBUG:
- def get_grad(loss: Union[torch.Tensor, float]) -> Union[torch.Tensor, float]:
- if isinstance(loss, (int, float)):
- return 0
- loss_grads = torch.autograd.grad(
- loss,
- list(cast(torch.nn.ModuleList, agent.critics).parameters()),
- retain_graph=True,
- allow_unused=True,
- )
- return cast(
- torch.Tensor,
- sum(pg.pow(2).sum() for pg in loss_grads if pg is not None),
- ).sqrt()
-
- for k, loss_r in loss_results.items():
- critic_grad_norm.update({
- k: loss_r.map_losses(get_grad),
- })
-
torch.stack(
- [loss_r.total_loss for loss_r in loss_results.values()]
+ [cast(torch.Tensor, loss_r.loss) for loss_r in loss_results.values()]
).sum().backward()
if self.critics_total_grad_clip_norm is not None:
- critic_grad_norm['total'] = torch.nn.utils.clip_grad_norm_(
- cast(torch.nn.ModuleList, agent.critics).parameters(),
- max_norm=self.critics_total_grad_clip_norm)
- else:
- critic_grad_norm['total'] = cast(
- torch.Tensor,
- sum(p.grad.pow(2).sum() for p in cast(torch.nn.ModuleList, agent.critics).parameters() if p.grad is not None),
- ).sqrt()
+ torch.nn.utils.clip_grad_norm_(cast(torch.nn.ModuleList, agent.critics).parameters(),
+ max_norm=self.critics_total_grad_clip_norm)
if optimize:
for critic in agent.critics:
- critic.update_target_models_()
+ critic.update_target_encoder_()
- actor_grad_norm: InfoValT = {}
if self.actor_loss is not None:
assert agent.actor is not None
with torch.no_grad(), torch.inference_mode():
- for idx, critic in enumerate(agent.critics): # FIXME?
- if self.recompute_critic_for_actor_loss or (critic.has_separate_target_encoder and self.actor_loss.use_target_encoder):
- zx, zy = critic.get_encoder(target=self.actor_loss.use_target_encoder)(
- torch.stack([data.observations, data.next_observations], dim=0)).unbind(0)
+ for idx, critic in enumerate(agent.critics):
+ if self.recompute_critic_for_actor_loss or critic.has_separate_target_encoder:
+ zx, zy = critic.target_encoder(torch.stack([data.observations, data.next_observations], dim=0)).unbind(0)
critic_batch_infos[idx] = quasimetric_critic.CriticBatchInfo(
critic=critic,
zx=zx,
zy=zy,
- zx_from_target_encoder=self.actor_loss.use_target_encoder,
- zy_from_target_encoder=self.actor_loss.use_target_encoder,
)
with self.actor_loss.optim_update_context(optimize=optimize):
loss_results['actor'] = loss_r = self.actor_loss(agent.actor, critic_batch_infos, data)
+ cast(torch.Tensor, loss_r.loss).backward()
- if FLAGS.DEBUG:
- def get_grad(loss: Union[torch.Tensor, float]) -> Union[torch.Tensor, float]:
- if isinstance(loss, (int, float)):
- return 0
- assert agent.actor is not None
- loss_grads = torch.autograd.grad(
- loss,
- list(agent.actor.parameters()),
- retain_graph=True,
- allow_unused=True,
- )
- return cast(
- torch.Tensor,
- sum(pg.pow(2).sum() for pg in loss_grads if pg is not None),
- ).sqrt()
-
- actor_grad_norm.update(cast(Mapping, loss_r.map_losses(get_grad)))
-
- loss_r.total_loss.backward()
-
- actor_grad_norm['total'] = cast(
- torch.Tensor,
- sum(p.grad.pow(2).sum() for p in agent.actor.parameters() if p.grad is not None),
- ).sqrt()
-
- return LossResult.combine(loss_results, grad_norm=dict(critic=critic_grad_norm, actor=actor_grad_norm))
+ return LossResult.combine(loss_results)
# for type hints
def __call__(self, agent: QRLAgent, data: BatchData, *, optimize: bool = True) -> LossResult:
@@ -198,12 +142,7 @@ class QRLLosses(Module):
critic_loss.dynamics_lagrange_mult_sched.load_state_dict(optim_scheds[f"critic_{idx:02d}"]['dynamics_lagrange_mult_sched'])
def extra_repr(self) -> str:
- return '\n'.join([
- f'recompute_critic_for_actor_loss={self.recompute_critic_for_actor_loss}',
- f'critics_share_embedding={self.critics_share_embedding}',
- f'critics_total_grad_clip_norm={self.critics_total_grad_clip_norm}',
- f'critic_losses_use_target_encoder={self.critic_losses_use_target_encoder}',
- ])
+ return f'recompute_critic_for_actor_loss={self.recompute_critic_for_actor_loss}'
@attrs.define(kw_only=True)
@@ -216,7 +155,6 @@ class QRLConf:
default=None, validator=attrs.validators.optional(attrs.validators.gt(0)),
)
recompute_critic_for_actor_loss: bool = False
- critic_losses_use_target_encoder: bool = True
def make(self, *, env_spec: EnvSpec, total_optim_steps: int) -> Tuple[QRLAgent, QRLLosses]:
if self.actor is None:
@@ -236,12 +174,9 @@ class QRLConf:
critics.append(critic)
critic_losses.append(critic_loss)
- return QRLAgent(actor=actor, critics=critics), QRLLosses(
- actor_loss=actor_losses, critic_losses=critic_losses,
- critics_share_embedding=self.critics_share_embedding,
- critics_total_grad_clip_norm=self.critics_total_grad_clip_norm,
- recompute_critic_for_actor_loss=self.recompute_critic_for_actor_loss,
- critic_losses_use_target_encoder=self.critic_losses_use_target_encoder,
- )
+ return QRLAgent(actor=actor, critics=critics), QRLLosses(actor_loss=actor_losses, critic_losses=critic_losses,
+ critics_share_embedding=self.critics_share_embedding,
+ critics_total_grad_clip_norm=self.critics_total_grad_clip_norm,
+ recompute_critic_for_actor_loss=self.recompute_critic_for_actor_loss)
__all__ = ['QRLAgent', 'QRLLosses', 'QRLConf', 'InfoT', 'InfoValT']
diff --git a/quasimetric_rl/modules/actor/losses/__init__.py b/quasimetric_rl/modules/actor/losses/__init__.py
index f9c78a1..c8f1ec7 100644
--- a/quasimetric_rl/modules/actor/losses/__init__.py
+++ b/quasimetric_rl/modules/actor/losses/__init__.py
@@ -15,15 +15,12 @@ from ...optim import OptimWrapper, AdamWSpec, LRScheduler
class ActorLossBase(LossBase):
@abc.abstractmethod
- def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData,
- use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult:
+ def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData) -> LossResult:
pass
# for type hints
- def __call__(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData,
- use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult:
- return super().__call__(actor, critic_batch_infos, data, use_target_encoder=use_target_encoder,
- use_target_quasimetric_model=use_target_quasimetric_model)
+ def __call__(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData) -> LossResult:
+ return super().__call__(actor, critic_batch_infos, data)
from .min_dist import MinDistLoss
@@ -43,9 +40,6 @@ class ActorLosses(ActorLossBase):
actor_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=3e-5)
entropy_weight_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=3e-4) # TODO: move to loss
- use_target_encoder: bool = True
- use_target_quasimetric_model: bool = True
-
def make(self, actor: Actor, total_optim_steps: int, env_spec: EnvSpec) -> 'ActorLosses':
return ActorLosses(
actor,
@@ -55,8 +49,6 @@ class ActorLosses(ActorLossBase):
advantage_weighted_regression=self.advantage_weighted_regression.make(),
actor_optim_spec=self.actor_optim.make(),
entropy_weight_optim_spec=self.entropy_weight_optim.make(),
- use_target_encoder=self.use_target_encoder,
- use_target_quasimetric_model=self.use_target_quasimetric_model,
)
min_dist: Optional[MinDistLoss]
@@ -68,14 +60,10 @@ class ActorLosses(ActorLossBase):
entropy_weight_optim: OptimWrapper
entropy_weight_sched: LRScheduler
- use_target_encoder: bool
- use_target_quasimetric_model: bool
-
def __init__(self, actor: Actor, *, total_optim_steps: int,
min_dist: Optional[MinDistLoss], behavior_cloning: Optional[BCLoss],
advantage_weighted_regression: Optional[AWRLoss],
- actor_optim_spec: AdamWSpec, entropy_weight_optim_spec: AdamWSpec,
- use_target_encoder: bool, use_target_quasimetric_model: bool):
+ actor_optim_spec: AdamWSpec, entropy_weight_optim_spec: AdamWSpec):
super().__init__()
self.add_module('min_dist', min_dist)
self.add_module('behavior_cloning', behavior_cloning)
@@ -88,9 +76,6 @@ class ActorLosses(ActorLossBase):
if min_dist is not None:
assert len(list(min_dist.parameters())) <= 1
- self.use_target_encoder = use_target_encoder
- self.use_target_quasimetric_model = use_target_quasimetric_model
-
def optimizers(self) -> Iterable[OptimWrapper]:
return [self.actor_optim, self.entropy_weight_optim]
@@ -101,24 +86,18 @@ class ActorLosses(ActorLossBase):
loss_results: Dict[str, LossResult] = {}
if self.min_dist is not None:
loss_results.update(
- min_dist=self.min_dist(actor, critic_batch_infos, data, self.use_target_encoder, self.use_target_quasimetric_model),
+ min_dist=self.min_dist(actor, critic_batch_infos, data),
)
if self.behavior_cloning is not None:
loss_results.update(
- bc=self.behavior_cloning(actor, critic_batch_infos, data, self.use_target_encoder, self.use_target_quasimetric_model),
+ bc=self.behavior_cloning(actor, critic_batch_infos, data),
)
if self.advantage_weighted_regression is not None:
loss_results.update(
- awr=self.advantage_weighted_regression(actor, critic_batch_infos, data, self.use_target_encoder, self.use_target_quasimetric_model),
+ awr=self.advantage_weighted_regression(actor, critic_batch_infos, data),
)
return LossResult.combine(loss_results)
# for type hints
def __call__(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData) -> LossResult:
return torch.nn.Module.__call__(self, actor, critic_batch_infos, data)
-
- def extra_repr(self) -> str:
- return '\n'.join([
- f"actor_optim={self.actor_optim}, entropy_weight_optim={self.entropy_weight_optim}",
- f"use_target_encoder={self.use_target_encoder}, use_target_quasimetric_model={self.use_target_quasimetric_model}",
- ])
diff --git a/quasimetric_rl/modules/actor/losses/awr.py b/quasimetric_rl/modules/actor/losses/awr.py
index 1313dc3..c69e79c 100644
--- a/quasimetric_rl/modules/actor/losses/awr.py
+++ b/quasimetric_rl/modules/actor/losses/awr.py
@@ -3,10 +3,11 @@ from typing import *
import attrs
import torch
+import torch.nn as nn
-from ....data import BatchData
+from ....data import BatchData, EnvSpec
-from ...utils import LatentTensor, LossResult, bcast_bshape
+from ...utils import LatentTensor, LossResult, grad_mul
from ..model import Actor
from ...quasimetric_critic import QuasimetricCritic, CriticBatchInfo
@@ -81,8 +82,8 @@ class AWRLoss(ActorLossBase):
self.clamp = clamp
self.add_goal_as_future_state = add_goal_as_future_state
- def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData,
- use_target_encoder: bool) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]:
+ def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo],
+ data: BatchData) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]:
r"""
Returns (
obs,
@@ -115,7 +116,7 @@ class AWRLoss(ActorLossBase):
# add future_observations
zg = torch.stack([
zg,
- critic_batch_info.critic.get_encoder(target=use_target_encoder)(data.future_observations),
+ critic_batch_info.critic.target_encoder(data.future_observations),
], 0)
# zo = zo.expand_as(zg)
@@ -127,11 +128,9 @@ class AWRLoss(ActorLossBase):
return obs, goal, actor_obs_goal_critic_infos
- def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData,
- use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult:
-
+ def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData) -> LossResult:
with torch.no_grad():
- obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data, use_target_encoder)
+ obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data)
info: Dict[str, Union[float, torch.Tensor]] = {}
@@ -141,25 +140,14 @@ class AWRLoss(ActorLossBase):
for idx, actor_obs_goal_critic_info in enumerate(actor_obs_goal_critic_infos):
critic = actor_obs_goal_critic_info.critic
- latent_dynamics = critic.latent_dynamics
- quasimetric_model = critic.get_quasimetric_model(target=use_target_quasimetric_model)
- zo = actor_obs_goal_critic_info.zo.detach() # [B,D]
- zg = actor_obs_goal_critic_info.zg.detach() # [2?,B,D]
+ zo = actor_obs_goal_critic_info.zo.detach()
+ zg = actor_obs_goal_critic_info.zg.detach()
with torch.no_grad(), critic.mode(False):
- zp = latent_dynamics(data.observations, zo, data.actions) # [B,D]
- if idx == 0 or not critic.borrowing_embedding:
- zo, zp, zg = bcast_bshape(
- (zo, 1),
- (zp, 1),
- (zg, 1),
- )
- z = torch.stack([zo, zp], dim=0) # [2,2?,B,D]
- # z = z[(slice(None),) + (None,) * (zg.ndim - zo.ndim) + (Ellipsis,)] # if zg is batched, add enough dim to be broadcast-able
- dist_noact, dist = quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B]
- dist_noact = dist_noact.detach()
- else:
- dist = quasimetric_model(zp, zg)
- dist_noact = dists_noact[0]
+ zp = critic.latent_dynamics(data.observations, zo, data.actions)
+ z = torch.stack([zo, zp], dim=0)
+ z = z[(slice(None),) + (None,) * (zg.ndim - zo.ndim) + (Ellipsis,)] # if zg is batched, add enough dim to be broadcast-able
+ dist_noact, dist = critic.quasimetric_model(z, zg).unbind(0)
+ dist_noact = dist_noact.detach()
info[f'dist_delta_{idx:02d}'] = (dist_noact - dist).mean()
info[f'dist_{idx:02d}'] = dist.mean()
dists_noact.append(dist_noact)
diff --git a/quasimetric_rl/modules/actor/losses/behavior_cloning.py b/quasimetric_rl/modules/actor/losses/behavior_cloning.py
index eba36af..632221e 100644
--- a/quasimetric_rl/modules/actor/losses/behavior_cloning.py
+++ b/quasimetric_rl/modules/actor/losses/behavior_cloning.py
@@ -34,8 +34,7 @@ class BCLoss(ActorLossBase):
super().__init__()
self.weight = weight
- def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData,
- use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult:
+ def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData) -> LossResult:
actor_distn = actor(data.observations, data.future_observations)
log_prob: torch.Tensor = actor_distn.log_prob(data.actions).mean()
loss = -log_prob * self.weight
diff --git a/quasimetric_rl/modules/actor/losses/min_dist.py b/quasimetric_rl/modules/actor/losses/min_dist.py
index 1d57609..0f90cff 100644
--- a/quasimetric_rl/modules/actor/losses/min_dist.py
+++ b/quasimetric_rl/modules/actor/losses/min_dist.py
@@ -87,8 +87,8 @@ class MinDistLoss(ActorLossBase):
self.register_parameter('raw_entropy_weight', None)
self.target_entropy = None
- def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData,
- use_target_encoder: bool) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]:
+ def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo],
+ data: BatchData) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]:
r"""
Returns (
obs,
@@ -121,7 +121,7 @@ class MinDistLoss(ActorLossBase):
# add future_observations
zg = torch.stack([
zg,
- critic_batch_info.critic.get_encoder(target=use_target_encoder)(data.future_observations),
+ critic_batch_info.critic.target_encoder(data.future_observations),
], 0)
# zo = zo.expand_as(zg)
@@ -133,10 +133,9 @@ class MinDistLoss(ActorLossBase):
return obs, goal, actor_obs_goal_critic_infos
- def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData,
- use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult:
+ def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData) -> LossResult:
with torch.no_grad():
- obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data, use_target_encoder)
+ obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data)
actor_distn = actor(obs, goal)
action = actor_distn.rsample()
@@ -148,20 +147,13 @@ class MinDistLoss(ActorLossBase):
for idx, actor_obs_goal_critic_info in enumerate(actor_obs_goal_critic_infos):
critic = actor_obs_goal_critic_info.critic
- latent_dynamics = critic.latent_dynamics
- quasimetric_model = critic.get_quasimetric_model(target=use_target_quasimetric_model)
- zo = actor_obs_goal_critic_info.zo.detach() # [B,D]
- zg = actor_obs_goal_critic_info.zg.detach() # [2?,B,D]
+ zo = actor_obs_goal_critic_info.zo.detach()
+ zg = actor_obs_goal_critic_info.zg.detach()
with critic.requiring_grad(False), critic.mode(False):
- zp = latent_dynamics(data.observations, zo, action) # [2?,B,D]
- if idx == 0 or not critic.borrowing_embedding:
- # action: [2?,B,A]
- z = torch.stack(torch.broadcast_tensors(zo, zp), dim=0) # [2,2?,B,D]
- dist_noact, dist = quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B]
- dist_noact = dist_noact.detach()
- else:
- dist = quasimetric_model(zp, zg) # [2?,B]
- dist_noact = dists_noact[0]
+ zp = critic.latent_dynamics(data.observations, zo, action)
+ z = torch.stack(torch.broadcast_tensors(zo, zp), dim=0)
+ dist_noact, dist = critic.quasimetric_model(z, zg).unbind(0)
+ dist_noact = dist_noact.detach()
info[f'dist_delta_{idx:02d}'] = (dist_noact - dist).mean()
info[f'dist_{idx:02d}'] = dist.mean()
dists_noact.append(dist_noact)
diff --git a/quasimetric_rl/modules/optim.py b/quasimetric_rl/modules/optim.py
index 5019e1b..5131ce0 100644
--- a/quasimetric_rl/modules/optim.py
+++ b/quasimetric_rl/modules/optim.py
@@ -1,7 +1,6 @@
from typing import *
import attrs
-import logging
import contextlib
import torch
@@ -92,11 +91,9 @@ class AdamWSpec:
if len(params) == 0:
params = [dict(params=[])] # dummy param group so pytorch doesn't complain
for ii in range(len(params)):
- pg = params[ii]
- if isinstance(pg, Mapping) and 'lr_mul' in pg: # handle lr_multiplier
- pg['params'] = params_list = cast(List[torch.Tensor], list(pg['params']))
+ if isinstance(params, Mapping) and 'lr_mul' in params: # handle lr_multiplier
+ pg = params[ii]
assert 'lr' not in pg
- logging.info(f'params (#tensor={len(params_list)}, #={sum(p.numel() for p in params_list)}): lr_mul={pg["lr_mul"]}')
pg['lr'] = self.lr * pg['lr_mul'] # type: ignore
del pg['lr_mul']
return OptimWrapper(
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py b/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py
index 5487442..fa8eba0 100644
--- a/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py
@@ -19,8 +19,6 @@ class CriticBatchInfo:
critic: QuasimetricCritic
zx: LatentTensor
zy: LatentTensor
- zx_from_target_encoder: bool = False
- zy_from_target_encoder: bool = False
class CriticLossBase(LossBase):
@@ -33,7 +31,7 @@ class CriticLossBase(LossBase):
return super().__call__(data, critic_batch_info)
-from .global_push import GlobalPushLoss, GlobalPushLinearLoss, GlobalPushLogLoss, GlobalPushRBFLoss, GlobalPushNextMSELoss
+from .global_push import GlobalPushLoss, GlobalPushLinearLoss, GlobalPushLogLoss, GlobalPushRBFLoss
from .local_constraint import LocalConstraintLoss
from .latent_dynamics import LatentDynamicsLoss
@@ -43,24 +41,20 @@ class QuasimetricCriticLosses(CriticLossBase):
class Conf:
global_push: GlobalPushLoss.Conf = GlobalPushLoss.Conf()
global_push_linear: GlobalPushLinearLoss.Conf = GlobalPushLinearLoss.Conf()
- global_push_next_mse: GlobalPushNextMSELoss.Conf = GlobalPushNextMSELoss.Conf()
global_push_log: GlobalPushLogLoss.Conf = GlobalPushLogLoss.Conf()
global_push_rbf: GlobalPushRBFLoss.Conf = GlobalPushRBFLoss.Conf()
local_constraint: LocalConstraintLoss.Conf = LocalConstraintLoss.Conf()
latent_dynamics: LatentDynamicsLoss.Conf = LatentDynamicsLoss.Conf()
critic_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=1e-4)
- latent_dynamics_lr_mul: float = attrs.field(default=1., validator=attrs.validators.ge(0))
- quasimetric_model_lr_mul: float = attrs.field(default=1., validator=attrs.validators.ge(0))
- encoder_lr_mul: float = attrs.field(default=1., validator=attrs.validators.ge(0)) # TD-MPC2 uses 0.3
- quasimetric_head_lr_mul: float = attrs.field(default=1., validator=attrs.validators.ge(0)) # IQE2 can benefit from smaller lr, ~1e-5
+ latent_dynamics_lr_mul: float = 1
+ quasimetric_model_lr_mul: float = 1
+ encoder_lr_mul: float = 1 # TD-MPC2 uses 0.3
+ quasimetric_head_lr_mul: float = 1 # IQE2 can benefit from smaller lr, ~1e-5
local_lagrange_mult_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=1e-2)
dynamics_lagrange_mult_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=0)
- quasimetric_scale: Optional[str] = attrs.field(
- default=None, validator=attrs.validators.optional(attrs.validators.in_(
- ['best_local_fit', 'best_local_fit_clip5', 'best_local_fit_clip10',
- 'best_local_fit_detach']))) # type: ignore
+ scale_with_best_local_fit: bool = False
def make(self, critic: QuasimetricCritic, total_optim_steps: int,
share_embedding_from: Optional[QuasimetricCritic] = None) -> 'QuasimetricCriticLosses':
@@ -71,7 +65,6 @@ class QuasimetricCriticLosses(CriticLossBase):
# global losses
global_push=self.global_push.make(),
global_push_linear=self.global_push_linear.make(),
- global_push_next_mse=self.global_push_next_mse.make(),
global_push_log=self.global_push_log.make(),
global_push_rbf=self.global_push_rbf.make(),
# local loss
@@ -88,13 +81,12 @@ class QuasimetricCriticLosses(CriticLossBase):
local_lagrange_mult_optim_spec=self.local_lagrange_mult_optim.make(),
dynamics_lagrange_mult_optim_spec=self.dynamics_lagrange_mult_optim.make(),
#
- quasimetric_scale=self.quasimetric_scale,
+ scale_with_best_local_fit=self.scale_with_best_local_fit,
)
borrowing_embedding: bool
global_push: Optional[GlobalPushLoss]
global_push_linear: Optional[GlobalPushLinearLoss]
- global_push_next_mse: Optional[GlobalPushNextMSELoss]
global_push_log: Optional[GlobalPushLogLoss]
global_push_rbf: Optional[GlobalPushRBFLoss]
local_constraint: Optional[LocalConstraintLoss]
@@ -106,13 +98,12 @@ class QuasimetricCriticLosses(CriticLossBase):
local_lagrange_mult_sched: LRScheduler
dynamics_lagrange_mult_optim: OptimWrapper
dynamics_lagrange_mult_sched: LRScheduler
- quasimetric_scale: Optional[str]
+ scale_with_best_local_fit: bool
def __init__(self, critic: QuasimetricCritic, *, total_optim_steps: int,
share_embedding_from: Optional[QuasimetricCritic] = None,
global_push: Optional[GlobalPushLoss], global_push_linear: Optional[GlobalPushLinearLoss],
- global_push_next_mse: Optional[GlobalPushNextMSELoss], global_push_log: Optional[GlobalPushLogLoss],
- global_push_rbf: Optional[GlobalPushRBFLoss],
+ global_push_log: Optional[GlobalPushLogLoss], global_push_rbf: Optional[GlobalPushRBFLoss],
local_constraint: Optional[LocalConstraintLoss], latent_dynamics: LatentDynamicsLoss,
critic_optim_spec: AdamWSpec,
latent_dynamics_lr_mul: float,
@@ -121,7 +112,7 @@ class QuasimetricCriticLosses(CriticLossBase):
quasimetric_head_lr_mul: float,
local_lagrange_mult_optim_spec: AdamWSpec,
dynamics_lagrange_mult_optim_spec: AdamWSpec,
- quasimetric_scale: Optional[str]):
+ scale_with_best_local_fit: bool):
super().__init__()
self.borrowing_embedding = share_embedding_from is not None
if self.borrowing_embedding:
@@ -132,7 +123,6 @@ class QuasimetricCriticLosses(CriticLossBase):
local_constraint = None
self.add_module('global_push', global_push)
self.add_module('global_push_linear', global_push_linear)
- self.add_module('global_push_next_mse', global_push_next_mse)
self.add_module('global_push_log', global_push_log)
self.add_module('global_push_rbf', global_push_rbf)
self.add_module('local_constraint', local_constraint)
@@ -157,7 +147,7 @@ class QuasimetricCriticLosses(CriticLossBase):
self.dynamics_lagrange_mult_optim, self.dynamics_lagrange_mult_sched = dynamics_lagrange_mult_optim_spec.create_optim_scheduler(
latent_dynamics.parameters(), total_optim_steps)
assert len(list(latent_dynamics.parameters())) == 1
- self.quasimetric_scale = quasimetric_scale
+ self.scale_with_best_local_fit = scale_with_best_local_fit
def optimizers(self) -> Iterable[OptimWrapper]:
return [self.critic_optim, self.local_lagrange_mult_optim, self.dynamics_lagrange_mult_optim]
@@ -165,34 +155,23 @@ class QuasimetricCriticLosses(CriticLossBase):
def schedulers(self) -> Iterable[LRScheduler]:
return [self.critic_sched, self.local_lagrange_mult_sched, self.dynamics_lagrange_mult_sched]
- def compute_best_quasimetric_scale(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> Tuple[torch.Tensor, torch.Tensor]:
+ @torch.no_grad()
+ def compute_best_quasimetric_scale(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> torch.Tensor:
assert self.local_constraint is not None and not self.borrowing_embedding
- critic_batch_info.critic.quasimetric_model.quasimetric_head.scale.detach_().fill_(1) # reset
dist = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, critic_batch_info.zy)
- return dist, (self.local_constraint.step_cost * (dist.mean() / dist.square().mean().clamp_min(1e-12))) # .detach().clamp_(1e-1, 1e1)
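+        # the closed-form minimizer of E[(s*d - step_cost)^2] over a scalar s is step_cost * E[d] / E[d^2];
+        # the denominator is clamped for numerical stability and the result is clamped to [1e-3, 1e3]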
+ return (self.local_constraint.step_cost * (dist.mean() / dist.square().mean().clamp_min_(1e-8))).detach().clamp_(1e-3, 1e3)
def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
extra_info: Dict[str, torch.Tensor] = {}
- if self.quasimetric_scale is not None and not self.borrowing_embedding:
- unscaled_dist, scale = self.compute_best_quasimetric_scale(data, critic_batch_info)
- assert scale.grad_fn is not None # allow bp
- if self.quasimetric_scale == 'best_local_fit_detach':
- scale = scale.detach()
- elif self.quasimetric_scale == 'best_local_fit_clip5':
- scale = scale.clamp(1 / 5, 5)
- elif self.quasimetric_scale == 'best_local_fit_clip10':
- scale = scale.clamp(1 / 10, 10)
- extra_info['unscaled_dist'] = unscaled_dist
- extra_info['quasimetric_autoscale'] = scale
- critic_batch_info.critic.quasimetric_model.quasimetric_head.scale = scale
+ if self.scale_with_best_local_fit and not self.borrowing_embedding:
+ scale = extra_info['quasimetric_autoscale'] = self.compute_best_quasimetric_scale(data, critic_batch_info)
+ critic_batch_info.critic.quasimetric_model.quasimetric_head.scale.copy_(scale)
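+            # the scale is computed under torch.no_grad and copied in-place, so no gradient flows through the autoscale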
loss_results: Dict[str, LossResult] = {}
if self.global_push is not None:
loss_results.update(global_push=self.global_push(data, critic_batch_info))
if self.global_push_linear is not None:
loss_results.update(global_push_linear=self.global_push_linear(data, critic_batch_info))
- if self.global_push_next_mse is not None:
- loss_results.update(global_push_next_mse=self.global_push_next_mse(data, critic_batch_info))
if self.global_push_log is not None:
loss_results.update(global_push_log=self.global_push_log(data, critic_batch_info))
if self.global_push_rbf is not None:
@@ -210,4 +189,4 @@ class QuasimetricCriticLosses(CriticLossBase):
return torch.nn.Module.__call__(self, data, critic_batch_info)
def extra_repr(self) -> str:
- return f"borrowing_embedding={self.borrowing_embedding}, quasimetric_scale={self.quasimetric_scale!r}"
+ return f"borrowing_embedding={self.borrowing_embedding}"
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
index 3879518..7c3f9f3 100644
--- a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
@@ -53,104 +53,16 @@ from . import CriticLossBase, CriticBatchInfo
# return f"weight={self.weight:g}, softplus_beta={self.softplus_beta:g}, softplus_offset={self.softplus_offset:g}"
-class GlobalPushLossBase(CriticLossBase):
+
+class GlobalPushLoss(CriticLossBase):
@attrs.define(kw_only=True)
- class Conf(abc.ABC):
+ class Conf:
# config / argparse uses this to specify behavior
enabled: bool = True
detach_goal: bool = False
detach_proj_goal: bool = False
- detach_qmet: bool = False
- step_cost: float = attrs.field(default=1., validator=attrs.validators.gt(0))
weight: float = attrs.field(default=1., validator=attrs.validators.gt(0))
- weight_future_goal: float = attrs.field(default=0., validator=attrs.validators.ge(0))
- clamp_max_future_goal: bool = True
-
- @abc.abstractmethod
- def make(self) -> Optional['GlobalPushLossBase']:
- if not self.enabled:
- return None
-
- weight: float
- weight_future_goal: float
- detach_goal: bool
- detach_proj_goal: bool
- detach_qmet: bool
- step_cost: float
- clamp_max_future_goal: bool
-
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool,
- detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool):
- super().__init__()
- self.weight = weight
- self.weight_future_goal = weight_future_goal
- self.detach_goal = detach_goal
- self.detach_proj_goal = detach_proj_goal
- self.detach_qmet = detach_qmet
- self.step_cost = step_cost
- self.clamp_max_future_goal = clamp_max_future_goal
-
- def generate_dist_weight(self, data: BatchData, critic_batch_info: CriticBatchInfo):
- def get_dist(za: torch.Tensor, zb: torch.Tensor):
- if self.detach_goal:
- zb = zb.detach()
- with critic_batch_info.critic.quasimetric_model.requiring_grad(not self.detach_qmet):
- return critic_batch_info.critic.quasimetric_model(za, zb, proj_grad_enabled=(True, not self.detach_proj_goal))
-
- # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy
- # are latents of randomly ordered random batches.
- zgoal = torch.roll(critic_batch_info.zy, 1, dims=0)
- yield (
- 'random_goal',
- zgoal,
- get_dist(critic_batch_info.zx, zgoal),
- self.weight,
- {},
- )
- if self.weight_future_goal > 0:
- zgoal = critic_batch_info.critic.get_encoder(critic_batch_info.zy_from_target_encoder)(data.future_observations)
- dist = get_dist(critic_batch_info.zx, zgoal)
- if self.clamp_max_future_goal:
- observed_upper_bound = self.step_cost * data.future_tdelta
- info = dict(
- ratio_future_observed_dist=(dist / observed_upper_bound).mean(),
- exceed_future_observed_dist_rate=(dist > observed_upper_bound).mean(dtype=torch.float32),
- )
- dist = dist.clamp_max(self.step_cost * data.future_tdelta)
- else:
- info = {}
- yield (
- 'future_goal',
- zgoal,
- dist,
-                self.weight_future_goal, info,
- )
-
- @abc.abstractmethod
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor,
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult:
- raise NotImplementedError
-
- def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
- return LossResult.combine(
- {
- name: self.compute_loss(data, critic_batch_info, zgoal, dist, weight, info)
- for name, zgoal, dist, weight, info in self.generate_dist_weight(data, critic_batch_info)
- },
- )
-
- def extra_repr(self) -> str:
- return '\n'.join([
- f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}",
- f"weight_future_goal={self.weight_future_goal:g}, detach_qmet={self.detach_qmet}",
- f"step_cost={self.step_cost:g}, clamp_max_future_goal={self.clamp_max_future_goal}",
- ])
-
-
-class GlobalPushLoss(GlobalPushLossBase):
- @attrs.define(kw_only=True)
- class Conf(GlobalPushLossBase.Conf):
# smaller => smoother loss
softplus_beta: float = attrs.field(default=0.1, validator=attrs.validators.gt(0))
@@ -163,175 +75,99 @@ class GlobalPushLoss(GlobalPushLossBase):
return None
return GlobalPushLoss(
weight=self.weight,
- weight_future_goal=self.weight_future_goal,
detach_goal=self.detach_goal,
detach_proj_goal=self.detach_proj_goal,
- detach_qmet=self.detach_qmet,
- step_cost=self.step_cost,
- clamp_max_future_goal=self.clamp_max_future_goal,
softplus_beta=self.softplus_beta,
softplus_offset=self.softplus_offset,
)
+ weight: float
+ detach_goal: bool
+ detach_proj_goal: bool
softplus_beta: float
softplus_offset: float
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool,
+ def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool,
softplus_beta: float, softplus_offset: float):
- super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal,
- detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal)
+ super().__init__()
+ self.weight = weight
+ self.detach_goal = detach_goal
+ self.detach_proj_goal = detach_proj_goal
self.softplus_beta = softplus_beta
self.softplus_offset = softplus_offset
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor,
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult:
- dict_info: Dict[str, torch.Tensor] = dict(info)
+ def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
+ # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy
+ # are latents of randomly ordered random batches.
+ zgoal = torch.roll(critic_batch_info.zy, 1, dims=0)
+ if self.detach_goal:
+ zgoal = zgoal.detach()
+ dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal,
+ proj_grad_enabled=(True, not self.detach_proj_goal))
# Sec 3.2. Transform so that we penalize large distances less.
- tsfm_dist: torch.Tensor = F.softplus(self.softplus_offset - dist, beta=self.softplus_beta) # type: ignore
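+        # softplus(offset - d) is a smooth hinge: roughly (offset - d) for d << offset and roughly 0 for d >> offset,
+        # so the gradient vanishes for already-large distances; smaller softplus_beta widens and smooths the transition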
+ tsfm_dist: torch.Tensor = F.softplus(self.softplus_offset - dists, beta=self.softplus_beta) # type: ignore
tsfm_dist = tsfm_dist.mean()
- dict_info.update(dist=dist.mean(), tsfm_dist=tsfm_dist)
- return LossResult(loss=tsfm_dist * weight, info=dict_info)
+ return LossResult(loss=tsfm_dist * self.weight, info=dict(dist=dists.mean(), tsfm_dist=tsfm_dist)) # type: ignore
def extra_repr(self) -> str:
- return '\n'.join([
- super().extra_repr(),
- f"softplus_beta={self.softplus_beta:g}, softplus_offset={self.softplus_offset:g}",
- ])
+ return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}, softplus_beta={self.softplus_beta:g}, softplus_offset={self.softplus_offset:g}"
+
-class GlobalPushLinearLoss(GlobalPushLossBase):
+class GlobalPushLinearLoss(CriticLossBase):
@attrs.define(kw_only=True)
- class Conf(GlobalPushLossBase.Conf):
- enabled: bool = False
+ class Conf:
+ # config / argparse uses this to specify behavior
- clamp_max: Optional[float] = attrs.field(default=None, validator=attrs.validators.optional(attrs.validators.gt(0)))
+ enabled: bool = False
+ detach_goal: bool = False
+ detach_proj_goal: bool = False
+ weight: float = attrs.field(default=1., validator=attrs.validators.gt(0))
def make(self) -> Optional['GlobalPushLinearLoss']:
if not self.enabled:
return None
return GlobalPushLinearLoss(
weight=self.weight,
- weight_future_goal=self.weight_future_goal,
detach_goal=self.detach_goal,
detach_proj_goal=self.detach_proj_goal,
- detach_qmet=self.detach_qmet,
- step_cost=self.step_cost,
- clamp_max_future_goal=self.clamp_max_future_goal,
- clamp_max=self.clamp_max,
)
- clamp_max: Optional[float]
-
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool,
- clamp_max: Optional[float]):
- super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal,
- detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal)
- self.clamp_max = clamp_max
-
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor,
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult:
- dict_info: Dict[str, torch.Tensor] = dict(info)
- if self.clamp_max is None:
- dist = dist.mean()
- dict_info.update(dist=dist)
- neg_loss = dist
- else:
- tsfm_dist = dist.clamp_max(self.clamp_max)
- dict_info.update(
- dist=dist.mean(),
- tsfm_dist=tsfm_dist.mean(),
- exceed_rate=(dist >= self.clamp_max).mean(dtype=torch.float32),
- )
- neg_loss = tsfm_dist
- return LossResult(loss=neg_loss * weight, info=dict_info)
-
- def extra_repr(self) -> str:
- return '\n'.join([
- super().extra_repr(),
- f"clamp_max={self.clamp_max!r}",
- ])
-
-
-
-
-class GlobalPushNextMSELoss(GlobalPushLossBase):
- @attrs.define(kw_only=True)
- class Conf(GlobalPushLossBase.Conf):
- enabled: bool = False
- detach_target_dist: bool = True
- allow_gt: bool = False
- gamma: Optional[float] = attrs.field(
- default=None, validator=attrs.validators.optional(attrs.validators.and_(
- attrs.validators.gt(0),
- attrs.validators.lt(1),
- )))
-
- def make(self) -> Optional['GlobalPushNextMSELoss']:
- if not self.enabled:
- return None
- return GlobalPushNextMSELoss(
- weight=self.weight,
- weight_future_goal=self.weight_future_goal,
- detach_goal=self.detach_goal,
- detach_proj_goal=self.detach_proj_goal,
- detach_qmet=self.detach_qmet,
- clamp_max_future_goal=self.clamp_max_future_goal,
- step_cost=self.step_cost,
- detach_target_dist=self.detach_target_dist,
- allow_gt=self.allow_gt,
- gamma=self.gamma,
- )
-
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool,
- detach_target_dist: bool, allow_gt: bool, gamma: Optional[float]):
- super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal,
- detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal)
- self.detach_target_dist = detach_target_dist
- self.allow_gt = allow_gt
- self.gamma = gamma
-
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor,
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult:
- dict_info: Dict[str, torch.Tensor] = dict(info)
-
- with torch.enable_grad(self.detach_target_dist):
- # by tri-eq, the actual cost can't be larger than step_cost + d(s', g)
- next_dist = critic_batch_info.critic.quasimetric_model(
- critic_batch_info.zy, zgoal, proj_grad_enabled=(True, not self.detach_proj_goal)
- )
- if self.detach_target_dist:
- next_dist = next_dist.detach()
- target_dist = self.step_cost + next_dist
-
- if self.allow_gt:
- dist = dist.clamp_max(target_dist)
- dict_info.update(
- exceed_rate=(dist >= target_dist).mean(dtype=torch.float32),
- )
-
- if self.gamma is None:
- loss = F.mse_loss(dist, target_dist)
- else:
- loss = F.mse_loss(self.gamma ** dist, self.gamma ** target_dist)
+ weight: float
+ detach_goal: bool
+ detach_proj_goal: bool
- dict_info.update(
- loss=loss, dist=dist.mean(), target_dist=target_dist.mean()
- )
+ def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool):
+ super().__init__()
+ self.weight = weight
+ self.detach_goal = detach_goal
+ self.detach_proj_goal = detach_proj_goal
- return LossResult(loss=loss * self.weight, info=dict_info)
+ def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
+ # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy
+ # are latents of randomly ordered random batches.
+ zgoal = torch.roll(critic_batch_info.zy, 1, dims=0)
+ if self.detach_goal:
+ zgoal = zgoal.detach()
+ dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal,
+ proj_grad_enabled=(True, not self.detach_proj_goal))
+ dists = dists.mean()
+ return LossResult(loss=dists * (-self.weight), info=dict(dist=dists)) # type: ignore
def extra_repr(self) -> str:
- return '\n'.join([
- super().extra_repr(),
- "detach_target_dist={self.detach_target_dist}",
- ])
+ return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}"
+
-class GlobalPushLogLoss(GlobalPushLossBase):
+class GlobalPushLogLoss(CriticLossBase):
@attrs.define(kw_only=True)
- class Conf(GlobalPushLossBase.Conf):
- enabled: bool = False
+ class Conf:
+ # config / argparse uses this to specify behavior
+ enabled: bool = False
+ detach_goal: bool = False
+ detach_proj_goal: bool = False
+ weight: float = attrs.field(default=1., validator=attrs.validators.gt(0))
offset: float = attrs.field(default=1., validator=attrs.validators.gt(0))
def make(self) -> Optional['GlobalPushLogLoss']:
@@ -339,46 +175,54 @@ class GlobalPushLogLoss(GlobalPushLossBase):
return None
return GlobalPushLogLoss(
weight=self.weight,
- weight_future_goal=self.weight_future_goal,
detach_goal=self.detach_goal,
detach_proj_goal=self.detach_proj_goal,
- detach_qmet=self.detach_qmet,
- step_cost=self.step_cost,
- clamp_max_future_goal=self.clamp_max_future_goal,
offset=self.offset,
)
+ weight: float
+ detach_goal: bool
+ detach_proj_goal: bool
offset: float
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool,
- offset: float):
- super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal,
- detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal)
+ def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool, offset: float):
+ super().__init__()
+ self.weight = weight
+ self.detach_goal = detach_goal
+ self.detach_proj_goal = detach_proj_goal
self.offset = offset
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor,
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult:
- dict_info: Dict[str, torch.Tensor] = dict(info)
- tsfm_dist: torch.Tensor = -dist.add(self.offset).log()
+ def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
+ # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy
+ # are latents of randomly ordered random batches.
+ zgoal = torch.roll(critic_batch_info.zy, 1, dims=0)
+ if self.detach_goal:
+ zgoal = zgoal.detach()
+ dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal,
+ proj_grad_enabled=(True, not self.detach_proj_goal))
+ # Sec 3.2. Transform so that we penalize large distances less.
+ tsfm_dist: torch.Tensor = -dists.add(self.offset).log()
tsfm_dist = tsfm_dist.mean()
- dict_info.update(dist=dist.mean(), tsfm_dist=tsfm_dist)
- return LossResult(loss=tsfm_dist * weight, info=dict_info)
+ return LossResult(loss=tsfm_dist * self.weight, info=dict(dist=dists.mean(), tsfm_dist=tsfm_dist)) # type: ignore
def extra_repr(self) -> str:
- return '\n'.join([
- super().extra_repr(),
- f"offset={self.offset:g}",
- ])
+ return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}, offset={self.offset:g}"
+
-class GlobalPushRBFLoss(GlobalPushLossBase):
+class GlobalPushRBFLoss(CriticLossBase):
    # say E[opt T] approx sqrt(2)/2 * timeout, so E[opt T^2] approx 1/2 * timeout^2
    # to emulate log E exp(-2 d^2), where 2 d^2 is around 4, we scale model T by r and use log E exp(-r^2 T^2), letting r^2 T^2 be around 4
    # so r^2 approx 8 / timeout^2, i.e. r approx 2.82 / timeout and inv_scale = 1/r approx timeout / 2.82. If timeout = 850, inv_scale approx 300
@attrs.define(kw_only=True)
- class Conf(GlobalPushLossBase.Conf):
+ class Conf:
+ # config / argparse uses this to specify behavior
+
enabled: bool = False
+ detach_goal: bool = False
+ detach_proj_goal: bool = False
+ weight: float = attrs.field(default=1., validator=attrs.validators.gt(0))
inv_scale: float = attrs.field(default=300., validator=attrs.validators.ge(1e-3))
@@ -387,37 +231,37 @@ class GlobalPushRBFLoss(GlobalPushLossBase):
return None
return GlobalPushRBFLoss(
weight=self.weight,
- weight_future_goal=self.weight_future_goal,
detach_goal=self.detach_goal,
detach_proj_goal=self.detach_proj_goal,
- detach_qmet=self.detach_qmet,
- step_cost=self.step_cost,
- clamp_max_future_goal=self.clamp_max_future_goal,
inv_scale=self.inv_scale,
)
+ weight: float
+ detach_goal: bool
+ detach_proj_goal: bool
inv_scale: float
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool,
- inv_scale: float):
- super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal,
- detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal)
+ def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool, inv_scale: float):
+ super().__init__()
+ self.weight = weight
+ self.detach_goal = detach_goal
+ self.detach_proj_goal = detach_proj_goal
self.inv_scale = inv_scale
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor,
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult:
- dict_info: Dict[str, torch.Tensor] = dict(info)
- inv_scale = dist.detach().square().mean().div(2).sqrt().clamp(1e-3, self.inv_scale) # make E[d^2]/r^2 approx 2
- tsfm_dist: torch.Tensor = (dist / inv_scale).square().neg().exp()
+ def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
+ # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy
+ # are latents of randomly ordered random batches.
+ zgoal = torch.roll(critic_batch_info.zy, 1, dims=0)
+ if self.detach_goal:
+ zgoal = zgoal.detach()
+ dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal,
+ proj_grad_enabled=(True, not self.detach_proj_goal))
+ inv_scale = dists.detach().square().mean().div(2).sqrt().clamp(1e-3, self.inv_scale) # make E[d^2]/r^2 approx 2
+ tsfm_dist: torch.Tensor = (dists / inv_scale).square().neg().exp()
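+        # log-mean-exp of a Gaussian kernel over the random pairs; minimizing it favors larger pairwise distances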
rbf_potential = tsfm_dist.mean().log()
- dict_info.update(
- dist=dist.mean(), inv_scale=inv_scale,
- tsfm_dist=tsfm_dist, rbf_potential=rbf_potential,
- )
- return LossResult(loss=rbf_potential * self.weight, info=dict_info)
+ return LossResult(loss=rbf_potential * self.weight,
+ info=dict(dist=dists.mean(), inv_scale=inv_scale,
+ tsfm_dist=tsfm_dist, rbf_potential=rbf_potential)) # type: ignore
def extra_repr(self) -> str:
- return '\n'.join([
- super().extra_repr(),
- f"inv_scale={self.inv_scale:g}",
- ])
+ return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}, inv_scale={self.inv_scale:g}"
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py b/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py
index 19e4886..2701e6f 100644
--- a/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py
@@ -2,7 +2,6 @@ from typing import *
import attrs
-import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -26,7 +25,7 @@ class LatentDynamicsLoss(CriticLossBase):
# weight: float = attrs.field(default=0.1, validator=attrs.validators.gt(0))
epsilon: float = attrs.field(default=0.25, validator=attrs.validators.gt(0))
- # bidirectional: bool = True
+ bidirectional: bool = True
gamma: Optional[float] = attrs.field(
default=None, validator=attrs.validators.optional(attrs.validators.and_(
attrs.validators.gt(0),
@@ -36,54 +35,41 @@ class LatentDynamicsLoss(CriticLossBase):
detach_sp: bool = False
detach_proj_sp: bool = False
detach_qmet: bool = False
- non_quasimetric_dim_mse_weight: float = attrs.field(default=0., validator=attrs.validators.ge(0))
-
- kind: str = attrs.field(
- default='mean', validator=attrs.validators.in_({'mean', 'unidir', 'max', 'bound', 'bound_l2comp', 'l2comp'})) # type: ignore
def make(self) -> 'LatentDynamicsLoss':
return LatentDynamicsLoss(
# weight=self.weight,
epsilon=self.epsilon,
- # bidirectional=self.bidirectional,
+ bidirectional=self.bidirectional,
gamma=self.gamma,
detach_sp=self.detach_sp,
detach_proj_sp=self.detach_proj_sp,
detach_qmet=self.detach_qmet,
- non_quasimetric_dim_mse_weight=self.non_quasimetric_dim_mse_weight,
init_lagrange_multiplier=self.init_lagrange_multiplier,
- kind=self.kind,
)
# weight: float
epsilon: float
- # bidirectional: bool
+ bidirectional: bool
gamma: Optional[float]
detach_sp: bool
detach_proj_sp: bool
detach_qmet: bool
- non_quasimetric_dim_mse_weight: float
c: float
init_lagrange_multiplier: float
- kind: str
- def __init__(self, *, epsilon: float,
- # bidirectional: bool,
+ def __init__(self, *, epsilon: float, bidirectional: bool,
gamma: Optional[float], init_lagrange_multiplier: float,
# weight: float,
- detach_qmet: bool, detach_proj_sp: bool, detach_sp: bool,
- non_quasimetric_dim_mse_weight: float,
- kind: str):
+ detach_qmet: bool, detach_proj_sp: bool, detach_sp: bool):
super().__init__()
# self.weight = weight
self.epsilon = epsilon
- # self.bidirectional = bidirectional
+ self.bidirectional = bidirectional
self.gamma = gamma
self.detach_qmet = detach_qmet
self.detach_sp = detach_sp
self.detach_proj_sp = detach_proj_sp
- self.kind = kind
- self.non_quasimetric_dim_mse_weight = non_quasimetric_dim_mse_weight
self.init_lagrange_multiplier = init_lagrange_multiplier
self.raw_lagrange_multiplier = nn.Parameter(
torch.tensor(softplus_inv_float(init_lagrange_multiplier), dtype=torch.float32))
@@ -95,58 +81,15 @@ class LatentDynamicsLoss(CriticLossBase):
pred_zy = critic.latent_dynamics(data.observations, zx, data.actions)
if self.detach_sp:
zy = zy.detach()
+ with critic.quasimetric_model.requiring_grad(not self.detach_qmet):
+ dists = critic.quasimetric_model(zy, pred_zy, bidirectional=self.bidirectional, # at least optimize d(s', hat{s'})
+ proj_grad_enabled=(self.detach_proj_sp, True))
lagrange_mult = F.softplus(self.raw_lagrange_multiplier) # make positive
# lagrange multiplier is minimax training, so grad_mul -1
lagrange_mult = grad_mul(lagrange_mult, -1)
- info: Dict[str, torch.Tensor] = {}
-
- if self.kind == 'unidir':
- with critic.quasimetric_model.requiring_grad(not self.detach_qmet):
- dists = critic.quasimetric_model(zy, pred_zy, bidirectional=False, # at least optimize d(s', hat{s'})
- proj_grad_enabled=(self.detach_proj_sp, True))
- info.update(dist_n2p=dists.mean())
- elif self.kind in ['mean', 'max', 'bound', 'l2comp']:
- with critic.quasimetric_model.requiring_grad(not self.detach_qmet):
- dists_comps = critic.quasimetric_model(
- zy, pred_zy, bidirectional=True, proj_grad_enabled=(self.detach_proj_sp, True),
- reduced=False)
- dists = critic.quasimetric_model.quasimetric_head.reduction(dists_comps)
-
- dist_p2n, dist_n2p = dists.unbind(-1)
- info.update(dist_p2n=dist_p2n.mean(), dist_n2p=dist_n2p.mean())
- if self.kind == 'max':
- dists = dists.max(-1).values
- elif self.kind == 'bound':
- with critic.quasimetric_model.requiring_grad(not self.detach_qmet):
- dists = critic.quasimetric_model(
- zy, pred_zy, bidirectional=False, # NB: false is enough for bound
- proj_grad_enabled=(self.detach_proj_sp, True),
- symmetric_upperbound=True,
- )
- info.update(dist_upbnd=dists.mean())
- elif self.kind == 'bound_l2comp':
- with critic.quasimetric_model.requiring_grad(not self.detach_qmet):
- dists_comps = critic.quasimetric_model(
- zy, pred_zy, bidirectional=False, # NB: false is enough for bound
- proj_grad_enabled=(self.detach_proj_sp, True),
- reduced=False, symmetric_upperbound=True,
- )
- dists = cast(
- torch.Tensor,
- dists_comps.norm(dim=-1) / np.sqrt(dists_comps.shape[-1]),
- )
- info.update(dist_upbnd_l2comp=dists.mean())
- elif self.kind == 'l2comp':
- dists = cast(
- torch.Tensor,
- dists_comps.norm(dim=-1) / np.sqrt(dists_comps.shape[-1]),
- )
- info.update(dist_l2comp=dists.mean())
- else:
- raise NotImplementedError(self.kind)
-
+ info = {}
if self.gamma is None:
sq_dists = dists.square().mean()
violation = (sq_dists - self.epsilon ** 2)
@@ -158,15 +101,11 @@ class LatentDynamicsLoss(CriticLossBase):
loss = violation * lagrange_mult
info.update(violation=violation, lagrange_mult=lagrange_mult)
- if self.non_quasimetric_dim_mse_weight > 0:
- assert critic.quasimetric_model.input_slice_size < critic.quasimetric_model.input_size, \
- "non-quasimetric dim mse only makes sense if input_slice_size < input_size, but got " \
- f"{critic.quasimetric_model.input_slice_size} >= {critic.quasimetric_model.input_size}"
- _zy = zy[..., critic.quasimetric_model.input_slice_size:]
- _pred_zy = pred_zy[..., critic.quasimetric_model.input_slice_size:]
- non_quasimetric_dim_mse = F.mse_loss(_pred_zy, _zy)
- info.update(non_quasimetric_dim_mse=non_quasimetric_dim_mse)
- loss += non_quasimetric_dim_mse * self.non_quasimetric_dim_mse_weight
+ if self.bidirectional:
+ dist_p2n, dist_n2p = dists.flatten(0, -2).mean(0).unbind(-1)
+ info.update(dist_p2n=dist_p2n, dist_n2p=dist_n2p)
+ else:
+ info.update(dist_n2p=dists.mean())
return LossResult(
loss=loss,
@@ -175,7 +114,7 @@ class LatentDynamicsLoss(CriticLossBase):
def extra_repr(self) -> str:
return '\n'.join([
- f"kind={self.kind}, gamma={self.gamma!r}",
- f"epsilon={self.epsilon!r}, detach_sp={self.detach_sp}, detach_proj_sp={self.detach_proj_sp}, detach_qmet={self.detach_qmet}",
+ f"epsilon={self.epsilon!r}, bidirectional={self.bidirectional}",
+ f"gamma={self.gamma!r}, detach_sp={self.detach_sp}, detach_proj_sp={self.detach_proj_sp}, detach_qmet={self.detach_qmet}",
])
# return f"weight={self.weight:g}, detach_sp={self.detach_sp}"
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py b/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py
index eb87576..1806248 100644
--- a/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py
@@ -2,7 +2,6 @@ from typing import *
import attrs
-import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -31,14 +30,8 @@ class LocalConstraintLoss(CriticLossBase):
init_lagrange_multiplier: float = attrs.field(default=0.01, validator=attrs.validators.gt(0))
- p: float = attrs.field(default=2., validator=attrs.validators.ge(1))
- compare_before_pow: bool = True
- # kind: str = attrs.field(
- # default='mse', validator=attrs.validators.in_(['mse', 'sq'])) # type: ignore
- # batch_reduction: str = attrs.field(
- # default='mean', validator=attrs.validators.in_(['mean', 'l2'])) # type: ignore
- invpow_after_batch_agg: bool = False
- log: bool = False
+ kind: str = attrs.field(
+ default='mse', validator=attrs.validators.in_(['mse', 'sq'])) # type: ignore
detach_proj_sp: bool = False
detach_sp: bool = False
@@ -56,22 +49,12 @@ class LocalConstraintLoss(CriticLossBase):
step_cost=self.step_cost,
step_cost_high=self.step_cost_high,
init_lagrange_multiplier=self.init_lagrange_multiplier,
- # kind=self.kind,
- p=self.p,
- compare_before_pow=self.compare_before_pow,
- log=self.log,
- # batch_reduction=self.batch_reduction,
- invpow_after_batch_agg=self.invpow_after_batch_agg,
+ kind=self.kind,
detach_sp=self.detach_sp,
detach_proj_sp=self.detach_proj_sp,
)
- # kind: str
- p: float
- compare_before_pow: bool
- log: bool
- # batch_reduction: str
- invpow_after_batch_agg: bool
+ kind: str
epsilon: float
allow_le: bool
step_cost: float
@@ -82,21 +65,11 @@ class LocalConstraintLoss(CriticLossBase):
raw_lagrange_multiplier: nn.Parameter # for the QRL constrained optimization
- def __init__(self, *,
- # kind: str,
- p: float, compare_before_pow: bool, invpow_after_batch_agg: bool,
- log: bool,
- # batch_reduction: str,
- epsilon: float, allow_le: bool,
+ def __init__(self, *, kind: str, epsilon: float, allow_le: bool,
step_cost: float, step_cost_high: float, init_lagrange_multiplier: float,
detach_sp: bool, detach_proj_sp: bool):
super().__init__()
- # self.kind = kind
- self.p = p
- self.compare_before_pow = compare_before_pow
- self.log = log
- # self.batch_reduction = batch_reduction
- self.invpow_after_batch_agg = invpow_after_batch_agg
+ self.kind = kind
self.epsilon = epsilon
self.allow_le = allow_le
self.step_cost = step_cost
@@ -127,7 +100,6 @@ class LocalConstraintLoss(CriticLossBase):
info['dist_090'],
info['dist_100'],
) = dist.quantile(dist.new_tensor([0, 0.1, 0.25, 0.5, 0.75, 0.9, 1])).unbind()
- info.update(dist=dist.mean())
lagrange_mult = F.softplus(self.raw_lagrange_multiplier) # make positive
# lagrange multiplier is minimax training, so grad_mul -1
@@ -138,63 +110,31 @@ class LocalConstraintLoss(CriticLossBase):
else:
target = dist.detach().clamp(self.step_cost, self.step_cost_high)
- if self.log:
- dist = dist.log()
- if isinstance(target, torch.Tensor):
- target = target.log()
+ if self.kind == 'mse':
+ if self.allow_le:
+ sq_deviation = (dist - target).relu().square().mean()
else:
- target = np.log(target)
- info.update(dist_log=dist.mean())
-
- if self.compare_before_pow:
- deviation = dist - target
- deviation = (torch.relu if self.allow_le else torch.abs)(deviation)
- deviation = deviation ** self.p
- else:
- deviation = (dist ** self.p - target ** self.p)
- deviation = (torch.relu if self.allow_le else torch.abs)(deviation)
-
- if self.invpow_after_batch_agg:
- deviation = deviation.mean().pow(1 / self.p)
- violation = deviation - self.epsilon
+ sq_deviation = (dist - target).square().mean()
+ elif self.kind == 'sq':
+ if self.allow_le:
+ sq_deviation = (dist.square() - target ** 2).relu().mean()
+ else:
+ sq_deviation = (dist.square() - target ** 2).abs().mean()
else:
- violation = deviation.mean() - (self.epsilon ** self.p)
- info.update(deviation=deviation.mean(), violation=violation, lagrange_mult=lagrange_mult)
-
- # if self.kind == 'mse':
- # if self.allow_le:
- # sq_deviation = (dist - target).relu().square()
- # else:
- # sq_deviation = (dist - target).square()
- # elif self.kind == 'sq':
- # if self.allow_le:
- # sq_deviation = (dist.square() - target ** 2).relu()
- # else:
- # sq_deviation = (dist.square() - target ** 2).abs()
- # else:
- # raise ValueError(f"Unknown kind {self.kind}")
- # info.update(sq_deviation=sq_deviation.mean())
-
- # if self.batch_reduction == 'mean':
- # batch_sq_deviation = sq_deviation.mean()
- # elif self.batch_reduction == 'l2':
- # batch_sq_deviation = sq_deviation.square().mean().sqrt()
- # else:
- # raise ValueError(f"Unknown batch_reduction {self.batch_reduction}")
- # info.update(batch_sq_deviation=batch_sq_deviation)
-
- # violation = (batch_sq_deviation - self.epsilon ** 2)
-
+ raise ValueError(f"Unknown kind {self.kind}")
+ violation = (sq_deviation - self.epsilon ** 2)
loss = violation * lagrange_mult
- # info.update(violation=violation, lagrange_mult=lagrange_mult)
+
+ info.update(
+ dist=dist.mean(), sq_deviation=sq_deviation,
+ violation=violation, lagrange_mult=lagrange_mult,
+ )
return LossResult(loss=loss, info=info)
def extra_repr(self) -> str:
return '\n'.join([
- # f"kind={self.kind}, log={self.log}, batch_reduction={self.batch_reduction}",
- f"p={self.p}, compare_before_pow={self.compare_before_pow}, invpow_after_batch_agg={self.invpow_after_batch_agg}",
- f"epsilon={self.epsilon:g}, allow_le={self.allow_le}",
+ f"{self.kind}, epsilon={self.epsilon:g}, allow_le={self.allow_le}",
f"step_cost={self.step_cost:g}, step_cost_high={self.step_cost_high:g}",
f"detach_sp={self.detach_sp}, detach_proj_sp={self.detach_proj_sp}",
])
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/__init__.py b/quasimetric_rl/modules/quasimetric_critic/models/__init__.py
index 266f592..f9e05c4 100644
--- a/quasimetric_rl/modules/quasimetric_critic/models/__init__.py
+++ b/quasimetric_rl/modules/quasimetric_critic/models/__init__.py
@@ -25,11 +25,6 @@ class QuasimetricCritic(Module):
attrs.validators.le(1),
)))
encoder: Encoder.Conf = Encoder.Conf()
- target_quasimetric_model_ema: Optional[float] = attrs.field(
- default=None, validator=attrs.validators.optional(attrs.validators.and_(
- attrs.validators.ge(0),
- attrs.validators.le(1),
- )))
quasimetric_model: QuasimetricModel.Conf = QuasimetricModel.Conf()
latent_dynamics: Dynamics.Conf = Dynamics.Conf()
@@ -47,46 +42,34 @@ class QuasimetricCritic(Module):
)
return QuasimetricCritic(encoder, quasimetric_model, latent_dynamics,
share_embedding_from=share_embedding_from,
- target_encoder_ema=self.target_encoder_ema,
- target_quasimetric_model_ema=self.target_quasimetric_model_ema)
+ target_encoder_ema=self.target_encoder_ema)
borrowing_embedding: bool
encoder: Encoder
_target_encoder: Optional[Encoder]
target_encoder_ema: Optional[float]
quasimetric_model: QuasimetricModel
- _target_quasimetric_model: Optional[QuasimetricModel]
- target_quasimetric_model_ema: Optional[float]
latent_dynamics: Dynamics # FIXME: add BC to model loading & change name
def __init__(self, encoder: Encoder, quasimetric_model: QuasimetricModel, latent_dynamics: Dynamics,
- share_embedding_from: Optional['QuasimetricCritic'], target_encoder_ema: Optional[float],
- target_quasimetric_model_ema: Optional[float] = None):
+ share_embedding_from: Optional['QuasimetricCritic'], target_encoder_ema: Optional[float]):
super().__init__()
self.borrowing_embedding = share_embedding_from is not None
if share_embedding_from is not None:
encoder = share_embedding_from.encoder
- _target_encoder = share_embedding_from._target_encoder
+ target_encoder = share_embedding_from.target_encoder
assert target_encoder_ema == share_embedding_from.target_encoder_ema
quasimetric_model = share_embedding_from.quasimetric_model
- assert target_quasimetric_model_ema == share_embedding_from.target_quasimetric_model_ema
- _target_quasimetric_model = share_embedding_from._target_quasimetric_model
else:
if target_encoder_ema is None:
- _target_encoder = None
- else:
- _target_encoder = copy.deepcopy(encoder).requires_grad_(False).eval()
- if target_quasimetric_model_ema is None:
- _target_quasimetric_model = None
+ target_encoder = None
else:
- _target_quasimetric_model = copy.deepcopy(quasimetric_model).requires_grad_(False).eval()
+ target_encoder = copy.deepcopy(encoder).requires_grad_(False).eval()
self.encoder = encoder
+ self.add_module('_target_encoder', target_encoder)
self.target_encoder_ema = target_encoder_ema
- self.add_module('_target_encoder', _target_encoder)
self.quasimetric_model = quasimetric_model
- self.target_quasimetric_model_ema = target_quasimetric_model_ema
- self.add_module('_target_quasimetric_model', _target_quasimetric_model)
self.latent_dynamics = latent_dynamics
def forward(self, x: torch.Tensor, y: torch.Tensor, *, action: Optional[torch.Tensor] = None) -> torch.Tensor:
@@ -108,34 +91,16 @@ class QuasimetricCritic(Module):
else:
return self.encoder
- def get_encoder(self, target: bool = False) -> Encoder:
- return self.target_encoder if target else self.encoder
-
- @property
- def target_quasimetric_model(self) -> QuasimetricModel:
- if self._target_quasimetric_model is not None:
- return self._target_quasimetric_model
- else:
- return self.quasimetric_model
-
- def get_quasimetric_model(self, target: bool = False) -> QuasimetricModel:
- return self.target_quasimetric_model if target else self.quasimetric_model
-
@torch.no_grad()
- def update_target_models_(self):
- if not self.borrowing_embedding:
- if self.target_encoder_ema is not None:
- assert self._target_encoder is not None
- for p, p_target in zip(self.encoder.parameters(), self._target_encoder.parameters()):
- p_target.data.lerp_(p.data, 1 - self.target_encoder_ema)
- if self.target_quasimetric_model_ema is not None:
- assert self._target_quasimetric_model is not None
- for p, p_target in zip(self.quasimetric_model.parameters(), self._target_quasimetric_model.parameters()):
- p_target.data.lerp_(p.data, 1 - self.target_quasimetric_model_ema)
+ def update_target_encoder_(self):
+ if not self.borrowing_embedding and self.target_encoder_ema is not None:
+ assert self._target_encoder is not None
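+            # Polyak/EMA update of the target encoder: p_target <- ema * p_target + (1 - ema) * p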
+ for p, p_target in zip(self.encoder.parameters(), self._target_encoder.parameters()):
+ p_target.data.lerp_(p.data, 1 - self.target_encoder_ema)
# for type hints
def __call__(self, x: torch.Tensor, y: torch.Tensor, *, action: Optional[torch.Tensor] = None) -> torch.Tensor:
return super().__call__(x, y, action=action)
def extra_repr(self) -> str:
- return f"borrowing_embedding={self.borrowing_embedding}, target_encoder_ema={self.target_encoder_ema}, target_quasimetric_model_ema={self.target_quasimetric_model_ema}"
+ return f"borrowing_embedding={self.borrowing_embedding}"
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py b/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py
index 9613674..0847076 100644
--- a/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py
+++ b/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py
@@ -20,7 +20,6 @@ class Dynamics(MLP):
# config / argparse uses this to specify behavior
arch: Tuple[int, ...] = (512, 512)
- action_arch: Tuple[int, ...] = ()
layer_norm: bool = True
latent_input: bool = True
raw_observation_input_arch: Optional[Tuple[int, ...]] = None
@@ -35,7 +34,6 @@ class Dynamics(MLP):
latent_input=self.latent_input,
raw_observation_input_arch=self.raw_observation_input_arch,
env_spec=env_spec,
- action_arch=self.action_arch,
hidden_sizes=self.arch,
layer_norm=self.layer_norm,
use_latent_space_proj=self.use_latent_space_proj,
@@ -43,8 +41,7 @@ class Dynamics(MLP):
fc_fuse=self.fc_fuse,
)
- action_input: Union[InputEncoding, nn.Sequential]
- action_arch: Tuple[int, ...]
+ action_input: InputEncoding
latent_input: bool
raw_observation_input: Optional[nn.Sequential]
residual: bool
@@ -53,50 +50,32 @@ class Dynamics(MLP):
def __init__(self, *, latent: LatentSpaceConf,
latent_input: bool,
raw_observation_input_arch: Optional[Tuple[int, ...]],
- env_spec: EnvSpec, action_arch: Tuple[int, ...],
- hidden_sizes: Tuple[int, ...], layer_norm: bool,
+ env_spec: EnvSpec, hidden_sizes: Tuple[int, ...], layer_norm: bool,
use_latent_space_proj: bool, residual: bool,
fc_fuse: bool):
- self.action_arch = action_arch
action_input = env_spec.make_action_input()
- action_outsize = action_input.output_size
- if action_arch:
- action_input = nn.Sequential(
- action_input,
- MLP(
- action_input.output_size,
- action_arch[-1],
- hidden_sizes=action_arch[:-1],
- activation_last_fc=True,
- layer_norm_last_fc=layer_norm,
- layer_norm=layer_norm,
- ),
- )
- action_outsize = action_arch[-1]
- mlp_input_size = action_outsize
+ mlp_input_size = action_input.output_size
mlp_output_size = latent.latent_size
self.latent_input = latent_input
if latent_input:
mlp_input_size += latent.latent_size
if raw_observation_input_arch is not None:
+ assert raw_observation_input_arch is not None
+ assert len(raw_observation_input_arch) >= 1
input_encoding = env_spec.make_observation_input()
- if len(raw_observation_input_arch) >= 1:
- raw_observation_input = MLP(
- input_encoding.output_size,
- raw_observation_input_arch[-1],
- hidden_sizes=raw_observation_input_arch[:-1],
- activation_last_fc=True,
- layer_norm_last_fc=layer_norm,
- layer_norm=layer_norm,
- )
- mlp_input_size += raw_observation_input.output_size
- raw_observation_input = nn.Sequential(
- input_encoding,
- raw_observation_input,
- )
- else:
- raw_observation_input = input_encoding
- mlp_input_size += input_encoding.output_size
+ raw_observation_input = MLP(
+ input_encoding.output_size,
+ raw_observation_input_arch[-1],
+ hidden_sizes=raw_observation_input_arch[:-1],
+ activation_last_fc=True,
+ layer_norm_last_fc=layer_norm,
+ layer_norm=layer_norm,
+ )
+ mlp_input_size += raw_observation_input.output_size
+ raw_observation_input = nn.Sequential(
+ input_encoding,
+ raw_observation_input,
+ )
else:
raw_observation_input = None
super().__init__(
@@ -134,7 +113,7 @@ class Dynamics(MLP):
inputs: List[torch.Tensor] = []
if self.raw_observation_input is not None:
inputs.append(self.raw_observation_input(observation))
- if self.latent_input:
+ if self.latent_input is not None:
inputs.append(zx)
inputs.append(self.action_input(action))
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py b/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py
index 11577ee..5d19754 100644
--- a/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py
+++ b/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py
@@ -20,7 +20,7 @@ class L2(torchqmet.QuasimetricBase):
super().__init__(input_size, num_components=1, guaranteed_quasimetric=True,
transforms=[], reduction='sum', discount=None)
- def compute_components(self, x: torch.Tensor, y: torch.Tensor, symmetric_upperbound: bool = False) -> torch.Tensor:
+ def compute_components(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
r'''
Inputs:
x (torch.Tensor): Shape [..., input_size]
@@ -41,11 +41,11 @@ def create_quasimetric_head_from_spec(spec: str) -> torchqmet.QuasimetricBase:
assert dim % components == 0, "IQE: dim must be divisible by components"
return torchqmet.IQE(dim, dim // components)
- def iqe2(*, dim: int, components: int, norm_delta: bool = False, fake_grad: bool = False, reduction: str = 'maxl12_sm') -> torchqmet.IQE:
+ def iqe2(*, dim: int, components: int, scale: bool = False, norm_delta: bool = False, fake_grad: bool = False) -> torchqmet.IQE:
assert dim % components == 0, "IQE: dim must be divisible by components"
return torchqmet.IQE2(
dim, dim // components,
- reduction=reduction,
+ reduction='maxl12_sm' if not scale else 'maxl12_sm_scale',
learned_delta=True,
learned_div=False,
div_init_mul=0.25,
@@ -80,7 +80,6 @@ class QuasimetricModel(Module):
class Conf:
# config / argparse uses this to specify behavior
- input_slice_size: Optional[int] = attrs.field(default=None, validator=attrs.validators.optional(attrs.validators.gt(0))) # take the first n dims
projector_arch: Optional[Tuple[int, ...]] = (512,)
projector_layer_norm: bool = True
projector_dropout: float = attrs.field(default=0., validator=attrs.validators.ge(0)) # TD-MPC2 uses 0.01
@@ -89,11 +88,8 @@ class QuasimetricModel(Module):
quasimetric_head_spec: str = 'iqe(dim=2048,components=64)'
def make(self, *, input_size: int) -> 'QuasimetricModel':
- if self.input_slice_size is not None:
- assert self.input_slice_size <= input_size, f'input_slice_size={self.input_slice_size} > input_size={input_size}'
return QuasimetricModel(
input_size=input_size,
- input_slice_size=self.input_slice_size or input_size,
projector_arch=self.projector_arch,
projector_layer_norm=self.projector_layer_norm,
projector_dropout=self.projector_dropout,
@@ -103,36 +99,30 @@ class QuasimetricModel(Module):
)
input_size: int
- input_slice_size: int
projector: Union[Identity, MLP]
quasimetric_head: torchqmet.QuasimetricBase
- def __init__(self, *, input_size: int, input_slice_size: int, projector_arch: Optional[Tuple[int, ...]],
+ def __init__(self, *, input_size: int, projector_arch: Optional[Tuple[int, ...]],
projector_layer_norm: bool, projector_dropout: float, projector_weight_norm: bool,
projector_unit_norm: bool, quasimetric_head_spec: str):
super().__init__()
self.input_size = input_size
- self.input_slice_size = input_slice_size
self.quasimetric_head = create_quasimetric_head_from_spec(quasimetric_head_spec)
if projector_arch is None:
- assert input_slice_size == self.quasimetric_head.input_size, \
- f'no projector but latent input_slice_size={input_slice_size}, quasimetric_head.input_size={self.quasimetric_head.input_size}'
+ assert input_size == self.quasimetric_head.input_size, \
+ f'no projector but latent input_size={input_size}, quasimetric_head.input_size={self.quasimetric_head.input_size}'
self.projector = Identity()
else:
self.projector = MLP(
- input_slice_size, self.quasimetric_head.input_size,
+ input_size, self.quasimetric_head.input_size,
hidden_sizes=projector_arch,
layer_norm=projector_layer_norm,
dropout=projector_dropout,
weight_norm_last_fc=projector_weight_norm,
- unit_norm_last_fc=projector_unit_norm,
- )
+ unit_norm_last_fc=projector_unit_norm)
def forward(self, zx: LatentTensor, zy: LatentTensor, *, bidirectional: bool = False,
- proj_grad_enabled: Tuple[bool, bool] = (True, True), reduced: bool = True,
- symmetric_upperbound: bool = False) -> torch.Tensor:
- zx = zx[..., :self.input_slice_size]
- zy = zy[..., :self.input_slice_size]
+ proj_grad_enabled: Tuple[bool, bool] = (True, True)) -> torch.Tensor:
with self.projector.requiring_grad(proj_grad_enabled[0]):
px = self.projector(zx) # [B x D]
with self.projector.requiring_grad(proj_grad_enabled[1]):
@@ -142,7 +132,7 @@ class QuasimetricModel(Module):
px, py = torch.broadcast_tensors(px, py)
px, py = torch.stack([px, py], dim=-2), torch.stack([py, px], dim=-2) # [B x 2 x D]
- return self.quasimetric_head(px, py, reduced=reduced, symmetric_upperbound=symmetric_upperbound)
+ return self.quasimetric_head(px, py)
def parameters(self, recurse: bool = True, *, include_head: bool = True) -> Iterator[Parameter]:
if include_head:
@@ -155,12 +145,8 @@ class QuasimetricModel(Module):
# for type hint
def __call__(self, zx: LatentTensor, zy: LatentTensor, *, bidirectional: bool = False,
- proj_grad_enabled: Tuple[bool, bool] = (True, True), reduced: bool = True,
- symmetric_upperbound: bool = False) -> torch.Tensor:
- return super().__call__(zx, zy, bidirectional=bidirectional,
- proj_grad_enabled=proj_grad_enabled,
- reduced=reduced,
- symmetric_upperbound=symmetric_upperbound)
+ proj_grad_enabled: Tuple[bool, bool] = (True, True)) -> torch.Tensor:
+ return super().__call__(zx, zy, bidirectional=bidirectional, proj_grad_enabled=proj_grad_enabled)
def extra_repr(self) -> str:
- return f"input_size={self.input_size}, input_slice_size={self.input_slice_size}"
+ return f"input_size={self.input_size}"
diff --git a/quasimetric_rl/modules/utils.py b/quasimetric_rl/modules/utils.py
index 0241ac2..5d7a845 100644
--- a/quasimetric_rl/modules/utils.py
+++ b/quasimetric_rl/modules/utils.py
@@ -33,18 +33,11 @@ InfoValT = Union[InfoT, float, torch.Tensor]
@attrs.define(kw_only=True)
class LossResult:
- loss: InfoValT
+ loss: Union[torch.Tensor, float]
info: InfoT
def __attrs_post_init__(self):
- def test_loss(loss: InfoValT):
- if isinstance(loss, Mapping):
- for v in loss.values():
- test_loss(v)
- else:
- assert isinstance(loss, (int, float)) or loss.numel() == 1
-
- test_loss(self.loss)
+ assert isinstance(self.loss, (int, float)) or self.loss.numel() == 1
# detach info tensors
def detach(d: InfoValT) -> InfoValT:
@@ -57,40 +50,17 @@ class LossResult:
object.__setattr__(self, "info", detach(self.info))
- def map_losses(self, fn: Callable[[torch.Tensor | float], torch.Tensor | float]) -> InfoValT:
- def map_loss(loss: InfoValT):
- if isinstance(loss, Mapping):
- return {k: map_loss(v) for k, v in loss.items()}
- else:
- return fn(loss)
-
- return map_loss(self.loss)
-
@classmethod
def empty(cls) -> 'LossResult':
return LossResult(loss=0, info={})
- @property
- def total_loss(self) -> torch.Tensor:
- l = 0
- def add_loss(loss: InfoValT):
- nonlocal l
- if isinstance(loss, Mapping):
- for v in loss.values():
- add_loss(v)
- else:
- l += loss
- add_loss(self.loss)
- assert isinstance(l, torch.Tensor)
- return l
-
@classmethod
def combine(cls, results: Mapping[str, 'LossResult'], **kwargs) -> 'LossResult':
info = {k: r.info for k, r in results.items()}
assert info.keys().isdisjoint(kwargs.keys())
info.update(**kwargs)
return LossResult(
- loss={k: r.loss for k, r in results.items()},
+ loss=sum(r.loss for r in results.values()),
info=info,
)
diff --git a/quasimetric_rl/utils/logging.py b/quasimetric_rl/utils/logging.py
index 6efab8d..2a9e23f 100644
--- a/quasimetric_rl/utils/logging.py
+++ b/quasimetric_rl/utils/logging.py
@@ -4,8 +4,6 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
-from typing import *
-
import sys
import os
import logging
@@ -36,12 +34,10 @@ class TqdmLoggingHandler(logging.Handler):
class MultiLineFormatter(logging.Formatter):
- _fmt: str
def __init__(self, fmt=None, datefmt=None, style='%'):
assert style == '%'
super(MultiLineFormatter, self).__init__(fmt, datefmt, style)
- assert fmt is not None
self.multiline_fmt = fmt
def format(self, record):
@@ -79,7 +75,7 @@ class MultiLineFormatter(logging.Formatter):
output += '\n'.join(
self.multiline_fmt % dict(record.__dict__, message=line)
for index, line
- in enumerate(record.exc_text.decode(sys.getfilesystemencoding(), 'replace').splitlines()) # type: ignore
+ in enumerate(record.exc_text.decode(sys.getfilesystemencoding(), 'replace').splitlines())
)
return output
@@ -100,7 +96,7 @@ def configure(logging_file, log_level=logging.INFO, level_prefix='', prefix='',
sys.excepthook = handle_exception # automatically log uncaught errors
- handlers: List[logging.Handler] = []
+ handlers = []
if write_to_stdout:
handlers.append(TqdmLoggingHandler())
Submodule third_party/torch-quasimetric 186ed87..c5213ff (rewind):
diff --git a/third_party/torch-quasimetric/torchqmet/__init__.py b/third_party/torch-quasimetric/torchqmet/__init__.py
index b776de5..0008b7a 100644
--- a/third_party/torch-quasimetric/torchqmet/__init__.py
+++ b/third_party/torch-quasimetric/torchqmet/__init__.py
@@ -54,22 +54,16 @@ class QuasimetricBase(nn.Module, metaclass=abc.ABCMeta):
'''
pass
- def forward(self, x: torch.Tensor, y: torch.Tensor, *, reduced: bool = True, **kwargs) -> torch.Tensor:
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
assert x.shape[-1] == y.shape[-1] == self.input_size
- d = self.compute_components(x, y, **kwargs)
+ d = self.compute_components(x, y)
d: torch.Tensor = self.transforms(d)
- scale = self.scale
- if not self.training:
- scale = scale.detach()
- if reduced:
- return self.reduction(d) * scale
- else:
- return d * scale
-
- def __call__(self, x: torch.Tensor, y: torch.Tensor, reduced: bool = True, **kwargs) -> torch.Tensor:
+ return self.reduction(d) * self.scale
+
+ def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# Manually define for typing
# https://github.com/pytorch/pytorch/issues/45414
- return super().__call__(x, y, reduced=reduced, **kwargs)
+ return super().__call__(x, y)
def extra_repr(self) -> str:
return f"guaranteed_quasimetric={self.guaranteed_quasimetric}\ninput_size={self.input_size}, num_components={self.num_components}" + (
@@ -82,5 +76,5 @@ from .iqe import IQE, IQE2
from .mrn import MRN, MRNFixed
from .neural_norms import DeepNorm, WideNorm
-__all__ = ['PQE', 'PQELH', 'PQEGG', 'IQE', 'IQE2', 'MRN', 'MRNFixed', 'DeepNorm', 'WideNorm']
+__all__ = ['PQE', 'PQELH', 'PQEGG', 'IQE', 'MRN', 'MRNFixed', 'DeepNorm', 'WideNorm']
__version__ = "0.1.0"
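The `__init__.py` hunks above simplify `QuasimetricBase.forward` to a plain `forward(x, y)` that always returns the reduced, scaled distance (the `reduced` flag, the extra kwargs, and the eval-time `scale.detach()` are gone), and drop `IQE2` from `__all__`. Assuming the usual torchqmet constructor signature (`IQE(input_size, dim_per_component)`), a call after this change might look like the sketch below; the shapes are illustrative.

# Hedged usage sketch for the simplified forward(x, y) signature.
import torch
from torchqmet import IQE

qmet = IQE(input_size=128, dim_per_component=16)

x = torch.randn(5, 128)  # batch of source latents
y = torch.randn(5, 128)  # batch of goal latents

d = qmet(x, y)  # reduced, scaled quasimetric distances, shape (5,)
print(d.shape)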
diff --git a/third_party/torch-quasimetric/torchqmet/iqe.py b/third_party/torch-quasimetric/torchqmet/iqe.py
index f084b6d..bc03f05 100644
--- a/third_party/torch-quasimetric/torchqmet/iqe.py
+++ b/third_party/torch-quasimetric/torchqmet/iqe.py
@@ -12,6 +12,7 @@ from . import QuasimetricBase
# The PQELH function.
+@torch.jit.script
def f_PQELH(h: torch.Tensor): # PQELH: strictly monotonically increasing mapping from [0, +infty) -> [0, 1)
return -torch.expm1(-h)
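As the comment above says, `f_PQELH(h) = -expm1(-h) = 1 - exp(-h)` is strictly increasing on [0, +infty) and stays in [0, 1); the new `@torch.jit.script` decorator only scripts the function, the math is unchanged. A quick numeric check, purely for illustration:

import torch
h = torch.tensor([0.0, 0.5, 2.0, 10.0])
print(-torch.expm1(-h))  # ~ tensor([0.0000, 0.3935, 0.8647, 1.0000]): monotone, always < 1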
@@ -20,28 +21,28 @@ def iqe_tensor_delta(x: torch.Tensor, y: torch.Tensor, delta: torch.Tensor, div_
D = x.shape[-1] # D: component_dim
# ignore pairs that x >= y
- valid = (x < y) # [..., K, D]
+ valid = (x < y)
# sort to better count
- xy = torch.cat(torch.broadcast_tensors(x, y), dim=-1) # [..., K, 2D]
+ xy = torch.cat(torch.broadcast_tensors(x, y), dim=-1)
sxy, ixy = xy.sort(dim=-1)
# neg_inc: the **negated** increment of **input** of f at sorted locations
# inc = torch.gather(delta * valid, dim=-1, index=ixy % D) * torch.where(ixy < D, 1, -1)
- neg_inc = torch.gather(delta * valid, dim=-1, index=ixy % D) * torch.where(ixy < D, -1, 1) # [..., K, 2D-sort]
+ neg_inc = torch.gather(delta * valid, dim=-1, index=ixy % D) * torch.where(ixy < D, -1, 1)
# neg_incf: the **negated** increment of **output** of f at sorted locations
neg_f_input = torch.cumsum(neg_inc, dim=-1) / div_pre_f[:, None]
if fake_grad:
neg_f_input__grad_path = neg_f_input.clone()
- neg_f_input__grad_path.data.clamp_(min=-15) # fake grad
+ neg_f_input__grad_path.data.clamp_(max=17) # fake grad
neg_f_input = neg_f_input__grad_path + (
neg_f_input - neg_f_input__grad_path
).detach()
neg_f = torch.expm1(neg_f_input)
- neg_incf = torch.diff(neg_f, dim=-1, prepend=neg_f.new_zeros(()).expand_as(neg_f[..., :1]))
+ neg_incf = torch.cat([neg_f.narrow(-1, 0, 1), torch.diff(neg_f, dim=-1)], dim=-1)
# reduction
if neg_incf.ndim == 3:
@@ -62,6 +63,7 @@ def iqe_tensor_delta(x: torch.Tensor, y: torch.Tensor, delta: torch.Tensor, div_
return comp
+
def iqe(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
D = x.shape[-1] # D: dim_per_component
@@ -87,35 +89,19 @@ def iqe(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# neg_inp = torch.where(neg_inp_copies == 0, 0., -delta)
# f output: 0 -> 0, x -> 1.
neg_f = (neg_inp_copies < 0) * (-1.)
- neg_incf = torch.diff(neg_f, dim=-1, prepend=neg_f.new_zeros(()).expand_as(neg_f[..., :1]))
+ neg_incf = torch.cat([neg_f.narrow(-1, 0, 1), torch.diff(neg_f, dim=-1)], dim=-1)
# reduction
return (sxy * neg_incf).sum(-1)
-def is_notebook():
- r"""
- Inspired by
- https://github.com/tqdm/tqdm/blob/cc372d09dcd5a5eabdc6ed4cf365bdb0be004d44/tqdm/autonotebook.py
- """
- import sys
- try:
- get_ipython = sys.modules['IPython'].get_ipython
- if 'IPKernelApp' not in get_ipython().config: # pragma: no cover
- raise ImportError("console")
- except Exception:
- return False
- else: # pragma: no cover
- return True
-
-
-if torch.__version__ >= '2.0.1' and not is_notebook(): # well, broken process pool in notebooks
+if torch.__version__ >= '2.0.1' and False: # well, broken process pool in notebooks
iqe = torch.compile(iqe)
iqe_tensor_delta = torch.compile(iqe_tensor_delta)
# iqe = torch.compile(iqe, dynamic=True)
else:
- iqe = torch.jit.script(iqe) # type: ignore
- iqe_tensor_delta = torch.jit.script(iqe_tensor_delta) # type: ignore
+ iqe = torch.jit.script(iqe)
+ iqe_tensor_delta = torch.jit.script(iqe_tensor_delta)
class IQE(QuasimetricBase):
@@ -245,8 +231,8 @@ class IQE2(IQE):
ema_weight: float = 0.95):
super().__init__(input_size, dim_per_component, transforms=transforms, reduction=reduction,
discount=discount, warn_if_not_quasimetric=warn_if_not_quasimetric)
- self.component_dropout_thresh = tuple(component_dropout_thresh) # type: ignore
- self.dropout_p_thresh = tuple(dropout_p_thresh) # type: ignore
+ self.component_dropout_thresh = tuple(component_dropout_thresh)
+ self.dropout_p_thresh = tuple(dropout_p_thresh)
self.dropout_batch_frac = float(dropout_batch_frac)
self.fake_grad = fake_grad
assert 0 <= self.dropout_batch_frac <= 1
@@ -263,7 +249,7 @@ class IQE2(IQE):
# )
self.register_parameter(
'raw_delta',
- torch.nn.Parameter( # type: ignore
+ torch.nn.Parameter(
torch.zeros(self.latent_2d_shape).requires_grad_()
)
)
@@ -284,7 +270,7 @@ class IQE2(IQE):
self.register_parameter(
'raw_div',
- torch.nn.Parameter(torch.zeros(self.num_components).requires_grad_()) # type: ignore
+ torch.nn.Parameter(torch.zeros(self.num_components).requires_grad_())
)
else:
self.register_buffer(
@@ -299,11 +285,11 @@ class IQE2(IQE):
self.div_init_mul = div_init_mul
self.mul_kind = mul_kind
- self.last_components = None # type: ignore
- self.last_drop_p = None # type: ignore
+ self.last_components = None
+ self.last_drop_p = None
- def compute_components(self, x: torch.Tensor, y: torch.Tensor, *, symmetric_upperbound: bool = False) -> torch.Tensor:
+ def compute_components(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# if self.raw_delta is None:
# components = super().compute_components(x, y)
# else:
@@ -313,9 +299,6 @@ class IQE2(IQE):
delta.data.clamp_(max=1e3 / (self.latent_2d_shape[-1] / 8))
div_pre_f.data.clamp_(min=1e-3)
- if symmetric_upperbound:
- x, y = torch.minimum(x, y), torch.maximum(x, y)
-
components = iqe_tensor_delta(
x=x.unflatten(-1, self.latent_2d_shape),
y=y.unflatten(-1, self.latent_2d_shape),
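Two hunks in `iqe.py` above replace `torch.diff(..., prepend=zeros)` with `torch.cat([first, torch.diff(...)], dim=-1)` when building `neg_incf`. The two forms are numerically identical when the prepended value is zero; one plausible motivation (an assumption, not stated in the diff) is that the cat-based form is friendlier to `torch.jit.script`, which this revision switches back to. A quick illustrative check:

# Equivalence check for the neg_incf rewrite:
# diff with a zero prepend == cat(first element, plain diff).
import torch

t = torch.tensor([[0.3, 0.5, 0.9],
                  [1.0, 1.0, 2.5]])

a = torch.diff(t, dim=-1, prepend=t.new_zeros(()).expand_as(t[..., :1]))
b = torch.cat([t.narrow(-1, 0, 1), torch.diff(t, dim=-1)], dim=-1)

assert torch.allclose(a, b)  # both give [t0, t1 - t0, t2 - t1] along the last dim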
diff --git a/third_party/torch-quasimetric/torchqmet/reductions.py b/third_party/torch-quasimetric/torchqmet/reductions.py
index a87be8b..7681242 100644
--- a/third_party/torch-quasimetric/torchqmet/reductions.py
+++ b/third_party/torch-quasimetric/torchqmet/reductions.py
@@ -59,11 +59,6 @@ class Mean(ReductionBase):
return d.mean(dim=-1)
-class L2(ReductionBase):
- def reduce_distance(self, d: torch.Tensor) -> torch.Tensor:
- return d.norm(p=2, dim=-1)
-
-
class MaxMean(ReductionBase):
r'''
`maxmean` from Neural Norms paper:
@@ -149,7 +144,7 @@ class MaxL12_PGsm(ReductionBase):
super().__init__(input_num_components=input_num_components, discount=discount)
self.raw_alpha = nn.Parameter(torch.tensor([0., 0., 0., 0.], dtype=torch.float32).requires_grad_()) # pre normalizing
self.raw_alpha_w = nn.Parameter(torch.tensor([0., 0., 0.], dtype=torch.float32).requires_grad_()) # pre normalizing
- self.last_logp = None # type: ignore
+ self.last_logp = None
self.on_pi = True
# self.last_p = None
@@ -227,7 +222,7 @@ class MaxL12_PG3(ReductionBase):
super().__init__(input_num_components=input_num_components, discount=discount)
self.raw_alpha = nn.Parameter(torch.tensor([0., 0., 0.], dtype=torch.float32).requires_grad_()) # pre normalizing
self.raw_alpha_w = torch.tensor([], dtype=torch.float32) # just to make logging easier
- self.last_logp = None # type: ignore
+ self.last_logp = None
self.on_pi = True
# self.last_p = None
@@ -327,7 +322,6 @@ class DeepLinearNetWeightedSum(ReductionBase):
REDUCTIONS: Mapping[str, Type[ReductionBase]] = dict(
sum=Sum,
mean=Mean,
- l2=L2,
maxmean=MaxMean,
maxl12=MaxL12,
maxl12_sm=MaxL12_sm,
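The final hunk drops the `l2` entry from the `REDUCTIONS` registry, matching the deleted `L2` class earlier in the file, which reduced per-component distances with a Euclidean norm. A minimal stand-alone equivalent of what was removed (not the class itself), for reference:

# Minimal stand-in for the removed `l2` reduction: an L2 norm over the
# last (component) dimension of the per-component distances.
import torch

def l2_reduce(d: torch.Tensor) -> torch.Tensor:
    return d.norm(p=2, dim=-1)

d = torch.rand(5, 8)       # (batch, num_components)
print(l2_reduce(d).shape)  # torch.Size([5])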