Created March 15, 2024 17:44
Save ssnl/c7ac69147331b1db6232f640603a9f75 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters.
diff --git a/offline/main.py b/offline/main.py | |
index 5502749..1581d30 100644 | |
--- a/offline/main.py | |
+++ b/offline/main.py | |
@@ -43,6 +43,7 @@ class Conf(BaseConf): | |
total_optim_steps: int = attrs.field(default=int(2e5), validator=attrs.validators.gt(0)) | |
log_steps: int = attrs.field(default=250, validator=attrs.validators.gt(0)) # type: ignore | |
+ eval_before_training: bool = False | |
eval_steps: int = attrs.field(default=20000, validator=attrs.validators.gt(0)) # type: ignore | |
save_steps: int = attrs.field(default=50000, validator=attrs.validators.gt(0)) # type: ignore | |
num_eval_episodes: int = attrs.field(default=50, validator=attrs.validators.ge(0)) # type: ignore | |
@@ -57,7 +58,7 @@ cs.store(name='config', node=Conf()) | |
@hydra.main(version_base=None, config_name="config") | |
def train(dict_cfg: DictConfig): | |
cfg: Conf = Conf.from_DictConfig(dict_cfg) # type: ignore | |
- cfg.setup_for_experiment() # checking & setup logging | |
+ wandb_run = cfg.setup_for_experiment() # checking & setup logging | |
dataset = cfg.env.make() | |
@@ -116,7 +117,7 @@ def train(dict_cfg: DictConfig): | |
logging.info(f"Checkpointed to {relpath}") | |
def eval(epoch, it, optim_steps): | |
- val_result_allenvs = trainer.evaluate() | |
+ val_result_allenvs = trainer.evaluate(desc=f"opt{optim_steps:08d}") | |
val_results.clear() | |
val_results.append(dict( | |
epoch=epoch, | |
@@ -128,25 +129,10 @@ def train(dict_cfg: DictConfig): | |
epoch=epoch, | |
it=it, | |
optim_steps=optim_steps, | |
- result={}, | |
+ result={ | |
+ k: val_result.summarize() for k, val_result in val_result_allenvs.items() | |
+ }, | |
)) | |
- for k, val_result in val_result_allenvs.items(): | |
- succ_rate_ts = ( | |
- None if val_result.timestep_is_success is None | |
- else torch.stack([_x.mean(dtype=torch.float32) for _x in val_result.timestep_is_success]) | |
- ) | |
- hitting_time = val_result.capped_hitting_time | |
- summary = dict( | |
- epi_return=val_result.episode_return, | |
- epi_score=val_result.episode_score, | |
- succ_rate_ts=succ_rate_ts, | |
- succ_rate=val_result.is_success, | |
- hitting_time=hitting_time, | |
- ) | |
- for kk, v in val_result.extra_timestep_results.items(): | |
- summary[kk] = torch.stack([_v.mean(dtype=torch.float32) for _v in v]) | |
- summary[kk + '_last'] = torch.stack([_v[-1] for _v in v]) | |
- val_summaries[-1]['result'][k] = summary | |
averaged_info = cfg.stage_log_info(dict(eval=val_summaries[-1]), optim_steps) | |
with open(os.path.join(cfg.output_dir, 'eval.log'), 'a') as f: # type: ignore | |
print(json.dumps(averaged_info), file=f) | |
@@ -184,15 +170,18 @@ def train(dict_cfg: DictConfig): | |
) | |
num_total_epochs = int(np.ceil(cfg.total_optim_steps / trainer.num_batches)) | |
+ logging.info(f"save folder: {cfg.output_dir}") | |
+ logging.info(f"wandb: {wandb_run.get_url()}") | |
# Training loop | |
optim_steps = 0 | |
- # eval(0, 0, optim_steps) | |
+ if cfg.eval_before_training: | |
+ eval(0, 0, optim_steps) | |
save(0, 0, optim_steps) | |
if start_epoch < num_total_epochs: | |
for epoch in range(num_total_epochs): | |
epoch_desc = f"Train epoch {epoch:05d}/{num_total_epochs:05d}" | |
- for it, (data, data_info) in enumerate(tqdm(trainer.iter_training_data(), total=trainer.num_batches, desc=epoch_desc)): | |
+ for it, (data, data_info) in enumerate(tqdm(trainer.iter_training_data(), total=trainer.num_batches, desc=epoch_desc, leave=False)): | |
step_counter.update_then_record_alerts() | |
optim_steps += 1 | |
@@ -231,4 +220,9 @@ if __name__ == '__main__': | |
# set up some hydra flags before parsing | |
os.environ['HYDRA_FULL_ERROR'] = str(int(FLAGS.DEBUG)) | |
- train() | |
+ try: | |
+ train() | |
+ except: | |
+ import wandb | |
+ wandb.finish(1) # sometimes crashes are not reported?? let's be safe | |
+ raise | |
diff --git a/offline/trainer.py b/offline/trainer.py | |
index e7f48cf..c75abed 100644 | |
--- a/offline/trainer.py | |
+++ b/offline/trainer.py | |
@@ -11,52 +11,7 @@ import torch | |
import torch.utils.data | |
from quasimetric_rl.modules import QRLConf, QRLAgent, QRLLosses, InfoT | |
-from quasimetric_rl.data import BatchData, Dataset, EpisodeData, EnvSpec, OfflineEnv | |
- | |
- | |
-def first_nonzero(arr: torch.Tensor, dim: int = -1, invalid_val: int = -1): | |
- mask = (arr != 0) | |
- return torch.where(mask.any(dim=dim), mask.to(torch.uint8).argmax(dim=dim), invalid_val) | |
- | |
- | |
-@attrs.define(kw_only=True) | |
-class EvalEpisodeResult: | |
- timestep_reward: List[torch.Tensor] | |
- episode_return: torch.Tensor | |
- episode_score: torch.Tensor | |
- timestep_is_success: Optional[List[torch.Tensor]] | |
- is_success: Optional[torch.Tensor] | |
- hitting_time: Optional[torch.Tensor] | |
- extra_timestep_results: Mapping[str, List[torch.Tensor]] | |
- | |
- @property | |
- def capped_hitting_time(self) -> Optional[torch.Tensor]: | |
- # if not hit -> |ts| + 1 | |
- if self.hitting_time is None: | |
- return None | |
- assert self.timestep_is_success is not None | |
- return torch.stack([torch.where(_x < 0, _succ.shape[0] + 1, _x) for _x, _succ in zip(self.hitting_time, self.timestep_is_success)]) | |
- | |
- @classmethod | |
- def from_timestep_reward_is_success(cls, dataset: Dataset, | |
- timestep_reward: List[torch.Tensor], | |
- timestep_is_success: Optional[List[torch.Tensor]], | |
- extra_timestep_results) -> Self: | |
- return cls( | |
- timestep_reward=timestep_reward, | |
- episode_return=torch.stack([r.sum() for r in timestep_reward]), | |
- episode_score=dataset.normalize_score(timestep_reward), | |
- timestep_is_success=timestep_is_success, | |
- is_success=( | |
- None if timestep_is_success is None | |
- else torch.stack([_x.any(dim=-1) for _x in timestep_is_success]) | |
- ), | |
- hitting_time=( | |
- None if timestep_is_success is None | |
- else torch.stack([first_nonzero(_x, dim=-1) for _x in timestep_is_success]) | |
- ), # NB this is off by 1 | |
- extra_timestep_results=dict(extra_timestep_results), | |
- ) | |
+from quasimetric_rl.data import BatchData, Dataset, EpisodeData, EnvSpec, OfflineEnv, interaction | |
class Trainer(object): | |
@@ -133,32 +88,20 @@ class Trainer(object): | |
adistn = self.agent.actor(obs[None].to(self.device), goal[None].to(self.device)) | |
return adistn.mode.cpu().numpy()[0] | |
- rollout = Dataset.collect_rollout_general( | |
+ rollout = interaction.collect_rollout( | |
actor, env=env, env_spec=EnvSpec.from_env(env), | |
max_episode_length=env.max_episode_steps) | |
return rollout | |
- def evaluate(self) -> Mapping[str, EvalEpisodeResult]: | |
+ def evaluate(self, desc=None) -> Mapping[str, interaction.EvalEpisodeResult]: | |
envs = self.dataset.create_eval_envs(self.eval_seed) | |
- results: Dict[str, EvalEpisodeResult] = {} | |
+ results: Dict[str, interaction.EvalEpisodeResult] = {} | |
for k, env in envs.items(): | |
rollouts: List[EpisodeData] = [] | |
- for _ in tqdm(range(self.num_eval_episodes), desc=f'eval/{k}'): | |
+ this_desc = f'eval/{k}' | |
+ if desc is not None: | |
+ this_desc = f'{desc}/{this_desc}' | |
+ for _ in tqdm(range(self.num_eval_episodes), desc=this_desc): | |
rollouts.append(self.collect_eval_rollout(env=env)) | |
- results[k] = EvalEpisodeResult.from_timestep_reward_is_success( | |
- self.dataset, | |
- timestep_reward=[rollout.rewards for rollout in rollouts], | |
- timestep_is_success=( | |
- None | |
- if len(rollouts) == 0 or 'is_success' not in rollouts[0].transition_infos | |
- else [rollout.transition_infos['is_success'] for rollout in rollouts] | |
- ), | |
- extra_timestep_results=( | |
- {} if len(rollouts) == 0 else | |
- { | |
- k: [rollout.transition_infos[k] for rollout in rollouts] | |
- for k in rollouts[0].transition_infos.keys() if k != 'is_success' | |
- } | |
- ), | |
- ) | |
+ results[k] = interaction.EvalEpisodeResult.from_episode_rollouts(self.dataset, rollouts) | |
return results | |
diff --git a/online/main.py b/online/main.py | |
index cd58f7a..c16e094 100644 | |
--- a/online/main.py | |
+++ b/online/main.py | |
@@ -50,7 +50,7 @@ cs.store(name='config', node=Conf()) | |
@hydra.main(version_base=None, config_name="config") | |
def train(dict_cfg: DictConfig): | |
cfg: Conf = Conf.from_DictConfig(dict_cfg) # type: ignore | |
- cfg.setup_for_experiment() # checking & setup logging | |
+ wandb_run = cfg.setup_for_experiment() # checking & setup logging | |
replay_buffer = cfg.env.make() | |
@@ -85,7 +85,7 @@ def train(dict_cfg: DictConfig): | |
logging.info(f"Checkpointed to {relpath}") | |
def eval(env_steps, optim_steps): | |
- val_result_allenvs = trainer.evaluate() | |
+ val_result_allenvs = trainer.evaluate(desc=f'env{env_steps:08d}_opt{optim_steps:08d}') | |
val_results.clear() | |
val_results.append(dict( | |
env_steps=env_steps, | |
@@ -95,18 +95,10 @@ def train(dict_cfg: DictConfig): | |
val_summaries.append(dict( | |
env_steps=env_steps, | |
optim_steps=optim_steps, | |
- result={}, | |
+ result={ | |
+ k: val_result.summarize() for k, val_result in val_result_allenvs.items() | |
+ }, | |
)) | |
- for k, val_result in val_result_allenvs.items(): | |
- succ_rate_ts = val_result.timestep_is_success.mean(dtype=torch.float32, dim=-1) | |
- hitting_time = val_result.capped_hitting_time | |
- val_summaries[-1]['result'][k] = dict( | |
- epi_return=val_result.episode_return, | |
- epi_score=val_result.episode_score, | |
- succ_rate_ts=succ_rate_ts, | |
- succ_rate=val_result.is_success, | |
- hitting_time=hitting_time, | |
- ) | |
averaged_info = cfg.stage_log_info(dict(eval=val_summaries[-1]), optim_steps) | |
with open(os.path.join(cfg.output_dir, 'eval.log'), 'a') as f: # type: ignore | |
print(json.dumps(averaged_info), file=f) | |
@@ -121,6 +113,8 @@ def train(dict_cfg: DictConfig): | |
), | |
) | |
+ logging.info(f"save folder: {cfg.output_dir}") | |
+ logging.info(f"wandb: {wandb_run.get_url()}") | |
# Training loop | |
eval(0, 0); save(0, 0) | |
for optim_steps, (env_steps, next_iter_new_env_step, data, data_info) in enumerate(trainer.iter_training_data(), start=1): | |
@@ -162,4 +156,9 @@ if __name__ == '__main__': | |
# set up some hydra flags before parsing | |
os.environ['HYDRA_FULL_ERROR'] = str(int(FLAGS.DEBUG)) | |
- train() | |
+ try: | |
+ train() | |
+ except: | |
+ import wandb | |
+ wandb.finish(1) # sometimes crashes are not reported?? let's be safe | |
+ raise | |
diff --git a/online/trainer.py b/online/trainer.py | |
index 01faad8..8926871 100644 | |
--- a/online/trainer.py | |
+++ b/online/trainer.py | |
@@ -12,43 +12,11 @@ import torch | |
import torch.utils.data | |
from quasimetric_rl.modules import QRLConf, QRLAgent, QRLLosses, InfoT | |
-from quasimetric_rl.data import Dataset, BatchData, EpisodeData, MultiEpisodeData | |
+from quasimetric_rl.data import BatchData, EpisodeData, interaction | |
from quasimetric_rl.data.online import ReplayBuffer, OnlineFixedLengthEnv | |
from quasimetric_rl.utils import tqdm | |
-def first_nonzero(arr: torch.Tensor, dim: int = -1, invalid_val: int = -1): | |
- mask = (arr != 0) | |
- return torch.where(mask.any(dim=dim), mask.to(torch.uint8).argmax(dim=dim), invalid_val) | |
- | |
- | |
-@attrs.define(kw_only=True) | |
-class EvalEpisodeResult: | |
- timestep_reward: torch.Tensor | |
- episode_return: torch.Tensor | |
- episode_score: torch.Tensor | |
- timestep_is_success: torch.Tensor | |
- is_success: torch.Tensor | |
- hitting_time: torch.Tensor | |
- | |
- @property | |
- def capped_hitting_time(self) -> torch.Tensor: | |
- # if not hit -> |ts| + 1 | |
- return torch.stack([torch.where(_x < 0, self.timestep_is_success.shape[-1] + 1, _x) for _x in self.hitting_time]) | |
- | |
- @classmethod | |
- def from_timestep_reward_is_success(cls, dataset: Dataset, timestep_reward: torch.Tensor, | |
- timestep_is_success: torch.Tensor) -> Self: | |
- return cls( | |
- timestep_reward=timestep_reward, | |
- episode_return=timestep_reward.sum(-1), | |
- episode_score=dataset.normalize_score(cast(Sequence[torch.Tensor], timestep_reward)), | |
- timestep_is_success=timestep_is_success, | |
- is_success=timestep_is_success.any(dim=-1), | |
- hitting_time=first_nonzero(timestep_is_success, dim=-1), # NB this is off by 1 | |
- ) | |
- | |
- | |
@attrs.define(kw_only=True) | |
class InteractionConf: | |
total_env_steps: int = attrs.field(default=int(1e6), validator=attrs.validators.gt(0)) | |
@@ -165,23 +133,18 @@ class Trainer(object): | |
self.replay.add_rollout(rollout) | |
return rollout | |
- def evaluate(self) -> Mapping[str, EvalEpisodeResult]: | |
+ def evaluate(self, desc=None) -> Mapping[str, interaction.EvalEpisodeResult]: | |
envs = self.make_evaluate_envs() | |
- results: Dict[str, EvalEpisodeResult] = {} | |
+ results: Dict[str, interaction.EvalEpisodeResult] = {} | |
for k, env in envs.items(): | |
rollouts = [] | |
- for _ in tqdm(range(self.num_eval_episodes), desc=f'eval/{k}'): | |
+ this_desc = f'eval/{k}' | |
+ if desc is not None: | |
+ this_desc = f'{desc}/{this_desc}' | |
+ for _ in tqdm(range(self.num_eval_episodes), desc=this_desc): | |
rollouts.append(self.collect_rollout(eval=True, store=False, env=env)) | |
- mrollouts = MultiEpisodeData.cat(rollouts) | |
- results[k] = EvalEpisodeResult.from_timestep_reward_is_success( | |
- self.replay, | |
- mrollouts.rewards.reshape( | |
- self.num_eval_episodes, env.episode_length, | |
- ), | |
- mrollouts.transition_infos['is_success'].reshape( | |
- self.num_eval_episodes, env.episode_length, | |
- ), | |
- ) | |
+ results[k] = interaction.EvalEpisodeResult.from_episode_rollouts( | |
+ self.replay, rollouts) | |
return results | |
def iter_training_data(self) -> Iterator[Tuple[int, bool, BatchData, InfoT]]: | |
@@ -197,7 +160,7 @@ class Trainer(object): | |
""" | |
def yield_data(): | |
num_transitions = self.replay.num_transitions_realized | |
- for icyc in tqdm(range(self.num_samples_per_cycle), desc=f"{num_transitions} env steps, train batches"): | |
+ for icyc in tqdm(range(self.num_samples_per_cycle), desc=f"{num_transitions} env steps, train batches", leave=False): | |
data_t0 = time.time() | |
data = self.sample() | |
info = dict( | |
diff --git a/quasimetric_rl/base_conf.py b/quasimetric_rl/base_conf.py | |
index b3a332e..e9f2202 100644 | |
--- a/quasimetric_rl/base_conf.py | |
+++ b/quasimetric_rl/base_conf.py | |
@@ -172,7 +172,7 @@ class BaseConf(abc.ABC): | |
else: | |
raise RuntimeError(f'Output directory {self.output_dir} exists and is complete') | |
- wandb.init( | |
+ run = wandb.init( | |
project=self.wandb_project, | |
name=self.output_folder.replace('/', '__') + '__' + datetime.now().strftime(r"%Y%m%d_%H:%M:%S"), | |
config=yaml.safe_load(OmegaConf.to_yaml(self)), | |
@@ -219,3 +219,6 @@ class BaseConf(abc.ABC): | |
if self.device.type == 'cuda' and self.device.index is not None: | |
torch.cuda.set_device(self.device.index) | |
+ | |
+ assert run is not None | |
+ return run | |
diff --git a/quasimetric_rl/data/__init__.py b/quasimetric_rl/data/__init__.py | |
index 41fc413..519796b 100644 | |
--- a/quasimetric_rl/data/__init__.py | |
+++ b/quasimetric_rl/data/__init__.py | |
@@ -5,8 +5,9 @@ from .env_spec import EnvSpec | |
from . import online | |
from .online import register_online_env, OnlineFixedLengthEnv | |
from .offline import OfflineEnv | |
+from . import interaction | |
__all__ = [ | |
'BatchData', 'EpisodeData', 'MultiEpisodeData', 'Dataset', 'register_offline_env', | |
- 'EnvSpec', 'online', 'register_online_env', 'OnlineFixedLengthEnv', 'OfflineEnv', | |
+ 'EnvSpec', 'online', 'register_online_env', 'OnlineFixedLengthEnv', 'OfflineEnv', 'interaction', | |
] | |
diff --git a/quasimetric_rl/data/base.py b/quasimetric_rl/data/base.py | |
index c86f525..09711c4 100644 | |
--- a/quasimetric_rl/data/base.py | |
+++ b/quasimetric_rl/data/base.py | |
@@ -36,6 +36,7 @@ class BatchData(TensorCollectionAttrsMixin): # TensorCollectionAttrsMixin has s | |
timeouts: torch.Tensor | |
future_observations: torch.Tensor # sampled! | |
+ future_tdelta: torch.Tensor | |
@property | |
def device(self) -> torch.device: | |
@@ -143,6 +144,12 @@ class EpisodeData(MultiEpisodeData): | |
rewards=self.rewards[:t], | |
terminals=self.terminals[:t], | |
timeouts=self.timeouts[:t], | |
+ observation_infos={ | |
+ k: v[:t + 1] for k, v in self.observation_infos.items() | |
+ }, | |
+ transition_infos={ | |
+ k: v[:t] for k, v in self.transition_infos.items() | |
+ }, | |
) | |
@@ -235,6 +242,7 @@ class Dataset(torch.utils.data.Dataset): | |
indices_to_episode_timesteps: torch.Tensor | |
max_episode_length: int | |
# ----- | |
+ device: torch.device | |
def create_env(self, *, dict_obseravtion: Optional[bool] = None, seed: Optional[int] = None, **kwargs) -> 'OfflineEnv': | |
from .offline import OfflineEnv | |
@@ -272,6 +280,7 @@ class Dataset(torch.utils.data.Dataset): | |
def __init__(self, kind: str, name: str, *, | |
future_observation_discount: float, | |
dummy: bool = False, # when you don't want to load data, e.g., in analysis | |
+ device: torch.device = torch.device('cpu'), # FIXME: get some heuristic | |
) -> None: | |
self.kind = kind | |
self.name = name | |
@@ -298,97 +307,22 @@ class Dataset(torch.utils.data.Dataset): | |
indices_to_episode_timesteps.append(torch.arange(l, dtype=torch.int64)) | |
assert len(episodes) > 0, "must have at least one episode" | |
- self.raw_data = MultiEpisodeData.cat(episodes) | |
+ self.raw_data = MultiEpisodeData.cat(episodes).to(device) | |
- self.obs_indices_to_obs_index_in_episode = torch.cat(obs_indices_to_obs_index_in_episode, dim=0) | |
- self.indices_to_episode_indices = torch.cat(indices_to_episode_indices, dim=0) | |
- self.indices_to_episode_timesteps = torch.cat(indices_to_episode_timesteps, dim=0) | |
+ self.obs_indices_to_obs_index_in_episode = torch.cat(obs_indices_to_obs_index_in_episode, dim=0).to(device) | |
+ self.indices_to_episode_indices = torch.cat(indices_to_episode_indices, dim=0).to(device) | |
+ self.indices_to_episode_timesteps = torch.cat(indices_to_episode_timesteps, dim=0).to(device) | |
self.max_episode_length = int(self.raw_data.episode_lengths.max().item()) | |
+ self.device = device | |
- def get_observations(self, obs_indices: torch.Tensor): | |
- return self.raw_data.all_observations[obs_indices] | |
+ # def max_bytes_used(self): | |
+ # return self | |
- @classmethod | |
- def collect_rollout_general(cls, actor: Callable[[torch.Tensor, torch.Tensor, gym.Space], np.ndarray], *, | |
- env: gym.Env, env_spec: EnvSpec, max_episode_length: int, | |
- assert_exact_episode_length: bool = False, extra_transition_info_keys: Collection[str] = []) -> EpisodeData: | |
- from .utils import get_empty_episode | |
- | |
- epi = get_empty_episode(env_spec, max_episode_length) | |
- | |
- # check observation space | |
- obs_dict_keys = {'observation', 'achieved_goal', 'desired_goal'} | |
- WRONG_OBS_ERR_MESSAGE = ( | |
- f"{cls.__name__} collect_rollout only supports Dict " | |
- f"observation space with keys {obs_dict_keys}, but got {env.observation_space}" | |
- ) | |
- assert isinstance(env.observation_space, gym.spaces.Dict), WRONG_OBS_ERR_MESSAGE | |
- assert set(env.observation_space.spaces.keys()) == {'observation', 'achieved_goal', 'desired_goal'}, WRONG_OBS_ERR_MESSAGE | |
- | |
- observation_dict = cast(Mapping[str, np.ndarray], env.reset()) | |
- observation: torch.Tensor = torch.as_tensor(observation_dict['observation'], dtype=torch.float32) | |
- | |
- goal: torch.Tensor = torch.as_tensor(observation_dict['desired_goal'], dtype=torch.float32) | |
- agoal: torch.Tensor = torch.as_tensor(observation_dict['achieved_goal'], dtype=torch.float32) | |
- epi.all_observations[0] = observation | |
- epi.observation_infos['desired_goals'][0] = goal | |
- epi.observation_infos['achieved_goals'][0] = agoal | |
- if len(extra_transition_info_keys): | |
- epi.transition_infos = dict(epi.transition_infos) | |
- epi.transition_infos.update({ | |
- k: torch.empty([max_episode_length], dtype=torch.float32) for k in extra_transition_info_keys | |
- }) | |
- | |
- t = 0 | |
- timeout = False | |
- terminal = False | |
- while not timeout and not terminal: | |
- assert t < max_episode_length | |
- | |
- action = actor( | |
- observation, | |
- goal, | |
- env_spec.action_space, | |
- ) | |
- transition_out = env.step(np.asarray(action)) | |
- observation_dict, reward, terminal, info = transition_out[:3] + transition_out[-1:] # some BC | |
- | |
- observation = torch.tensor(observation_dict['observation'], dtype=torch.float32) # copy just in case | |
- | |
- goal: torch.Tensor = torch.as_tensor(observation_dict['desired_goal'], dtype=torch.float32) | |
- agoal: torch.Tensor = torch.as_tensor(observation_dict['achieved_goal'], dtype=torch.float32) | |
- | |
- if 'is_success' in info: | |
- is_success: bool = info['is_success'] | |
- epi.transition_infos['is_success'][t] = is_success | |
- else: | |
- if t == 0: | |
- # remove field | |
- transition_infos = dict(epi.transition_infos) | |
- del transition_infos['is_success'] | |
- epi.transition_infos = transition_infos | |
- | |
- for k in extra_transition_info_keys: | |
- epi.transition_infos[k][t] = info[k] | |
- | |
- epi.all_observations[t + 1] = observation | |
- epi.actions[t] = torch.as_tensor(action, dtype=torch.float32) | |
- epi.rewards[t] = reward | |
- epi.observation_infos['desired_goals'][t + 1] = goal | |
- epi.observation_infos['achieved_goals'][t + 1] = agoal | |
- | |
- t += 1 | |
- timeout = info.get('TimeLimit.truncated', False) | |
- if assert_exact_episode_length: | |
- assert (timeout or terminal) == (t == max_episode_length) | |
- | |
- if t < max_episode_length: | |
- epi = epi.first_t(t) | |
- | |
- return epi | |
+ def get_observations(self, obs_indices: torch.Tensor): | |
+ return self.raw_data.all_observations[obs_indices.to(self.device)] | |
def __getitem__(self, indices: torch.Tensor) -> BatchData: | |
- indices = torch.as_tensor(indices) | |
+ indices = torch.as_tensor(indices, device=self.device) | |
eindices = self.indices_to_episode_indices[indices] | |
obs_indices = indices + eindices # index for `observation`: skip the s_last from previous episodes | |
obs = self.get_observations(obs_indices) | |
@@ -398,7 +332,7 @@ class Dataset(torch.utils.data.Dataset): | |
tindices = self.indices_to_episode_timesteps[indices] | |
epilengths = self.raw_data.episode_lengths[eindices] # max idx is this | |
- deltas = torch.arange(self.max_episode_length) | |
+ deltas = torch.arange(self.max_episode_length, device=self.device) | |
pdeltas = torch.where( | |
# test tidx + 1 + delta <= max_idx = epi_length | |
(tindices[:, None] + deltas) < epilengths[:, None], | |
@@ -407,14 +341,16 @@ class Dataset(torch.utils.data.Dataset): | |
) | |
deltas = torch.distributions.Categorical( | |
probs=pdeltas, | |
- ).sample() | |
- future_observations = self.get_observations(obs_indices + 1 + deltas) | |
+ validate_args=False, | |
+ ).sample() + 1 | |
+ future_observations = self.get_observations(obs_indices + deltas) | |
return BatchData( | |
observations=obs, | |
actions=self.raw_data.actions[indices], | |
next_observations=nobs, | |
future_observations=future_observations, | |
+ future_tdelta=deltas, | |
rewards=self.raw_data.rewards[indices], | |
terminals=terminals, | |
timeouts=self.raw_data.timeouts[indices], | |
diff --git a/quasimetric_rl/data/interaction.py b/quasimetric_rl/data/interaction.py | |
new file mode 100644 | |
index 0000000..699fa61 | |
--- /dev/null | |
+++ b/quasimetric_rl/data/interaction.py | |
@@ -0,0 +1,178 @@ | |
+from __future__ import annotations | |
+from typing import * | |
+ | |
+import attrs | |
+ | |
+import gym | |
+import gym.spaces | |
+import numpy as np | |
+import torch | |
+import torch.utils.data | |
+ | |
+from . import Dataset, EpisodeData, EnvSpec | |
+ | |
+ | |
+def first_nonzero(arr: torch.Tensor, dim: int = -1, invalid_val: int = -1): | |
+ mask = (arr != 0) | |
+ return torch.where(mask.any(dim=dim), mask.to(torch.uint8).argmax(dim=dim), invalid_val) | |
+ | |
+ | |
+@attrs.define(kw_only=True) | |
+class EvalEpisodeResult: | |
+ timestep_reward: List[torch.Tensor] | |
+ episode_return: torch.Tensor | |
+ episode_score: torch.Tensor | |
+ timestep_is_success: Optional[List[torch.Tensor]] | |
+ is_success: Optional[torch.Tensor] | |
+ hitting_time: Optional[torch.Tensor] | |
+ extra_timestep_results: Mapping[str, List[torch.Tensor]] | |
+ | |
+ @property | |
+ def capped_hitting_time(self) -> Optional[torch.Tensor]: | |
+ # if not hit -> |ts| + 1 | |
+ if self.hitting_time is None: | |
+ return None | |
+ assert self.timestep_is_success is not None | |
+ return torch.stack([torch.where(_x < 0, _succ.shape[0] + 1, _x) for _x, _succ in zip(self.hitting_time, self.timestep_is_success)]) | |
+ | |
+ @classmethod | |
+ def from_timestep_reward_is_success(cls, dataset: Dataset, | |
+ timestep_reward: List[torch.Tensor], | |
+ timestep_is_success: Optional[List[torch.Tensor]], | |
+ extra_timestep_results) -> Self: | |
+ return cls( | |
+ timestep_reward=timestep_reward, | |
+ episode_return=torch.stack([r.sum() for r in timestep_reward]), | |
+ episode_score=dataset.normalize_score(timestep_reward), | |
+ timestep_is_success=timestep_is_success, | |
+ is_success=( | |
+ None if timestep_is_success is None | |
+ else torch.stack([_x.any(dim=-1) for _x in timestep_is_success]) | |
+ ), | |
+ hitting_time=( | |
+ None if timestep_is_success is None | |
+ else torch.stack([first_nonzero(_x, dim=-1) for _x in timestep_is_success]) | |
+ ), # NB this is off by 1 | |
+ extra_timestep_results=dict(extra_timestep_results), | |
+ ) | |
+ | |
+ @classmethod | |
+ def from_episode_rollouts(cls, dataset: Dataset,rollouts: Sequence[EpisodeData]) -> Self: | |
+ return cls.from_timestep_reward_is_success( | |
+ dataset, | |
+ timestep_reward=[rollout.rewards for rollout in rollouts], | |
+ timestep_is_success=( | |
+ None | |
+ if len(rollouts) == 0 or 'is_success' not in rollouts[0].transition_infos | |
+ else [rollout.transition_infos['is_success'] for rollout in rollouts] | |
+ ), | |
+ extra_timestep_results=( | |
+ {} if len(rollouts) == 0 else | |
+ { | |
+ k: [rollout.transition_infos[k] for rollout in rollouts] | |
+ for k in rollouts[0].transition_infos.keys() if k != 'is_success' | |
+ } | |
+ ), | |
+ ) | |
+ | |
+ def summarize(self) -> Mapping[str, Union[torch.Tensor, float, None]]: | |
+ succ_rate_ts = ( | |
+ None if self.timestep_is_success is None | |
+ else torch.stack([_x.mean(dtype=torch.float32) for _x in self.timestep_is_success]) | |
+ ) | |
+ hitting_time = self.capped_hitting_time | |
+ summary = dict( | |
+ epi_return=self.episode_return, | |
+ epi_score=self.episode_score, | |
+ succ_rate_ts=succ_rate_ts, | |
+ succ_rate=self.is_success, | |
+ hitting_time=hitting_time, | |
+ ) | |
+ for kk, v in self.extra_timestep_results.items(): | |
+ summary[kk] = torch.stack([_v.mean(dtype=torch.float32) for _v in v]) | |
+ summary[kk + '_last'] = torch.stack([_v[-1] for _v in v]) | |
+ | |
+ return summary | |
+ | |
+ | |
+def collect_rollout(actor: Callable[[torch.Tensor, torch.Tensor, gym.Space], np.ndarray], *, | |
+ env: gym.Env, env_spec: EnvSpec, max_episode_length: int, | |
+ assert_exact_episode_length: bool = False) -> EpisodeData: | |
+ # NOTE: extra tracked info can be specified by env.tracked_info_keys | |
+ | |
+ from .utils import get_empty_episode | |
+ | |
+ epi = get_empty_episode(env_spec, max_episode_length) | |
+ | |
+ # check observation space | |
+ obs_dict_keys = {'observation', 'achieved_goal', 'desired_goal'} | |
+ WRONG_OBS_ERR_MESSAGE = ( | |
+ f"collect_rollout only supports Dict " | |
+ f"observation space with keys {obs_dict_keys}, but got {env.observation_space}" | |
+ ) | |
+ assert isinstance(env.observation_space, gym.spaces.Dict), WRONG_OBS_ERR_MESSAGE | |
+ assert set(env.observation_space.spaces.keys()) == {'observation', 'achieved_goal', 'desired_goal'}, WRONG_OBS_ERR_MESSAGE | |
+ | |
+ observation_dict = cast(Mapping[str, np.ndarray], env.reset()) | |
+ observation: torch.Tensor = torch.as_tensor(observation_dict['observation'], dtype=torch.float32) | |
+ | |
+ goal: torch.Tensor = torch.as_tensor(observation_dict['desired_goal'], dtype=torch.float32) | |
+ agoal: torch.Tensor = torch.as_tensor(observation_dict['achieved_goal'], dtype=torch.float32) | |
+ epi.all_observations[0] = observation | |
+ epi.observation_infos['desired_goals'][0] = goal | |
+ epi.observation_infos['achieved_goals'][0] = agoal | |
+ | |
+ extra_transition_info_keys = getattr(env, 'tracked_info_keys', []) | |
+ if len(extra_transition_info_keys): | |
+ epi.transition_infos = dict(epi.transition_infos) | |
+ epi.transition_infos.update({ | |
+ k: torch.empty([max_episode_length], dtype=torch.float32) for k in extra_transition_info_keys | |
+ }) | |
+ | |
+ t = 0 | |
+ timeout = False | |
+ terminal = False | |
+ while not timeout and not terminal: | |
+ assert t < max_episode_length | |
+ | |
+ action = actor( | |
+ observation, | |
+ goal, | |
+ env_spec.action_space, | |
+ ) | |
+ transition_out = env.step(np.asarray(action)) | |
+ observation_dict, reward, terminal, info = transition_out[:3] + transition_out[-1:] # some BC | |
+ | |
+ observation = torch.tensor(observation_dict['observation'], dtype=torch.float32) # copy just in case | |
+ | |
+ goal: torch.Tensor = torch.as_tensor(observation_dict['desired_goal'], dtype=torch.float32) | |
+ agoal: torch.Tensor = torch.as_tensor(observation_dict['achieved_goal'], dtype=torch.float32) | |
+ | |
+ if 'is_success' in info: | |
+ is_success: bool = info['is_success'] | |
+ epi.transition_infos['is_success'][t] = is_success | |
+ else: | |
+ if t == 0: | |
+ # remove field | |
+ transition_infos = dict(epi.transition_infos) | |
+ del transition_infos['is_success'] | |
+ epi.transition_infos = transition_infos | |
+ | |
+ for k in extra_transition_info_keys: | |
+ epi.transition_infos[k][t] = info[k] | |
+ | |
+ epi.all_observations[t + 1] = observation | |
+ epi.actions[t] = torch.as_tensor(action, dtype=torch.float32) | |
+ epi.rewards[t] = reward | |
+ epi.observation_infos['desired_goals'][t + 1] = goal | |
+ epi.observation_infos['achieved_goals'][t + 1] = agoal | |
+ | |
+ t += 1 | |
+ timeout = info.get('TimeLimit.truncated', False) | |
+ if assert_exact_episode_length: | |
+ assert (timeout or terminal) == (t == max_episode_length) | |
+ | |
+ if t < max_episode_length: | |
+ epi = epi.first_t(t) | |
+ | |
+ return epi | |
diff --git a/quasimetric_rl/data/offline/__init__.py b/quasimetric_rl/data/offline/__init__.py | |
index 22bb42c..c816a7b 100644 | |
--- a/quasimetric_rl/data/offline/__init__.py | |
+++ b/quasimetric_rl/data/offline/__init__.py | |
@@ -60,6 +60,10 @@ class OfflineGoalCondEnv(gym.ObservationWrapper, OfflineEnv): # type: ignore | |
self.get_goal_fn = get_goal_fn | |
self.extra_info_fns = extra_info_fns | |
+ @property | |
+ def tracked_info_keys(self): | |
+ return tuple(self.extra_info_fns.keys()) | |
+ | |
def observation(self, observation: np.ndarray): | |
o, g = observation, self.get_goal_fn(self.env) | |
if self.is_image_based: | |
diff --git a/quasimetric_rl/data/offline/d4rl/antmaze.py b/quasimetric_rl/data/offline/d4rl/antmaze.py | |
index c48ab34..dbf5198 100644 | |
--- a/quasimetric_rl/data/offline/d4rl/antmaze.py | |
+++ b/quasimetric_rl/data/offline/d4rl/antmaze.py | |
@@ -182,12 +182,13 @@ def create_env_antmaze(name, dict_obseravtion: Optional[bool] = None, *, random_ | |
return env | |
-def load_episodes_antmaze(name): | |
+def load_episodes_antmaze(name, normalize_observation=True): | |
env = load_environment(name) | |
d4rl_dataset = cached_d4rl_dataset(name) | |
- # normalize | |
- d4rl_dataset['observations'] = obs_norm(name, d4rl_dataset['observations']) | |
- d4rl_dataset['next_observations'] = obs_norm(name, d4rl_dataset['next_observations']) | |
+ if normalize_observation: | |
+ # normalize | |
+ d4rl_dataset['observations'] = obs_norm(name, d4rl_dataset['observations']) | |
+ d4rl_dataset['next_observations'] = obs_norm(name, d4rl_dataset['next_observations']) | |
yield from convert_dict_to_EpisodeData_iter( | |
sequence_dataset( | |
env, | |
@@ -206,3 +207,11 @@ for name in ['antmaze-umaze-v2', 'antmaze-umaze-diverse-v2', | |
normalize_score_fn=functools.partial(get_normalized_score, name), | |
eval_specs=dict(single_task=dict(random_start_goal=False), multi_task=dict(random_start_goal=True)), | |
) | |
+ register_offline_env( | |
+ 'd4rl', name + '-nonorm', | |
+ create_env_fn=functools.partial(create_env_antmaze, name, normalize_observation=False), | |
+ load_episodes_fn=functools.partial(load_episodes_antmaze, name, normalize_observation=False), | |
+ normalize_score_fn=functools.partial(get_normalized_score, name), | |
+ eval_specs=dict(single_task=dict(random_start_goal=False, normalize_observation=False), | |
+ multi_task=dict(random_start_goal=True, normalize_observation=False)), | |
+ ) | |
diff --git a/quasimetric_rl/data/online/memory.py b/quasimetric_rl/data/online/memory.py | |
index b7cb94b..a3445f4 100644 | |
--- a/quasimetric_rl/data/online/memory.py | |
+++ b/quasimetric_rl/data/online/memory.py | |
@@ -12,6 +12,7 @@ import gym.spaces | |
from . import OnlineFixedLengthEnv | |
from ..base import EpisodeData, MultiEpisodeData, Dataset, BatchData | |
+from ..interaction import collect_rollout | |
from ..utils import get_empty_episode, get_empty_episodes | |
@@ -153,7 +154,7 @@ class ReplayBuffer(Dataset): | |
get_empty_episodes( | |
self.env_spec, self.episode_length, | |
int(np.ceil(self.increment_num_transitions / self.episode_length)), | |
- ), | |
+ ).to(self.device), | |
], | |
dim=0, | |
) | |
@@ -164,19 +165,19 @@ class ReplayBuffer(Dataset): | |
# indices_to_episode_timesteps: torch.Tensor | |
self.indices_to_episode_indices = torch.cat([ | |
self.indices_to_episode_indices, | |
- torch.repeat_interleave(torch.arange(original_capacity, new_capacity), self.episode_length), | |
+ torch.repeat_interleave(torch.arange(original_capacity, new_capacity, device=self.device), self.episode_length), | |
], dim=0) | |
self.indices_to_episode_timesteps = torch.cat([ | |
self.indices_to_episode_timesteps, | |
- torch.arange(self.episode_length).repeat(new_capacity - original_capacity), | |
+ torch.arange(self.episode_length, device=self.device).repeat(new_capacity - original_capacity), | |
], dim=0) | |
logging.info(f'ReplayBuffer: Expanded from capacity={original_capacity} to {new_capacity} episodes') | |
def collect_rollout(self, actor: Callable[[torch.Tensor, torch.Tensor, gym.Space], np.ndarray], *, | |
env: Optional[OnlineFixedLengthEnv] = None) -> EpisodeData: | |
- return self.collect_rollout_general(actor, env=(env or self.env), env_spec=self.env_spec, | |
- max_episode_length=self.episode_length, assert_exact_episode_length=True) | |
+ return collect_rollout(actor, env=(env or self.env), env_spec=self.env_spec, | |
+ max_episode_length=self.episode_length, assert_exact_episode_length=True) | |
def add_rollout(self, episode: EpisodeData): | |
if self.num_episodes_realized == self.episodes_capacity: | |
@@ -214,7 +215,8 @@ class ReplayBuffer(Dataset): | |
def sample(self, batch_size: int) -> BatchData: | |
indices = torch.as_tensor( | |
- np.random.choice(self.num_transitions_realized, size=[batch_size]) | |
+ np.random.choice(self.num_transitions_realized, size=[batch_size]), | |
+ device=self.device, | |
) | |
return self[indices] | |
diff --git a/quasimetric_rl/flags.py b/quasimetric_rl/flags.py | |
index 2fb0fab..2f7578a 100644 | |
--- a/quasimetric_rl/flags.py | |
+++ b/quasimetric_rl/flags.py | |
@@ -25,31 +25,33 @@ FLAGS = FlagsDefinition() | |
def pdb_if_DEBUG(fn: Callable): | |
    @functools.wraps(fn) | |
    def wrapped(*args, **kwargs): | |
-    try: | |
-        fn(*args, **kwargs) | |
-    except: | |
-        # follow ABSL: | |
-        # https://github.com/abseil/abseil-py/blob/a0ae31683e6cf3667886c500327f292c893a1740/absl/app.py#L311-L327 | |
- | |
-        exc = sys.exc_info()[1] | |
-        if isinstance(exc, KeyboardInterrupt): | |
-            raise | |
- | |
-        # Don't try to post-mortem debug successful SystemExits, since those | |
-        # mean there wasn't actually an error. In particular, the test framework | |
-        # raises SystemExit(False) even if all tests passed. | |
-        if isinstance(exc, SystemExit) and not exc.code: | |
-            raise | |
- | |
-        # Check the tty so that we don't hang waiting for input in an | |
-        # non-interactive scenario. | |
-        if FLAGS.DEBUG: | |
+        if not FLAGS.DEBUG: # check here, in case it is set after decorator call | |
+            return fn(*args, **kwargs) | |
+        else: | |
+            try: | |
+                return fn(*args, **kwargs) | |
+            except: | |
+                # follow ABSL: | |
+                # https://github.com/abseil/abseil-py/blob/a0ae31683e6cf3667886c500327f292c893a1740/absl/app.py#L311-L327 | |
+ | |
+                exc = sys.exc_info()[1] | |
+                if isinstance(exc, KeyboardInterrupt): | |
+                    raise | |
+ | |
+                # Don't try to post-mortem debug successful SystemExits, since those | |
+                # mean there wasn't actually an error. In particular, the test framework | |
+                # raises SystemExit(False) even if all tests passed. | |
+                if isinstance(exc, SystemExit) and not exc.code: | |
+                    raise | |
+ | |
+                # Unlike ABSL there is no tty check here: post-mortem runs whenever FLAGS.DEBUG | |
+                # is set, so a non-interactive run may end up waiting on stdin inside pdb. | |
            traceback.print_exc() | |
            print() | |
            print(' *** Entering post-mortem debugging ***') | |
            print() | |
            pdb.post_mortem() | |
-        raise | |
+                raise | |
    return wrapped | |
diff --git a/quasimetric_rl/modules/__init__.py b/quasimetric_rl/modules/__init__.py | |
index 79fd35f..1ed2315 100644 | |
--- a/quasimetric_rl/modules/__init__.py | |
+++ b/quasimetric_rl/modules/__init__.py | |
@@ -32,13 +32,17 @@ class QRLLosses(Module): | |
critic_losses: Collection[quasimetric_critic.QuasimetricCriticLosses], | |
critics_total_grad_clip_norm: Optional[float], | |
recompute_critic_for_actor_loss: bool, | |
- critics_share_embedding: bool): | |
+ critics_share_embedding: bool, | |
+ critic_losses_use_target_encoder: bool, | |
+ actor_loss_uses_target_encoder: bool): | |
super().__init__() | |
self.add_module('actor_loss', actor_loss) | |
self.critic_losses = torch.nn.ModuleList(critic_losses) # type: ignore | |
self.critics_total_grad_clip_norm = critics_total_grad_clip_norm | |
self.recompute_critic_for_actor_loss = recompute_critic_for_actor_loss | |
self.critics_share_embedding = critics_share_embedding | |
+ self.critic_losses_use_target_encoder = critic_losses_use_target_encoder | |
+ self.actor_loss_uses_target_encoder = actor_loss_uses_target_encoder | |
def forward(self, agent: QRLAgent, data: BatchData, *, optimize: bool = True) -> LossResult: | |
# compute CriticBatchInfo | |
@@ -57,8 +61,8 @@ class QRLLosses(Module): | |
) | |
else: | |
zx = critic.encoder(data.observations) | |
- zy = critic.target_encoder(data.next_observations) | |
- if critic.has_separate_target_encoder: | |
+ zy = critic.get_encoder(target=self.critic_losses_use_target_encoder)(data.next_observations) | |
+ if critic.has_separate_target_encoder and self.critic_losses_use_target_encoder: | |
assert not zy.requires_grad | |
critic_batch_info = quasimetric_critic.CriticBatchInfo( | |
critic=critic, | |
@@ -85,8 +89,9 @@ class QRLLosses(Module): | |
assert agent.actor is not None | |
with torch.no_grad(), torch.inference_mode(): | |
for idx, critic in enumerate(agent.critics): | |
- if self.recompute_critic_for_actor_loss or critic.has_separate_target_encoder: | |
- zx, zy = critic.target_encoder(torch.stack([data.observations, data.next_observations], dim=0)).unbind(0) | |
+ if self.recompute_critic_for_actor_loss or (critic.has_separate_target_encoder and self.actor_loss_uses_target_encoder): | |
+ zx, zy = critic.get_encoder(target=self.actor_loss_uses_target_encoder)( | |
+ torch.stack([data.observations, data.next_observations], dim=0)).unbind(0) | |
critic_batch_infos[idx] = quasimetric_critic.CriticBatchInfo( | |
critic=critic, | |
zx=zx, | |
@@ -142,7 +147,13 @@ class QRLLosses(Module): | |
critic_loss.dynamics_lagrange_mult_sched.load_state_dict(optim_scheds[f"critic_{idx:02d}"]['dynamics_lagrange_mult_sched']) | |
def extra_repr(self) -> str: | |
- return f'recompute_critic_for_actor_loss={self.recompute_critic_for_actor_loss}' | |
+ return '\n'.join([ | |
+ f'recompute_critic_for_actor_loss={self.recompute_critic_for_actor_loss}', | |
+ f'critics_share_embedding={self.critics_share_embedding}', | |
+ f'critics_total_grad_clip_norm={self.critics_total_grad_clip_norm}', | |
+ f'critic_losses_use_target_encoder={self.critic_losses_use_target_encoder}', | |
+ f'actor_loss_uses_target_encoder={self.actor_loss_uses_target_encoder}', | |
+ ]) | |
@attrs.define(kw_only=True) | |
@@ -155,6 +166,8 @@ class QRLConf: | |
default=None, validator=attrs.validators.optional(attrs.validators.gt(0)), | |
) | |
recompute_critic_for_actor_loss: bool = False | |
+ critic_losses_use_target_encoder: bool = True | |
+ actor_loss_uses_target_encoder: bool = True | |
def make(self, *, env_spec: EnvSpec, total_optim_steps: int) -> Tuple[QRLAgent, QRLLosses]: | |
if self.actor is None: | |
@@ -174,9 +187,13 @@ class QRLConf: | |
critics.append(critic) | |
critic_losses.append(critic_loss) | |
- return QRLAgent(actor=actor, critics=critics), QRLLosses(actor_loss=actor_losses, critic_losses=critic_losses, | |
- critics_share_embedding=self.critics_share_embedding, | |
- critics_total_grad_clip_norm=self.critics_total_grad_clip_norm, | |
- recompute_critic_for_actor_loss=self.recompute_critic_for_actor_loss) | |
+ return QRLAgent(actor=actor, critics=critics), QRLLosses( | |
+ actor_loss=actor_losses, critic_losses=critic_losses, | |
+ critics_share_embedding=self.critics_share_embedding, | |
+ critics_total_grad_clip_norm=self.critics_total_grad_clip_norm, | |
+ recompute_critic_for_actor_loss=self.recompute_critic_for_actor_loss, | |
+ critic_losses_use_target_encoder=self.critic_losses_use_target_encoder, | |
+ actor_loss_uses_target_encoder=self.actor_loss_uses_target_encoder, | |
+ ) | |
__all__ = ['QRLAgent', 'QRLLosses', 'QRLConf', 'InfoT', 'InfoValT'] | |
diff --git a/quasimetric_rl/modules/actor/losses/awr.py b/quasimetric_rl/modules/actor/losses/awr.py | |
index c69e79c..f6cf4de 100644 | |
--- a/quasimetric_rl/modules/actor/losses/awr.py | |
+++ b/quasimetric_rl/modules/actor/losses/awr.py | |
@@ -3,11 +3,10 @@ from typing import * | |
import attrs | |
import torch | |
-import torch.nn as nn | |
-from ....data import BatchData, EnvSpec | |
+from ....data import BatchData | |
-from ...utils import LatentTensor, LossResult, grad_mul | |
+from ...utils import LatentTensor, LossResult, bcast_bshape | |
from ..model import Actor | |
from ...quasimetric_critic import QuasimetricCritic, CriticBatchInfo | |
@@ -140,14 +139,23 @@ class AWRLoss(ActorLossBase): | |
for idx, actor_obs_goal_critic_info in enumerate(actor_obs_goal_critic_infos): | |
critic = actor_obs_goal_critic_info.critic | |
- zo = actor_obs_goal_critic_info.zo.detach() | |
- zg = actor_obs_goal_critic_info.zg.detach() | |
+ zo = actor_obs_goal_critic_info.zo.detach() # [B,D] | |
+ zg = actor_obs_goal_critic_info.zg.detach() # [2?,B,D] | |
with torch.no_grad(), critic.mode(False): | |
- zp = critic.latent_dynamics(data.observations, zo, data.actions) | |
- z = torch.stack([zo, zp], dim=0) | |
- z = z[(slice(None),) + (None,) * (zg.ndim - zo.ndim) + (Ellipsis,)] # if zg is batched, add enough dim to be broadcast-able | |
- dist_noact, dist = critic.quasimetric_model(z, zg).unbind(0) | |
- dist_noact = dist_noact.detach() | |
+ zp = critic.latent_dynamics(data.observations, zo, data.actions) # [B,D] | |
+ if not critic.borrowing_embedding: | |
+ zo, zp, zg = bcast_bshape( | |
+ (zo, 1), | |
+ (zp, 1), | |
+ (zg, 1), | |
+ ) | |
+ z = torch.stack([zo, zp], dim=0) # [2,2?,B,D] | |
+ # z = z[(slice(None),) + (None,) * (zg.ndim - zo.ndim) + (Ellipsis,)] # if zg is batched, add enough dim to be broadcast-able | |
+ dist_noact, dist = critic.quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B] | |
+ dist_noact = dist_noact.detach() | |
+ else: | |
+ dist = critic.quasimetric_model(zp, zg) | |
+ dist_noact = dists_noact[0] | |
info[f'dist_delta_{idx:02d}'] = (dist_noact - dist).mean() | |
info[f'dist_{idx:02d}'] = dist.mean() | |
dists_noact.append(dist_noact) | |
diff --git a/quasimetric_rl/modules/actor/losses/min_dist.py b/quasimetric_rl/modules/actor/losses/min_dist.py | |
index 0f90cff..bc36ee8 100644 | |
--- a/quasimetric_rl/modules/actor/losses/min_dist.py | |
+++ b/quasimetric_rl/modules/actor/losses/min_dist.py | |
@@ -147,13 +147,18 @@ class MinDistLoss(ActorLossBase): | |
for idx, actor_obs_goal_critic_info in enumerate(actor_obs_goal_critic_infos): | |
critic = actor_obs_goal_critic_info.critic | |
- zo = actor_obs_goal_critic_info.zo.detach() | |
- zg = actor_obs_goal_critic_info.zg.detach() | |
+ zo = actor_obs_goal_critic_info.zo.detach() # [B,D] | |
+ zg = actor_obs_goal_critic_info.zg.detach() # [2?,B,D] | |
with critic.requiring_grad(False), critic.mode(False): | |
- zp = critic.latent_dynamics(data.observations, zo, action) | |
- z = torch.stack(torch.broadcast_tensors(zo, zp), dim=0) | |
- dist_noact, dist = critic.quasimetric_model(z, zg).unbind(0) | |
- dist_noact = dist_noact.detach() | |
+ zp = critic.latent_dynamics(data.observations, zo, action) # [2?,B,D] | |
+ if not critic.borrowing_embedding: | |
+ # action: [2?,B,A] | |
+ z = torch.stack(torch.broadcast_tensors(zo, zp), dim=0) # [2,2?,B,D] | |
+ dist_noact, dist = critic.quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B] | |
+ dist_noact = dist_noact.detach() | |
+ else: | |
+ dist = critic.quasimetric_model(zp, zg) # [2?,B] | |
+ dist_noact = dists_noact[0] | |
info[f'dist_delta_{idx:02d}'] = (dist_noact - dist).mean() | |
info[f'dist_{idx:02d}'] = dist.mean() | |
dists_noact.append(dist_noact) | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py b/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py | |
index fa8eba0..48ed700 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py | |
@@ -31,7 +31,7 @@ class CriticLossBase(LossBase): | |
return super().__call__(data, critic_batch_info) | |
-from .global_push import GlobalPushLoss, GlobalPushLinearLoss, GlobalPushLogLoss, GlobalPushRBFLoss | |
+from .global_push import GlobalPushLoss, GlobalPushLinearLoss, GlobalPushLogLoss, GlobalPushRBFLoss, GlobalPushNextMSELoss | |
from .local_constraint import LocalConstraintLoss | |
from .latent_dynamics import LatentDynamicsLoss | |
@@ -41,6 +41,7 @@ class QuasimetricCriticLosses(CriticLossBase): | |
class Conf: | |
global_push: GlobalPushLoss.Conf = GlobalPushLoss.Conf() | |
global_push_linear: GlobalPushLinearLoss.Conf = GlobalPushLinearLoss.Conf() | |
+ global_push_next_mse: GlobalPushNextMSELoss.Conf = GlobalPushNextMSELoss.Conf() | |
global_push_log: GlobalPushLogLoss.Conf = GlobalPushLogLoss.Conf() | |
global_push_rbf: GlobalPushRBFLoss.Conf = GlobalPushRBFLoss.Conf() | |
local_constraint: LocalConstraintLoss.Conf = LocalConstraintLoss.Conf() | |
@@ -54,7 +55,10 @@ class QuasimetricCriticLosses(CriticLossBase): | |
local_lagrange_mult_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=1e-2) | |
dynamics_lagrange_mult_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=0) | |
- scale_with_best_local_fit: bool = False | |
+ quasimetric_scale: Optional[str] = attrs.field( | |
+ default=None, validator=attrs.validators.optional(attrs.validators.in_( | |
+ ['best_local_fit', 'best_local_fit_clip5', 'best_local_fit_clip10', | |
+ 'best_local_fit_detach']))) # type: ignore | |
def make(self, critic: QuasimetricCritic, total_optim_steps: int, | |
share_embedding_from: Optional[QuasimetricCritic] = None) -> 'QuasimetricCriticLosses': | |
@@ -65,6 +69,7 @@ class QuasimetricCriticLosses(CriticLossBase): | |
# global losses | |
global_push=self.global_push.make(), | |
global_push_linear=self.global_push_linear.make(), | |
+ global_push_next_mse=self.global_push_next_mse.make(), | |
global_push_log=self.global_push_log.make(), | |
global_push_rbf=self.global_push_rbf.make(), | |
# local loss | |
@@ -81,12 +86,13 @@ class QuasimetricCriticLosses(CriticLossBase): | |
local_lagrange_mult_optim_spec=self.local_lagrange_mult_optim.make(), | |
dynamics_lagrange_mult_optim_spec=self.dynamics_lagrange_mult_optim.make(), | |
# | |
- scale_with_best_local_fit=self.scale_with_best_local_fit, | |
+ quasimetric_scale=self.quasimetric_scale, | |
) | |
borrowing_embedding: bool | |
global_push: Optional[GlobalPushLoss] | |
global_push_linear: Optional[GlobalPushLinearLoss] | |
+ global_push_next_mse: Optional[GlobalPushNextMSELoss] | |
global_push_log: Optional[GlobalPushLogLoss] | |
global_push_rbf: Optional[GlobalPushRBFLoss] | |
local_constraint: Optional[LocalConstraintLoss] | |
@@ -98,12 +104,13 @@ class QuasimetricCriticLosses(CriticLossBase): | |
local_lagrange_mult_sched: LRScheduler | |
dynamics_lagrange_mult_optim: OptimWrapper | |
dynamics_lagrange_mult_sched: LRScheduler | |
- scale_with_best_local_fit: bool | |
+ quasimetric_scale: Optional[str] | |
def __init__(self, critic: QuasimetricCritic, *, total_optim_steps: int, | |
share_embedding_from: Optional[QuasimetricCritic] = None, | |
global_push: Optional[GlobalPushLoss], global_push_linear: Optional[GlobalPushLinearLoss], | |
- global_push_log: Optional[GlobalPushLogLoss], global_push_rbf: Optional[GlobalPushRBFLoss], | |
+ global_push_next_mse: Optional[GlobalPushNextMSELoss], global_push_log: Optional[GlobalPushLogLoss], | |
+ global_push_rbf: Optional[GlobalPushRBFLoss], | |
local_constraint: Optional[LocalConstraintLoss], latent_dynamics: LatentDynamicsLoss, | |
critic_optim_spec: AdamWSpec, | |
latent_dynamics_lr_mul: float, | |
@@ -112,7 +119,7 @@ class QuasimetricCriticLosses(CriticLossBase): | |
quasimetric_head_lr_mul: float, | |
local_lagrange_mult_optim_spec: AdamWSpec, | |
dynamics_lagrange_mult_optim_spec: AdamWSpec, | |
- scale_with_best_local_fit: bool): | |
+ quasimetric_scale: Optional[str]): | |
super().__init__() | |
self.borrowing_embedding = share_embedding_from is not None | |
if self.borrowing_embedding: | |
@@ -123,6 +130,7 @@ class QuasimetricCriticLosses(CriticLossBase): | |
local_constraint = None | |
self.add_module('global_push', global_push) | |
self.add_module('global_push_linear', global_push_linear) | |
+ self.add_module('global_push_next_mse', global_push_next_mse) | |
self.add_module('global_push_log', global_push_log) | |
self.add_module('global_push_rbf', global_push_rbf) | |
self.add_module('local_constraint', local_constraint) | |
@@ -147,7 +155,7 @@ class QuasimetricCriticLosses(CriticLossBase): | |
self.dynamics_lagrange_mult_optim, self.dynamics_lagrange_mult_sched = dynamics_lagrange_mult_optim_spec.create_optim_scheduler( | |
latent_dynamics.parameters(), total_optim_steps) | |
assert len(list(latent_dynamics.parameters())) == 1 | |
- self.scale_with_best_local_fit = scale_with_best_local_fit | |
+ self.quasimetric_scale = quasimetric_scale | |
def optimizers(self) -> Iterable[OptimWrapper]: | |
return [self.critic_optim, self.local_lagrange_mult_optim, self.dynamics_lagrange_mult_optim] | |
@@ -155,23 +163,34 @@ class QuasimetricCriticLosses(CriticLossBase): | |
def schedulers(self) -> Iterable[LRScheduler]: | |
return [self.critic_sched, self.local_lagrange_mult_sched, self.dynamics_lagrange_mult_sched] | |
- @torch.no_grad() | |
- def compute_best_quasimetric_scale(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> torch.Tensor: | |
+ def compute_best_quasimetric_scale(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> Tuple[torch.Tensor, torch.Tensor]: | |
assert self.local_constraint is not None and not self.borrowing_embedding | |
+ critic_batch_info.critic.quasimetric_model.quasimetric_head.scale.detach_().fill_(1) # reset | |
dist = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, critic_batch_info.zy) | |
- return (self.local_constraint.step_cost * (dist.mean() / dist.square().mean().clamp_min_(1e-8))).detach().clamp_(1e-3, 1e3) | |
+ return dist, (self.local_constraint.step_cost * (dist.mean() / dist.square().mean().clamp_min(1e-12))) # .detach().clamp_(1e-1, 1e1) | |
def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
extra_info: Dict[str, torch.Tensor] = {} | |
- if self.scale_with_best_local_fit and not self.borrowing_embedding: | |
- scale = extra_info['quasimetric_autoscale'] = self.compute_best_quasimetric_scale(data, critic_batch_info) | |
- critic_batch_info.critic.quasimetric_model.quasimetric_head.scale.copy_(scale) | |
+ if self.quasimetric_scale is not None and not self.borrowing_embedding: | |
+ unscaled_dist, scale = self.compute_best_quasimetric_scale(data, critic_batch_info) | |
+ assert scale.grad_fn is not None # allow bp | |
+ if self.quasimetric_scale == 'best_local_fit_detach': | |
+ scale = scale.detach() | |
+ elif self.quasimetric_scale == 'best_local_fit_clip5': | |
+ scale = scale.clamp(1 / 5, 5) | |
+ elif self.quasimetric_scale == 'best_local_fit_clip10': | |
+ scale = scale.clamp(1 / 10, 10) | |
+ extra_info['unscaled_dist'] = unscaled_dist | |
+ extra_info['quasimetric_autoscale'] = scale | |
+ critic_batch_info.critic.quasimetric_model.quasimetric_head.scale = scale | |
loss_results: Dict[str, LossResult] = {} | |
if self.global_push is not None: | |
loss_results.update(global_push=self.global_push(data, critic_batch_info)) | |
if self.global_push_linear is not None: | |
loss_results.update(global_push_linear=self.global_push_linear(data, critic_batch_info)) | |
+ if self.global_push_next_mse is not None: | |
+ loss_results.update(global_push_next_mse=self.global_push_next_mse(data, critic_batch_info)) | |
if self.global_push_log is not None: | |
loss_results.update(global_push_log=self.global_push_log(data, critic_batch_info)) | |
if self.global_push_rbf is not None: | |
@@ -189,4 +208,4 @@ class QuasimetricCriticLosses(CriticLossBase): | |
return torch.nn.Module.__call__(self, data, critic_batch_info) | |
def extra_repr(self) -> str: | |
- return f"borrowing_embedding={self.borrowing_embedding}" | |
+ return f"borrowing_embedding={self.borrowing_embedding}, quasimetric_scale={self.quasimetric_scale!r}" | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py | |
index 7c3f9f3..e469f13 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py | |
@@ -53,16 +53,95 @@ from . import CriticLossBase, CriticBatchInfo | |
# return f"weight={self.weight:g}, softplus_beta={self.softplus_beta:g}, softplus_offset={self.softplus_offset:g}" | |
- | |
-class GlobalPushLoss(CriticLossBase): | |
+class GlobalPushLossBase(CriticLossBase): | |
@attrs.define(kw_only=True) | |
- class Conf: | |
+ class Conf(abc.ABC): | |
# config / argparse uses this to specify behavior | |
enabled: bool = True | |
detach_goal: bool = False | |
detach_proj_goal: bool = False | |
+ detach_qmet: bool = False | |
+ step_cost: float = attrs.field(default=1., validator=attrs.validators.gt(0)) | |
weight: float = attrs.field(default=1., validator=attrs.validators.gt(0)) | |
+ weight_future_goal: float = attrs.field(default=0., validator=attrs.validators.ge(0)) | |
+ clamp_max_future_goal: bool = True | |
+ | |
+ @abc.abstractmethod | |
+ def make(self) -> Optional['GlobalPushLossBase']: | |
+ if not self.enabled: | |
+ return None | |
+ | |
+ weight: float | |
+ weight_future_goal: float | |
+ detach_goal: bool | |
+ detach_proj_goal: bool | |
+ detach_qmet: bool | |
+ step_cost: float | |
+ clamp_max_future_goal: bool | |
+ | |
+ def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, | |
+ detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool): | |
+ super().__init__() | |
+ self.weight = weight | |
+ self.weight_future_goal = weight_future_goal | |
+ self.detach_goal = detach_goal | |
+ self.detach_proj_goal = detach_proj_goal | |
+ self.detach_qmet = detach_qmet | |
+ self.step_cost = step_cost | |
+ self.clamp_max_future_goal = clamp_max_future_goal | |
+ | |
+ def generate_dist_weight(self, data: BatchData, critic_batch_info: CriticBatchInfo): | |
+ def get_dist(za: torch.Tensor, zb: torch.Tensor): | |
+ if self.detach_goal: | |
+ zb = zb.detach() | |
+ with critic_batch_info.critic.quasimetric_model.requiring_grad(not self.detach_qmet): | |
+ return critic_batch_info.critic.quasimetric_model(za, zb, proj_grad_enabled=(True, not self.detach_proj_goal)) | |
+ | |
+ # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy | |
+ # are latents of randomly ordered random batches. | |
+ zgoal = torch.roll(critic_batch_info.zy, 1, dims=0) | |
+ yield ( | |
+ 'random_goal', | |
+ zgoal, | |
+ get_dist(critic_batch_info.zx, zgoal), | |
+ self.weight, | |
+ ) | |
+ if self.weight_future_goal > 0: | |
+ zgoal = critic_batch_info.critic.encoder(data.future_observations) | |
+ dist = get_dist(critic_batch_info.zx, zgoal) | |
+ if self.clamp_max_future_goal: | |
+ dist = dist.clamp_max(self.step_cost * data.future_tdelta) | |
+ yield ( | |
+ 'future_goal', | |
+ zgoal, | |
+ dist, | |
+ self.weight_future_goal, | |
+ ) | |
+ | |
+ @abc.abstractmethod | |
+ def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult: | |
+ raise NotImplementedError | |
+ | |
+ def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
+ return LossResult.combine( | |
+ { | |
+ name: self.compute_loss(data, critic_batch_info, zgoal, dist, weight) | |
+ for name, zgoal, dist, weight in self.generate_dist_weight(data, critic_batch_info) | |
+ }, | |
+ ) | |
+ | |
+    def extra_repr(self) -> str: | |
+        return '\n'.join([  # one line per group of shared global-push config fields | |
+            f"weight={self.weight:g}, detach_goal={self.detach_goal}, detach_proj_goal={self.detach_proj_goal}", | |
+            f"weight_future_goal={self.weight_future_goal:g}, detach_qmet={self.detach_qmet}", | |
+            f"step_cost={self.step_cost:g}, clamp_max_future_goal={self.clamp_max_future_goal}", | |
+        ]) | |
+ | |
+ | |
+class GlobalPushLoss(GlobalPushLossBase): | |
+ @attrs.define(kw_only=True) | |
+ class Conf(GlobalPushLossBase.Conf): | |
# smaller => smoother loss | |
softplus_beta: float = attrs.field(default=0.1, validator=attrs.validators.gt(0)) | |
@@ -75,99 +154,161 @@ class GlobalPushLoss(CriticLossBase): | |
return None | |
return GlobalPushLoss( | |
weight=self.weight, | |
+ weight_future_goal=self.weight_future_goal, | |
detach_goal=self.detach_goal, | |
detach_proj_goal=self.detach_proj_goal, | |
+ detach_qmet=self.detach_qmet, | |
+ step_cost=self.step_cost, | |
+ clamp_max_future_goal=self.clamp_max_future_goal, | |
softplus_beta=self.softplus_beta, | |
softplus_offset=self.softplus_offset, | |
) | |
- weight: float | |
- detach_goal: bool | |
- detach_proj_goal: bool | |
softplus_beta: float | |
softplus_offset: float | |
- def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool, | |
+ def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool, | |
softplus_beta: float, softplus_offset: float): | |
- super().__init__() | |
- self.weight = weight | |
- self.detach_goal = detach_goal | |
- self.detach_proj_goal = detach_proj_goal | |
+ super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal, | |
+ detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal) | |
self.softplus_beta = softplus_beta | |
self.softplus_offset = softplus_offset | |
- def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
- # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy | |
- # are latents of randomly ordered random batches. | |
- zgoal = torch.roll(critic_batch_info.zy, 1, dims=0) | |
- if self.detach_goal: | |
- zgoal = zgoal.detach() | |
- dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal, | |
- proj_grad_enabled=(True, not self.detach_proj_goal)) | |
+ def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult: | |
# Sec 3.2. Transform so that we penalize large distances less. | |
- tsfm_dist: torch.Tensor = F.softplus(self.softplus_offset - dists, beta=self.softplus_beta) # type: ignore | |
+ tsfm_dist: torch.Tensor = F.softplus(self.softplus_offset - dist, beta=self.softplus_beta) # type: ignore | |
tsfm_dist = tsfm_dist.mean() | |
- return LossResult(loss=tsfm_dist * self.weight, info=dict(dist=dists.mean(), tsfm_dist=tsfm_dist)) # type: ignore | |
+ return LossResult(loss=tsfm_dist * weight, info=dict(dist=dist.mean(), tsfm_dist=tsfm_dist)) | |
def extra_repr(self) -> str: | |
- return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}, softplus_beta={self.softplus_beta:g}, softplus_offset={self.softplus_offset:g}" | |
- | |
+ return '\n'.join([ | |
+ super().extra_repr(), | |
+ f"softplus_beta={self.softplus_beta:g}, softplus_offset={self.softplus_offset:g}", | |
+ ]) | |
-class GlobalPushLinearLoss(CriticLossBase): | |
+class GlobalPushLinearLoss(GlobalPushLossBase): | |
@attrs.define(kw_only=True) | |
- class Conf: | |
- # config / argparse uses this to specify behavior | |
- | |
+ class Conf(GlobalPushLossBase.Conf): | |
enabled: bool = False | |
- detach_goal: bool = False | |
- detach_proj_goal: bool = False | |
- weight: float = attrs.field(default=1., validator=attrs.validators.gt(0)) | |
+ | |
+ clamp_max: Optional[float] = attrs.field(default=None, validator=attrs.validators.optional(attrs.validators.gt(0))) | |
def make(self) -> Optional['GlobalPushLinearLoss']: | |
if not self.enabled: | |
return None | |
return GlobalPushLinearLoss( | |
weight=self.weight, | |
+ weight_future_goal=self.weight_future_goal, | |
detach_goal=self.detach_goal, | |
detach_proj_goal=self.detach_proj_goal, | |
+ detach_qmet=self.detach_qmet, | |
+ step_cost=self.step_cost, | |
+ clamp_max_future_goal=self.clamp_max_future_goal, | |
+ clamp_max=self.clamp_max, | |
) | |
- weight: float | |
- detach_goal: bool | |
- detach_proj_goal: bool | |
- | |
- def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool): | |
- super().__init__() | |
- self.weight = weight | |
- self.detach_goal = detach_goal | |
- self.detach_proj_goal = detach_proj_goal | |
- | |
- def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
- # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy | |
- # are latents of randomly ordered random batches. | |
- zgoal = torch.roll(critic_batch_info.zy, 1, dims=0) | |
- if self.detach_goal: | |
- zgoal = zgoal.detach() | |
- dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal, | |
- proj_grad_enabled=(True, not self.detach_proj_goal)) | |
- dists = dists.mean() | |
- return LossResult(loss=dists * (-self.weight), info=dict(dist=dists)) # type: ignore | |
+ clamp_max: Optional[float] | |
+ | |
+ def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool, | |
+ clamp_max: Optional[float]): | |
+ super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal, | |
+ detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal) | |
+ self.clamp_max = clamp_max | |
+ | |
+    def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult: | |
+        info: Dict[str, torch.Tensor] | |
+        if self.clamp_max is None: | |
+            dist = dist.mean() | |
+            info = dict(dist=dist) | |
+            neg_loss = dist | |
+        else: | |
+            info = dict(dist=dist.mean()) | |
+            tsfm_dist = dist.clamp_max(self.clamp_max) | |
+            info.update( | |
+                tsfm_dist=tsfm_dist.mean(), | |
+                exceed_rate=(dist >= self.clamp_max).mean(dtype=torch.float32), | |
+            ) | |
+            neg_loss = tsfm_dist.mean()  # reduce to a scalar, matching the unclamped branch | |
+        return LossResult(loss=-neg_loss * weight, info=info)  # negated: minimizing the loss pushes distances up | |
def extra_repr(self) -> str: | |
- return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}" | |
+ return '\n'.join([ | |
+ super().extra_repr(), | |
+ f"clamp_max={self.clamp_max!r}", | |
+ ]) | |
-class GlobalPushLogLoss(CriticLossBase): | |
+ | |
+class GlobalPushNextMSELoss(GlobalPushLossBase): | |
@attrs.define(kw_only=True) | |
- class Conf: | |
- # config / argparse uses this to specify behavior | |
+ class Conf(GlobalPushLossBase.Conf): | |
+ enabled: bool = False | |
+ detach_target_dist: bool = True | |
+ allow_gt: bool = False | |
+ gamma: Optional[float] = attrs.field( | |
+ default=None, validator=attrs.validators.optional(attrs.validators.and_( | |
+ attrs.validators.gt(0), | |
+ attrs.validators.lt(1), | |
+ ))) | |
+ | |
+ def make(self) -> Optional['GlobalPushNextMSELoss']: | |
+ if not self.enabled: | |
+ return None | |
+ return GlobalPushNextMSELoss( | |
+ weight=self.weight, | |
+ weight_future_goal=self.weight_future_goal, | |
+ detach_goal=self.detach_goal, | |
+ detach_proj_goal=self.detach_proj_goal, | |
+ detach_qmet=self.detach_qmet, | |
+ clamp_max_future_goal=self.clamp_max_future_goal, | |
+ step_cost=self.step_cost, | |
+ detach_target_dist=self.detach_target_dist, | |
+ allow_gt=self.allow_gt, | |
+ gamma=self.gamma, | |
+ ) | |
+ | |
+ def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool, | |
+ detach_target_dist: bool, allow_gt: bool, gamma: Optional[float]): | |
+ super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal, | |
+ detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal) | |
+ self.detach_target_dist = detach_target_dist | |
+ self.allow_gt = allow_gt | |
+ self.gamma = gamma | |
+ | |
+    def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult:  # regress d(s, g) onto step_cost + d(s', g) | |
+        with torch.set_grad_enabled(not self.detach_target_dist):  # no graph is needed when the target is detached | |
+            # by the triangle inequality, the actual cost can't be larger than step_cost + d(s', g) | |
+            next_dist = critic_batch_info.critic.quasimetric_model( | |
+                critic_batch_info.zy, zgoal, proj_grad_enabled=(True, not self.detach_proj_goal) | |
+            ) | |
+        if self.detach_target_dist: | |
+            next_dist = next_dist.detach() | |
+        target_dist = self.step_cost + next_dist | |
+ | |
+        if self.allow_gt: | |
+            dist = dist.clamp_max(target_dist)  # dist above the target is not penalized | |
+ | |
+        if self.gamma is None: | |
+            loss = F.mse_loss(dist, target_dist) | |
+        else: | |
+            loss = F.mse_loss(self.gamma ** dist, self.gamma ** target_dist)  # compare gamma**d instead of raw d | |
+ | |
+        return LossResult(loss=loss * weight, info=dict(loss=loss, dist=dist.mean(), target_dist=target_dist.mean())) | |
+ | |
+    def extra_repr(self) -> str:  # base-class settings plus this loss's own options | |
+        return '\n'.join([ | |
+            super().extra_repr(), | |
+            f"detach_target_dist={self.detach_target_dist}, allow_gt={self.allow_gt}, gamma={self.gamma!r}", | |
+        ]) | |
+ | |
+class GlobalPushLogLoss(GlobalPushLossBase): | |
+ @attrs.define(kw_only=True) | |
+ class Conf(GlobalPushLossBase.Conf): | |
enabled: bool = False | |
- detach_goal: bool = False | |
- detach_proj_goal: bool = False | |
- weight: float = attrs.field(default=1., validator=attrs.validators.gt(0)) | |
+ | |
offset: float = attrs.field(default=1., validator=attrs.validators.gt(0)) | |
def make(self) -> Optional['GlobalPushLogLoss']: | |
@@ -175,54 +316,43 @@ class GlobalPushLogLoss(CriticLossBase): | |
return None | |
return GlobalPushLogLoss( | |
weight=self.weight, | |
+ weight_future_goal=self.weight_future_goal, | |
detach_goal=self.detach_goal, | |
detach_proj_goal=self.detach_proj_goal, | |
+ detach_qmet=self.detach_qmet, | |
+ step_cost=self.step_cost, | |
+ clamp_max_future_goal=self.clamp_max_future_goal, | |
offset=self.offset, | |
) | |
- weight: float | |
- detach_goal: bool | |
- detach_proj_goal: bool | |
offset: float | |
- def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool, offset: float): | |
- super().__init__() | |
- self.weight = weight | |
- self.detach_goal = detach_goal | |
- self.detach_proj_goal = detach_proj_goal | |
+ def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool, | |
+ offset: float): | |
+ super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal, | |
+ detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal) | |
self.offset = offset | |
- def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
- # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy | |
- # are latents of randomly ordered random batches. | |
- zgoal = torch.roll(critic_batch_info.zy, 1, dims=0) | |
- if self.detach_goal: | |
- zgoal = zgoal.detach() | |
- dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal, | |
- proj_grad_enabled=(True, not self.detach_proj_goal)) | |
- # Sec 3.2. Transform so that we penalize large distances less. | |
- tsfm_dist: torch.Tensor = -dists.add(self.offset).log() | |
+ def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult: | |
+ tsfm_dist: torch.Tensor = -dist.add(self.offset).log() | |
tsfm_dist = tsfm_dist.mean() | |
- return LossResult(loss=tsfm_dist * self.weight, info=dict(dist=dists.mean(), tsfm_dist=tsfm_dist)) # type: ignore | |
+ return LossResult(loss=tsfm_dist * weight, info=dict(dist=dist.mean(), tsfm_dist=tsfm_dist)) | |
def extra_repr(self) -> str: | |
- return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}, offset={self.offset:g}" | |
- | |
+ return '\n'.join([ | |
+ super().extra_repr(), | |
+ f"offset={self.offset:g}", | |
+ ]) | |
-class GlobalPushRBFLoss(CriticLossBase): | |
+class GlobalPushRBFLoss(GlobalPushLossBase): | |
# say E[opt T] approx sqrt(2)/2 timeout, so E[opt T^2] approx 1/2 timeout^2 | |
# to emulate log E exp(-2 d^2), where 2 d^2 is around 4, we scale model T with r, and use log E exp(- r^2 T^2), and let r^2 T^2 to be around 4 | |
# so r^2 approx 8 / timeout^2 and r approx 2.82 / timeout. If timeout = 850, this = 300 | |
@attrs.define(kw_only=True) | |
- class Conf: | |
- # config / argparse uses this to specify behavior | |
- | |
+ class Conf(GlobalPushLossBase.Conf): | |
enabled: bool = False | |
- detach_goal: bool = False | |
- detach_proj_goal: bool = False | |
- weight: float = attrs.field(default=1., validator=attrs.validators.gt(0)) | |
inv_scale: float = attrs.field(default=300., validator=attrs.validators.ge(1e-3)) | |
@@ -231,37 +361,33 @@ class GlobalPushRBFLoss(CriticLossBase): | |
return None | |
return GlobalPushRBFLoss( | |
weight=self.weight, | |
+ weight_future_goal=self.weight_future_goal, | |
detach_goal=self.detach_goal, | |
detach_proj_goal=self.detach_proj_goal, | |
+ detach_qmet=self.detach_qmet, | |
+ step_cost=self.step_cost, | |
+ clamp_max_future_goal=self.clamp_max_future_goal, | |
inv_scale=self.inv_scale, | |
) | |
- weight: float | |
- detach_goal: bool | |
- detach_proj_goal: bool | |
inv_scale: float | |
- def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool, inv_scale: float): | |
- super().__init__() | |
- self.weight = weight | |
- self.detach_goal = detach_goal | |
- self.detach_proj_goal = detach_proj_goal | |
+ def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool, | |
+ inv_scale: float): | |
+ super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal, | |
+ detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal) | |
self.inv_scale = inv_scale | |
- def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
- # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy | |
- # are latents of randomly ordered random batches. | |
- zgoal = torch.roll(critic_batch_info.zy, 1, dims=0) | |
- if self.detach_goal: | |
- zgoal = zgoal.detach() | |
- dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal, | |
- proj_grad_enabled=(True, not self.detach_proj_goal)) | |
- inv_scale = dists.detach().square().mean().div(2).sqrt().clamp(1e-3, self.inv_scale) # make E[d^2]/r^2 approx 2 | |
- tsfm_dist: torch.Tensor = (dists / inv_scale).square().neg().exp() | |
+    def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult:  # minimize log E[exp(-(d/r)^2)] to push d up | |
+        inv_scale = dist.detach().square().mean().div(2).sqrt().clamp(1e-3, self.inv_scale) # make E[d^2]/r^2 approx 2 | |
+        tsfm_dist: torch.Tensor = (dist / inv_scale).square().neg().exp() | |
         rbf_potential = tsfm_dist.mean().log() | |
-        return LossResult(loss=rbf_potential * self.weight, | |
-                      info=dict(dist=dists.mean(), inv_scale=inv_scale, | |
+        return LossResult(loss=rbf_potential * weight,  # per-call weight (regular vs. future-goal), as in the sibling push losses | |
+                      info=dict(dist=dist.mean(), inv_scale=inv_scale, | |
                           tsfm_dist=tsfm_dist, rbf_potential=rbf_potential)) # type: ignore | |
def extra_repr(self) -> str: | |
- return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}, inv_scale={self.inv_scale:g}" | |
+ return '\n'.join([ | |
+ super().extra_repr(), | |
+ f"inv_scale={self.inv_scale:g}", | |
+ ]) | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py b/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py | |
index 2701e6f..f5e8248 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py | |
@@ -35,6 +35,7 @@ class LatentDynamicsLoss(CriticLossBase): | |
detach_sp: bool = False | |
detach_proj_sp: bool = False | |
detach_qmet: bool = False | |
+ non_quasimetric_dim_mse_weight: float = attrs.field(default=0., validator=attrs.validators.ge(0)) | |
def make(self) -> 'LatentDynamicsLoss': | |
return LatentDynamicsLoss( | |
@@ -45,6 +46,7 @@ class LatentDynamicsLoss(CriticLossBase): | |
detach_sp=self.detach_sp, | |
detach_proj_sp=self.detach_proj_sp, | |
detach_qmet=self.detach_qmet, | |
+ non_quasimetric_dim_mse_weight=self.non_quasimetric_dim_mse_weight, | |
init_lagrange_multiplier=self.init_lagrange_multiplier, | |
) | |
@@ -55,13 +57,15 @@ class LatentDynamicsLoss(CriticLossBase): | |
detach_sp: bool | |
detach_proj_sp: bool | |
detach_qmet: bool | |
+ non_quasimetric_dim_mse_weight: float | |
c: float | |
init_lagrange_multiplier: float | |
def __init__(self, *, epsilon: float, bidirectional: bool, | |
gamma: Optional[float], init_lagrange_multiplier: float, | |
# weight: float, | |
- detach_qmet: bool, detach_proj_sp: bool, detach_sp: bool): | |
+ detach_qmet: bool, detach_proj_sp: bool, detach_sp: bool, | |
+ non_quasimetric_dim_mse_weight: float): | |
super().__init__() | |
# self.weight = weight | |
self.epsilon = epsilon | |
@@ -70,6 +74,7 @@ class LatentDynamicsLoss(CriticLossBase): | |
self.detach_qmet = detach_qmet | |
self.detach_sp = detach_sp | |
self.detach_proj_sp = detach_proj_sp | |
+ self.non_quasimetric_dim_mse_weight = non_quasimetric_dim_mse_weight | |
self.init_lagrange_multiplier = init_lagrange_multiplier | |
self.raw_lagrange_multiplier = nn.Parameter( | |
torch.tensor(softplus_inv_float(init_lagrange_multiplier), dtype=torch.float32)) | |
@@ -85,6 +90,7 @@ class LatentDynamicsLoss(CriticLossBase): | |
dists = critic.quasimetric_model(zy, pred_zy, bidirectional=self.bidirectional, # at least optimize d(s', hat{s'}) | |
proj_grad_enabled=(self.detach_proj_sp, True)) | |
+ | |
lagrange_mult = F.softplus(self.raw_lagrange_multiplier) # make positive | |
# lagrange multiplier is minimax training, so grad_mul -1 | |
lagrange_mult = grad_mul(lagrange_mult, -1) | |
@@ -101,6 +107,16 @@ class LatentDynamicsLoss(CriticLossBase): | |
loss = violation * lagrange_mult | |
info.update(violation=violation, lagrange_mult=lagrange_mult) | |
+ if self.non_quasimetric_dim_mse_weight > 0: | |
+ assert critic.quasimetric_model.input_slice_size < critic.quasimetric_model.input_size, \ | |
+ "non-quasimetric dim mse only makes sense if input_slice_size < input_size, but got " \ | |
+ f"{critic.quasimetric_model.input_slice_size} >= {critic.quasimetric_model.input_size}" | |
+ _zy = zy[..., critic.quasimetric_model.input_slice_size:] | |
+ _pred_zy = pred_zy[..., critic.quasimetric_model.input_slice_size:] | |
+ non_quasimetric_dim_mse = F.mse_loss(_pred_zy, _zy) | |
+ info.update(non_quasimetric_dim_mse=non_quasimetric_dim_mse) | |
+ loss += non_quasimetric_dim_mse * self.non_quasimetric_dim_mse_weight | |
+ | |
if self.bidirectional: | |
dist_p2n, dist_n2p = dists.flatten(0, -2).mean(0).unbind(-1) | |
info.update(dist_p2n=dist_p2n, dist_n2p=dist_n2p) | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/__init__.py b/quasimetric_rl/modules/quasimetric_critic/models/__init__.py | |
index f9e05c4..8789800 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/models/__init__.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/models/__init__.py | |
@@ -91,6 +91,9 @@ class QuasimetricCritic(Module): | |
else: | |
return self.encoder | |
+ def get_encoder(self, target: bool = False) -> Encoder: | |
+ return self.target_encoder if target else self.encoder | |
+ | |
@torch.no_grad() | |
def update_target_encoder_(self): | |
if not self.borrowing_embedding and self.target_encoder_ema is not None: | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py b/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py | |
index 5d19754..92fbc58 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py | |
@@ -41,11 +41,11 @@ def create_quasimetric_head_from_spec(spec: str) -> torchqmet.QuasimetricBase: | |
assert dim % components == 0, "IQE: dim must be divisible by components" | |
return torchqmet.IQE(dim, dim // components) | |
- def iqe2(*, dim: int, components: int, scale: bool = False, norm_delta: bool = False, fake_grad: bool = False) -> torchqmet.IQE: | |
+ def iqe2(*, dim: int, components: int, norm_delta: bool = False, fake_grad: bool = False, reduction: str = 'maxl12_sm') -> torchqmet.IQE: | |
assert dim % components == 0, "IQE: dim must be divisible by components" | |
return torchqmet.IQE2( | |
dim, dim // components, | |
- reduction='maxl12_sm' if not scale else 'maxl12_sm_scale', | |
+ reduction=reduction, | |
learned_delta=True, | |
learned_div=False, | |
div_init_mul=0.25, | |
@@ -80,6 +80,7 @@ class QuasimetricModel(Module): | |
class Conf: | |
# config / argparse uses this to specify behavior | |
+ input_slice_size: Optional[int] = attrs.field(default=None, validator=attrs.validators.optional(attrs.validators.gt(0))) # take the first n dims | |
projector_arch: Optional[Tuple[int, ...]] = (512,) | |
projector_layer_norm: bool = True | |
projector_dropout: float = attrs.field(default=0., validator=attrs.validators.ge(0)) # TD-MPC2 uses 0.01 | |
@@ -88,8 +89,11 @@ class QuasimetricModel(Module): | |
quasimetric_head_spec: str = 'iqe(dim=2048,components=64)' | |
def make(self, *, input_size: int) -> 'QuasimetricModel': | |
+ if self.input_slice_size is not None: | |
+ assert self.input_slice_size <= input_size, f'input_slice_size={self.input_slice_size} > input_size={input_size}' | |
return QuasimetricModel( | |
input_size=input_size, | |
+ input_slice_size=self.input_slice_size or input_size, | |
projector_arch=self.projector_arch, | |
projector_layer_norm=self.projector_layer_norm, | |
projector_dropout=self.projector_dropout, | |
@@ -99,22 +103,24 @@ class QuasimetricModel(Module): | |
) | |
input_size: int | |
+ input_slice_size: int | |
projector: Union[Identity, MLP] | |
quasimetric_head: torchqmet.QuasimetricBase | |
- def __init__(self, *, input_size: int, projector_arch: Optional[Tuple[int, ...]], | |
+ def __init__(self, *, input_size: int, input_slice_size: int, projector_arch: Optional[Tuple[int, ...]], | |
projector_layer_norm: bool, projector_dropout: float, projector_weight_norm: bool, | |
projector_unit_norm: bool, quasimetric_head_spec: str): | |
super().__init__() | |
self.input_size = input_size | |
+ self.input_slice_size = input_slice_size | |
self.quasimetric_head = create_quasimetric_head_from_spec(quasimetric_head_spec) | |
if projector_arch is None: | |
- assert input_size == self.quasimetric_head.input_size, \ | |
- f'no projector but latent input_size={input_size}, quasimetric_head.input_size={self.quasimetric_head.input_size}' | |
+ assert input_slice_size == self.quasimetric_head.input_size, \ | |
+ f'no projector but latent input_slice_size={input_slice_size}, quasimetric_head.input_size={self.quasimetric_head.input_size}' | |
self.projector = Identity() | |
else: | |
self.projector = MLP( | |
- input_size, self.quasimetric_head.input_size, | |
+ input_slice_size, self.quasimetric_head.input_size, | |
hidden_sizes=projector_arch, | |
layer_norm=projector_layer_norm, | |
dropout=projector_dropout, | |
@@ -123,6 +129,8 @@ class QuasimetricModel(Module): | |
def forward(self, zx: LatentTensor, zy: LatentTensor, *, bidirectional: bool = False, | |
proj_grad_enabled: Tuple[bool, bool] = (True, True)) -> torch.Tensor: | |
+ zx = zx[..., :self.input_slice_size] | |
+ zy = zy[..., :self.input_slice_size] | |
with self.projector.requiring_grad(proj_grad_enabled[0]): | |
px = self.projector(zx) # [B x D] | |
with self.projector.requiring_grad(proj_grad_enabled[1]): | |
@@ -149,4 +157,4 @@ class QuasimetricModel(Module): | |
return super().__call__(zx, zy, bidirectional=bidirectional, proj_grad_enabled=proj_grad_enabled) | |
def extra_repr(self) -> str: | |
- return f"input_size={self.input_size}" | |
+ return f"input_size={self.input_size}, input_slice_size={self.input_slice_size}" | |
diff --git a/quasimetric_rl/utils/logging.py b/quasimetric_rl/utils/logging.py | |
index 2a9e23f..6efab8d 100644 | |
--- a/quasimetric_rl/utils/logging.py | |
+++ b/quasimetric_rl/utils/logging.py | |
@@ -4,6 +4,8 @@ | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
+from typing import * | |
+ | |
import sys | |
import os | |
import logging | |
@@ -34,10 +36,12 @@ class TqdmLoggingHandler(logging.Handler): | |
class MultiLineFormatter(logging.Formatter): | |
+ _fmt: str | |
def __init__(self, fmt=None, datefmt=None, style='%'): | |
assert style == '%' | |
super(MultiLineFormatter, self).__init__(fmt, datefmt, style) | |
+ assert fmt is not None | |
self.multiline_fmt = fmt | |
def format(self, record): | |
@@ -75,7 +79,7 @@ class MultiLineFormatter(logging.Formatter): | |
output += '\n'.join( | |
self.multiline_fmt % dict(record.__dict__, message=line) | |
for index, line | |
- in enumerate(record.exc_text.decode(sys.getfilesystemencoding(), 'replace').splitlines()) | |
+ in enumerate(record.exc_text.decode(sys.getfilesystemencoding(), 'replace').splitlines()) # type: ignore | |
) | |
return output | |
@@ -96,7 +100,7 @@ def configure(logging_file, log_level=logging.INFO, level_prefix='', prefix='', | |
sys.excepthook = handle_exception # automatically log uncaught errors | |
- handlers = [] | |
+ handlers: List[logging.Handler] = [] | |
if write_to_stdout: | |
handlers.append(TqdmLoggingHandler()) | |
Submodule third_party/torch-quasimetric c5213ff..0fce12e: | |
diff --git a/third_party/torch-quasimetric/torchqmet/__init__.py b/third_party/torch-quasimetric/torchqmet/__init__.py | |
index 0008b7a..3afb467 100644 | |
--- a/third_party/torch-quasimetric/torchqmet/__init__.py | |
+++ b/third_party/torch-quasimetric/torchqmet/__init__.py | |
@@ -58,7 +58,10 @@ class QuasimetricBase(nn.Module, metaclass=abc.ABCMeta): | |
assert x.shape[-1] == y.shape[-1] == self.input_size | |
d = self.compute_components(x, y) | |
d: torch.Tensor = self.transforms(d) | |
- return self.reduction(d) * self.scale | |
+ scale = self.scale | |
+ if not self.training: | |
+ scale = scale.detach() | |
+ return self.reduction(d) * scale | |
def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
# Manually define for typing | |
diff --git a/third_party/torch-quasimetric/torchqmet/iqe.py b/third_party/torch-quasimetric/torchqmet/iqe.py | |
index bc03f05..a8e8c92 100644 | |
--- a/third_party/torch-quasimetric/torchqmet/iqe.py | |
+++ b/third_party/torch-quasimetric/torchqmet/iqe.py | |
@@ -12,7 +12,6 @@ from . import QuasimetricBase | |
# The PQELH function. | |
-@torch.jit.script | |
def f_PQELH(h: torch.Tensor): # PQELH: strictly monotonically increasing mapping from [0, +infty) -> [0, 1) | |
return -torch.expm1(-h) | |
@@ -21,22 +20,22 @@ def iqe_tensor_delta(x: torch.Tensor, y: torch.Tensor, delta: torch.Tensor, div_ | |
D = x.shape[-1] # D: component_dim | |
# ignore pairs that x >= y | |
- valid = (x < y) | |
+ valid = (x < y) # [..., K, D] | |
# sort to better count | |
- xy = torch.cat(torch.broadcast_tensors(x, y), dim=-1) | |
+ xy = torch.cat(torch.broadcast_tensors(x, y), dim=-1) # [..., K, 2D] | |
sxy, ixy = xy.sort(dim=-1) | |
# neg_inc: the **negated** increment of **input** of f at sorted locations | |
# inc = torch.gather(delta * valid, dim=-1, index=ixy % D) * torch.where(ixy < D, 1, -1) | |
- neg_inc = torch.gather(delta * valid, dim=-1, index=ixy % D) * torch.where(ixy < D, -1, 1) | |
+ neg_inc = torch.gather(delta * valid, dim=-1, index=ixy % D) * torch.where(ixy < D, -1, 1) # [..., K, 2D-sort] | |
# neg_incf: the **negated** increment of **output** of f at sorted locations | |
neg_f_input = torch.cumsum(neg_inc, dim=-1) / div_pre_f[:, None] | |
if fake_grad: | |
neg_f_input__grad_path = neg_f_input.clone() | |
- neg_f_input__grad_path.data.clamp_(max=17) # fake grad | |
+ neg_f_input__grad_path.data.clamp_(min=-15) # fake grad | |
neg_f_input = neg_f_input__grad_path + ( | |
neg_f_input - neg_f_input__grad_path | |
).detach() | |
@@ -95,13 +94,29 @@ def iqe(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
return (sxy * neg_incf).sum(-1) | |
-if torch.__version__ >= '2.0.1' and False: # well, broken process pool in notebooks | |
- iqe = torch.compile(iqe) | |
- iqe_tensor_delta = torch.compile(iqe_tensor_delta) | |
+def is_notebook(): | |
+ r""" | |
+ Inspired by | |
+ https://github.com/tqdm/tqdm/blob/cc372d09dcd5a5eabdc6ed4cf365bdb0be004d44/tqdm/autonotebook.py | |
+ """ | |
+ import sys | |
+ try: | |
+ get_ipython = sys.modules['IPython'].get_ipython | |
+ if 'IPKernelApp' not in get_ipython().config: # pragma: no cover | |
+ raise ImportError("console") | |
+ except Exception: | |
+ return False | |
+ else: # pragma: no cover | |
+ return True | |
+ | |
+ | |
+if torch.__version__ >= '2.0.1' and not is_notebook(): # well, broken process pool in notebooks | |
+ iqe = torch.compile(iqe, mode="max-autotune") | |
+ iqe_tensor_delta = torch.compile(iqe_tensor_delta, mode="max-autotune") | |
# iqe = torch.compile(iqe, dynamic=True) | |
else: | |
- iqe = torch.jit.script(iqe) | |
- iqe_tensor_delta = torch.jit.script(iqe_tensor_delta) | |
+ iqe = torch.jit.script(iqe) # type: ignore | |
+ iqe_tensor_delta = torch.jit.script(iqe_tensor_delta) # type: ignore | |
class IQE(QuasimetricBase): | |
@@ -231,8 +246,8 @@ class IQE2(IQE): | |
ema_weight: float = 0.95): | |
super().__init__(input_size, dim_per_component, transforms=transforms, reduction=reduction, | |
discount=discount, warn_if_not_quasimetric=warn_if_not_quasimetric) | |
- self.component_dropout_thresh = tuple(component_dropout_thresh) | |
- self.dropout_p_thresh = tuple(dropout_p_thresh) | |
+ self.component_dropout_thresh = tuple(component_dropout_thresh) # type: ignore | |
+ self.dropout_p_thresh = tuple(dropout_p_thresh) # type: ignore | |
self.dropout_batch_frac = float(dropout_batch_frac) | |
self.fake_grad = fake_grad | |
assert 0 <= self.dropout_batch_frac <= 1 | |
@@ -249,7 +264,7 @@ class IQE2(IQE): | |
# ) | |
self.register_parameter( | |
'raw_delta', | |
- torch.nn.Parameter( | |
+ torch.nn.Parameter( # type: ignore | |
torch.zeros(self.latent_2d_shape).requires_grad_() | |
) | |
) | |
@@ -270,7 +285,7 @@ class IQE2(IQE): | |
self.register_parameter( | |
'raw_div', | |
- torch.nn.Parameter(torch.zeros(self.num_components).requires_grad_()) | |
+ torch.nn.Parameter(torch.zeros(self.num_components).requires_grad_()) # type: ignore | |
) | |
else: | |
self.register_buffer( | |
@@ -285,8 +300,8 @@ class IQE2(IQE): | |
self.div_init_mul = div_init_mul | |
self.mul_kind = mul_kind | |
- self.last_components = None | |
- self.last_drop_p = None | |
+ self.last_components = None # type: ignore | |
+ self.last_drop_p = None # type: ignore | |
def compute_components(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
diff --git a/third_party/torch-quasimetric/torchqmet/reductions.py b/third_party/torch-quasimetric/torchqmet/reductions.py | |
index 7681242..a87be8b 100644 | |
--- a/third_party/torch-quasimetric/torchqmet/reductions.py | |
+++ b/third_party/torch-quasimetric/torchqmet/reductions.py | |
@@ -59,6 +59,11 @@ class Mean(ReductionBase): | |
return d.mean(dim=-1) | |
+class L2(ReductionBase): | |
+ def reduce_distance(self, d: torch.Tensor) -> torch.Tensor: | |
+ return d.norm(p=2, dim=-1) | |
+ | |
+ | |
class MaxMean(ReductionBase): | |
r''' | |
`maxmean` from Neural Norms paper: | |
@@ -144,7 +149,7 @@ class MaxL12_PGsm(ReductionBase): | |
super().__init__(input_num_components=input_num_components, discount=discount) | |
self.raw_alpha = nn.Parameter(torch.tensor([0., 0., 0., 0.], dtype=torch.float32).requires_grad_()) # pre normalizing | |
self.raw_alpha_w = nn.Parameter(torch.tensor([0., 0., 0.], dtype=torch.float32).requires_grad_()) # pre normalizing | |
- self.last_logp = None | |
+ self.last_logp = None # type: ignore | |
self.on_pi = True | |
# self.last_p = None | |
@@ -222,7 +227,7 @@ class MaxL12_PG3(ReductionBase): | |
super().__init__(input_num_components=input_num_components, discount=discount) | |
self.raw_alpha = nn.Parameter(torch.tensor([0., 0., 0.], dtype=torch.float32).requires_grad_()) # pre normalizing | |
self.raw_alpha_w = torch.tensor([], dtype=torch.float32) # just to make logging easier | |
- self.last_logp = None | |
+ self.last_logp = None # type: ignore | |
self.on_pi = True | |
# self.last_p = None | |
@@ -322,6 +327,7 @@ class DeepLinearNetWeightedSum(ReductionBase): | |
REDUCTIONS: Mapping[str, Type[ReductionBase]] = dict( | |
sum=Sum, | |
mean=Mean, | |
+ l2=L2, | |
maxmean=MaxMean, | |
maxl12=MaxL12, | |
maxl12_sm=MaxL12_sm, |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment