Created
March 15, 2024 14:59
-
-
Save ssnl/99a0f80224dd69026858007620532f87 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/offline/main.py b/offline/main.py | |
index 1581d30..5502749 100644 | |
--- a/offline/main.py | |
+++ b/offline/main.py | |
@@ -43,7 +43,6 @@ class Conf(BaseConf): | |
total_optim_steps: int = attrs.field(default=int(2e5), validator=attrs.validators.gt(0)) | |
log_steps: int = attrs.field(default=250, validator=attrs.validators.gt(0)) # type: ignore | |
- eval_before_training: bool = False | |
eval_steps: int = attrs.field(default=20000, validator=attrs.validators.gt(0)) # type: ignore | |
save_steps: int = attrs.field(default=50000, validator=attrs.validators.gt(0)) # type: ignore | |
num_eval_episodes: int = attrs.field(default=50, validator=attrs.validators.ge(0)) # type: ignore | |
@@ -58,7 +57,7 @@ cs.store(name='config', node=Conf()) | |
@hydra.main(version_base=None, config_name="config") | |
def train(dict_cfg: DictConfig): | |
cfg: Conf = Conf.from_DictConfig(dict_cfg) # type: ignore | |
- wandb_run = cfg.setup_for_experiment() # checking & setup logging | |
+ cfg.setup_for_experiment() # checking & setup logging | |
dataset = cfg.env.make() | |
@@ -117,7 +116,7 @@ def train(dict_cfg: DictConfig): | |
logging.info(f"Checkpointed to {relpath}") | |
def eval(epoch, it, optim_steps): | |
- val_result_allenvs = trainer.evaluate(desc=f"opt{optim_steps:08d}") | |
+ val_result_allenvs = trainer.evaluate() | |
val_results.clear() | |
val_results.append(dict( | |
epoch=epoch, | |
@@ -129,10 +128,25 @@ def train(dict_cfg: DictConfig): | |
epoch=epoch, | |
it=it, | |
optim_steps=optim_steps, | |
- result={ | |
- k: val_result.summarize() for k, val_result in val_result_allenvs.items() | |
- }, | |
+ result={}, | |
)) | |
+ for k, val_result in val_result_allenvs.items(): | |
+ succ_rate_ts = ( | |
+ None if val_result.timestep_is_success is None | |
+ else torch.stack([_x.mean(dtype=torch.float32) for _x in val_result.timestep_is_success]) | |
+ ) | |
+ hitting_time = val_result.capped_hitting_time | |
+ summary = dict( | |
+ epi_return=val_result.episode_return, | |
+ epi_score=val_result.episode_score, | |
+ succ_rate_ts=succ_rate_ts, | |
+ succ_rate=val_result.is_success, | |
+ hitting_time=hitting_time, | |
+ ) | |
+ for kk, v in val_result.extra_timestep_results.items(): | |
+ summary[kk] = torch.stack([_v.mean(dtype=torch.float32) for _v in v]) | |
+ summary[kk + '_last'] = torch.stack([_v[-1] for _v in v]) | |
+ val_summaries[-1]['result'][k] = summary | |
averaged_info = cfg.stage_log_info(dict(eval=val_summaries[-1]), optim_steps) | |
with open(os.path.join(cfg.output_dir, 'eval.log'), 'a') as f: # type: ignore | |
print(json.dumps(averaged_info), file=f) | |
@@ -170,18 +184,15 @@ def train(dict_cfg: DictConfig): | |
) | |
num_total_epochs = int(np.ceil(cfg.total_optim_steps / trainer.num_batches)) | |
- logging.info(f"save folder: {cfg.output_dir}") | |
- logging.info(f"wandb: {wandb_run.get_url()}") | |
# Training loop | |
optim_steps = 0 | |
- if cfg.eval_before_training: | |
- eval(0, 0, optim_steps) | |
+ # eval(0, 0, optim_steps) | |
save(0, 0, optim_steps) | |
if start_epoch < num_total_epochs: | |
for epoch in range(num_total_epochs): | |
epoch_desc = f"Train epoch {epoch:05d}/{num_total_epochs:05d}" | |
- for it, (data, data_info) in enumerate(tqdm(trainer.iter_training_data(), total=trainer.num_batches, desc=epoch_desc, leave=False)): | |
+ for it, (data, data_info) in enumerate(tqdm(trainer.iter_training_data(), total=trainer.num_batches, desc=epoch_desc)): | |
step_counter.update_then_record_alerts() | |
optim_steps += 1 | |
@@ -220,9 +231,4 @@ if __name__ == '__main__': | |
# set up some hydra flags before parsing | |
os.environ['HYDRA_FULL_ERROR'] = str(int(FLAGS.DEBUG)) | |
- try: | |
- train() | |
- except: | |
- import wandb | |
- wandb.finish(1) # sometimes crashes are not reported?? let's be safe | |
- raise | |
+ train() | |
diff --git a/offline/trainer.py b/offline/trainer.py | |
index c75abed..e7f48cf 100644 | |
--- a/offline/trainer.py | |
+++ b/offline/trainer.py | |
@@ -11,7 +11,52 @@ import torch | |
import torch.utils.data | |
from quasimetric_rl.modules import QRLConf, QRLAgent, QRLLosses, InfoT | |
-from quasimetric_rl.data import BatchData, Dataset, EpisodeData, EnvSpec, OfflineEnv, interaction | |
+from quasimetric_rl.data import BatchData, Dataset, EpisodeData, EnvSpec, OfflineEnv | |
+ | |
+ | |
+def first_nonzero(arr: torch.Tensor, dim: int = -1, invalid_val: int = -1): | |
+ mask = (arr != 0) | |
+ return torch.where(mask.any(dim=dim), mask.to(torch.uint8).argmax(dim=dim), invalid_val) | |
+ | |
+ | |
+@attrs.define(kw_only=True) | |
+class EvalEpisodeResult: | |
+ timestep_reward: List[torch.Tensor] | |
+ episode_return: torch.Tensor | |
+ episode_score: torch.Tensor | |
+ timestep_is_success: Optional[List[torch.Tensor]] | |
+ is_success: Optional[torch.Tensor] | |
+ hitting_time: Optional[torch.Tensor] | |
+ extra_timestep_results: Mapping[str, List[torch.Tensor]] | |
+ | |
+ @property | |
+ def capped_hitting_time(self) -> Optional[torch.Tensor]: | |
+ # if not hit -> |ts| + 1 | |
+ if self.hitting_time is None: | |
+ return None | |
+ assert self.timestep_is_success is not None | |
+ return torch.stack([torch.where(_x < 0, _succ.shape[0] + 1, _x) for _x, _succ in zip(self.hitting_time, self.timestep_is_success)]) | |
+ | |
+ @classmethod | |
+ def from_timestep_reward_is_success(cls, dataset: Dataset, | |
+ timestep_reward: List[torch.Tensor], | |
+ timestep_is_success: Optional[List[torch.Tensor]], | |
+ extra_timestep_results) -> Self: | |
+ return cls( | |
+ timestep_reward=timestep_reward, | |
+ episode_return=torch.stack([r.sum() for r in timestep_reward]), | |
+ episode_score=dataset.normalize_score(timestep_reward), | |
+ timestep_is_success=timestep_is_success, | |
+ is_success=( | |
+ None if timestep_is_success is None | |
+ else torch.stack([_x.any(dim=-1) for _x in timestep_is_success]) | |
+ ), | |
+ hitting_time=( | |
+ None if timestep_is_success is None | |
+ else torch.stack([first_nonzero(_x, dim=-1) for _x in timestep_is_success]) | |
+ ), # NB this is off by 1 | |
+ extra_timestep_results=dict(extra_timestep_results), | |
+ ) | |
class Trainer(object): | |
@@ -88,20 +133,32 @@ class Trainer(object): | |
adistn = self.agent.actor(obs[None].to(self.device), goal[None].to(self.device)) | |
return adistn.mode.cpu().numpy()[0] | |
- rollout = interaction.collect_rollout( | |
+ rollout = Dataset.collect_rollout_general( | |
actor, env=env, env_spec=EnvSpec.from_env(env), | |
max_episode_length=env.max_episode_steps) | |
return rollout | |
- def evaluate(self, desc=None) -> Mapping[str, interaction.EvalEpisodeResult]: | |
+ def evaluate(self) -> Mapping[str, EvalEpisodeResult]: | |
envs = self.dataset.create_eval_envs(self.eval_seed) | |
- results: Dict[str, interaction.EvalEpisodeResult] = {} | |
+ results: Dict[str, EvalEpisodeResult] = {} | |
for k, env in envs.items(): | |
rollouts: List[EpisodeData] = [] | |
- this_desc = f'eval/{k}' | |
- if desc is not None: | |
- this_desc = f'{desc}/{this_desc}' | |
- for _ in tqdm(range(self.num_eval_episodes), desc=this_desc): | |
+ for _ in tqdm(range(self.num_eval_episodes), desc=f'eval/{k}'): | |
rollouts.append(self.collect_eval_rollout(env=env)) | |
- results[k] = interaction.EvalEpisodeResult.from_episode_rollouts(self.dataset, rollouts) | |
+ results[k] = EvalEpisodeResult.from_timestep_reward_is_success( | |
+ self.dataset, | |
+ timestep_reward=[rollout.rewards for rollout in rollouts], | |
+ timestep_is_success=( | |
+ None | |
+ if len(rollouts) == 0 or 'is_success' not in rollouts[0].transition_infos | |
+ else [rollout.transition_infos['is_success'] for rollout in rollouts] | |
+ ), | |
+ extra_timestep_results=( | |
+ {} if len(rollouts) == 0 else | |
+ { | |
+ k: [rollout.transition_infos[k] for rollout in rollouts] | |
+ for k in rollouts[0].transition_infos.keys() if k != 'is_success' | |
+ } | |
+ ), | |
+ ) | |
return results | |
diff --git a/online/main.py b/online/main.py | |
index c16e094..cd58f7a 100644 | |
--- a/online/main.py | |
+++ b/online/main.py | |
@@ -50,7 +50,7 @@ cs.store(name='config', node=Conf()) | |
@hydra.main(version_base=None, config_name="config") | |
def train(dict_cfg: DictConfig): | |
cfg: Conf = Conf.from_DictConfig(dict_cfg) # type: ignore | |
- wandb_run = cfg.setup_for_experiment() # checking & setup logging | |
+ cfg.setup_for_experiment() # checking & setup logging | |
replay_buffer = cfg.env.make() | |
@@ -85,7 +85,7 @@ def train(dict_cfg: DictConfig): | |
logging.info(f"Checkpointed to {relpath}") | |
def eval(env_steps, optim_steps): | |
- val_result_allenvs = trainer.evaluate(desc=f'env{env_steps:08d}_opt{optim_steps:08d}') | |
+ val_result_allenvs = trainer.evaluate() | |
val_results.clear() | |
val_results.append(dict( | |
env_steps=env_steps, | |
@@ -95,10 +95,18 @@ def train(dict_cfg: DictConfig): | |
val_summaries.append(dict( | |
env_steps=env_steps, | |
optim_steps=optim_steps, | |
- result={ | |
- k: val_result.summarize() for k, val_result in val_result_allenvs.items() | |
- }, | |
+ result={}, | |
)) | |
+ for k, val_result in val_result_allenvs.items(): | |
+ succ_rate_ts = val_result.timestep_is_success.mean(dtype=torch.float32, dim=-1) | |
+ hitting_time = val_result.capped_hitting_time | |
+ val_summaries[-1]['result'][k] = dict( | |
+ epi_return=val_result.episode_return, | |
+ epi_score=val_result.episode_score, | |
+ succ_rate_ts=succ_rate_ts, | |
+ succ_rate=val_result.is_success, | |
+ hitting_time=hitting_time, | |
+ ) | |
averaged_info = cfg.stage_log_info(dict(eval=val_summaries[-1]), optim_steps) | |
with open(os.path.join(cfg.output_dir, 'eval.log'), 'a') as f: # type: ignore | |
print(json.dumps(averaged_info), file=f) | |
@@ -113,8 +121,6 @@ def train(dict_cfg: DictConfig): | |
), | |
) | |
- logging.info(f"save folder: {cfg.output_dir}") | |
- logging.info(f"wandb: {wandb_run.get_url()}") | |
# Training loop | |
eval(0, 0); save(0, 0) | |
for optim_steps, (env_steps, next_iter_new_env_step, data, data_info) in enumerate(trainer.iter_training_data(), start=1): | |
@@ -156,9 +162,4 @@ if __name__ == '__main__': | |
# set up some hydra flags before parsing | |
os.environ['HYDRA_FULL_ERROR'] = str(int(FLAGS.DEBUG)) | |
- try: | |
- train() | |
- except: | |
- import wandb | |
- wandb.finish(1) # sometimes crashes are not reported?? let's be safe | |
- raise | |
+ train() | |
diff --git a/online/trainer.py b/online/trainer.py | |
index 8926871..01faad8 100644 | |
--- a/online/trainer.py | |
+++ b/online/trainer.py | |
@@ -12,11 +12,43 @@ import torch | |
import torch.utils.data | |
from quasimetric_rl.modules import QRLConf, QRLAgent, QRLLosses, InfoT | |
-from quasimetric_rl.data import BatchData, EpisodeData, interaction | |
+from quasimetric_rl.data import Dataset, BatchData, EpisodeData, MultiEpisodeData | |
from quasimetric_rl.data.online import ReplayBuffer, OnlineFixedLengthEnv | |
from quasimetric_rl.utils import tqdm | |
+def first_nonzero(arr: torch.Tensor, dim: int = -1, invalid_val: int = -1): | |
+ mask = (arr != 0) | |
+ return torch.where(mask.any(dim=dim), mask.to(torch.uint8).argmax(dim=dim), invalid_val) | |
+ | |
+ | |
+@attrs.define(kw_only=True) | |
+class EvalEpisodeResult: | |
+ timestep_reward: torch.Tensor | |
+ episode_return: torch.Tensor | |
+ episode_score: torch.Tensor | |
+ timestep_is_success: torch.Tensor | |
+ is_success: torch.Tensor | |
+ hitting_time: torch.Tensor | |
+ | |
+ @property | |
+ def capped_hitting_time(self) -> torch.Tensor: | |
+ # if not hit -> |ts| + 1 | |
+ return torch.stack([torch.where(_x < 0, self.timestep_is_success.shape[-1] + 1, _x) for _x in self.hitting_time]) | |
+ | |
+ @classmethod | |
+ def from_timestep_reward_is_success(cls, dataset: Dataset, timestep_reward: torch.Tensor, | |
+ timestep_is_success: torch.Tensor) -> Self: | |
+ return cls( | |
+ timestep_reward=timestep_reward, | |
+ episode_return=timestep_reward.sum(-1), | |
+ episode_score=dataset.normalize_score(cast(Sequence[torch.Tensor], timestep_reward)), | |
+ timestep_is_success=timestep_is_success, | |
+ is_success=timestep_is_success.any(dim=-1), | |
+ hitting_time=first_nonzero(timestep_is_success, dim=-1), # NB this is off by 1 | |
+ ) | |
+ | |
+ | |
@attrs.define(kw_only=True) | |
class InteractionConf: | |
total_env_steps: int = attrs.field(default=int(1e6), validator=attrs.validators.gt(0)) | |
@@ -133,18 +165,23 @@ class Trainer(object): | |
self.replay.add_rollout(rollout) | |
return rollout | |
- def evaluate(self, desc=None) -> Mapping[str, interaction.EvalEpisodeResult]: | |
+ def evaluate(self) -> Mapping[str, EvalEpisodeResult]: | |
envs = self.make_evaluate_envs() | |
- results: Dict[str, interaction.EvalEpisodeResult] = {} | |
+ results: Dict[str, EvalEpisodeResult] = {} | |
for k, env in envs.items(): | |
rollouts = [] | |
- this_desc = f'eval/{k}' | |
- if desc is not None: | |
- this_desc = f'{desc}/{this_desc}' | |
- for _ in tqdm(range(self.num_eval_episodes), desc=this_desc): | |
+ for _ in tqdm(range(self.num_eval_episodes), desc=f'eval/{k}'): | |
rollouts.append(self.collect_rollout(eval=True, store=False, env=env)) | |
- results[k] = interaction.EvalEpisodeResult.from_episode_rollouts( | |
- self.replay, rollouts) | |
+ mrollouts = MultiEpisodeData.cat(rollouts) | |
+ results[k] = EvalEpisodeResult.from_timestep_reward_is_success( | |
+ self.replay, | |
+ mrollouts.rewards.reshape( | |
+ self.num_eval_episodes, env.episode_length, | |
+ ), | |
+ mrollouts.transition_infos['is_success'].reshape( | |
+ self.num_eval_episodes, env.episode_length, | |
+ ), | |
+ ) | |
return results | |
def iter_training_data(self) -> Iterator[Tuple[int, bool, BatchData, InfoT]]: | |
@@ -160,7 +197,7 @@ class Trainer(object): | |
""" | |
def yield_data(): | |
num_transitions = self.replay.num_transitions_realized | |
- for icyc in tqdm(range(self.num_samples_per_cycle), desc=f"{num_transitions} env steps, train batches", leave=False): | |
+ for icyc in tqdm(range(self.num_samples_per_cycle), desc=f"{num_transitions} env steps, train batches"): | |
data_t0 = time.time() | |
data = self.sample() | |
info = dict( | |
diff --git a/quasimetric_rl/base_conf.py b/quasimetric_rl/base_conf.py | |
index e9f2202..b3a332e 100644 | |
--- a/quasimetric_rl/base_conf.py | |
+++ b/quasimetric_rl/base_conf.py | |
@@ -172,7 +172,7 @@ class BaseConf(abc.ABC): | |
else: | |
raise RuntimeError(f'Output directory {self.output_dir} exists and is complete') | |
- run = wandb.init( | |
+ wandb.init( | |
project=self.wandb_project, | |
name=self.output_folder.replace('/', '__') + '__' + datetime.now().strftime(r"%Y%m%d_%H:%M:%S"), | |
config=yaml.safe_load(OmegaConf.to_yaml(self)), | |
@@ -219,6 +219,3 @@ class BaseConf(abc.ABC): | |
if self.device.type == 'cuda' and self.device.index is not None: | |
torch.cuda.set_device(self.device.index) | |
- | |
- assert run is not None | |
- return run | |
diff --git a/quasimetric_rl/data/__init__.py b/quasimetric_rl/data/__init__.py | |
index 519796b..41fc413 100644 | |
--- a/quasimetric_rl/data/__init__.py | |
+++ b/quasimetric_rl/data/__init__.py | |
@@ -5,9 +5,8 @@ from .env_spec import EnvSpec | |
from . import online | |
from .online import register_online_env, OnlineFixedLengthEnv | |
from .offline import OfflineEnv | |
-from . import interaction | |
__all__ = [ | |
'BatchData', 'EpisodeData', 'MultiEpisodeData', 'Dataset', 'register_offline_env', | |
- 'EnvSpec', 'online', 'register_online_env', 'OnlineFixedLengthEnv', 'OfflineEnv', 'interaction', | |
+ 'EnvSpec', 'online', 'register_online_env', 'OnlineFixedLengthEnv', 'OfflineEnv', | |
] | |
diff --git a/quasimetric_rl/data/base.py b/quasimetric_rl/data/base.py | |
index 09711c4..c86f525 100644 | |
--- a/quasimetric_rl/data/base.py | |
+++ b/quasimetric_rl/data/base.py | |
@@ -36,7 +36,6 @@ class BatchData(TensorCollectionAttrsMixin): # TensorCollectionAttrsMixin has s | |
timeouts: torch.Tensor | |
future_observations: torch.Tensor # sampled! | |
- future_tdelta: torch.Tensor | |
@property | |
def device(self) -> torch.device: | |
@@ -144,12 +143,6 @@ class EpisodeData(MultiEpisodeData): | |
rewards=self.rewards[:t], | |
terminals=self.terminals[:t], | |
timeouts=self.timeouts[:t], | |
- observation_infos={ | |
- k: v[:t + 1] for k, v in self.observation_infos.items() | |
- }, | |
- transition_infos={ | |
- k: v[:t] for k, v in self.transition_infos.items() | |
- }, | |
) | |
@@ -242,7 +235,6 @@ class Dataset(torch.utils.data.Dataset): | |
indices_to_episode_timesteps: torch.Tensor | |
max_episode_length: int | |
# ----- | |
- device: torch.device | |
def create_env(self, *, dict_obseravtion: Optional[bool] = None, seed: Optional[int] = None, **kwargs) -> 'OfflineEnv': | |
from .offline import OfflineEnv | |
@@ -280,7 +272,6 @@ class Dataset(torch.utils.data.Dataset): | |
def __init__(self, kind: str, name: str, *, | |
future_observation_discount: float, | |
dummy: bool = False, # when you don't want to load data, e.g., in analysis | |
- device: torch.device = torch.device('cpu'), # FIXME: get some heuristic | |
) -> None: | |
self.kind = kind | |
self.name = name | |
@@ -307,22 +298,97 @@ class Dataset(torch.utils.data.Dataset): | |
indices_to_episode_timesteps.append(torch.arange(l, dtype=torch.int64)) | |
assert len(episodes) > 0, "must have at least one episode" | |
- self.raw_data = MultiEpisodeData.cat(episodes).to(device) | |
+ self.raw_data = MultiEpisodeData.cat(episodes) | |
- self.obs_indices_to_obs_index_in_episode = torch.cat(obs_indices_to_obs_index_in_episode, dim=0).to(device) | |
- self.indices_to_episode_indices = torch.cat(indices_to_episode_indices, dim=0).to(device) | |
- self.indices_to_episode_timesteps = torch.cat(indices_to_episode_timesteps, dim=0).to(device) | |
+ self.obs_indices_to_obs_index_in_episode = torch.cat(obs_indices_to_obs_index_in_episode, dim=0) | |
+ self.indices_to_episode_indices = torch.cat(indices_to_episode_indices, dim=0) | |
+ self.indices_to_episode_timesteps = torch.cat(indices_to_episode_timesteps, dim=0) | |
self.max_episode_length = int(self.raw_data.episode_lengths.max().item()) | |
- self.device = device | |
- | |
- # def max_bytes_used(self): | |
- # return self | |
def get_observations(self, obs_indices: torch.Tensor): | |
- return self.raw_data.all_observations[obs_indices.to(self.device)] | |
+ return self.raw_data.all_observations[obs_indices] | |
+ | |
+ @classmethod | |
+ def collect_rollout_general(cls, actor: Callable[[torch.Tensor, torch.Tensor, gym.Space], np.ndarray], *, | |
+ env: gym.Env, env_spec: EnvSpec, max_episode_length: int, | |
+ assert_exact_episode_length: bool = False, extra_transition_info_keys: Collection[str] = []) -> EpisodeData: | |
+ from .utils import get_empty_episode | |
+ | |
+ epi = get_empty_episode(env_spec, max_episode_length) | |
+ | |
+ # check observation space | |
+ obs_dict_keys = {'observation', 'achieved_goal', 'desired_goal'} | |
+ WRONG_OBS_ERR_MESSAGE = ( | |
+ f"{cls.__name__} collect_rollout only supports Dict " | |
+ f"observation space with keys {obs_dict_keys}, but got {env.observation_space}" | |
+ ) | |
+ assert isinstance(env.observation_space, gym.spaces.Dict), WRONG_OBS_ERR_MESSAGE | |
+ assert set(env.observation_space.spaces.keys()) == {'observation', 'achieved_goal', 'desired_goal'}, WRONG_OBS_ERR_MESSAGE | |
+ | |
+ observation_dict = cast(Mapping[str, np.ndarray], env.reset()) | |
+ observation: torch.Tensor = torch.as_tensor(observation_dict['observation'], dtype=torch.float32) | |
+ | |
+ goal: torch.Tensor = torch.as_tensor(observation_dict['desired_goal'], dtype=torch.float32) | |
+ agoal: torch.Tensor = torch.as_tensor(observation_dict['achieved_goal'], dtype=torch.float32) | |
+ epi.all_observations[0] = observation | |
+ epi.observation_infos['desired_goals'][0] = goal | |
+ epi.observation_infos['achieved_goals'][0] = agoal | |
+ if len(extra_transition_info_keys): | |
+ epi.transition_infos = dict(epi.transition_infos) | |
+ epi.transition_infos.update({ | |
+ k: torch.empty([max_episode_length], dtype=torch.float32) for k in extra_transition_info_keys | |
+ }) | |
+ | |
+ t = 0 | |
+ timeout = False | |
+ terminal = False | |
+ while not timeout and not terminal: | |
+ assert t < max_episode_length | |
+ | |
+ action = actor( | |
+ observation, | |
+ goal, | |
+ env_spec.action_space, | |
+ ) | |
+ transition_out = env.step(np.asarray(action)) | |
+ observation_dict, reward, terminal, info = transition_out[:3] + transition_out[-1:] # some BC | |
+ | |
+ observation = torch.tensor(observation_dict['observation'], dtype=torch.float32) # copy just in case | |
+ | |
+ goal: torch.Tensor = torch.as_tensor(observation_dict['desired_goal'], dtype=torch.float32) | |
+ agoal: torch.Tensor = torch.as_tensor(observation_dict['achieved_goal'], dtype=torch.float32) | |
+ | |
+ if 'is_success' in info: | |
+ is_success: bool = info['is_success'] | |
+ epi.transition_infos['is_success'][t] = is_success | |
+ else: | |
+ if t == 0: | |
+ # remove field | |
+ transition_infos = dict(epi.transition_infos) | |
+ del transition_infos['is_success'] | |
+ epi.transition_infos = transition_infos | |
+ | |
+ for k in extra_transition_info_keys: | |
+ epi.transition_infos[k][t] = info[k] | |
+ | |
+ epi.all_observations[t + 1] = observation | |
+ epi.actions[t] = torch.as_tensor(action, dtype=torch.float32) | |
+ epi.rewards[t] = reward | |
+ epi.observation_infos['desired_goals'][t + 1] = goal | |
+ epi.observation_infos['achieved_goals'][t + 1] = agoal | |
+ | |
+ t += 1 | |
+ timeout = info.get('TimeLimit.truncated', False) | |
+ if assert_exact_episode_length: | |
+ assert (timeout or terminal) == (t == max_episode_length) | |
+ | |
+ if t < max_episode_length: | |
+ epi = epi.first_t(t) | |
+ | |
+ return epi | |
def __getitem__(self, indices: torch.Tensor) -> BatchData: | |
- indices = torch.as_tensor(indices, device=self.device) | |
+ indices = torch.as_tensor(indices) | |
eindices = self.indices_to_episode_indices[indices] | |
obs_indices = indices + eindices # index for `observation`: skip the s_last from previous episodes | |
obs = self.get_observations(obs_indices) | |
@@ -332,7 +398,7 @@ class Dataset(torch.utils.data.Dataset): | |
tindices = self.indices_to_episode_timesteps[indices] | |
epilengths = self.raw_data.episode_lengths[eindices] # max idx is this | |
- deltas = torch.arange(self.max_episode_length, device=self.device) | |
+ deltas = torch.arange(self.max_episode_length) | |
pdeltas = torch.where( | |
# test tidx + 1 + delta <= max_idx = epi_length | |
(tindices[:, None] + deltas) < epilengths[:, None], | |
@@ -341,16 +407,14 @@ class Dataset(torch.utils.data.Dataset): | |
) | |
deltas = torch.distributions.Categorical( | |
probs=pdeltas, | |
- validate_args=False, | |
- ).sample() + 1 | |
- future_observations = self.get_observations(obs_indices + deltas) | |
+ ).sample() | |
+ future_observations = self.get_observations(obs_indices + 1 + deltas) | |
return BatchData( | |
observations=obs, | |
actions=self.raw_data.actions[indices], | |
next_observations=nobs, | |
future_observations=future_observations, | |
- future_tdelta=deltas, | |
rewards=self.raw_data.rewards[indices], | |
terminals=terminals, | |
timeouts=self.raw_data.timeouts[indices], | |
diff --git a/quasimetric_rl/data/interaction.py b/quasimetric_rl/data/interaction.py | |
deleted file mode 100644 | |
index 699fa61..0000000 | |
--- a/quasimetric_rl/data/interaction.py | |
+++ /dev/null | |
@@ -1,178 +0,0 @@ | |
-from __future__ import annotations | |
-from typing import * | |
- | |
-import attrs | |
- | |
-import gym | |
-import gym.spaces | |
-import numpy as np | |
-import torch | |
-import torch.utils.data | |
- | |
-from . import Dataset, EpisodeData, EnvSpec | |
- | |
- | |
-def first_nonzero(arr: torch.Tensor, dim: int = -1, invalid_val: int = -1): | |
- mask = (arr != 0) | |
- return torch.where(mask.any(dim=dim), mask.to(torch.uint8).argmax(dim=dim), invalid_val) | |
- | |
- | |
-@attrs.define(kw_only=True) | |
-class EvalEpisodeResult: | |
- timestep_reward: List[torch.Tensor] | |
- episode_return: torch.Tensor | |
- episode_score: torch.Tensor | |
- timestep_is_success: Optional[List[torch.Tensor]] | |
- is_success: Optional[torch.Tensor] | |
- hitting_time: Optional[torch.Tensor] | |
- extra_timestep_results: Mapping[str, List[torch.Tensor]] | |
- | |
- @property | |
- def capped_hitting_time(self) -> Optional[torch.Tensor]: | |
- # if not hit -> |ts| + 1 | |
- if self.hitting_time is None: | |
- return None | |
- assert self.timestep_is_success is not None | |
- return torch.stack([torch.where(_x < 0, _succ.shape[0] + 1, _x) for _x, _succ in zip(self.hitting_time, self.timestep_is_success)]) | |
- | |
- @classmethod | |
- def from_timestep_reward_is_success(cls, dataset: Dataset, | |
- timestep_reward: List[torch.Tensor], | |
- timestep_is_success: Optional[List[torch.Tensor]], | |
- extra_timestep_results) -> Self: | |
- return cls( | |
- timestep_reward=timestep_reward, | |
- episode_return=torch.stack([r.sum() for r in timestep_reward]), | |
- episode_score=dataset.normalize_score(timestep_reward), | |
- timestep_is_success=timestep_is_success, | |
- is_success=( | |
- None if timestep_is_success is None | |
- else torch.stack([_x.any(dim=-1) for _x in timestep_is_success]) | |
- ), | |
- hitting_time=( | |
- None if timestep_is_success is None | |
- else torch.stack([first_nonzero(_x, dim=-1) for _x in timestep_is_success]) | |
- ), # NB this is off by 1 | |
- extra_timestep_results=dict(extra_timestep_results), | |
- ) | |
- | |
- @classmethod | |
- def from_episode_rollouts(cls, dataset: Dataset,rollouts: Sequence[EpisodeData]) -> Self: | |
- return cls.from_timestep_reward_is_success( | |
- dataset, | |
- timestep_reward=[rollout.rewards for rollout in rollouts], | |
- timestep_is_success=( | |
- None | |
- if len(rollouts) == 0 or 'is_success' not in rollouts[0].transition_infos | |
- else [rollout.transition_infos['is_success'] for rollout in rollouts] | |
- ), | |
- extra_timestep_results=( | |
- {} if len(rollouts) == 0 else | |
- { | |
- k: [rollout.transition_infos[k] for rollout in rollouts] | |
- for k in rollouts[0].transition_infos.keys() if k != 'is_success' | |
- } | |
- ), | |
- ) | |
- | |
- def summarize(self) -> Mapping[str, Union[torch.Tensor, float, None]]: | |
- succ_rate_ts = ( | |
- None if self.timestep_is_success is None | |
- else torch.stack([_x.mean(dtype=torch.float32) for _x in self.timestep_is_success]) | |
- ) | |
- hitting_time = self.capped_hitting_time | |
- summary = dict( | |
- epi_return=self.episode_return, | |
- epi_score=self.episode_score, | |
- succ_rate_ts=succ_rate_ts, | |
- succ_rate=self.is_success, | |
- hitting_time=hitting_time, | |
- ) | |
- for kk, v in self.extra_timestep_results.items(): | |
- summary[kk] = torch.stack([_v.mean(dtype=torch.float32) for _v in v]) | |
- summary[kk + '_last'] = torch.stack([_v[-1] for _v in v]) | |
- | |
- return summary | |
- | |
- | |
-def collect_rollout(actor: Callable[[torch.Tensor, torch.Tensor, gym.Space], np.ndarray], *, | |
- env: gym.Env, env_spec: EnvSpec, max_episode_length: int, | |
- assert_exact_episode_length: bool = False) -> EpisodeData: | |
- # NOTE: extra tracked info can be specified by env.tracked_info_keys | |
- | |
- from .utils import get_empty_episode | |
- | |
- epi = get_empty_episode(env_spec, max_episode_length) | |
- | |
- # check observation space | |
- obs_dict_keys = {'observation', 'achieved_goal', 'desired_goal'} | |
- WRONG_OBS_ERR_MESSAGE = ( | |
- f"collect_rollout only supports Dict " | |
- f"observation space with keys {obs_dict_keys}, but got {env.observation_space}" | |
- ) | |
- assert isinstance(env.observation_space, gym.spaces.Dict), WRONG_OBS_ERR_MESSAGE | |
- assert set(env.observation_space.spaces.keys()) == {'observation', 'achieved_goal', 'desired_goal'}, WRONG_OBS_ERR_MESSAGE | |
- | |
- observation_dict = cast(Mapping[str, np.ndarray], env.reset()) | |
- observation: torch.Tensor = torch.as_tensor(observation_dict['observation'], dtype=torch.float32) | |
- | |
- goal: torch.Tensor = torch.as_tensor(observation_dict['desired_goal'], dtype=torch.float32) | |
- agoal: torch.Tensor = torch.as_tensor(observation_dict['achieved_goal'], dtype=torch.float32) | |
- epi.all_observations[0] = observation | |
- epi.observation_infos['desired_goals'][0] = goal | |
- epi.observation_infos['achieved_goals'][0] = agoal | |
- | |
- extra_transition_info_keys = getattr(env, 'tracked_info_keys', []) | |
- if len(extra_transition_info_keys): | |
- epi.transition_infos = dict(epi.transition_infos) | |
- epi.transition_infos.update({ | |
- k: torch.empty([max_episode_length], dtype=torch.float32) for k in extra_transition_info_keys | |
- }) | |
- | |
- t = 0 | |
- timeout = False | |
- terminal = False | |
- while not timeout and not terminal: | |
- assert t < max_episode_length | |
- | |
- action = actor( | |
- observation, | |
- goal, | |
- env_spec.action_space, | |
- ) | |
- transition_out = env.step(np.asarray(action)) | |
- observation_dict, reward, terminal, info = transition_out[:3] + transition_out[-1:] # some BC | |
- | |
- observation = torch.tensor(observation_dict['observation'], dtype=torch.float32) # copy just in case | |
- | |
- goal: torch.Tensor = torch.as_tensor(observation_dict['desired_goal'], dtype=torch.float32) | |
- agoal: torch.Tensor = torch.as_tensor(observation_dict['achieved_goal'], dtype=torch.float32) | |
- | |
- if 'is_success' in info: | |
- is_success: bool = info['is_success'] | |
- epi.transition_infos['is_success'][t] = is_success | |
- else: | |
- if t == 0: | |
- # remove field | |
- transition_infos = dict(epi.transition_infos) | |
- del transition_infos['is_success'] | |
- epi.transition_infos = transition_infos | |
- | |
- for k in extra_transition_info_keys: | |
- epi.transition_infos[k][t] = info[k] | |
- | |
- epi.all_observations[t + 1] = observation | |
- epi.actions[t] = torch.as_tensor(action, dtype=torch.float32) | |
- epi.rewards[t] = reward | |
- epi.observation_infos['desired_goals'][t + 1] = goal | |
- epi.observation_infos['achieved_goals'][t + 1] = agoal | |
- | |
- t += 1 | |
- timeout = info.get('TimeLimit.truncated', False) | |
- if assert_exact_episode_length: | |
- assert (timeout or terminal) == (t == max_episode_length) | |
- | |
- if t < max_episode_length: | |
- epi = epi.first_t(t) | |
- | |
- return epi | |
diff --git a/quasimetric_rl/data/offline/__init__.py b/quasimetric_rl/data/offline/__init__.py | |
index c816a7b..22bb42c 100644 | |
--- a/quasimetric_rl/data/offline/__init__.py | |
+++ b/quasimetric_rl/data/offline/__init__.py | |
@@ -60,10 +60,6 @@ class OfflineGoalCondEnv(gym.ObservationWrapper, OfflineEnv): # type: ignore | |
self.get_goal_fn = get_goal_fn | |
self.extra_info_fns = extra_info_fns | |
- @property | |
- def tracked_info_keys(self): | |
- return tuple(self.extra_info_fns.keys()) | |
- | |
def observation(self, observation: np.ndarray): | |
o, g = observation, self.get_goal_fn(self.env) | |
if self.is_image_based: | |
diff --git a/quasimetric_rl/data/offline/d4rl/antmaze.py b/quasimetric_rl/data/offline/d4rl/antmaze.py | |
index dbf5198..c48ab34 100644 | |
--- a/quasimetric_rl/data/offline/d4rl/antmaze.py | |
+++ b/quasimetric_rl/data/offline/d4rl/antmaze.py | |
@@ -182,13 +182,12 @@ def create_env_antmaze(name, dict_obseravtion: Optional[bool] = None, *, random_ | |
return env | |
-def load_episodes_antmaze(name, normalize_observation=True): | |
+def load_episodes_antmaze(name): | |
env = load_environment(name) | |
d4rl_dataset = cached_d4rl_dataset(name) | |
- if normalize_observation: | |
- # normalize | |
- d4rl_dataset['observations'] = obs_norm(name, d4rl_dataset['observations']) | |
- d4rl_dataset['next_observations'] = obs_norm(name, d4rl_dataset['next_observations']) | |
+ # normalize | |
+ d4rl_dataset['observations'] = obs_norm(name, d4rl_dataset['observations']) | |
+ d4rl_dataset['next_observations'] = obs_norm(name, d4rl_dataset['next_observations']) | |
yield from convert_dict_to_EpisodeData_iter( | |
sequence_dataset( | |
env, | |
@@ -207,11 +206,3 @@ for name in ['antmaze-umaze-v2', 'antmaze-umaze-diverse-v2', | |
normalize_score_fn=functools.partial(get_normalized_score, name), | |
eval_specs=dict(single_task=dict(random_start_goal=False), multi_task=dict(random_start_goal=True)), | |
) | |
- register_offline_env( | |
- 'd4rl', name + '-nonorm', | |
- create_env_fn=functools.partial(create_env_antmaze, name, normalize_observation=False), | |
- load_episodes_fn=functools.partial(load_episodes_antmaze, name, normalize_observation=False), | |
- normalize_score_fn=functools.partial(get_normalized_score, name), | |
- eval_specs=dict(single_task=dict(random_start_goal=False, normalize_observation=False), | |
- multi_task=dict(random_start_goal=True, normalize_observation=False)), | |
- ) | |
diff --git a/quasimetric_rl/data/online/memory.py b/quasimetric_rl/data/online/memory.py | |
index a3445f4..b7cb94b 100644 | |
--- a/quasimetric_rl/data/online/memory.py | |
+++ b/quasimetric_rl/data/online/memory.py | |
@@ -12,7 +12,6 @@ import gym.spaces | |
from . import OnlineFixedLengthEnv | |
from ..base import EpisodeData, MultiEpisodeData, Dataset, BatchData | |
-from ..interaction import collect_rollout | |
from ..utils import get_empty_episode, get_empty_episodes | |
@@ -154,7 +153,7 @@ class ReplayBuffer(Dataset): | |
get_empty_episodes( | |
self.env_spec, self.episode_length, | |
int(np.ceil(self.increment_num_transitions / self.episode_length)), | |
- ).to(self.device), | |
+ ), | |
], | |
dim=0, | |
) | |
@@ -165,19 +164,19 @@ class ReplayBuffer(Dataset): | |
# indices_to_episode_timesteps: torch.Tensor | |
self.indices_to_episode_indices = torch.cat([ | |
self.indices_to_episode_indices, | |
- torch.repeat_interleave(torch.arange(original_capacity, new_capacity, device=self.device), self.episode_length), | |
+ torch.repeat_interleave(torch.arange(original_capacity, new_capacity), self.episode_length), | |
], dim=0) | |
self.indices_to_episode_timesteps = torch.cat([ | |
self.indices_to_episode_timesteps, | |
- torch.arange(self.episode_length, device=self.device).repeat(new_capacity - original_capacity), | |
+ torch.arange(self.episode_length).repeat(new_capacity - original_capacity), | |
], dim=0) | |
logging.info(f'ReplayBuffer: Expanded from capacity={original_capacity} to {new_capacity} episodes') | |
def collect_rollout(self, actor: Callable[[torch.Tensor, torch.Tensor, gym.Space], np.ndarray], *, | |
env: Optional[OnlineFixedLengthEnv] = None) -> EpisodeData: | |
- return collect_rollout(actor, env=(env or self.env), env_spec=self.env_spec, | |
- max_episode_length=self.episode_length, assert_exact_episode_length=True) | |
+ return self.collect_rollout_general(actor, env=(env or self.env), env_spec=self.env_spec, | |
+ max_episode_length=self.episode_length, assert_exact_episode_length=True) | |
def add_rollout(self, episode: EpisodeData): | |
if self.num_episodes_realized == self.episodes_capacity: | |
@@ -215,8 +214,7 @@ class ReplayBuffer(Dataset): | |
def sample(self, batch_size: int) -> BatchData: | |
indices = torch.as_tensor( | |
- np.random.choice(self.num_transitions_realized, size=[batch_size]), | |
- device=self.device, | |
+ np.random.choice(self.num_transitions_realized, size=[batch_size]) | |
) | |
return self[indices] | |
diff --git a/quasimetric_rl/flags.py b/quasimetric_rl/flags.py | |
index 2f7578a..2fb0fab 100644 | |
--- a/quasimetric_rl/flags.py | |
+++ b/quasimetric_rl/flags.py | |
@@ -25,33 +25,31 @@ FLAGS = FlagsDefinition() | |
def pdb_if_DEBUG(fn: Callable): | |
@functools.wraps(fn) | |
def wrapped(*args, **kwargs): | |
- if not FLAGS.DEBUG: # check here, in case it is set after decorator call | |
- return fn(*args, **kwargs) | |
- else: | |
- try: | |
- return fn(*args, **kwargs) | |
- except: | |
- # follow ABSL: | |
- # https://github.com/abseil/abseil-py/blob/a0ae31683e6cf3667886c500327f292c893a1740/absl/app.py#L311-L327 | |
- | |
- exc = sys.exc_info()[1] | |
- if isinstance(exc, KeyboardInterrupt): | |
- raise | |
- | |
- # Don't try to post-mortem debug successful SystemExits, since those | |
- # mean there wasn't actually an error. In particular, the test framework | |
- # raises SystemExit(False) even if all tests passed. | |
- if isinstance(exc, SystemExit) and not exc.code: | |
- raise | |
- | |
- # Check the tty so that we don't hang waiting for input in an | |
- # non-interactive scenario. | |
+ try: | |
+ fn(*args, **kwargs) | |
+ except: | |
+ # follow ABSL: | |
+ # https://github.com/abseil/abseil-py/blob/a0ae31683e6cf3667886c500327f292c893a1740/absl/app.py#L311-L327 | |
+ | |
+ exc = sys.exc_info()[1] | |
+ if isinstance(exc, KeyboardInterrupt): | |
+ raise | |
+ | |
+ # Don't try to post-mortem debug successful SystemExits, since those | |
+ # mean there wasn't actually an error. In particular, the test framework | |
+ # raises SystemExit(False) even if all tests passed. | |
+ if isinstance(exc, SystemExit) and not exc.code: | |
+ raise | |
+ | |
+ # Check the tty so that we don't hang waiting for input in an | |
+ # non-interactive scenario. | |
+ if FLAGS.DEBUG: | |
traceback.print_exc() | |
print() | |
print(' *** Entering post-mortem debugging ***') | |
print() | |
pdb.post_mortem() | |
- raise | |
+ raise | |
return wrapped | |
diff --git a/quasimetric_rl/modules/__init__.py b/quasimetric_rl/modules/__init__.py | |
index 81b5443..79fd35f 100644 | |
--- a/quasimetric_rl/modules/__init__.py | |
+++ b/quasimetric_rl/modules/__init__.py | |
@@ -9,7 +9,6 @@ from . import actor, quasimetric_critic | |
from ..data import EnvSpec, BatchData | |
from .utils import LossResult, Module, InfoT, InfoValT | |
-from ..flags import FLAGS | |
class QRLAgent(Module): | |
@@ -33,15 +32,13 @@ class QRLLosses(Module): | |
critic_losses: Collection[quasimetric_critic.QuasimetricCriticLosses], | |
critics_total_grad_clip_norm: Optional[float], | |
recompute_critic_for_actor_loss: bool, | |
- critics_share_embedding: bool, | |
- critic_losses_use_target_encoder: bool): | |
+ critics_share_embedding: bool): | |
super().__init__() | |
self.add_module('actor_loss', actor_loss) | |
self.critic_losses = torch.nn.ModuleList(critic_losses) # type: ignore | |
self.critics_total_grad_clip_norm = critics_total_grad_clip_norm | |
self.recompute_critic_for_actor_loss = recompute_critic_for_actor_loss | |
self.critics_share_embedding = critics_share_embedding | |
- self.critic_losses_use_target_encoder = critic_losses_use_target_encoder | |
def forward(self, agent: QRLAgent, data: BatchData, *, optimize: bool = True) -> LossResult: | |
# compute CriticBatchInfo | |
@@ -53,106 +50,53 @@ class QRLLosses(Module): | |
stack.enter_context(critic_loss.optim_update_context(optimize=optimize)) | |
if self.critics_share_embedding and idx > 0: | |
- critic_batch_info = attrs.evolve(critic_batch_infos[0], critic=critic) | |
+ critic_batch_info = quasimetric_critic.CriticBatchInfo( | |
+ critic=critic, | |
+ zx=critic_batch_infos[0].zx, | |
+ zy=critic_batch_infos[0].zy, | |
+ ) | |
else: | |
zx = critic.encoder(data.observations) | |
- zy = critic.get_encoder(target=self.critic_losses_use_target_encoder)(data.next_observations) | |
- if critic.has_separate_target_encoder and self.critic_losses_use_target_encoder: | |
+ zy = critic.target_encoder(data.next_observations) | |
+ if critic.has_separate_target_encoder: | |
assert not zy.requires_grad | |
critic_batch_info = quasimetric_critic.CriticBatchInfo( | |
critic=critic, | |
zx=zx, | |
zy=zy, | |
- zy_from_target_encoder=self.critic_losses_use_target_encoder, | |
) | |
loss_results[f"critic_{idx:02d}"] = critic_loss(data, critic_batch_info) # we update together to handle shared embedding | |
critic_batch_infos.append(critic_batch_info) | |
- critic_grad_norm: InfoValT = {} | |
- | |
- if FLAGS.DEBUG: | |
- def get_grad(loss: Union[torch.Tensor, float]) -> Union[torch.Tensor, float]: | |
- if isinstance(loss, (int, float)): | |
- return 0 | |
- loss_grads = torch.autograd.grad( | |
- loss, | |
- list(cast(torch.nn.ModuleList, agent.critics).parameters()), | |
- retain_graph=True, | |
- allow_unused=True, | |
- ) | |
- return cast( | |
- torch.Tensor, | |
- sum(pg.pow(2).sum() for pg in loss_grads if pg is not None), | |
- ).sqrt() | |
- | |
- for k, loss_r in loss_results.items(): | |
- critic_grad_norm.update({ | |
- k: loss_r.map_losses(get_grad), | |
- }) | |
- | |
torch.stack( | |
- [loss_r.total_loss for loss_r in loss_results.values()] | |
+ [cast(torch.Tensor, loss_r.loss) for loss_r in loss_results.values()] | |
).sum().backward() | |
if self.critics_total_grad_clip_norm is not None: | |
- critic_grad_norm['total'] = torch.nn.utils.clip_grad_norm_( | |
- cast(torch.nn.ModuleList, agent.critics).parameters(), | |
- max_norm=self.critics_total_grad_clip_norm) | |
- else: | |
- critic_grad_norm['total'] = cast( | |
- torch.Tensor, | |
- sum(p.grad.pow(2).sum() for p in cast(torch.nn.ModuleList, agent.critics).parameters() if p.grad is not None), | |
- ).sqrt() | |
+ torch.nn.utils.clip_grad_norm_(cast(torch.nn.ModuleList, agent.critics).parameters(), | |
+ max_norm=self.critics_total_grad_clip_norm) | |
if optimize: | |
for critic in agent.critics: | |
- critic.update_target_models_() | |
+ critic.update_target_encoder_() | |
- actor_grad_norm: InfoValT = {} | |
if self.actor_loss is not None: | |
assert agent.actor is not None | |
with torch.no_grad(), torch.inference_mode(): | |
- for idx, critic in enumerate(agent.critics): # FIXME? | |
- if self.recompute_critic_for_actor_loss or (critic.has_separate_target_encoder and self.actor_loss.use_target_encoder): | |
- zx, zy = critic.get_encoder(target=self.actor_loss.use_target_encoder)( | |
- torch.stack([data.observations, data.next_observations], dim=0)).unbind(0) | |
+ for idx, critic in enumerate(agent.critics): | |
+ if self.recompute_critic_for_actor_loss or critic.has_separate_target_encoder: | |
+ zx, zy = critic.target_encoder(torch.stack([data.observations, data.next_observations], dim=0)).unbind(0) | |
critic_batch_infos[idx] = quasimetric_critic.CriticBatchInfo( | |
critic=critic, | |
zx=zx, | |
zy=zy, | |
- zx_from_target_encoder=self.actor_loss.use_target_encoder, | |
- zy_from_target_encoder=self.actor_loss.use_target_encoder, | |
) | |
with self.actor_loss.optim_update_context(optimize=optimize): | |
loss_results['actor'] = loss_r = self.actor_loss(agent.actor, critic_batch_infos, data) | |
+ cast(torch.Tensor, loss_r.loss).backward() | |
- if FLAGS.DEBUG: | |
- def get_grad(loss: Union[torch.Tensor, float]) -> Union[torch.Tensor, float]: | |
- if isinstance(loss, (int, float)): | |
- return 0 | |
- assert agent.actor is not None | |
- loss_grads = torch.autograd.grad( | |
- loss, | |
- list(agent.actor.parameters()), | |
- retain_graph=True, | |
- allow_unused=True, | |
- ) | |
- return cast( | |
- torch.Tensor, | |
- sum(pg.pow(2).sum() for pg in loss_grads if pg is not None), | |
- ).sqrt() | |
- | |
- actor_grad_norm.update(cast(Mapping, loss_r.map_losses(get_grad))) | |
- | |
- loss_r.total_loss.backward() | |
- | |
- actor_grad_norm['total'] = cast( | |
- torch.Tensor, | |
- sum(p.grad.pow(2).sum() for p in agent.actor.parameters() if p.grad is not None), | |
- ).sqrt() | |
- | |
- return LossResult.combine(loss_results, grad_norm=dict(critic=critic_grad_norm, actor=actor_grad_norm)) | |
+ return LossResult.combine(loss_results) | |
# for type hints | |
def __call__(self, agent: QRLAgent, data: BatchData, *, optimize: bool = True) -> LossResult: | |
@@ -198,12 +142,7 @@ class QRLLosses(Module): | |
critic_loss.dynamics_lagrange_mult_sched.load_state_dict(optim_scheds[f"critic_{idx:02d}"]['dynamics_lagrange_mult_sched']) | |
def extra_repr(self) -> str: | |
- return '\n'.join([ | |
- f'recompute_critic_for_actor_loss={self.recompute_critic_for_actor_loss}', | |
- f'critics_share_embedding={self.critics_share_embedding}', | |
- f'critics_total_grad_clip_norm={self.critics_total_grad_clip_norm}', | |
- f'critic_losses_use_target_encoder={self.critic_losses_use_target_encoder}', | |
- ]) | |
+ return f'recompute_critic_for_actor_loss={self.recompute_critic_for_actor_loss}' | |
@attrs.define(kw_only=True) | |
@@ -216,7 +155,6 @@ class QRLConf: | |
default=None, validator=attrs.validators.optional(attrs.validators.gt(0)), | |
) | |
recompute_critic_for_actor_loss: bool = False | |
- critic_losses_use_target_encoder: bool = True | |
def make(self, *, env_spec: EnvSpec, total_optim_steps: int) -> Tuple[QRLAgent, QRLLosses]: | |
if self.actor is None: | |
@@ -236,12 +174,9 @@ class QRLConf: | |
critics.append(critic) | |
critic_losses.append(critic_loss) | |
- return QRLAgent(actor=actor, critics=critics), QRLLosses( | |
- actor_loss=actor_losses, critic_losses=critic_losses, | |
- critics_share_embedding=self.critics_share_embedding, | |
- critics_total_grad_clip_norm=self.critics_total_grad_clip_norm, | |
- recompute_critic_for_actor_loss=self.recompute_critic_for_actor_loss, | |
- critic_losses_use_target_encoder=self.critic_losses_use_target_encoder, | |
- ) | |
+ return QRLAgent(actor=actor, critics=critics), QRLLosses(actor_loss=actor_losses, critic_losses=critic_losses, | |
+ critics_share_embedding=self.critics_share_embedding, | |
+ critics_total_grad_clip_norm=self.critics_total_grad_clip_norm, | |
+ recompute_critic_for_actor_loss=self.recompute_critic_for_actor_loss) | |
__all__ = ['QRLAgent', 'QRLLosses', 'QRLConf', 'InfoT', 'InfoValT'] | |
diff --git a/quasimetric_rl/modules/actor/losses/__init__.py b/quasimetric_rl/modules/actor/losses/__init__.py | |
index f9c78a1..c8f1ec7 100644 | |
--- a/quasimetric_rl/modules/actor/losses/__init__.py | |
+++ b/quasimetric_rl/modules/actor/losses/__init__.py | |
@@ -15,15 +15,12 @@ from ...optim import OptimWrapper, AdamWSpec, LRScheduler | |
class ActorLossBase(LossBase): | |
@abc.abstractmethod | |
- def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData, | |
- use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult: | |
+ def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData) -> LossResult: | |
pass | |
# for type hints | |
- def __call__(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData, | |
- use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult: | |
- return super().__call__(actor, critic_batch_infos, data, use_target_encoder=use_target_encoder, | |
- use_target_quasimetric_model=use_target_quasimetric_model) | |
+ def __call__(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData) -> LossResult: | |
+ return super().__call__(actor, critic_batch_infos, data) | |
from .min_dist import MinDistLoss | |
@@ -43,9 +40,6 @@ class ActorLosses(ActorLossBase): | |
actor_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=3e-5) | |
entropy_weight_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=3e-4) # TODO: move to loss | |
- use_target_encoder: bool = True | |
- use_target_quasimetric_model: bool = True | |
- | |
def make(self, actor: Actor, total_optim_steps: int, env_spec: EnvSpec) -> 'ActorLosses': | |
return ActorLosses( | |
actor, | |
@@ -55,8 +49,6 @@ class ActorLosses(ActorLossBase): | |
advantage_weighted_regression=self.advantage_weighted_regression.make(), | |
actor_optim_spec=self.actor_optim.make(), | |
entropy_weight_optim_spec=self.entropy_weight_optim.make(), | |
- use_target_encoder=self.use_target_encoder, | |
- use_target_quasimetric_model=self.use_target_quasimetric_model, | |
) | |
min_dist: Optional[MinDistLoss] | |
@@ -68,14 +60,10 @@ class ActorLosses(ActorLossBase): | |
entropy_weight_optim: OptimWrapper | |
entropy_weight_sched: LRScheduler | |
- use_target_encoder: bool | |
- use_target_quasimetric_model: bool | |
- | |
def __init__(self, actor: Actor, *, total_optim_steps: int, | |
min_dist: Optional[MinDistLoss], behavior_cloning: Optional[BCLoss], | |
advantage_weighted_regression: Optional[AWRLoss], | |
- actor_optim_spec: AdamWSpec, entropy_weight_optim_spec: AdamWSpec, | |
- use_target_encoder: bool, use_target_quasimetric_model: bool): | |
+ actor_optim_spec: AdamWSpec, entropy_weight_optim_spec: AdamWSpec): | |
super().__init__() | |
self.add_module('min_dist', min_dist) | |
self.add_module('behavior_cloning', behavior_cloning) | |
@@ -88,9 +76,6 @@ class ActorLosses(ActorLossBase): | |
if min_dist is not None: | |
assert len(list(min_dist.parameters())) <= 1 | |
- self.use_target_encoder = use_target_encoder | |
- self.use_target_quasimetric_model = use_target_quasimetric_model | |
- | |
def optimizers(self) -> Iterable[OptimWrapper]: | |
return [self.actor_optim, self.entropy_weight_optim] | |
@@ -101,24 +86,18 @@ class ActorLosses(ActorLossBase): | |
loss_results: Dict[str, LossResult] = {} | |
if self.min_dist is not None: | |
loss_results.update( | |
- min_dist=self.min_dist(actor, critic_batch_infos, data, self.use_target_encoder, self.use_target_quasimetric_model), | |
+ min_dist=self.min_dist(actor, critic_batch_infos, data), | |
) | |
if self.behavior_cloning is not None: | |
loss_results.update( | |
- bc=self.behavior_cloning(actor, critic_batch_infos, data, self.use_target_encoder, self.use_target_quasimetric_model), | |
+ bc=self.behavior_cloning(actor, critic_batch_infos, data), | |
) | |
if self.advantage_weighted_regression is not None: | |
loss_results.update( | |
- awr=self.advantage_weighted_regression(actor, critic_batch_infos, data, self.use_target_encoder, self.use_target_quasimetric_model), | |
+ awr=self.advantage_weighted_regression(actor, critic_batch_infos, data), | |
) | |
return LossResult.combine(loss_results) | |
# for type hints | |
def __call__(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData) -> LossResult: | |
return torch.nn.Module.__call__(self, actor, critic_batch_infos, data) | |
- | |
- def extra_repr(self) -> str: | |
- return '\n'.join([ | |
- f"actor_optim={self.actor_optim}, entropy_weight_optim={self.entropy_weight_optim}", | |
- f"use_target_encoder={self.use_target_encoder}, use_target_quasimetric_model={self.use_target_quasimetric_model}", | |
- ]) | |
diff --git a/quasimetric_rl/modules/actor/losses/awr.py b/quasimetric_rl/modules/actor/losses/awr.py | |
index 1313dc3..c69e79c 100644 | |
--- a/quasimetric_rl/modules/actor/losses/awr.py | |
+++ b/quasimetric_rl/modules/actor/losses/awr.py | |
@@ -3,10 +3,11 @@ from typing import * | |
import attrs | |
import torch | |
+import torch.nn as nn | |
-from ....data import BatchData | |
+from ....data import BatchData, EnvSpec | |
-from ...utils import LatentTensor, LossResult, bcast_bshape | |
+from ...utils import LatentTensor, LossResult, grad_mul | |
from ..model import Actor | |
from ...quasimetric_critic import QuasimetricCritic, CriticBatchInfo | |
@@ -81,8 +82,8 @@ class AWRLoss(ActorLossBase): | |
self.clamp = clamp | |
self.add_goal_as_future_state = add_goal_as_future_state | |
- def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData, | |
- use_target_encoder: bool) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]: | |
+ def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo], | |
+ data: BatchData) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]: | |
r""" | |
Returns ( | |
obs, | |
@@ -115,7 +116,7 @@ class AWRLoss(ActorLossBase): | |
# add future_observations | |
zg = torch.stack([ | |
zg, | |
- critic_batch_info.critic.get_encoder(target=use_target_encoder)(data.future_observations), | |
+ critic_batch_info.critic.target_encoder(data.future_observations), | |
], 0) | |
# zo = zo.expand_as(zg) | |
@@ -127,11 +128,9 @@ class AWRLoss(ActorLossBase): | |
return obs, goal, actor_obs_goal_critic_infos | |
- def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData, | |
- use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult: | |
- | |
+ def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData) -> LossResult: | |
with torch.no_grad(): | |
- obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data, use_target_encoder) | |
+ obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data) | |
info: Dict[str, Union[float, torch.Tensor]] = {} | |
@@ -141,25 +140,14 @@ class AWRLoss(ActorLossBase): | |
for idx, actor_obs_goal_critic_info in enumerate(actor_obs_goal_critic_infos): | |
critic = actor_obs_goal_critic_info.critic | |
- latent_dynamics = critic.latent_dynamics | |
- quasimetric_model = critic.get_quasimetric_model(target=use_target_quasimetric_model) | |
- zo = actor_obs_goal_critic_info.zo.detach() # [B,D] | |
- zg = actor_obs_goal_critic_info.zg.detach() # [2?,B,D] | |
+ zo = actor_obs_goal_critic_info.zo.detach() | |
+ zg = actor_obs_goal_critic_info.zg.detach() | |
with torch.no_grad(), critic.mode(False): | |
- zp = latent_dynamics(data.observations, zo, data.actions) # [B,D] | |
- if idx == 0 or not critic.borrowing_embedding: | |
- zo, zp, zg = bcast_bshape( | |
- (zo, 1), | |
- (zp, 1), | |
- (zg, 1), | |
- ) | |
- z = torch.stack([zo, zp], dim=0) # [2,2?,B,D] | |
- # z = z[(slice(None),) + (None,) * (zg.ndim - zo.ndim) + (Ellipsis,)] # if zg is batched, add enough dim to be broadcast-able | |
- dist_noact, dist = quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B] | |
- dist_noact = dist_noact.detach() | |
- else: | |
- dist = quasimetric_model(zp, zg) | |
- dist_noact = dists_noact[0] | |
+ zp = critic.latent_dynamics(data.observations, zo, data.actions) | |
+ z = torch.stack([zo, zp], dim=0) | |
+ z = z[(slice(None),) + (None,) * (zg.ndim - zo.ndim) + (Ellipsis,)] # if zg is batched, add enough dim to be broadcast-able | |
+ dist_noact, dist = critic.quasimetric_model(z, zg).unbind(0) | |
+ dist_noact = dist_noact.detach() | |
info[f'dist_delta_{idx:02d}'] = (dist_noact - dist).mean() | |
info[f'dist_{idx:02d}'] = dist.mean() | |
dists_noact.append(dist_noact) | |
diff --git a/quasimetric_rl/modules/actor/losses/behavior_cloning.py b/quasimetric_rl/modules/actor/losses/behavior_cloning.py | |
index eba36af..632221e 100644 | |
--- a/quasimetric_rl/modules/actor/losses/behavior_cloning.py | |
+++ b/quasimetric_rl/modules/actor/losses/behavior_cloning.py | |
@@ -34,8 +34,7 @@ class BCLoss(ActorLossBase): | |
super().__init__() | |
self.weight = weight | |
- def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData, | |
- use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult: | |
+ def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData) -> LossResult: | |
actor_distn = actor(data.observations, data.future_observations) | |
log_prob: torch.Tensor = actor_distn.log_prob(data.actions).mean() | |
loss = -log_prob * self.weight | |
diff --git a/quasimetric_rl/modules/actor/losses/min_dist.py b/quasimetric_rl/modules/actor/losses/min_dist.py | |
index 1d57609..0f90cff 100644 | |
--- a/quasimetric_rl/modules/actor/losses/min_dist.py | |
+++ b/quasimetric_rl/modules/actor/losses/min_dist.py | |
@@ -87,8 +87,8 @@ class MinDistLoss(ActorLossBase): | |
self.register_parameter('raw_entropy_weight', None) | |
self.target_entropy = None | |
- def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData, | |
- use_target_encoder: bool) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]: | |
+ def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo], | |
+ data: BatchData) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]: | |
r""" | |
Returns ( | |
obs, | |
@@ -121,7 +121,7 @@ class MinDistLoss(ActorLossBase): | |
# add future_observations | |
zg = torch.stack([ | |
zg, | |
- critic_batch_info.critic.get_encoder(target=use_target_encoder)(data.future_observations), | |
+ critic_batch_info.critic.target_encoder(data.future_observations), | |
], 0) | |
# zo = zo.expand_as(zg) | |
@@ -133,10 +133,9 @@ class MinDistLoss(ActorLossBase): | |
return obs, goal, actor_obs_goal_critic_infos | |
- def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData, | |
- use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult: | |
+ def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData) -> LossResult: | |
with torch.no_grad(): | |
- obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data, use_target_encoder) | |
+ obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data) | |
actor_distn = actor(obs, goal) | |
action = actor_distn.rsample() | |
@@ -148,20 +147,13 @@ class MinDistLoss(ActorLossBase): | |
for idx, actor_obs_goal_critic_info in enumerate(actor_obs_goal_critic_infos): | |
critic = actor_obs_goal_critic_info.critic | |
- latent_dynamics = critic.latent_dynamics | |
- quasimetric_model = critic.get_quasimetric_model(target=use_target_quasimetric_model) | |
- zo = actor_obs_goal_critic_info.zo.detach() # [B,D] | |
- zg = actor_obs_goal_critic_info.zg.detach() # [2?,B,D] | |
+ zo = actor_obs_goal_critic_info.zo.detach() | |
+ zg = actor_obs_goal_critic_info.zg.detach() | |
with critic.requiring_grad(False), critic.mode(False): | |
- zp = latent_dynamics(data.observations, zo, action) # [2?,B,D] | |
- if idx == 0 or not critic.borrowing_embedding: | |
- # action: [2?,B,A] | |
- z = torch.stack(torch.broadcast_tensors(zo, zp), dim=0) # [2,2?,B,D] | |
- dist_noact, dist = quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B] | |
- dist_noact = dist_noact.detach() | |
- else: | |
- dist = quasimetric_model(zp, zg) # [2?,B] | |
- dist_noact = dists_noact[0] | |
+ zp = critic.latent_dynamics(data.observations, zo, action) | |
+ z = torch.stack(torch.broadcast_tensors(zo, zp), dim=0) | |
+ dist_noact, dist = critic.quasimetric_model(z, zg).unbind(0) | |
+ dist_noact = dist_noact.detach() | |
info[f'dist_delta_{idx:02d}'] = (dist_noact - dist).mean() | |
info[f'dist_{idx:02d}'] = dist.mean() | |
dists_noact.append(dist_noact) | |
diff --git a/quasimetric_rl/modules/optim.py b/quasimetric_rl/modules/optim.py | |
index 5019e1b..5131ce0 100644 | |
--- a/quasimetric_rl/modules/optim.py | |
+++ b/quasimetric_rl/modules/optim.py | |
@@ -1,7 +1,6 @@ | |
from typing import * | |
import attrs | |
-import logging | |
import contextlib | |
import torch | |
@@ -92,11 +91,9 @@ class AdamWSpec: | |
if len(params) == 0: | |
params = [dict(params=[])] # dummy param group so pytorch doesn't complain | |
for ii in range(len(params)): | |
- pg = params[ii] | |
- if isinstance(pg, Mapping) and 'lr_mul' in pg: # handle lr_multiplier | |
- pg['params'] = params_list = cast(List[torch.Tensor], list(pg['params'])) | |
+ if isinstance(params, Mapping) and 'lr_mul' in params: # handle lr_multiplier | |
+ pg = params[ii] | |
assert 'lr' not in pg | |
- logging.info(f'params (#tensor={len(params_list)}, #={sum(p.numel() for p in params_list)}): lr_mul={pg["lr_mul"]}') | |
pg['lr'] = self.lr * pg['lr_mul'] # type: ignore | |
del pg['lr_mul'] | |
return OptimWrapper( | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py b/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py | |
index 5487442..fa8eba0 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py | |
@@ -19,8 +19,6 @@ class CriticBatchInfo: | |
critic: QuasimetricCritic | |
zx: LatentTensor | |
zy: LatentTensor | |
- zx_from_target_encoder: bool = False | |
- zy_from_target_encoder: bool = False | |
class CriticLossBase(LossBase): | |
@@ -33,7 +31,7 @@ class CriticLossBase(LossBase): | |
return super().__call__(data, critic_batch_info) | |
-from .global_push import GlobalPushLoss, GlobalPushLinearLoss, GlobalPushLogLoss, GlobalPushRBFLoss, GlobalPushNextMSELoss | |
+from .global_push import GlobalPushLoss, GlobalPushLinearLoss, GlobalPushLogLoss, GlobalPushRBFLoss | |
from .local_constraint import LocalConstraintLoss | |
from .latent_dynamics import LatentDynamicsLoss | |
@@ -43,24 +41,20 @@ class QuasimetricCriticLosses(CriticLossBase): | |
class Conf: | |
global_push: GlobalPushLoss.Conf = GlobalPushLoss.Conf() | |
global_push_linear: GlobalPushLinearLoss.Conf = GlobalPushLinearLoss.Conf() | |
- global_push_next_mse: GlobalPushNextMSELoss.Conf = GlobalPushNextMSELoss.Conf() | |
global_push_log: GlobalPushLogLoss.Conf = GlobalPushLogLoss.Conf() | |
global_push_rbf: GlobalPushRBFLoss.Conf = GlobalPushRBFLoss.Conf() | |
local_constraint: LocalConstraintLoss.Conf = LocalConstraintLoss.Conf() | |
latent_dynamics: LatentDynamicsLoss.Conf = LatentDynamicsLoss.Conf() | |
critic_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=1e-4) | |
- latent_dynamics_lr_mul: float = attrs.field(default=1., validator=attrs.validators.ge(0)) | |
- quasimetric_model_lr_mul: float = attrs.field(default=1., validator=attrs.validators.ge(0)) | |
- encoder_lr_mul: float = attrs.field(default=1., validator=attrs.validators.ge(0)) # TD-MPC2 uses 0.3 | |
- quasimetric_head_lr_mul: float = attrs.field(default=1., validator=attrs.validators.ge(0)) # IQE2 can benefit from smaller lr, ~1e-5 | |
+ latent_dynamics_lr_mul: float = 1 | |
+ quasimetric_model_lr_mul: float = 1 | |
+ encoder_lr_mul: float = 1 # TD-MPC2 uses 0.3 | |
+ quasimetric_head_lr_mul: float = 1 # IQE2 can benefit from smaller lr, ~1e-5 | |
local_lagrange_mult_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=1e-2) | |
dynamics_lagrange_mult_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=0) | |
- quasimetric_scale: Optional[str] = attrs.field( | |
- default=None, validator=attrs.validators.optional(attrs.validators.in_( | |
- ['best_local_fit', 'best_local_fit_clip5', 'best_local_fit_clip10', | |
- 'best_local_fit_detach']))) # type: ignore | |
+ scale_with_best_local_fit: bool = False | |
def make(self, critic: QuasimetricCritic, total_optim_steps: int, | |
share_embedding_from: Optional[QuasimetricCritic] = None) -> 'QuasimetricCriticLosses': | |
@@ -71,7 +65,6 @@ class QuasimetricCriticLosses(CriticLossBase): | |
# global losses | |
global_push=self.global_push.make(), | |
global_push_linear=self.global_push_linear.make(), | |
- global_push_next_mse=self.global_push_next_mse.make(), | |
global_push_log=self.global_push_log.make(), | |
global_push_rbf=self.global_push_rbf.make(), | |
# local loss | |
@@ -88,13 +81,12 @@ class QuasimetricCriticLosses(CriticLossBase): | |
local_lagrange_mult_optim_spec=self.local_lagrange_mult_optim.make(), | |
dynamics_lagrange_mult_optim_spec=self.dynamics_lagrange_mult_optim.make(), | |
# | |
- quasimetric_scale=self.quasimetric_scale, | |
+ scale_with_best_local_fit=self.scale_with_best_local_fit, | |
) | |
borrowing_embedding: bool | |
global_push: Optional[GlobalPushLoss] | |
global_push_linear: Optional[GlobalPushLinearLoss] | |
- global_push_next_mse: Optional[GlobalPushNextMSELoss] | |
global_push_log: Optional[GlobalPushLogLoss] | |
global_push_rbf: Optional[GlobalPushRBFLoss] | |
local_constraint: Optional[LocalConstraintLoss] | |
@@ -106,13 +98,12 @@ class QuasimetricCriticLosses(CriticLossBase): | |
local_lagrange_mult_sched: LRScheduler | |
dynamics_lagrange_mult_optim: OptimWrapper | |
dynamics_lagrange_mult_sched: LRScheduler | |
- quasimetric_scale: Optional[str] | |
+ scale_with_best_local_fit: bool | |
def __init__(self, critic: QuasimetricCritic, *, total_optim_steps: int, | |
share_embedding_from: Optional[QuasimetricCritic] = None, | |
global_push: Optional[GlobalPushLoss], global_push_linear: Optional[GlobalPushLinearLoss], | |
- global_push_next_mse: Optional[GlobalPushNextMSELoss], global_push_log: Optional[GlobalPushLogLoss], | |
- global_push_rbf: Optional[GlobalPushRBFLoss], | |
+ global_push_log: Optional[GlobalPushLogLoss], global_push_rbf: Optional[GlobalPushRBFLoss], | |
local_constraint: Optional[LocalConstraintLoss], latent_dynamics: LatentDynamicsLoss, | |
critic_optim_spec: AdamWSpec, | |
latent_dynamics_lr_mul: float, | |
@@ -121,7 +112,7 @@ class QuasimetricCriticLosses(CriticLossBase): | |
quasimetric_head_lr_mul: float, | |
local_lagrange_mult_optim_spec: AdamWSpec, | |
dynamics_lagrange_mult_optim_spec: AdamWSpec, | |
- quasimetric_scale: Optional[str]): | |
+ scale_with_best_local_fit: bool): | |
super().__init__() | |
self.borrowing_embedding = share_embedding_from is not None | |
if self.borrowing_embedding: | |
@@ -132,7 +123,6 @@ class QuasimetricCriticLosses(CriticLossBase): | |
local_constraint = None | |
self.add_module('global_push', global_push) | |
self.add_module('global_push_linear', global_push_linear) | |
- self.add_module('global_push_next_mse', global_push_next_mse) | |
self.add_module('global_push_log', global_push_log) | |
self.add_module('global_push_rbf', global_push_rbf) | |
self.add_module('local_constraint', local_constraint) | |
@@ -157,7 +147,7 @@ class QuasimetricCriticLosses(CriticLossBase): | |
self.dynamics_lagrange_mult_optim, self.dynamics_lagrange_mult_sched = dynamics_lagrange_mult_optim_spec.create_optim_scheduler( | |
latent_dynamics.parameters(), total_optim_steps) | |
assert len(list(latent_dynamics.parameters())) == 1 | |
- self.quasimetric_scale = quasimetric_scale | |
+ self.scale_with_best_local_fit = scale_with_best_local_fit | |
def optimizers(self) -> Iterable[OptimWrapper]: | |
return [self.critic_optim, self.local_lagrange_mult_optim, self.dynamics_lagrange_mult_optim] | |
@@ -165,34 +155,23 @@ class QuasimetricCriticLosses(CriticLossBase): | |
def schedulers(self) -> Iterable[LRScheduler]: | |
return [self.critic_sched, self.local_lagrange_mult_sched, self.dynamics_lagrange_mult_sched] | |
- def compute_best_quasimetric_scale(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> Tuple[torch.Tensor, torch.Tensor]: | |
+ @torch.no_grad() | |
+ def compute_best_quasimetric_scale(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> torch.Tensor: | |
assert self.local_constraint is not None and not self.borrowing_embedding | |
- critic_batch_info.critic.quasimetric_model.quasimetric_head.scale.detach_().fill_(1) # reset | |
dist = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, critic_batch_info.zy) | |
- return dist, (self.local_constraint.step_cost * (dist.mean() / dist.square().mean().clamp_min(1e-12))) # .detach().clamp_(1e-1, 1e1) | |
+ return (self.local_constraint.step_cost * (dist.mean() / dist.square().mean().clamp_min_(1e-8))).detach().clamp_(1e-3, 1e3) | |
def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
extra_info: Dict[str, torch.Tensor] = {} | |
- if self.quasimetric_scale is not None and not self.borrowing_embedding: | |
- unscaled_dist, scale = self.compute_best_quasimetric_scale(data, critic_batch_info) | |
- assert scale.grad_fn is not None # allow bp | |
- if self.quasimetric_scale == 'best_local_fit_detach': | |
- scale = scale.detach() | |
- elif self.quasimetric_scale == 'best_local_fit_clip5': | |
- scale = scale.clamp(1 / 5, 5) | |
- elif self.quasimetric_scale == 'best_local_fit_clip10': | |
- scale = scale.clamp(1 / 10, 10) | |
- extra_info['unscaled_dist'] = unscaled_dist | |
- extra_info['quasimetric_autoscale'] = scale | |
- critic_batch_info.critic.quasimetric_model.quasimetric_head.scale = scale | |
+ if self.scale_with_best_local_fit and not self.borrowing_embedding: | |
+ scale = extra_info['quasimetric_autoscale'] = self.compute_best_quasimetric_scale(data, critic_batch_info) | |
+ critic_batch_info.critic.quasimetric_model.quasimetric_head.scale.copy_(scale) | |
loss_results: Dict[str, LossResult] = {} | |
if self.global_push is not None: | |
loss_results.update(global_push=self.global_push(data, critic_batch_info)) | |
if self.global_push_linear is not None: | |
loss_results.update(global_push_linear=self.global_push_linear(data, critic_batch_info)) | |
- if self.global_push_next_mse is not None: | |
- loss_results.update(global_push_next_mse=self.global_push_next_mse(data, critic_batch_info)) | |
if self.global_push_log is not None: | |
loss_results.update(global_push_log=self.global_push_log(data, critic_batch_info)) | |
if self.global_push_rbf is not None: | |
@@ -210,4 +189,4 @@ class QuasimetricCriticLosses(CriticLossBase): | |
return torch.nn.Module.__call__(self, data, critic_batch_info) | |
def extra_repr(self) -> str: | |
- return f"borrowing_embedding={self.borrowing_embedding}, quasimetric_scale={self.quasimetric_scale!r}" | |
+ return f"borrowing_embedding={self.borrowing_embedding}" | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py | |
index 3879518..7c3f9f3 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py | |
@@ -53,104 +53,16 @@ from . import CriticLossBase, CriticBatchInfo | |
# return f"weight={self.weight:g}, softplus_beta={self.softplus_beta:g}, softplus_offset={self.softplus_offset:g}" | |
-class GlobalPushLossBase(CriticLossBase): | |
+ | |
+class GlobalPushLoss(CriticLossBase): | |
@attrs.define(kw_only=True) | |
- class Conf(abc.ABC): | |
+ class Conf: | |
# config / argparse uses this to specify behavior | |
enabled: bool = True | |
detach_goal: bool = False | |
detach_proj_goal: bool = False | |
- detach_qmet: bool = False | |
- step_cost: float = attrs.field(default=1., validator=attrs.validators.gt(0)) | |
weight: float = attrs.field(default=1., validator=attrs.validators.gt(0)) | |
- weight_future_goal: float = attrs.field(default=0., validator=attrs.validators.ge(0)) | |
- clamp_max_future_goal: bool = True | |
- | |
- @abc.abstractmethod | |
- def make(self) -> Optional['GlobalPushLossBase']: | |
- if not self.enabled: | |
- return None | |
- | |
- weight: float | |
- weight_future_goal: float | |
- detach_goal: bool | |
- detach_proj_goal: bool | |
- detach_qmet: bool | |
- step_cost: float | |
- clamp_max_future_goal: bool | |
- | |
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, | |
- detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool): | |
- super().__init__() | |
- self.weight = weight | |
- self.weight_future_goal = weight_future_goal | |
- self.detach_goal = detach_goal | |
- self.detach_proj_goal = detach_proj_goal | |
- self.detach_qmet = detach_qmet | |
- self.step_cost = step_cost | |
- self.clamp_max_future_goal = clamp_max_future_goal | |
- | |
- def generate_dist_weight(self, data: BatchData, critic_batch_info: CriticBatchInfo): | |
- def get_dist(za: torch.Tensor, zb: torch.Tensor): | |
- if self.detach_goal: | |
- zb = zb.detach() | |
- with critic_batch_info.critic.quasimetric_model.requiring_grad(not self.detach_qmet): | |
- return critic_batch_info.critic.quasimetric_model(za, zb, proj_grad_enabled=(True, not self.detach_proj_goal)) | |
- | |
- # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy | |
- # are latents of randomly ordered random batches. | |
- zgoal = torch.roll(critic_batch_info.zy, 1, dims=0) | |
- yield ( | |
- 'random_goal', | |
- zgoal, | |
- get_dist(critic_batch_info.zx, zgoal), | |
- self.weight, | |
- {}, | |
- ) | |
- if self.weight_future_goal > 0: | |
- zgoal = critic_batch_info.critic.get_encoder(critic_batch_info.zy_from_target_encoder)(data.future_observations) | |
- dist = get_dist(critic_batch_info.zx, zgoal) | |
- if self.clamp_max_future_goal: | |
- observed_upper_bound = self.step_cost * data.future_tdelta | |
- info = dict( | |
- ratio_future_observed_dist=(dist / observed_upper_bound).mean(), | |
- exceed_future_observed_dist_rate=(dist > observed_upper_bound).mean(dtype=torch.float32), | |
- ) | |
- dist = dist.clamp_max(self.step_cost * data.future_tdelta) | |
- else: | |
- info = {} | |
- yield ( | |
- 'future_goal', | |
- zgoal, | |
- dist, | |
- self.weight_future_goal,info, | |
- ) | |
- | |
- @abc.abstractmethod | |
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, | |
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult: | |
- raise NotImplementedError | |
- | |
- def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
- return LossResult.combine( | |
- { | |
- name: self.compute_loss(data, critic_batch_info, zgoal, dist, weight, info) | |
- for name, zgoal, dist, weight, info in self.generate_dist_weight(data, critic_batch_info) | |
- }, | |
- ) | |
- | |
- def extra_repr(self) -> str: | |
- return '\n'.join([ | |
- f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}", | |
- f"weight_future_goal={self.weight_future_goal:g}, detach_qmet={self.detach_qmet}", | |
- f"step_cost={self.step_cost:g}, clamp_max_future_goal={self.clamp_max_future_goal}", | |
- ]) | |
- | |
- | |
-class GlobalPushLoss(GlobalPushLossBase): | |
- @attrs.define(kw_only=True) | |
- class Conf(GlobalPushLossBase.Conf): | |
# smaller => smoother loss | |
softplus_beta: float = attrs.field(default=0.1, validator=attrs.validators.gt(0)) | |
@@ -163,175 +75,99 @@ class GlobalPushLoss(GlobalPushLossBase): | |
return None | |
return GlobalPushLoss( | |
weight=self.weight, | |
- weight_future_goal=self.weight_future_goal, | |
detach_goal=self.detach_goal, | |
detach_proj_goal=self.detach_proj_goal, | |
- detach_qmet=self.detach_qmet, | |
- step_cost=self.step_cost, | |
- clamp_max_future_goal=self.clamp_max_future_goal, | |
softplus_beta=self.softplus_beta, | |
softplus_offset=self.softplus_offset, | |
) | |
+ weight: float | |
+ detach_goal: bool | |
+ detach_proj_goal: bool | |
softplus_beta: float | |
softplus_offset: float | |
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool, | |
+ def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool, | |
softplus_beta: float, softplus_offset: float): | |
- super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal, | |
- detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal) | |
+ super().__init__() | |
+ self.weight = weight | |
+ self.detach_goal = detach_goal | |
+ self.detach_proj_goal = detach_proj_goal | |
self.softplus_beta = softplus_beta | |
self.softplus_offset = softplus_offset | |
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, | |
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult: | |
- dict_info: Dict[str, torch.Tensor] = dict(info) | |
+ def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
+ # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy | |
+ # are latents of randomly ordered random batches. | |
+ zgoal = torch.roll(critic_batch_info.zy, 1, dims=0) | |
+ if self.detach_goal: | |
+ zgoal = zgoal.detach() | |
+ dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal, | |
+ proj_grad_enabled=(True, not self.detach_proj_goal)) | |
# Sec 3.2. Transform so that we penalize large distances less. | |
- tsfm_dist: torch.Tensor = F.softplus(self.softplus_offset - dist, beta=self.softplus_beta) # type: ignore | |
+ tsfm_dist: torch.Tensor = F.softplus(self.softplus_offset - dists, beta=self.softplus_beta) # type: ignore | |
tsfm_dist = tsfm_dist.mean() | |
- dict_info.update(dist=dist.mean(), tsfm_dist=tsfm_dist) | |
- return LossResult(loss=tsfm_dist * weight, info=dict_info) | |
+ return LossResult(loss=tsfm_dist * self.weight, info=dict(dist=dists.mean(), tsfm_dist=tsfm_dist)) # type: ignore | |
def extra_repr(self) -> str: | |
- return '\n'.join([ | |
- super().extra_repr(), | |
- f"softplus_beta={self.softplus_beta:g}, softplus_offset={self.softplus_offset:g}", | |
- ]) | |
+ return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}, softplus_beta={self.softplus_beta:g}, softplus_offset={self.softplus_offset:g}" | |
+ | |
-class GlobalPushLinearLoss(GlobalPushLossBase): | |
+class GlobalPushLinearLoss(CriticLossBase): | |
@attrs.define(kw_only=True) | |
- class Conf(GlobalPushLossBase.Conf): | |
- enabled: bool = False | |
+ class Conf: | |
+ # config / argparse uses this to specify behavior | |
- clamp_max: Optional[float] = attrs.field(default=None, validator=attrs.validators.optional(attrs.validators.gt(0))) | |
+ enabled: bool = False | |
+ detach_goal: bool = False | |
+ detach_proj_goal: bool = False | |
+ weight: float = attrs.field(default=1., validator=attrs.validators.gt(0)) | |
def make(self) -> Optional['GlobalPushLinearLoss']: | |
if not self.enabled: | |
return None | |
return GlobalPushLinearLoss( | |
weight=self.weight, | |
- weight_future_goal=self.weight_future_goal, | |
detach_goal=self.detach_goal, | |
detach_proj_goal=self.detach_proj_goal, | |
- detach_qmet=self.detach_qmet, | |
- step_cost=self.step_cost, | |
- clamp_max_future_goal=self.clamp_max_future_goal, | |
- clamp_max=self.clamp_max, | |
) | |
- clamp_max: Optional[float] | |
- | |
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool, | |
- clamp_max: Optional[float]): | |
- super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal, | |
- detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal) | |
- self.clamp_max = clamp_max | |
- | |
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, | |
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult: | |
- dict_info: Dict[str, torch.Tensor] = dict(info) | |
- if self.clamp_max is None: | |
- dist = dist.mean() | |
- dict_info.update(dist=dist) | |
- neg_loss = dist | |
- else: | |
- tsfm_dist = dist.clamp_max(self.clamp_max) | |
- dict_info.update( | |
- dist=dist.mean(), | |
- tsfm_dist=tsfm_dist.mean(), | |
- exceed_rate=(dist >= self.clamp_max).mean(dtype=torch.float32), | |
- ) | |
- neg_loss = tsfm_dist | |
- return LossResult(loss=neg_loss * weight, info=dict_info) | |
- | |
- def extra_repr(self) -> str: | |
- return '\n'.join([ | |
- super().extra_repr(), | |
- f"clamp_max={self.clamp_max!r}", | |
- ]) | |
- | |
- | |
- | |
- | |
-class GlobalPushNextMSELoss(GlobalPushLossBase): | |
- @attrs.define(kw_only=True) | |
- class Conf(GlobalPushLossBase.Conf): | |
- enabled: bool = False | |
- detach_target_dist: bool = True | |
- allow_gt: bool = False | |
- gamma: Optional[float] = attrs.field( | |
- default=None, validator=attrs.validators.optional(attrs.validators.and_( | |
- attrs.validators.gt(0), | |
- attrs.validators.lt(1), | |
- ))) | |
- | |
- def make(self) -> Optional['GlobalPushNextMSELoss']: | |
- if not self.enabled: | |
- return None | |
- return GlobalPushNextMSELoss( | |
- weight=self.weight, | |
- weight_future_goal=self.weight_future_goal, | |
- detach_goal=self.detach_goal, | |
- detach_proj_goal=self.detach_proj_goal, | |
- detach_qmet=self.detach_qmet, | |
- clamp_max_future_goal=self.clamp_max_future_goal, | |
- step_cost=self.step_cost, | |
- detach_target_dist=self.detach_target_dist, | |
- allow_gt=self.allow_gt, | |
- gamma=self.gamma, | |
- ) | |
- | |
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool, | |
- detach_target_dist: bool, allow_gt: bool, gamma: Optional[float]): | |
- super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal, | |
- detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal) | |
- self.detach_target_dist = detach_target_dist | |
- self.allow_gt = allow_gt | |
- self.gamma = gamma | |
- | |
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, | |
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult: | |
- dict_info: Dict[str, torch.Tensor] = dict(info) | |
- | |
- with torch.enable_grad(self.detach_target_dist): | |
- # by tri-eq, the actual cost can't be larger than step_cost + d(s', g) | |
- next_dist = critic_batch_info.critic.quasimetric_model( | |
- critic_batch_info.zy, zgoal, proj_grad_enabled=(True, not self.detach_proj_goal) | |
- ) | |
- if self.detach_target_dist: | |
- next_dist = next_dist.detach() | |
- target_dist = self.step_cost + next_dist | |
- | |
- if self.allow_gt: | |
- dist = dist.clamp_max(target_dist) | |
- dict_info.update( | |
- exceed_rate=(dist >= target_dist).mean(dtype=torch.float32), | |
- ) | |
- | |
- if self.gamma is None: | |
- loss = F.mse_loss(dist, target_dist) | |
- else: | |
- loss = F.mse_loss(self.gamma ** dist, self.gamma ** target_dist) | |
+ weight: float | |
+ detach_goal: bool | |
+ detach_proj_goal: bool | |
- dict_info.update( | |
- loss=loss, dist=dist.mean(), target_dist=target_dist.mean() | |
- ) | |
+ def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool): | |
+ super().__init__() | |
+ self.weight = weight | |
+ self.detach_goal = detach_goal | |
+ self.detach_proj_goal = detach_proj_goal | |
- return LossResult(loss=loss * self.weight, info=dict_info) | |
+ def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
+ # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy | |
+ # are latents of randomly ordered random batches. | |
+ zgoal = torch.roll(critic_batch_info.zy, 1, dims=0) | |
+ if self.detach_goal: | |
+ zgoal = zgoal.detach() | |
+ dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal, | |
+ proj_grad_enabled=(True, not self.detach_proj_goal)) | |
+ dists = dists.mean() | |
+ return LossResult(loss=dists * (-self.weight), info=dict(dist=dists)) # type: ignore | |
def extra_repr(self) -> str: | |
- return '\n'.join([ | |
- super().extra_repr(), | |
- "detach_target_dist={self.detach_target_dist}", | |
- ]) | |
+ return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}" | |
+ | |
-class GlobalPushLogLoss(GlobalPushLossBase): | |
+class GlobalPushLogLoss(CriticLossBase): | |
@attrs.define(kw_only=True) | |
- class Conf(GlobalPushLossBase.Conf): | |
- enabled: bool = False | |
+ class Conf: | |
+ # config / argparse uses this to specify behavior | |
+ enabled: bool = False | |
+ detach_goal: bool = False | |
+ detach_proj_goal: bool = False | |
+ weight: float = attrs.field(default=1., validator=attrs.validators.gt(0)) | |
offset: float = attrs.field(default=1., validator=attrs.validators.gt(0)) | |
def make(self) -> Optional['GlobalPushLogLoss']: | |
@@ -339,46 +175,54 @@ class GlobalPushLogLoss(GlobalPushLossBase): | |
return None | |
return GlobalPushLogLoss( | |
weight=self.weight, | |
- weight_future_goal=self.weight_future_goal, | |
detach_goal=self.detach_goal, | |
detach_proj_goal=self.detach_proj_goal, | |
- detach_qmet=self.detach_qmet, | |
- step_cost=self.step_cost, | |
- clamp_max_future_goal=self.clamp_max_future_goal, | |
offset=self.offset, | |
) | |
+ weight: float | |
+ detach_goal: bool | |
+ detach_proj_goal: bool | |
offset: float | |
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool, | |
- offset: float): | |
- super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal, | |
- detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal) | |
+ def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool, offset: float): | |
+ super().__init__() | |
+ self.weight = weight | |
+ self.detach_goal = detach_goal | |
+ self.detach_proj_goal = detach_proj_goal | |
self.offset = offset | |
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, | |
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult: | |
- dict_info: Dict[str, torch.Tensor] = dict(info) | |
- tsfm_dist: torch.Tensor = -dist.add(self.offset).log() | |
+ def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
+ # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy | |
+ # are latents of randomly ordered random batches. | |
+ zgoal = torch.roll(critic_batch_info.zy, 1, dims=0) | |
+ if self.detach_goal: | |
+ zgoal = zgoal.detach() | |
+ dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal, | |
+ proj_grad_enabled=(True, not self.detach_proj_goal)) | |
+ # Sec 3.2. Transform so that we penalize large distances less. | |
+ tsfm_dist: torch.Tensor = -dists.add(self.offset).log() | |
tsfm_dist = tsfm_dist.mean() | |
- dict_info.update(dist=dist.mean(), tsfm_dist=tsfm_dist) | |
- return LossResult(loss=tsfm_dist * weight, info=dict_info) | |
+ return LossResult(loss=tsfm_dist * self.weight, info=dict(dist=dists.mean(), tsfm_dist=tsfm_dist)) # type: ignore | |
def extra_repr(self) -> str: | |
- return '\n'.join([ | |
- super().extra_repr(), | |
- f"offset={self.offset:g}", | |
- ]) | |
+ return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}, offset={self.offset:g}" | |
+ | |
-class GlobalPushRBFLoss(GlobalPushLossBase): | |
+class GlobalPushRBFLoss(CriticLossBase): | |
# say E[opt T] approx sqrt(2)/2 timeout, so E[opt T^2] approx 1/2 timeout^2 | |
# to emulate log E exp(-2 d^2), where 2 d^2 is around 4, we scale model T with r, and use log E exp(- r^2 T^2), and let r^2 T^2 to be around 4 | |
# so r^2 approx 8 / timeout^2 and r approx 2.82 / timeout. If timeout = 850, this = 300 | |
@attrs.define(kw_only=True) | |
- class Conf(GlobalPushLossBase.Conf): | |
+ class Conf: | |
+ # config / argparse uses this to specify behavior | |
+ | |
enabled: bool = False | |
+ detach_goal: bool = False | |
+ detach_proj_goal: bool = False | |
+ weight: float = attrs.field(default=1., validator=attrs.validators.gt(0)) | |
inv_scale: float = attrs.field(default=300., validator=attrs.validators.ge(1e-3)) | |
@@ -387,37 +231,37 @@ class GlobalPushRBFLoss(GlobalPushLossBase): | |
return None | |
return GlobalPushRBFLoss( | |
weight=self.weight, | |
- weight_future_goal=self.weight_future_goal, | |
detach_goal=self.detach_goal, | |
detach_proj_goal=self.detach_proj_goal, | |
- detach_qmet=self.detach_qmet, | |
- step_cost=self.step_cost, | |
- clamp_max_future_goal=self.clamp_max_future_goal, | |
inv_scale=self.inv_scale, | |
) | |
+ weight: float | |
+ detach_goal: bool | |
+ detach_proj_goal: bool | |
inv_scale: float | |
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool, | |
- inv_scale: float): | |
- super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal, | |
- detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal) | |
+ def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool, inv_scale: float): | |
+ super().__init__() | |
+ self.weight = weight | |
+ self.detach_goal = detach_goal | |
+ self.detach_proj_goal = detach_proj_goal | |
self.inv_scale = inv_scale | |
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, | |
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult: | |
- dict_info: Dict[str, torch.Tensor] = dict(info) | |
- inv_scale = dist.detach().square().mean().div(2).sqrt().clamp(1e-3, self.inv_scale) # make E[d^2]/r^2 approx 2 | |
- tsfm_dist: torch.Tensor = (dist / inv_scale).square().neg().exp() | |
+ def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
+ # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy | |
+ # are latents of randomly ordered random batches. | |
+ zgoal = torch.roll(critic_batch_info.zy, 1, dims=0) | |
+ if self.detach_goal: | |
+ zgoal = zgoal.detach() | |
+ dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal, | |
+ proj_grad_enabled=(True, not self.detach_proj_goal)) | |
+ inv_scale = dists.detach().square().mean().div(2).sqrt().clamp(1e-3, self.inv_scale) # make E[d^2]/r^2 approx 2 | |
+ tsfm_dist: torch.Tensor = (dists / inv_scale).square().neg().exp() | |
rbf_potential = tsfm_dist.mean().log() | |
- dict_info.update( | |
- dist=dist.mean(), inv_scale=inv_scale, | |
- tsfm_dist=tsfm_dist, rbf_potential=rbf_potential, | |
- ) | |
- return LossResult(loss=rbf_potential * self.weight, info=dict_info) | |
+ return LossResult(loss=rbf_potential * self.weight, | |
+ info=dict(dist=dists.mean(), inv_scale=inv_scale, | |
+ tsfm_dist=tsfm_dist, rbf_potential=rbf_potential)) # type: ignore | |
def extra_repr(self) -> str: | |
- return '\n'.join([ | |
- super().extra_repr(), | |
- f"inv_scale={self.inv_scale:g}", | |
- ]) | |
+ return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}, inv_scale={self.inv_scale:g}" | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py b/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py | |
index 19e4886..2701e6f 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py | |
@@ -2,7 +2,6 @@ from typing import * | |
import attrs | |
-import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
@@ -26,7 +25,7 @@ class LatentDynamicsLoss(CriticLossBase): | |
# weight: float = attrs.field(default=0.1, validator=attrs.validators.gt(0)) | |
epsilon: float = attrs.field(default=0.25, validator=attrs.validators.gt(0)) | |
- # bidirectional: bool = True | |
+ bidirectional: bool = True | |
gamma: Optional[float] = attrs.field( | |
default=None, validator=attrs.validators.optional(attrs.validators.and_( | |
attrs.validators.gt(0), | |
@@ -36,54 +35,41 @@ class LatentDynamicsLoss(CriticLossBase): | |
detach_sp: bool = False | |
detach_proj_sp: bool = False | |
detach_qmet: bool = False | |
- non_quasimetric_dim_mse_weight: float = attrs.field(default=0., validator=attrs.validators.ge(0)) | |
- | |
- kind: str = attrs.field( | |
- default='mean', validator=attrs.validators.in_({'mean', 'unidir', 'max', 'bound', 'bound_l2comp', 'l2comp'})) # type: ignore | |
def make(self) -> 'LatentDynamicsLoss': | |
return LatentDynamicsLoss( | |
# weight=self.weight, | |
epsilon=self.epsilon, | |
- # bidirectional=self.bidirectional, | |
+ bidirectional=self.bidirectional, | |
gamma=self.gamma, | |
detach_sp=self.detach_sp, | |
detach_proj_sp=self.detach_proj_sp, | |
detach_qmet=self.detach_qmet, | |
- non_quasimetric_dim_mse_weight=self.non_quasimetric_dim_mse_weight, | |
init_lagrange_multiplier=self.init_lagrange_multiplier, | |
- kind=self.kind, | |
) | |
# weight: float | |
epsilon: float | |
- # bidirectional: bool | |
+ bidirectional: bool | |
gamma: Optional[float] | |
detach_sp: bool | |
detach_proj_sp: bool | |
detach_qmet: bool | |
- non_quasimetric_dim_mse_weight: float | |
c: float | |
init_lagrange_multiplier: float | |
- kind: str | |
- def __init__(self, *, epsilon: float, | |
- # bidirectional: bool, | |
+ def __init__(self, *, epsilon: float, bidirectional: bool, | |
gamma: Optional[float], init_lagrange_multiplier: float, | |
# weight: float, | |
- detach_qmet: bool, detach_proj_sp: bool, detach_sp: bool, | |
- non_quasimetric_dim_mse_weight: float, | |
- kind: str): | |
+ detach_qmet: bool, detach_proj_sp: bool, detach_sp: bool): | |
super().__init__() | |
# self.weight = weight | |
self.epsilon = epsilon | |
- # self.bidirectional = bidirectional | |
+ self.bidirectional = bidirectional | |
self.gamma = gamma | |
self.detach_qmet = detach_qmet | |
self.detach_sp = detach_sp | |
self.detach_proj_sp = detach_proj_sp | |
- self.kind = kind | |
- self.non_quasimetric_dim_mse_weight = non_quasimetric_dim_mse_weight | |
self.init_lagrange_multiplier = init_lagrange_multiplier | |
self.raw_lagrange_multiplier = nn.Parameter( | |
torch.tensor(softplus_inv_float(init_lagrange_multiplier), dtype=torch.float32)) | |
@@ -95,58 +81,15 @@ class LatentDynamicsLoss(CriticLossBase): | |
pred_zy = critic.latent_dynamics(data.observations, zx, data.actions) | |
if self.detach_sp: | |
zy = zy.detach() | |
+ with critic.quasimetric_model.requiring_grad(not self.detach_qmet): | |
+ dists = critic.quasimetric_model(zy, pred_zy, bidirectional=self.bidirectional, # at least optimize d(s', hat{s'}) | |
+ proj_grad_enabled=(self.detach_proj_sp, True)) | |
lagrange_mult = F.softplus(self.raw_lagrange_multiplier) # make positive | |
# lagrange multiplier is minimax training, so grad_mul -1 | |
lagrange_mult = grad_mul(lagrange_mult, -1) | |
- info: Dict[str, torch.Tensor] = {} | |
- | |
- if self.kind == 'unidir': | |
- with critic.quasimetric_model.requiring_grad(not self.detach_qmet): | |
- dists = critic.quasimetric_model(zy, pred_zy, bidirectional=False, # at least optimize d(s', hat{s'}) | |
- proj_grad_enabled=(self.detach_proj_sp, True)) | |
- info.update(dist_n2p=dists.mean()) | |
- elif self.kind in ['mean', 'max', 'bound', 'l2comp']: | |
- with critic.quasimetric_model.requiring_grad(not self.detach_qmet): | |
- dists_comps = critic.quasimetric_model( | |
- zy, pred_zy, bidirectional=True, proj_grad_enabled=(self.detach_proj_sp, True), | |
- reduced=False) | |
- dists = critic.quasimetric_model.quasimetric_head.reduction(dists_comps) | |
- | |
- dist_p2n, dist_n2p = dists.unbind(-1) | |
- info.update(dist_p2n=dist_p2n.mean(), dist_n2p=dist_n2p.mean()) | |
- if self.kind == 'max': | |
- dists = dists.max(-1).values | |
- elif self.kind == 'bound': | |
- with critic.quasimetric_model.requiring_grad(not self.detach_qmet): | |
- dists = critic.quasimetric_model( | |
- zy, pred_zy, bidirectional=False, # NB: false is enough for bound | |
- proj_grad_enabled=(self.detach_proj_sp, True), | |
- symmetric_upperbound=True, | |
- ) | |
- info.update(dist_upbnd=dists.mean()) | |
- elif self.kind == 'bound_l2comp': | |
- with critic.quasimetric_model.requiring_grad(not self.detach_qmet): | |
- dists_comps = critic.quasimetric_model( | |
- zy, pred_zy, bidirectional=False, # NB: false is enough for bound | |
- proj_grad_enabled=(self.detach_proj_sp, True), | |
- reduced=False, symmetric_upperbound=True, | |
- ) | |
- dists = cast( | |
- torch.Tensor, | |
- dists_comps.norm(dim=-1) / np.sqrt(dists_comps.shape[-1]), | |
- ) | |
- info.update(dist_upbnd_l2comp=dists.mean()) | |
- elif self.kind == 'l2comp': | |
- dists = cast( | |
- torch.Tensor, | |
- dists_comps.norm(dim=-1) / np.sqrt(dists_comps.shape[-1]), | |
- ) | |
- info.update(dist_l2comp=dists.mean()) | |
- else: | |
- raise NotImplementedError(self.kind) | |
- | |
+ info = {} | |
if self.gamma is None: | |
sq_dists = dists.square().mean() | |
violation = (sq_dists - self.epsilon ** 2) | |
@@ -158,15 +101,11 @@ class LatentDynamicsLoss(CriticLossBase): | |
loss = violation * lagrange_mult | |
info.update(violation=violation, lagrange_mult=lagrange_mult) | |
- if self.non_quasimetric_dim_mse_weight > 0: | |
- assert critic.quasimetric_model.input_slice_size < critic.quasimetric_model.input_size, \ | |
- "non-quasimetric dim mse only makes sense if input_slice_size < input_size, but got " \ | |
- f"{critic.quasimetric_model.input_slice_size} >= {critic.quasimetric_model.input_size}" | |
- _zy = zy[..., critic.quasimetric_model.input_slice_size:] | |
- _pred_zy = pred_zy[..., critic.quasimetric_model.input_slice_size:] | |
- non_quasimetric_dim_mse = F.mse_loss(_pred_zy, _zy) | |
- info.update(non_quasimetric_dim_mse=non_quasimetric_dim_mse) | |
- loss += non_quasimetric_dim_mse * self.non_quasimetric_dim_mse_weight | |
+ if self.bidirectional: | |
+ dist_p2n, dist_n2p = dists.flatten(0, -2).mean(0).unbind(-1) | |
+ info.update(dist_p2n=dist_p2n, dist_n2p=dist_n2p) | |
+ else: | |
+ info.update(dist_n2p=dists.mean()) | |
return LossResult( | |
loss=loss, | |
@@ -175,7 +114,7 @@ class LatentDynamicsLoss(CriticLossBase): | |
def extra_repr(self) -> str: | |
return '\n'.join([ | |
- f"kind={self.kind}, gamma={self.gamma!r}", | |
- f"epsilon={self.epsilon!r}, detach_sp={self.detach_sp}, detach_proj_sp={self.detach_proj_sp}, detach_qmet={self.detach_qmet}", | |
+ f"epsilon={self.epsilon!r}, bidirectional={self.bidirectional}", | |
+ f"gamma={self.gamma!r}, detach_sp={self.detach_sp}, detach_proj_sp={self.detach_proj_sp}, detach_qmet={self.detach_qmet}", | |
]) | |
# return f"weight={self.weight:g}, detach_sp={self.detach_sp}" | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py b/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py | |
index eb87576..1806248 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py | |
@@ -2,7 +2,6 @@ from typing import * | |
import attrs | |
-import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
@@ -31,14 +30,8 @@ class LocalConstraintLoss(CriticLossBase): | |
init_lagrange_multiplier: float = attrs.field(default=0.01, validator=attrs.validators.gt(0)) | |
- p: float = attrs.field(default=2., validator=attrs.validators.ge(1)) | |
- compare_before_pow: bool = True | |
- # kind: str = attrs.field( | |
- # default='mse', validator=attrs.validators.in_(['mse', 'sq'])) # type: ignore | |
- # batch_reduction: str = attrs.field( | |
- # default='mean', validator=attrs.validators.in_(['mean', 'l2'])) # type: ignore | |
- invpow_after_batch_agg: bool = False | |
- log: bool = False | |
+ kind: str = attrs.field( | |
+ default='mse', validator=attrs.validators.in_(['mse', 'sq'])) # type: ignore | |
detach_proj_sp: bool = False | |
detach_sp: bool = False | |
@@ -56,22 +49,12 @@ class LocalConstraintLoss(CriticLossBase): | |
step_cost=self.step_cost, | |
step_cost_high=self.step_cost_high, | |
init_lagrange_multiplier=self.init_lagrange_multiplier, | |
- # kind=self.kind, | |
- p=self.p, | |
- compare_before_pow=self.compare_before_pow, | |
- log=self.log, | |
- # batch_reduction=self.batch_reduction, | |
- invpow_after_batch_agg=self.invpow_after_batch_agg, | |
+ kind=self.kind, | |
detach_sp=self.detach_sp, | |
detach_proj_sp=self.detach_proj_sp, | |
) | |
- # kind: str | |
- p: float | |
- compare_before_pow: bool | |
- log: bool | |
- # batch_reduction: str | |
- invpow_after_batch_agg: bool | |
+ kind: str | |
epsilon: float | |
allow_le: bool | |
step_cost: float | |
@@ -82,21 +65,11 @@ class LocalConstraintLoss(CriticLossBase): | |
raw_lagrange_multiplier: nn.Parameter # for the QRL constrained optimization | |
- def __init__(self, *, | |
- # kind: str, | |
- p: float, compare_before_pow: bool, invpow_after_batch_agg: bool, | |
- log: bool, | |
- # batch_reduction: str, | |
- epsilon: float, allow_le: bool, | |
+ def __init__(self, *, kind: str, epsilon: float, allow_le: bool, | |
step_cost: float, step_cost_high: float, init_lagrange_multiplier: float, | |
detach_sp: bool, detach_proj_sp: bool): | |
super().__init__() | |
- # self.kind = kind | |
- self.p = p | |
- self.compare_before_pow = compare_before_pow | |
- self.log = log | |
- # self.batch_reduction = batch_reduction | |
- self.invpow_after_batch_agg = invpow_after_batch_agg | |
+ self.kind = kind | |
self.epsilon = epsilon | |
self.allow_le = allow_le | |
self.step_cost = step_cost | |
@@ -127,7 +100,6 @@ class LocalConstraintLoss(CriticLossBase): | |
info['dist_090'], | |
info['dist_100'], | |
) = dist.quantile(dist.new_tensor([0, 0.1, 0.25, 0.5, 0.75, 0.9, 1])).unbind() | |
- info.update(dist=dist.mean()) | |
lagrange_mult = F.softplus(self.raw_lagrange_multiplier) # make positive | |
# lagrange multiplier is minimax training, so grad_mul -1 | |
@@ -138,63 +110,31 @@ class LocalConstraintLoss(CriticLossBase): | |
else: | |
target = dist.detach().clamp(self.step_cost, self.step_cost_high) | |
- if self.log: | |
- dist = dist.log() | |
- if isinstance(target, torch.Tensor): | |
- target = target.log() | |
+ if self.kind == 'mse': | |
+ if self.allow_le: | |
+ sq_deviation = (dist - target).relu().square().mean() | |
else: | |
- target = np.log(target) | |
- info.update(dist_log=dist.mean()) | |
- | |
- if self.compare_before_pow: | |
- deviation = dist - target | |
- deviation = (torch.relu if self.allow_le else torch.abs)(deviation) | |
- deviation = deviation ** self.p | |
- else: | |
- deviation = (dist ** self.p - target ** self.p) | |
- deviation = (torch.relu if self.allow_le else torch.abs)(deviation) | |
- | |
- if self.invpow_after_batch_agg: | |
- deviation = deviation.mean().pow(1 / self.p) | |
- violation = deviation - self.epsilon | |
+ sq_deviation = (dist - target).square().mean() | |
+ elif self.kind == 'sq': | |
+ if self.allow_le: | |
+ sq_deviation = (dist.square() - target ** 2).relu().mean() | |
+ else: | |
+ sq_deviation = (dist.square() - target ** 2).abs().mean() | |
else: | |
- violation = deviation.mean() - (self.epsilon ** self.p) | |
- info.update(deviation=deviation.mean(), violation=violation, lagrange_mult=lagrange_mult) | |
- | |
- # if self.kind == 'mse': | |
- # if self.allow_le: | |
- # sq_deviation = (dist - target).relu().square() | |
- # else: | |
- # sq_deviation = (dist - target).square() | |
- # elif self.kind == 'sq': | |
- # if self.allow_le: | |
- # sq_deviation = (dist.square() - target ** 2).relu() | |
- # else: | |
- # sq_deviation = (dist.square() - target ** 2).abs() | |
- # else: | |
- # raise ValueError(f"Unknown kind {self.kind}") | |
- # info.update(sq_deviation=sq_deviation.mean()) | |
- | |
- # if self.batch_reduction == 'mean': | |
- # batch_sq_deviation = sq_deviation.mean() | |
- # elif self.batch_reduction == 'l2': | |
- # batch_sq_deviation = sq_deviation.square().mean().sqrt() | |
- # else: | |
- # raise ValueError(f"Unknown batch_reduction {self.batch_reduction}") | |
- # info.update(batch_sq_deviation=batch_sq_deviation) | |
- | |
- # violation = (batch_sq_deviation - self.epsilon ** 2) | |
- | |
+ raise ValueError(f"Unknown kind {self.kind}") | |
+ violation = (sq_deviation - self.epsilon ** 2) | |
loss = violation * lagrange_mult | |
- # info.update(violation=violation, lagrange_mult=lagrange_mult) | |
+ | |
+ info.update( | |
+ dist=dist.mean(), sq_deviation=sq_deviation, | |
+ violation=violation, lagrange_mult=lagrange_mult, | |
+ ) | |
return LossResult(loss=loss, info=info) | |
def extra_repr(self) -> str: | |
return '\n'.join([ | |
- # f"kind={self.kind}, log={self.log}, batch_reduction={self.batch_reduction}", | |
- f"p={self.p}, compare_before_pow={self.compare_before_pow}, invpow_after_batch_agg={self.invpow_after_batch_agg}", | |
- f"epsilon={self.epsilon:g}, allow_le={self.allow_le}", | |
+ f"{self.kind}, epsilon={self.epsilon:g}, allow_le={self.allow_le}", | |
f"step_cost={self.step_cost:g}, step_cost_high={self.step_cost_high:g}", | |
f"detach_sp={self.detach_sp}, detach_proj_sp={self.detach_proj_sp}", | |
]) | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/__init__.py b/quasimetric_rl/modules/quasimetric_critic/models/__init__.py | |
index 266f592..f9e05c4 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/models/__init__.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/models/__init__.py | |
@@ -25,11 +25,6 @@ class QuasimetricCritic(Module): | |
attrs.validators.le(1), | |
))) | |
encoder: Encoder.Conf = Encoder.Conf() | |
- target_quasimetric_model_ema: Optional[float] = attrs.field( | |
- default=None, validator=attrs.validators.optional(attrs.validators.and_( | |
- attrs.validators.ge(0), | |
- attrs.validators.le(1), | |
- ))) | |
quasimetric_model: QuasimetricModel.Conf = QuasimetricModel.Conf() | |
latent_dynamics: Dynamics.Conf = Dynamics.Conf() | |
@@ -47,46 +42,34 @@ class QuasimetricCritic(Module): | |
) | |
return QuasimetricCritic(encoder, quasimetric_model, latent_dynamics, | |
share_embedding_from=share_embedding_from, | |
- target_encoder_ema=self.target_encoder_ema, | |
- target_quasimetric_model_ema=self.target_quasimetric_model_ema) | |
+ target_encoder_ema=self.target_encoder_ema) | |
borrowing_embedding: bool | |
encoder: Encoder | |
_target_encoder: Optional[Encoder] | |
target_encoder_ema: Optional[float] | |
quasimetric_model: QuasimetricModel | |
- _target_quasimetric_model: Optional[QuasimetricModel] | |
- target_quasimetric_model_ema: Optional[float] | |
latent_dynamics: Dynamics # FIXME: add BC to model loading & change name | |
def __init__(self, encoder: Encoder, quasimetric_model: QuasimetricModel, latent_dynamics: Dynamics, | |
- share_embedding_from: Optional['QuasimetricCritic'], target_encoder_ema: Optional[float], | |
- target_quasimetric_model_ema: Optional[float] = None): | |
+ share_embedding_from: Optional['QuasimetricCritic'], target_encoder_ema: Optional[float]): | |
super().__init__() | |
self.borrowing_embedding = share_embedding_from is not None | |
if share_embedding_from is not None: | |
encoder = share_embedding_from.encoder | |
- _target_encoder = share_embedding_from._target_encoder | |
+ target_encoder = share_embedding_from.target_encoder | |
assert target_encoder_ema == share_embedding_from.target_encoder_ema | |
quasimetric_model = share_embedding_from.quasimetric_model | |
- assert target_quasimetric_model_ema == share_embedding_from.target_quasimetric_model_ema | |
- _target_quasimetric_model = share_embedding_from._target_quasimetric_model | |
else: | |
if target_encoder_ema is None: | |
- _target_encoder = None | |
- else: | |
- _target_encoder = copy.deepcopy(encoder).requires_grad_(False).eval() | |
- if target_quasimetric_model_ema is None: | |
- _target_quasimetric_model = None | |
+ target_encoder = None | |
else: | |
- _target_quasimetric_model = copy.deepcopy(quasimetric_model).requires_grad_(False).eval() | |
+ target_encoder = copy.deepcopy(encoder).requires_grad_(False).eval() | |
self.encoder = encoder | |
+ self.add_module('_target_encoder', target_encoder) | |
self.target_encoder_ema = target_encoder_ema | |
- self.add_module('_target_encoder', _target_encoder) | |
self.quasimetric_model = quasimetric_model | |
- self.target_quasimetric_model_ema = target_quasimetric_model_ema | |
- self.add_module('_target_quasimetric_model', _target_quasimetric_model) | |
self.latent_dynamics = latent_dynamics | |
def forward(self, x: torch.Tensor, y: torch.Tensor, *, action: Optional[torch.Tensor] = None) -> torch.Tensor: | |
@@ -108,34 +91,16 @@ class QuasimetricCritic(Module): | |
else: | |
return self.encoder | |
- def get_encoder(self, target: bool = False) -> Encoder: | |
- return self.target_encoder if target else self.encoder | |
- | |
- @property | |
- def target_quasimetric_model(self) -> QuasimetricModel: | |
- if self._target_quasimetric_model is not None: | |
- return self._target_quasimetric_model | |
- else: | |
- return self.quasimetric_model | |
- | |
- def get_quasimetric_model(self, target: bool = False) -> QuasimetricModel: | |
- return self.target_quasimetric_model if target else self.quasimetric_model | |
- | |
@torch.no_grad() | |
- def update_target_models_(self): | |
- if not self.borrowing_embedding: | |
- if self.target_encoder_ema is not None: | |
- assert self._target_encoder is not None | |
- for p, p_target in zip(self.encoder.parameters(), self._target_encoder.parameters()): | |
- p_target.data.lerp_(p.data, 1 - self.target_encoder_ema) | |
- if self.target_quasimetric_model_ema is not None: | |
- assert self._target_quasimetric_model is not None | |
- for p, p_target in zip(self.quasimetric_model.parameters(), self._target_quasimetric_model.parameters()): | |
- p_target.data.lerp_(p.data, 1 - self.target_quasimetric_model_ema) | |
+ def update_target_encoder_(self): | |
+ if not self.borrowing_embedding and self.target_encoder_ema is not None: | |
+ assert self._target_encoder is not None | |
+ for p, p_target in zip(self.encoder.parameters(), self._target_encoder.parameters()): | |
+ p_target.data.lerp_(p.data, 1 - self.target_encoder_ema) | |
# for type hints | |
def __call__(self, x: torch.Tensor, y: torch.Tensor, *, action: Optional[torch.Tensor] = None) -> torch.Tensor: | |
return super().__call__(x, y, action=action) | |
def extra_repr(self) -> str: | |
- return f"borrowing_embedding={self.borrowing_embedding}, target_encoder_ema={self.target_encoder_ema}, target_quasimetric_model_ema={self.target_quasimetric_model_ema}" | |
+ return f"borrowing_embedding={self.borrowing_embedding}" | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py b/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py | |
index 9613674..0847076 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py | |
@@ -20,7 +20,6 @@ class Dynamics(MLP): | |
# config / argparse uses this to specify behavior | |
arch: Tuple[int, ...] = (512, 512) | |
- action_arch: Tuple[int, ...] = () | |
layer_norm: bool = True | |
latent_input: bool = True | |
raw_observation_input_arch: Optional[Tuple[int, ...]] = None | |
@@ -35,7 +34,6 @@ class Dynamics(MLP): | |
latent_input=self.latent_input, | |
raw_observation_input_arch=self.raw_observation_input_arch, | |
env_spec=env_spec, | |
- action_arch=self.action_arch, | |
hidden_sizes=self.arch, | |
layer_norm=self.layer_norm, | |
use_latent_space_proj=self.use_latent_space_proj, | |
@@ -43,8 +41,7 @@ class Dynamics(MLP): | |
fc_fuse=self.fc_fuse, | |
) | |
- action_input: Union[InputEncoding, nn.Sequential] | |
- action_arch: Tuple[int, ...] | |
+ action_input: InputEncoding | |
latent_input: bool | |
raw_observation_input: Optional[nn.Sequential] | |
residual: bool | |
@@ -53,50 +50,32 @@ class Dynamics(MLP): | |
def __init__(self, *, latent: LatentSpaceConf, | |
latent_input: bool, | |
raw_observation_input_arch: Optional[Tuple[int, ...]], | |
- env_spec: EnvSpec, action_arch: Tuple[int, ...], | |
- hidden_sizes: Tuple[int, ...], layer_norm: bool, | |
+ env_spec: EnvSpec, hidden_sizes: Tuple[int, ...], layer_norm: bool, | |
use_latent_space_proj: bool, residual: bool, | |
fc_fuse: bool): | |
- self.action_arch = action_arch | |
action_input = env_spec.make_action_input() | |
- action_outsize = action_input.output_size | |
- if action_arch: | |
- action_input = nn.Sequential( | |
- action_input, | |
- MLP( | |
- action_input.output_size, | |
- action_arch[-1], | |
- hidden_sizes=action_arch[:-1], | |
- activation_last_fc=True, | |
- layer_norm_last_fc=layer_norm, | |
- layer_norm=layer_norm, | |
- ), | |
- ) | |
- action_outsize = action_arch[-1] | |
- mlp_input_size = action_outsize | |
+ mlp_input_size = action_input.output_size | |
mlp_output_size = latent.latent_size | |
self.latent_input = latent_input | |
if latent_input: | |
mlp_input_size += latent.latent_size | |
if raw_observation_input_arch is not None: | |
+ assert raw_observation_input_arch is not None | |
+ assert len(raw_observation_input_arch) >= 1 | |
input_encoding = env_spec.make_observation_input() | |
- if len(raw_observation_input_arch) >= 1: | |
- raw_observation_input = MLP( | |
- input_encoding.output_size, | |
- raw_observation_input_arch[-1], | |
- hidden_sizes=raw_observation_input_arch[:-1], | |
- activation_last_fc=True, | |
- layer_norm_last_fc=layer_norm, | |
- layer_norm=layer_norm, | |
- ) | |
- mlp_input_size += raw_observation_input.output_size | |
- raw_observation_input = nn.Sequential( | |
- input_encoding, | |
- raw_observation_input, | |
- ) | |
- else: | |
- raw_observation_input = input_encoding | |
- mlp_input_size += input_encoding.output_size | |
+ raw_observation_input = MLP( | |
+ input_encoding.output_size, | |
+ raw_observation_input_arch[-1], | |
+ hidden_sizes=raw_observation_input_arch[:-1], | |
+ activation_last_fc=True, | |
+ layer_norm_last_fc=layer_norm, | |
+ layer_norm=layer_norm, | |
+ ) | |
+ mlp_input_size += raw_observation_input.output_size | |
+ raw_observation_input = nn.Sequential( | |
+ input_encoding, | |
+ raw_observation_input, | |
+ ) | |
else: | |
raw_observation_input = None | |
super().__init__( | |
@@ -134,7 +113,7 @@ class Dynamics(MLP): | |
inputs: List[torch.Tensor] = [] | |
if self.raw_observation_input is not None: | |
inputs.append(self.raw_observation_input(observation)) | |
- if self.latent_input: | |
+ if self.latent_input is not None: | |
inputs.append(zx) | |
inputs.append(self.action_input(action)) | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py b/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py | |
index 11577ee..5d19754 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py | |
@@ -20,7 +20,7 @@ class L2(torchqmet.QuasimetricBase): | |
super().__init__(input_size, num_components=1, guaranteed_quasimetric=True, | |
transforms=[], reduction='sum', discount=None) | |
- def compute_components(self, x: torch.Tensor, y: torch.Tensor, symmetric_upperbound: bool = False) -> torch.Tensor: | |
+ def compute_components(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
r''' | |
Inputs: | |
x (torch.Tensor): Shape [..., input_size] | |
@@ -41,11 +41,11 @@ def create_quasimetric_head_from_spec(spec: str) -> torchqmet.QuasimetricBase: | |
assert dim % components == 0, "IQE: dim must be divisible by components" | |
return torchqmet.IQE(dim, dim // components) | |
- def iqe2(*, dim: int, components: int, norm_delta: bool = False, fake_grad: bool = False, reduction: str = 'maxl12_sm') -> torchqmet.IQE: | |
+ def iqe2(*, dim: int, components: int, scale: bool = False, norm_delta: bool = False, fake_grad: bool = False) -> torchqmet.IQE: | |
assert dim % components == 0, "IQE: dim must be divisible by components" | |
return torchqmet.IQE2( | |
dim, dim // components, | |
- reduction=reduction, | |
+ reduction='maxl12_sm' if not scale else 'maxl12_sm_scale', | |
learned_delta=True, | |
learned_div=False, | |
div_init_mul=0.25, | |
@@ -80,7 +80,6 @@ class QuasimetricModel(Module): | |
class Conf: | |
# config / argparse uses this to specify behavior | |
- input_slice_size: Optional[int] = attrs.field(default=None, validator=attrs.validators.optional(attrs.validators.gt(0))) # take the first n dims | |
projector_arch: Optional[Tuple[int, ...]] = (512,) | |
projector_layer_norm: bool = True | |
projector_dropout: float = attrs.field(default=0., validator=attrs.validators.ge(0)) # TD-MPC2 uses 0.01 | |
@@ -89,11 +88,8 @@ class QuasimetricModel(Module): | |
quasimetric_head_spec: str = 'iqe(dim=2048,components=64)' | |
def make(self, *, input_size: int) -> 'QuasimetricModel': | |
- if self.input_slice_size is not None: | |
- assert self.input_slice_size <= input_size, f'input_slice_size={self.input_slice_size} > input_size={input_size}' | |
return QuasimetricModel( | |
input_size=input_size, | |
- input_slice_size=self.input_slice_size or input_size, | |
projector_arch=self.projector_arch, | |
projector_layer_norm=self.projector_layer_norm, | |
projector_dropout=self.projector_dropout, | |
@@ -103,36 +99,30 @@ class QuasimetricModel(Module): | |
) | |
input_size: int | |
- input_slice_size: int | |
projector: Union[Identity, MLP] | |
quasimetric_head: torchqmet.QuasimetricBase | |
- def __init__(self, *, input_size: int, input_slice_size: int, projector_arch: Optional[Tuple[int, ...]], | |
+ def __init__(self, *, input_size: int, projector_arch: Optional[Tuple[int, ...]], | |
projector_layer_norm: bool, projector_dropout: float, projector_weight_norm: bool, | |
projector_unit_norm: bool, quasimetric_head_spec: str): | |
super().__init__() | |
self.input_size = input_size | |
- self.input_slice_size = input_slice_size | |
self.quasimetric_head = create_quasimetric_head_from_spec(quasimetric_head_spec) | |
if projector_arch is None: | |
- assert input_slice_size == self.quasimetric_head.input_size, \ | |
- f'no projector but latent input_slice_size={input_slice_size}, quasimetric_head.input_size={self.quasimetric_head.input_size}' | |
+ assert input_size == self.quasimetric_head.input_size, \ | |
+ f'no projector but latent input_size={input_size}, quasimetric_head.input_size={self.quasimetric_head.input_size}' | |
self.projector = Identity() | |
else: | |
self.projector = MLP( | |
- input_slice_size, self.quasimetric_head.input_size, | |
+ input_size, self.quasimetric_head.input_size, | |
hidden_sizes=projector_arch, | |
layer_norm=projector_layer_norm, | |
dropout=projector_dropout, | |
weight_norm_last_fc=projector_weight_norm, | |
- unit_norm_last_fc=projector_unit_norm, | |
- ) | |
+ unit_norm_last_fc=projector_unit_norm) | |
def forward(self, zx: LatentTensor, zy: LatentTensor, *, bidirectional: bool = False, | |
- proj_grad_enabled: Tuple[bool, bool] = (True, True), reduced: bool = True, | |
- symmetric_upperbound: bool = False) -> torch.Tensor: | |
- zx = zx[..., :self.input_slice_size] | |
- zy = zy[..., :self.input_slice_size] | |
+ proj_grad_enabled: Tuple[bool, bool] = (True, True)) -> torch.Tensor: | |
with self.projector.requiring_grad(proj_grad_enabled[0]): | |
px = self.projector(zx) # [B x D] | |
with self.projector.requiring_grad(proj_grad_enabled[1]): | |
@@ -142,7 +132,7 @@ class QuasimetricModel(Module): | |
px, py = torch.broadcast_tensors(px, py) | |
px, py = torch.stack([px, py], dim=-2), torch.stack([py, px], dim=-2) # [B x 2 x D] | |
- return self.quasimetric_head(px, py, reduced=reduced, symmetric_upperbound=symmetric_upperbound) | |
+ return self.quasimetric_head(px, py) | |
def parameters(self, recurse: bool = True, *, include_head: bool = True) -> Iterator[Parameter]: | |
if include_head: | |
@@ -155,12 +145,8 @@ class QuasimetricModel(Module): | |
# for type hint | |
def __call__(self, zx: LatentTensor, zy: LatentTensor, *, bidirectional: bool = False, | |
- proj_grad_enabled: Tuple[bool, bool] = (True, True), reduced: bool = True, | |
- symmetric_upperbound: bool = False) -> torch.Tensor: | |
- return super().__call__(zx, zy, bidirectional=bidirectional, | |
- proj_grad_enabled=proj_grad_enabled, | |
- reduced=reduced, | |
- symmetric_upperbound=symmetric_upperbound) | |
+ proj_grad_enabled: Tuple[bool, bool] = (True, True)) -> torch.Tensor: | |
+ return super().__call__(zx, zy, bidirectional=bidirectional, proj_grad_enabled=proj_grad_enabled) | |
def extra_repr(self) -> str: | |
- return f"input_size={self.input_size}, input_slice_size={self.input_slice_size}" | |
+ return f"input_size={self.input_size}" | |
diff --git a/quasimetric_rl/modules/utils.py b/quasimetric_rl/modules/utils.py | |
index 0241ac2..5d7a845 100644 | |
--- a/quasimetric_rl/modules/utils.py | |
+++ b/quasimetric_rl/modules/utils.py | |
@@ -33,18 +33,11 @@ InfoValT = Union[InfoT, float, torch.Tensor] | |
@attrs.define(kw_only=True) | |
class LossResult: | |
- loss: InfoValT | |
+ loss: Union[torch.Tensor, float] | |
info: InfoT | |
def __attrs_post_init__(self): | |
- def test_loss(loss: InfoValT): | |
- if isinstance(loss, Mapping): | |
- for v in loss.values(): | |
- test_loss(v) | |
- else: | |
- assert isinstance(loss, (int, float)) or loss.numel() == 1 | |
- | |
- test_loss(self.loss) | |
+ assert isinstance(self.loss, (int, float)) or self.loss.numel() == 1 | |
# detach info tensors | |
def detach(d: InfoValT) -> InfoValT: | |
@@ -57,40 +50,17 @@ class LossResult: | |
object.__setattr__(self, "info", detach(self.info)) | |
- def map_losses(self, fn: Callable[[torch.Tensor | float], torch.Tensor | float]) -> InfoValT: | |
- def map_loss(loss: InfoValT): | |
- if isinstance(loss, Mapping): | |
- return {k: map_loss(v) for k, v in loss.items()} | |
- else: | |
- return fn(loss) | |
- | |
- return map_loss(self.loss) | |
- | |
@classmethod | |
def empty(cls) -> 'LossResult': | |
return LossResult(loss=0, info={}) | |
- @property | |
- def total_loss(self) -> torch.Tensor: | |
- l = 0 | |
- def add_loss(loss: InfoValT): | |
- nonlocal l | |
- if isinstance(loss, Mapping): | |
- for v in loss.values(): | |
- add_loss(v) | |
- else: | |
- l += loss | |
- add_loss(self.loss) | |
- assert isinstance(l, torch.Tensor) | |
- return l | |
- | |
@classmethod | |
def combine(cls, results: Mapping[str, 'LossResult'], **kwargs) -> 'LossResult': | |
info = {k: r.info for k, r in results.items()} | |
assert info.keys().isdisjoint(kwargs.keys()) | |
info.update(**kwargs) | |
return LossResult( | |
- loss={k: r.loss for k, r in results.items()}, | |
+ loss=sum(r.loss for r in results.values()), | |
info=info, | |
) | |
diff --git a/quasimetric_rl/utils/logging.py b/quasimetric_rl/utils/logging.py | |
index 6efab8d..2a9e23f 100644 | |
--- a/quasimetric_rl/utils/logging.py | |
+++ b/quasimetric_rl/utils/logging.py | |
@@ -4,8 +4,6 @@ | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
-from typing import * | |
- | |
import sys | |
import os | |
import logging | |
@@ -36,12 +34,10 @@ class TqdmLoggingHandler(logging.Handler): | |
class MultiLineFormatter(logging.Formatter): | |
- _fmt: str | |
def __init__(self, fmt=None, datefmt=None, style='%'): | |
assert style == '%' | |
super(MultiLineFormatter, self).__init__(fmt, datefmt, style) | |
- assert fmt is not None | |
self.multiline_fmt = fmt | |
def format(self, record): | |
@@ -79,7 +75,7 @@ class MultiLineFormatter(logging.Formatter): | |
output += '\n'.join( | |
self.multiline_fmt % dict(record.__dict__, message=line) | |
for index, line | |
- in enumerate(record.exc_text.decode(sys.getfilesystemencoding(), 'replace').splitlines()) # type: ignore | |
+ in enumerate(record.exc_text.decode(sys.getfilesystemencoding(), 'replace').splitlines()) | |
) | |
return output | |
@@ -100,7 +96,7 @@ def configure(logging_file, log_level=logging.INFO, level_prefix='', prefix='', | |
sys.excepthook = handle_exception # automatically log uncaught errors | |
- handlers: List[logging.Handler] = [] | |
+ handlers = [] | |
if write_to_stdout: | |
handlers.append(TqdmLoggingHandler()) | |
Submodule third_party/torch-quasimetric 186ed87..c5213ff (rewind): | |
diff --git a/third_party/torch-quasimetric/torchqmet/__init__.py b/third_party/torch-quasimetric/torchqmet/__init__.py | |
index b776de5..0008b7a 100644 | |
--- a/third_party/torch-quasimetric/torchqmet/__init__.py | |
+++ b/third_party/torch-quasimetric/torchqmet/__init__.py | |
@@ -54,22 +54,16 @@ class QuasimetricBase(nn.Module, metaclass=abc.ABCMeta): | |
''' | |
pass | |
- def forward(self, x: torch.Tensor, y: torch.Tensor, *, reduced: bool = True, **kwargs) -> torch.Tensor: | |
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
assert x.shape[-1] == y.shape[-1] == self.input_size | |
- d = self.compute_components(x, y, **kwargs) | |
+ d = self.compute_components(x, y) | |
d: torch.Tensor = self.transforms(d) | |
- scale = self.scale | |
- if not self.training: | |
- scale = scale.detach() | |
- if reduced: | |
- return self.reduction(d) * scale | |
- else: | |
- return d * scale | |
- | |
- def __call__(self, x: torch.Tensor, y: torch.Tensor, reduced: bool = True, **kwargs) -> torch.Tensor: | |
+ return self.reduction(d) * self.scale | |
+ | |
+ def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
# Manually define for typing | |
# https://github.com/pytorch/pytorch/issues/45414 | |
- return super().__call__(x, y, reduced=reduced, **kwargs) | |
+ return super().__call__(x, y) | |
def extra_repr(self) -> str: | |
return f"guaranteed_quasimetric={self.guaranteed_quasimetric}\ninput_size={self.input_size}, num_components={self.num_components}" + ( | |
@@ -82,5 +76,5 @@ from .iqe import IQE, IQE2 | |
from .mrn import MRN, MRNFixed | |
from .neural_norms import DeepNorm, WideNorm | |
-__all__ = ['PQE', 'PQELH', 'PQEGG', 'IQE', 'IQE2', 'MRN', 'MRNFixed', 'DeepNorm', 'WideNorm'] | |
+__all__ = ['PQE', 'PQELH', 'PQEGG', 'IQE', 'MRN', 'MRNFixed', 'DeepNorm', 'WideNorm'] | |
__version__ = "0.1.0" | |
diff --git a/third_party/torch-quasimetric/torchqmet/iqe.py b/third_party/torch-quasimetric/torchqmet/iqe.py | |
index f084b6d..bc03f05 100644 | |
--- a/third_party/torch-quasimetric/torchqmet/iqe.py | |
+++ b/third_party/torch-quasimetric/torchqmet/iqe.py | |
@@ -12,6 +12,7 @@ from . import QuasimetricBase | |
# The PQELH function. | |
+@torch.jit.script | |
def f_PQELH(h: torch.Tensor): # PQELH: strictly monotonically increasing mapping from [0, +infty) -> [0, 1) | |
return -torch.expm1(-h) | |
@@ -20,28 +21,28 @@ def iqe_tensor_delta(x: torch.Tensor, y: torch.Tensor, delta: torch.Tensor, div_ | |
D = x.shape[-1] # D: component_dim | |
# ignore pairs that x >= y | |
- valid = (x < y) # [..., K, D] | |
+ valid = (x < y) | |
# sort to better count | |
- xy = torch.cat(torch.broadcast_tensors(x, y), dim=-1) # [..., K, 2D] | |
+ xy = torch.cat(torch.broadcast_tensors(x, y), dim=-1) | |
sxy, ixy = xy.sort(dim=-1) | |
# neg_inc: the **negated** increment of **input** of f at sorted locations | |
# inc = torch.gather(delta * valid, dim=-1, index=ixy % D) * torch.where(ixy < D, 1, -1) | |
- neg_inc = torch.gather(delta * valid, dim=-1, index=ixy % D) * torch.where(ixy < D, -1, 1) # [..., K, 2D-sort] | |
+ neg_inc = torch.gather(delta * valid, dim=-1, index=ixy % D) * torch.where(ixy < D, -1, 1) | |
# neg_incf: the **negated** increment of **output** of f at sorted locations | |
neg_f_input = torch.cumsum(neg_inc, dim=-1) / div_pre_f[:, None] | |
if fake_grad: | |
neg_f_input__grad_path = neg_f_input.clone() | |
- neg_f_input__grad_path.data.clamp_(min=-15) # fake grad | |
+ neg_f_input__grad_path.data.clamp_(max=17) # fake grad | |
neg_f_input = neg_f_input__grad_path + ( | |
neg_f_input - neg_f_input__grad_path | |
).detach() | |
neg_f = torch.expm1(neg_f_input) | |
- neg_incf = torch.diff(neg_f, dim=-1, prepend=neg_f.new_zeros(()).expand_as(neg_f[..., :1])) | |
+ neg_incf = torch.cat([neg_f.narrow(-1, 0, 1), torch.diff(neg_f, dim=-1)], dim=-1) | |
# reduction | |
if neg_incf.ndim == 3: | |
@@ -62,6 +63,7 @@ def iqe_tensor_delta(x: torch.Tensor, y: torch.Tensor, delta: torch.Tensor, div_ | |
return comp | |
+ | |
def iqe(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
D = x.shape[-1] # D: dim_per_component | |
@@ -87,35 +89,19 @@ def iqe(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
# neg_inp = torch.where(neg_inp_copies == 0, 0., -delta) | |
# f output: 0 -> 0, x -> 1. | |
neg_f = (neg_inp_copies < 0) * (-1.) | |
- neg_incf = torch.diff(neg_f, dim=-1, prepend=neg_f.new_zeros(()).expand_as(neg_f[..., :1])) | |
+ neg_incf = torch.cat([neg_f.narrow(-1, 0, 1), torch.diff(neg_f, dim=-1)], dim=-1) | |
# reduction | |
return (sxy * neg_incf).sum(-1) | |
-def is_notebook(): | |
- r""" | |
- Inspired by | |
- https://github.com/tqdm/tqdm/blob/cc372d09dcd5a5eabdc6ed4cf365bdb0be004d44/tqdm/autonotebook.py | |
- """ | |
- import sys | |
- try: | |
- get_ipython = sys.modules['IPython'].get_ipython | |
- if 'IPKernelApp' not in get_ipython().config: # pragma: no cover | |
- raise ImportError("console") | |
- except Exception: | |
- return False | |
- else: # pragma: no cover | |
- return True | |
- | |
- | |
-if torch.__version__ >= '2.0.1' and not is_notebook(): # well, broken process pool in notebooks | |
+if torch.__version__ >= '2.0.1' and False: # well, broken process pool in notebooks | |
iqe = torch.compile(iqe) | |
iqe_tensor_delta = torch.compile(iqe_tensor_delta) | |
# iqe = torch.compile(iqe, dynamic=True) | |
else: | |
- iqe = torch.jit.script(iqe) # type: ignore | |
- iqe_tensor_delta = torch.jit.script(iqe_tensor_delta) # type: ignore | |
+ iqe = torch.jit.script(iqe) | |
+ iqe_tensor_delta = torch.jit.script(iqe_tensor_delta) | |
class IQE(QuasimetricBase): | |
@@ -245,8 +231,8 @@ class IQE2(IQE): | |
ema_weight: float = 0.95): | |
super().__init__(input_size, dim_per_component, transforms=transforms, reduction=reduction, | |
discount=discount, warn_if_not_quasimetric=warn_if_not_quasimetric) | |
- self.component_dropout_thresh = tuple(component_dropout_thresh) # type: ignore | |
- self.dropout_p_thresh = tuple(dropout_p_thresh) # type: ignore | |
+ self.component_dropout_thresh = tuple(component_dropout_thresh) | |
+ self.dropout_p_thresh = tuple(dropout_p_thresh) | |
self.dropout_batch_frac = float(dropout_batch_frac) | |
self.fake_grad = fake_grad | |
assert 0 <= self.dropout_batch_frac <= 1 | |
@@ -263,7 +249,7 @@ class IQE2(IQE): | |
# ) | |
self.register_parameter( | |
'raw_delta', | |
- torch.nn.Parameter( # type: ignore | |
+ torch.nn.Parameter( | |
torch.zeros(self.latent_2d_shape).requires_grad_() | |
) | |
) | |
@@ -284,7 +270,7 @@ class IQE2(IQE): | |
self.register_parameter( | |
'raw_div', | |
- torch.nn.Parameter(torch.zeros(self.num_components).requires_grad_()) # type: ignore | |
+ torch.nn.Parameter(torch.zeros(self.num_components).requires_grad_()) | |
) | |
else: | |
self.register_buffer( | |
@@ -299,11 +285,11 @@ class IQE2(IQE): | |
self.div_init_mul = div_init_mul | |
self.mul_kind = mul_kind | |
- self.last_components = None # type: ignore | |
- self.last_drop_p = None # type: ignore | |
+ self.last_components = None | |
+ self.last_drop_p = None | |
- def compute_components(self, x: torch.Tensor, y: torch.Tensor, *, symmetric_upperbound: bool = False) -> torch.Tensor: | |
+ def compute_components(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
# if self.raw_delta is None: | |
# components = super().compute_components(x, y) | |
# else: | |
@@ -313,9 +299,6 @@ class IQE2(IQE): | |
delta.data.clamp_(max=1e3 / (self.latent_2d_shape[-1] / 8)) | |
div_pre_f.data.clamp_(min=1e-3) | |
- if symmetric_upperbound: | |
- x, y = torch.minimum(x, y), torch.maximum(x, y) | |
- | |
components = iqe_tensor_delta( | |
x=x.unflatten(-1, self.latent_2d_shape), | |
y=y.unflatten(-1, self.latent_2d_shape), | |
diff --git a/third_party/torch-quasimetric/torchqmet/reductions.py b/third_party/torch-quasimetric/torchqmet/reductions.py | |
index a87be8b..7681242 100644 | |
--- a/third_party/torch-quasimetric/torchqmet/reductions.py | |
+++ b/third_party/torch-quasimetric/torchqmet/reductions.py | |
@@ -59,11 +59,6 @@ class Mean(ReductionBase): | |
return d.mean(dim=-1) | |
-class L2(ReductionBase): | |
- def reduce_distance(self, d: torch.Tensor) -> torch.Tensor: | |
- return d.norm(p=2, dim=-1) | |
- | |
- | |
class MaxMean(ReductionBase): | |
r''' | |
`maxmean` from Neural Norms paper: | |
@@ -149,7 +144,7 @@ class MaxL12_PGsm(ReductionBase): | |
super().__init__(input_num_components=input_num_components, discount=discount) | |
self.raw_alpha = nn.Parameter(torch.tensor([0., 0., 0., 0.], dtype=torch.float32).requires_grad_()) # pre normalizing | |
self.raw_alpha_w = nn.Parameter(torch.tensor([0., 0., 0.], dtype=torch.float32).requires_grad_()) # pre normalizing | |
- self.last_logp = None # type: ignore | |
+ self.last_logp = None | |
self.on_pi = True | |
# self.last_p = None | |
@@ -227,7 +222,7 @@ class MaxL12_PG3(ReductionBase): | |
super().__init__(input_num_components=input_num_components, discount=discount) | |
self.raw_alpha = nn.Parameter(torch.tensor([0., 0., 0.], dtype=torch.float32).requires_grad_()) # pre normalizing | |
self.raw_alpha_w = torch.tensor([], dtype=torch.float32) # just to make logging easier | |
- self.last_logp = None # type: ignore | |
+ self.last_logp = None | |
self.on_pi = True | |
# self.last_p = None | |
@@ -327,7 +322,6 @@ class DeepLinearNetWeightedSum(ReductionBase): | |
REDUCTIONS: Mapping[str, Type[ReductionBase]] = dict( | |
sum=Sum, | |
mean=Mean, | |
- l2=L2, | |
maxmean=MaxMean, | |
maxl12=MaxL12, | |
maxl12_sm=MaxL12_sm, |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.