Created
March 12, 2024 22:21
-
-
Save ssnl/f30c30fb7f2086d88fcf69eaf3aa31a8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/offline/main.py b/offline/main.py | |
index 1581d30..6647508 100644 | |
--- a/offline/main.py | |
+++ b/offline/main.py | |
@@ -58,7 +58,7 @@ cs.store(name='config', node=Conf()) | |
@hydra.main(version_base=None, config_name="config") | |
def train(dict_cfg: DictConfig): | |
cfg: Conf = Conf.from_DictConfig(dict_cfg) # type: ignore | |
- wandb_run = cfg.setup_for_experiment() # checking & setup logging | |
+ cfg.setup_for_experiment() # checking & setup logging | |
dataset = cfg.env.make() | |
@@ -117,7 +117,7 @@ def train(dict_cfg: DictConfig): | |
logging.info(f"Checkpointed to {relpath}") | |
def eval(epoch, it, optim_steps): | |
- val_result_allenvs = trainer.evaluate(desc=f"opt{optim_steps:08d}") | |
+ val_result_allenvs = trainer.evaluate() | |
val_results.clear() | |
val_results.append(dict( | |
epoch=epoch, | |
@@ -170,8 +170,6 @@ def train(dict_cfg: DictConfig): | |
) | |
num_total_epochs = int(np.ceil(cfg.total_optim_steps / trainer.num_batches)) | |
- logging.info(f"save folder: {cfg.output_dir}") | |
- logging.info(f"wandb: {wandb_run.get_url()}") | |
# Training loop | |
optim_steps = 0 | |
@@ -181,7 +179,7 @@ def train(dict_cfg: DictConfig): | |
if start_epoch < num_total_epochs: | |
for epoch in range(num_total_epochs): | |
epoch_desc = f"Train epoch {epoch:05d}/{num_total_epochs:05d}" | |
- for it, (data, data_info) in enumerate(tqdm(trainer.iter_training_data(), total=trainer.num_batches, desc=epoch_desc, leave=False)): | |
+ for it, (data, data_info) in enumerate(tqdm(trainer.iter_training_data(), total=trainer.num_batches, desc=epoch_desc)): | |
step_counter.update_then_record_alerts() | |
optim_steps += 1 | |
diff --git a/offline/trainer.py b/offline/trainer.py | |
index c75abed..bf29c9d 100644 | |
--- a/offline/trainer.py | |
+++ b/offline/trainer.py | |
@@ -93,15 +93,12 @@ class Trainer(object): | |
max_episode_length=env.max_episode_steps) | |
return rollout | |
- def evaluate(self, desc=None) -> Mapping[str, interaction.EvalEpisodeResult]: | |
+ def evaluate(self) -> Mapping[str, interaction.EvalEpisodeResult]: | |
envs = self.dataset.create_eval_envs(self.eval_seed) | |
results: Dict[str, interaction.EvalEpisodeResult] = {} | |
for k, env in envs.items(): | |
rollouts: List[EpisodeData] = [] | |
- this_desc = f'eval/{k}' | |
- if desc is not None: | |
- this_desc = f'{desc}/{this_desc}' | |
- for _ in tqdm(range(self.num_eval_episodes), desc=this_desc): | |
+ for _ in tqdm(range(self.num_eval_episodes), desc=f'eval/{k}'): | |
rollouts.append(self.collect_eval_rollout(env=env)) | |
results[k] = interaction.EvalEpisodeResult.from_episode_rollouts(self.dataset, rollouts) | |
return results | |
diff --git a/online/main.py b/online/main.py | |
index c16e094..2e4400b 100644 | |
--- a/online/main.py | |
+++ b/online/main.py | |
@@ -50,7 +50,7 @@ cs.store(name='config', node=Conf()) | |
@hydra.main(version_base=None, config_name="config") | |
def train(dict_cfg: DictConfig): | |
cfg: Conf = Conf.from_DictConfig(dict_cfg) # type: ignore | |
- wandb_run = cfg.setup_for_experiment() # checking & setup logging | |
+ cfg.setup_for_experiment() # checking & setup logging | |
replay_buffer = cfg.env.make() | |
@@ -85,7 +85,7 @@ def train(dict_cfg: DictConfig): | |
logging.info(f"Checkpointed to {relpath}") | |
def eval(env_steps, optim_steps): | |
- val_result_allenvs = trainer.evaluate(desc=f'env{env_steps:08d}_opt{optim_steps:08d}') | |
+ val_result_allenvs = trainer.evaluate() | |
val_results.clear() | |
val_results.append(dict( | |
env_steps=env_steps, | |
@@ -113,8 +113,6 @@ def train(dict_cfg: DictConfig): | |
), | |
) | |
- logging.info(f"save folder: {cfg.output_dir}") | |
- logging.info(f"wandb: {wandb_run.get_url()}") | |
# Training loop | |
eval(0, 0); save(0, 0) | |
for optim_steps, (env_steps, next_iter_new_env_step, data, data_info) in enumerate(trainer.iter_training_data(), start=1): | |
diff --git a/online/trainer.py b/online/trainer.py | |
index 8926871..4e35dff 100644 | |
--- a/online/trainer.py | |
+++ b/online/trainer.py | |
@@ -133,15 +133,12 @@ class Trainer(object): | |
self.replay.add_rollout(rollout) | |
return rollout | |
- def evaluate(self, desc=None) -> Mapping[str, interaction.EvalEpisodeResult]: | |
+ def evaluate(self) -> Mapping[str, interaction.EvalEpisodeResult]: | |
envs = self.make_evaluate_envs() | |
results: Dict[str, interaction.EvalEpisodeResult] = {} | |
for k, env in envs.items(): | |
rollouts = [] | |
- this_desc = f'eval/{k}' | |
- if desc is not None: | |
- this_desc = f'{desc}/{this_desc}' | |
- for _ in tqdm(range(self.num_eval_episodes), desc=this_desc): | |
+ for _ in tqdm(range(self.num_eval_episodes), desc=f'eval/{k}'): | |
rollouts.append(self.collect_rollout(eval=True, store=False, env=env)) | |
results[k] = interaction.EvalEpisodeResult.from_episode_rollouts( | |
self.replay, rollouts) | |
@@ -160,7 +157,7 @@ class Trainer(object): | |
""" | |
def yield_data(): | |
num_transitions = self.replay.num_transitions_realized | |
- for icyc in tqdm(range(self.num_samples_per_cycle), desc=f"{num_transitions} env steps, train batches", leave=False): | |
+ for icyc in tqdm(range(self.num_samples_per_cycle), desc=f"{num_transitions} env steps, train batches"): | |
data_t0 = time.time() | |
data = self.sample() | |
info = dict( | |
diff --git a/quasimetric_rl/base_conf.py b/quasimetric_rl/base_conf.py | |
index e9f2202..b3a332e 100644 | |
--- a/quasimetric_rl/base_conf.py | |
+++ b/quasimetric_rl/base_conf.py | |
@@ -172,7 +172,7 @@ class BaseConf(abc.ABC): | |
else: | |
raise RuntimeError(f'Output directory {self.output_dir} exists and is complete') | |
- run = wandb.init( | |
+ wandb.init( | |
project=self.wandb_project, | |
name=self.output_folder.replace('/', '__') + '__' + datetime.now().strftime(r"%Y%m%d_%H:%M:%S"), | |
config=yaml.safe_load(OmegaConf.to_yaml(self)), | |
@@ -219,6 +219,3 @@ class BaseConf(abc.ABC): | |
if self.device.type == 'cuda' and self.device.index is not None: | |
torch.cuda.set_device(self.device.index) | |
- | |
- assert run is not None | |
- return run | |
diff --git a/quasimetric_rl/data/base.py b/quasimetric_rl/data/base.py | |
index 09711c4..bf5383f 100644 | |
--- a/quasimetric_rl/data/base.py | |
+++ b/quasimetric_rl/data/base.py | |
@@ -36,7 +36,6 @@ class BatchData(TensorCollectionAttrsMixin): # TensorCollectionAttrsMixin has s | |
timeouts: torch.Tensor | |
future_observations: torch.Tensor # sampled! | |
- future_tdelta: torch.Tensor | |
@property | |
def device(self) -> torch.device: | |
@@ -144,12 +143,6 @@ class EpisodeData(MultiEpisodeData): | |
rewards=self.rewards[:t], | |
terminals=self.terminals[:t], | |
timeouts=self.timeouts[:t], | |
- observation_infos={ | |
- k: v[:t + 1] for k, v in self.observation_infos.items() | |
- }, | |
- transition_infos={ | |
- k: v[:t] for k, v in self.transition_infos.items() | |
- }, | |
) | |
@@ -242,7 +235,6 @@ class Dataset(torch.utils.data.Dataset): | |
indices_to_episode_timesteps: torch.Tensor | |
max_episode_length: int | |
# ----- | |
- device: torch.device | |
def create_env(self, *, dict_obseravtion: Optional[bool] = None, seed: Optional[int] = None, **kwargs) -> 'OfflineEnv': | |
from .offline import OfflineEnv | |
@@ -280,7 +272,6 @@ class Dataset(torch.utils.data.Dataset): | |
def __init__(self, kind: str, name: str, *, | |
future_observation_discount: float, | |
dummy: bool = False, # when you don't want to load data, e.g., in analysis | |
- device: torch.device = torch.device('cpu'), # FIXME: get some heuristic | |
) -> None: | |
self.kind = kind | |
self.name = name | |
@@ -307,22 +298,18 @@ class Dataset(torch.utils.data.Dataset): | |
indices_to_episode_timesteps.append(torch.arange(l, dtype=torch.int64)) | |
assert len(episodes) > 0, "must have at least one episode" | |
- self.raw_data = MultiEpisodeData.cat(episodes).to(device) | |
+ self.raw_data = MultiEpisodeData.cat(episodes) | |
- self.obs_indices_to_obs_index_in_episode = torch.cat(obs_indices_to_obs_index_in_episode, dim=0).to(device) | |
- self.indices_to_episode_indices = torch.cat(indices_to_episode_indices, dim=0).to(device) | |
- self.indices_to_episode_timesteps = torch.cat(indices_to_episode_timesteps, dim=0).to(device) | |
+ self.obs_indices_to_obs_index_in_episode = torch.cat(obs_indices_to_obs_index_in_episode, dim=0) | |
+ self.indices_to_episode_indices = torch.cat(indices_to_episode_indices, dim=0) | |
+ self.indices_to_episode_timesteps = torch.cat(indices_to_episode_timesteps, dim=0) | |
self.max_episode_length = int(self.raw_data.episode_lengths.max().item()) | |
- self.device = device | |
- | |
- # def max_bytes_used(self): | |
- # return self | |
def get_observations(self, obs_indices: torch.Tensor): | |
- return self.raw_data.all_observations[obs_indices.to(self.device)] | |
+ return self.raw_data.all_observations[obs_indices] | |
def __getitem__(self, indices: torch.Tensor) -> BatchData: | |
- indices = torch.as_tensor(indices, device=self.device) | |
+ indices = torch.as_tensor(indices) | |
eindices = self.indices_to_episode_indices[indices] | |
obs_indices = indices + eindices # index for `observation`: skip the s_last from previous episodes | |
obs = self.get_observations(obs_indices) | |
@@ -332,7 +319,7 @@ class Dataset(torch.utils.data.Dataset): | |
tindices = self.indices_to_episode_timesteps[indices] | |
epilengths = self.raw_data.episode_lengths[eindices] # max idx is this | |
- deltas = torch.arange(self.max_episode_length, device=self.device) | |
+ deltas = torch.arange(self.max_episode_length) | |
pdeltas = torch.where( | |
# test tidx + 1 + delta <= max_idx = epi_length | |
(tindices[:, None] + deltas) < epilengths[:, None], | |
@@ -341,16 +328,14 @@ class Dataset(torch.utils.data.Dataset): | |
) | |
deltas = torch.distributions.Categorical( | |
probs=pdeltas, | |
- validate_args=False, | |
- ).sample() + 1 | |
- future_observations = self.get_observations(obs_indices + deltas) | |
+ ).sample() | |
+ future_observations = self.get_observations(obs_indices + 1 + deltas) | |
return BatchData( | |
observations=obs, | |
actions=self.raw_data.actions[indices], | |
next_observations=nobs, | |
future_observations=future_observations, | |
- future_tdelta=deltas, | |
rewards=self.raw_data.rewards[indices], | |
terminals=terminals, | |
timeouts=self.raw_data.timeouts[indices], | |
diff --git a/quasimetric_rl/data/offline/d4rl/antmaze.py b/quasimetric_rl/data/offline/d4rl/antmaze.py | |
index dbf5198..c48ab34 100644 | |
--- a/quasimetric_rl/data/offline/d4rl/antmaze.py | |
+++ b/quasimetric_rl/data/offline/d4rl/antmaze.py | |
@@ -182,13 +182,12 @@ def create_env_antmaze(name, dict_obseravtion: Optional[bool] = None, *, random_ | |
return env | |
-def load_episodes_antmaze(name, normalize_observation=True): | |
+def load_episodes_antmaze(name): | |
env = load_environment(name) | |
d4rl_dataset = cached_d4rl_dataset(name) | |
- if normalize_observation: | |
- # normalize | |
- d4rl_dataset['observations'] = obs_norm(name, d4rl_dataset['observations']) | |
- d4rl_dataset['next_observations'] = obs_norm(name, d4rl_dataset['next_observations']) | |
+ # normalize | |
+ d4rl_dataset['observations'] = obs_norm(name, d4rl_dataset['observations']) | |
+ d4rl_dataset['next_observations'] = obs_norm(name, d4rl_dataset['next_observations']) | |
yield from convert_dict_to_EpisodeData_iter( | |
sequence_dataset( | |
env, | |
@@ -207,11 +206,3 @@ for name in ['antmaze-umaze-v2', 'antmaze-umaze-diverse-v2', | |
normalize_score_fn=functools.partial(get_normalized_score, name), | |
eval_specs=dict(single_task=dict(random_start_goal=False), multi_task=dict(random_start_goal=True)), | |
) | |
- register_offline_env( | |
- 'd4rl', name + '-nonorm', | |
- create_env_fn=functools.partial(create_env_antmaze, name, normalize_observation=False), | |
- load_episodes_fn=functools.partial(load_episodes_antmaze, name, normalize_observation=False), | |
- normalize_score_fn=functools.partial(get_normalized_score, name), | |
- eval_specs=dict(single_task=dict(random_start_goal=False, normalize_observation=False), | |
- multi_task=dict(random_start_goal=True, normalize_observation=False)), | |
- ) | |
diff --git a/quasimetric_rl/data/online/memory.py b/quasimetric_rl/data/online/memory.py | |
index a3445f4..df3aeb6 100644 | |
--- a/quasimetric_rl/data/online/memory.py | |
+++ b/quasimetric_rl/data/online/memory.py | |
@@ -154,7 +154,7 @@ class ReplayBuffer(Dataset): | |
get_empty_episodes( | |
self.env_spec, self.episode_length, | |
int(np.ceil(self.increment_num_transitions / self.episode_length)), | |
- ).to(self.device), | |
+ ), | |
], | |
dim=0, | |
) | |
@@ -165,11 +165,11 @@ class ReplayBuffer(Dataset): | |
# indices_to_episode_timesteps: torch.Tensor | |
self.indices_to_episode_indices = torch.cat([ | |
self.indices_to_episode_indices, | |
- torch.repeat_interleave(torch.arange(original_capacity, new_capacity, device=self.device), self.episode_length), | |
+ torch.repeat_interleave(torch.arange(original_capacity, new_capacity), self.episode_length), | |
], dim=0) | |
self.indices_to_episode_timesteps = torch.cat([ | |
self.indices_to_episode_timesteps, | |
- torch.arange(self.episode_length, device=self.device).repeat(new_capacity - original_capacity), | |
+ torch.arange(self.episode_length).repeat(new_capacity - original_capacity), | |
], dim=0) | |
logging.info(f'ReplayBuffer: Expanded from capacity={original_capacity} to {new_capacity} episodes') | |
@@ -215,8 +215,7 @@ class ReplayBuffer(Dataset): | |
def sample(self, batch_size: int) -> BatchData: | |
indices = torch.as_tensor( | |
- np.random.choice(self.num_transitions_realized, size=[batch_size]), | |
- device=self.device, | |
+ np.random.choice(self.num_transitions_realized, size=[batch_size]) | |
) | |
return self[indices] | |
diff --git a/quasimetric_rl/modules/__init__.py b/quasimetric_rl/modules/__init__.py | |
index 66704ff..79fd35f 100644 | |
--- a/quasimetric_rl/modules/__init__.py | |
+++ b/quasimetric_rl/modules/__init__.py | |
@@ -32,15 +32,13 @@ class QRLLosses(Module): | |
critic_losses: Collection[quasimetric_critic.QuasimetricCriticLosses], | |
critics_total_grad_clip_norm: Optional[float], | |
recompute_critic_for_actor_loss: bool, | |
- critics_share_embedding: bool, | |
- critic_losses_use_target_encoder: bool): | |
+ critics_share_embedding: bool): | |
super().__init__() | |
self.add_module('actor_loss', actor_loss) | |
self.critic_losses = torch.nn.ModuleList(critic_losses) # type: ignore | |
self.critics_total_grad_clip_norm = critics_total_grad_clip_norm | |
self.recompute_critic_for_actor_loss = recompute_critic_for_actor_loss | |
self.critics_share_embedding = critics_share_embedding | |
- self.critic_losses_use_target_encoder = critic_losses_use_target_encoder | |
def forward(self, agent: QRLAgent, data: BatchData, *, optimize: bool = True) -> LossResult: | |
# compute CriticBatchInfo | |
@@ -59,8 +57,8 @@ class QRLLosses(Module): | |
) | |
else: | |
zx = critic.encoder(data.observations) | |
- zy = critic.get_encoder(target=self.critic_losses_use_target_encoder)(data.next_observations) | |
- if critic.has_separate_target_encoder and self.critic_losses_use_target_encoder: | |
+ zy = critic.target_encoder(data.next_observations) | |
+ if critic.has_separate_target_encoder: | |
assert not zy.requires_grad | |
critic_batch_info = quasimetric_critic.CriticBatchInfo( | |
critic=critic, | |
@@ -76,24 +74,19 @@ class QRLLosses(Module): | |
).sum().backward() | |
if self.critics_total_grad_clip_norm is not None: | |
- critic_grad_norm = torch.nn.utils.clip_grad_norm_( | |
- cast(torch.nn.ModuleList, agent.critics).parameters(), | |
- max_norm=self.critics_total_grad_clip_norm) | |
- else: | |
- critic_grad_norm = sum(p.grad.pow(2).sum() for p in cast(torch.nn.ModuleList, agent.critics).parameters() if p.grad is not None) | |
- critic_grad_norm = cast(torch.Tensor, critic_grad_norm).sqrt() | |
+ torch.nn.utils.clip_grad_norm_(cast(torch.nn.ModuleList, agent.critics).parameters(), | |
+ max_norm=self.critics_total_grad_clip_norm) | |
if optimize: | |
for critic in agent.critics: | |
- critic.update_target_models_() | |
+ critic.update_target_encoder_() | |
if self.actor_loss is not None: | |
assert agent.actor is not None | |
with torch.no_grad(), torch.inference_mode(): | |
- for idx, critic in enumerate(agent.critics): # FIXME? | |
- if self.recompute_critic_for_actor_loss or (critic.has_separate_target_encoder and self.actor_loss.use_target_encoder): | |
- zx, zy = critic.get_encoder(target=self.actor_loss.use_target_encoder)( | |
- torch.stack([data.observations, data.next_observations], dim=0)).unbind(0) | |
+ for idx, critic in enumerate(agent.critics): | |
+ if self.recompute_critic_for_actor_loss or critic.has_separate_target_encoder: | |
+ zx, zy = critic.target_encoder(torch.stack([data.observations, data.next_observations], dim=0)).unbind(0) | |
critic_batch_infos[idx] = quasimetric_critic.CriticBatchInfo( | |
critic=critic, | |
zx=zx, | |
@@ -103,13 +96,7 @@ class QRLLosses(Module): | |
loss_results['actor'] = loss_r = self.actor_loss(agent.actor, critic_batch_infos, data) | |
cast(torch.Tensor, loss_r.loss).backward() | |
- actor_grad_norm = sum(p.grad.pow(2).sum() for p in cast(torch.nn.ModuleList, agent.actor).parameters() if p.grad is not None) | |
- actor_grad_norm = cast(torch.Tensor, actor_grad_norm).sqrt() | |
- | |
- else: | |
- actor_grad_norm = 0 | |
- | |
- return LossResult.combine(loss_results, grad_norm=dict(critic=critic_grad_norm, actor=actor_grad_norm)) | |
+ return LossResult.combine(loss_results) | |
# for type hints | |
def __call__(self, agent: QRLAgent, data: BatchData, *, optimize: bool = True) -> LossResult: | |
@@ -155,12 +142,7 @@ class QRLLosses(Module): | |
critic_loss.dynamics_lagrange_mult_sched.load_state_dict(optim_scheds[f"critic_{idx:02d}"]['dynamics_lagrange_mult_sched']) | |
def extra_repr(self) -> str: | |
- return '\n'.join([ | |
- f'recompute_critic_for_actor_loss={self.recompute_critic_for_actor_loss}', | |
- f'critics_share_embedding={self.critics_share_embedding}', | |
- f'critics_total_grad_clip_norm={self.critics_total_grad_clip_norm}', | |
- f'critic_losses_use_target_encoder={self.critic_losses_use_target_encoder}', | |
- ]) | |
+ return f'recompute_critic_for_actor_loss={self.recompute_critic_for_actor_loss}' | |
@attrs.define(kw_only=True) | |
@@ -173,7 +155,6 @@ class QRLConf: | |
default=None, validator=attrs.validators.optional(attrs.validators.gt(0)), | |
) | |
recompute_critic_for_actor_loss: bool = False | |
- critic_losses_use_target_encoder: bool = True | |
def make(self, *, env_spec: EnvSpec, total_optim_steps: int) -> Tuple[QRLAgent, QRLLosses]: | |
if self.actor is None: | |
@@ -193,12 +174,9 @@ class QRLConf: | |
critics.append(critic) | |
critic_losses.append(critic_loss) | |
- return QRLAgent(actor=actor, critics=critics), QRLLosses( | |
- actor_loss=actor_losses, critic_losses=critic_losses, | |
- critics_share_embedding=self.critics_share_embedding, | |
- critics_total_grad_clip_norm=self.critics_total_grad_clip_norm, | |
- recompute_critic_for_actor_loss=self.recompute_critic_for_actor_loss, | |
- critic_losses_use_target_encoder=self.critic_losses_use_target_encoder, | |
- ) | |
+ return QRLAgent(actor=actor, critics=critics), QRLLosses(actor_loss=actor_losses, critic_losses=critic_losses, | |
+ critics_share_embedding=self.critics_share_embedding, | |
+ critics_total_grad_clip_norm=self.critics_total_grad_clip_norm, | |
+ recompute_critic_for_actor_loss=self.recompute_critic_for_actor_loss) | |
__all__ = ['QRLAgent', 'QRLLosses', 'QRLConf', 'InfoT', 'InfoValT'] | |
diff --git a/quasimetric_rl/modules/actor/losses/__init__.py b/quasimetric_rl/modules/actor/losses/__init__.py | |
index f9c78a1..c8f1ec7 100644 | |
--- a/quasimetric_rl/modules/actor/losses/__init__.py | |
+++ b/quasimetric_rl/modules/actor/losses/__init__.py | |
@@ -15,15 +15,12 @@ from ...optim import OptimWrapper, AdamWSpec, LRScheduler | |
class ActorLossBase(LossBase): | |
@abc.abstractmethod | |
- def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData, | |
- use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult: | |
+ def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData) -> LossResult: | |
pass | |
# for type hints | |
- def __call__(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData, | |
- use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult: | |
- return super().__call__(actor, critic_batch_infos, data, use_target_encoder=use_target_encoder, | |
- use_target_quasimetric_model=use_target_quasimetric_model) | |
+ def __call__(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData) -> LossResult: | |
+ return super().__call__(actor, critic_batch_infos, data) | |
from .min_dist import MinDistLoss | |
@@ -43,9 +40,6 @@ class ActorLosses(ActorLossBase): | |
actor_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=3e-5) | |
entropy_weight_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=3e-4) # TODO: move to loss | |
- use_target_encoder: bool = True | |
- use_target_quasimetric_model: bool = True | |
- | |
def make(self, actor: Actor, total_optim_steps: int, env_spec: EnvSpec) -> 'ActorLosses': | |
return ActorLosses( | |
actor, | |
@@ -55,8 +49,6 @@ class ActorLosses(ActorLossBase): | |
advantage_weighted_regression=self.advantage_weighted_regression.make(), | |
actor_optim_spec=self.actor_optim.make(), | |
entropy_weight_optim_spec=self.entropy_weight_optim.make(), | |
- use_target_encoder=self.use_target_encoder, | |
- use_target_quasimetric_model=self.use_target_quasimetric_model, | |
) | |
min_dist: Optional[MinDistLoss] | |
@@ -68,14 +60,10 @@ class ActorLosses(ActorLossBase): | |
entropy_weight_optim: OptimWrapper | |
entropy_weight_sched: LRScheduler | |
- use_target_encoder: bool | |
- use_target_quasimetric_model: bool | |
- | |
def __init__(self, actor: Actor, *, total_optim_steps: int, | |
min_dist: Optional[MinDistLoss], behavior_cloning: Optional[BCLoss], | |
advantage_weighted_regression: Optional[AWRLoss], | |
- actor_optim_spec: AdamWSpec, entropy_weight_optim_spec: AdamWSpec, | |
- use_target_encoder: bool, use_target_quasimetric_model: bool): | |
+ actor_optim_spec: AdamWSpec, entropy_weight_optim_spec: AdamWSpec): | |
super().__init__() | |
self.add_module('min_dist', min_dist) | |
self.add_module('behavior_cloning', behavior_cloning) | |
@@ -88,9 +76,6 @@ class ActorLosses(ActorLossBase): | |
if min_dist is not None: | |
assert len(list(min_dist.parameters())) <= 1 | |
- self.use_target_encoder = use_target_encoder | |
- self.use_target_quasimetric_model = use_target_quasimetric_model | |
- | |
def optimizers(self) -> Iterable[OptimWrapper]: | |
return [self.actor_optim, self.entropy_weight_optim] | |
@@ -101,24 +86,18 @@ class ActorLosses(ActorLossBase): | |
loss_results: Dict[str, LossResult] = {} | |
if self.min_dist is not None: | |
loss_results.update( | |
- min_dist=self.min_dist(actor, critic_batch_infos, data, self.use_target_encoder, self.use_target_quasimetric_model), | |
+ min_dist=self.min_dist(actor, critic_batch_infos, data), | |
) | |
if self.behavior_cloning is not None: | |
loss_results.update( | |
- bc=self.behavior_cloning(actor, critic_batch_infos, data, self.use_target_encoder, self.use_target_quasimetric_model), | |
+ bc=self.behavior_cloning(actor, critic_batch_infos, data), | |
) | |
if self.advantage_weighted_regression is not None: | |
loss_results.update( | |
- awr=self.advantage_weighted_regression(actor, critic_batch_infos, data, self.use_target_encoder, self.use_target_quasimetric_model), | |
+ awr=self.advantage_weighted_regression(actor, critic_batch_infos, data), | |
) | |
return LossResult.combine(loss_results) | |
# for type hints | |
def __call__(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData) -> LossResult: | |
return torch.nn.Module.__call__(self, actor, critic_batch_infos, data) | |
- | |
- def extra_repr(self) -> str: | |
- return '\n'.join([ | |
- f"actor_optim={self.actor_optim}, entropy_weight_optim={self.entropy_weight_optim}", | |
- f"use_target_encoder={self.use_target_encoder}, use_target_quasimetric_model={self.use_target_quasimetric_model}", | |
- ]) | |
diff --git a/quasimetric_rl/modules/actor/losses/awr.py b/quasimetric_rl/modules/actor/losses/awr.py | |
index 1313dc3..f6cf4de 100644 | |
--- a/quasimetric_rl/modules/actor/losses/awr.py | |
+++ b/quasimetric_rl/modules/actor/losses/awr.py | |
@@ -81,8 +81,8 @@ class AWRLoss(ActorLossBase): | |
self.clamp = clamp | |
self.add_goal_as_future_state = add_goal_as_future_state | |
- def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData, | |
- use_target_encoder: bool) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]: | |
+ def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo], | |
+ data: BatchData) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]: | |
r""" | |
Returns ( | |
obs, | |
@@ -115,7 +115,7 @@ class AWRLoss(ActorLossBase): | |
# add future_observations | |
zg = torch.stack([ | |
zg, | |
- critic_batch_info.critic.get_encoder(target=use_target_encoder)(data.future_observations), | |
+ critic_batch_info.critic.target_encoder(data.future_observations), | |
], 0) | |
# zo = zo.expand_as(zg) | |
@@ -127,11 +127,9 @@ class AWRLoss(ActorLossBase): | |
return obs, goal, actor_obs_goal_critic_infos | |
- def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData, | |
- use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult: | |
- | |
+ def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData) -> LossResult: | |
with torch.no_grad(): | |
- obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data, use_target_encoder) | |
+ obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data) | |
info: Dict[str, Union[float, torch.Tensor]] = {} | |
@@ -141,13 +139,11 @@ class AWRLoss(ActorLossBase): | |
for idx, actor_obs_goal_critic_info in enumerate(actor_obs_goal_critic_infos): | |
critic = actor_obs_goal_critic_info.critic | |
- latent_dynamics = critic.latent_dynamics | |
- quasimetric_model = critic.get_quasimetric_model(target=use_target_quasimetric_model) | |
zo = actor_obs_goal_critic_info.zo.detach() # [B,D] | |
zg = actor_obs_goal_critic_info.zg.detach() # [2?,B,D] | |
with torch.no_grad(), critic.mode(False): | |
- zp = latent_dynamics(data.observations, zo, data.actions) # [B,D] | |
- if idx == 0 or not critic.borrowing_embedding: | |
+ zp = critic.latent_dynamics(data.observations, zo, data.actions) # [B,D] | |
+ if not critic.borrowing_embedding: | |
zo, zp, zg = bcast_bshape( | |
(zo, 1), | |
(zp, 1), | |
@@ -155,10 +151,10 @@ class AWRLoss(ActorLossBase): | |
) | |
z = torch.stack([zo, zp], dim=0) # [2,2?,B,D] | |
# z = z[(slice(None),) + (None,) * (zg.ndim - zo.ndim) + (Ellipsis,)] # if zg is batched, add enough dim to be broadcast-able | |
- dist_noact, dist = quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B] | |
+ dist_noact, dist = critic.quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B] | |
dist_noact = dist_noact.detach() | |
else: | |
- dist = quasimetric_model(zp, zg) | |
+ dist = critic.quasimetric_model(zp, zg) | |
dist_noact = dists_noact[0] | |
info[f'dist_delta_{idx:02d}'] = (dist_noact - dist).mean() | |
info[f'dist_{idx:02d}'] = dist.mean() | |
diff --git a/quasimetric_rl/modules/actor/losses/behavior_cloning.py b/quasimetric_rl/modules/actor/losses/behavior_cloning.py | |
index eba36af..632221e 100644 | |
--- a/quasimetric_rl/modules/actor/losses/behavior_cloning.py | |
+++ b/quasimetric_rl/modules/actor/losses/behavior_cloning.py | |
@@ -34,8 +34,7 @@ class BCLoss(ActorLossBase): | |
super().__init__() | |
self.weight = weight | |
- def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData, | |
- use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult: | |
+ def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData) -> LossResult: | |
actor_distn = actor(data.observations, data.future_observations) | |
log_prob: torch.Tensor = actor_distn.log_prob(data.actions).mean() | |
loss = -log_prob * self.weight | |
diff --git a/quasimetric_rl/modules/actor/losses/min_dist.py b/quasimetric_rl/modules/actor/losses/min_dist.py | |
index 1d57609..bc36ee8 100644 | |
--- a/quasimetric_rl/modules/actor/losses/min_dist.py | |
+++ b/quasimetric_rl/modules/actor/losses/min_dist.py | |
@@ -87,8 +87,8 @@ class MinDistLoss(ActorLossBase): | |
self.register_parameter('raw_entropy_weight', None) | |
self.target_entropy = None | |
- def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData, | |
- use_target_encoder: bool) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]: | |
+ def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo], | |
+ data: BatchData) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]: | |
r""" | |
Returns ( | |
obs, | |
@@ -121,7 +121,7 @@ class MinDistLoss(ActorLossBase): | |
# add future_observations | |
zg = torch.stack([ | |
zg, | |
- critic_batch_info.critic.get_encoder(target=use_target_encoder)(data.future_observations), | |
+ critic_batch_info.critic.target_encoder(data.future_observations), | |
], 0) | |
# zo = zo.expand_as(zg) | |
@@ -133,10 +133,9 @@ class MinDistLoss(ActorLossBase): | |
return obs, goal, actor_obs_goal_critic_infos | |
- def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData, | |
- use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult: | |
+ def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData) -> LossResult: | |
with torch.no_grad(): | |
- obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data, use_target_encoder) | |
+ obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data) | |
actor_distn = actor(obs, goal) | |
action = actor_distn.rsample() | |
@@ -148,19 +147,17 @@ class MinDistLoss(ActorLossBase): | |
for idx, actor_obs_goal_critic_info in enumerate(actor_obs_goal_critic_infos): | |
critic = actor_obs_goal_critic_info.critic | |
- latent_dynamics = critic.latent_dynamics | |
- quasimetric_model = critic.get_quasimetric_model(target=use_target_quasimetric_model) | |
zo = actor_obs_goal_critic_info.zo.detach() # [B,D] | |
zg = actor_obs_goal_critic_info.zg.detach() # [2?,B,D] | |
with critic.requiring_grad(False), critic.mode(False): | |
- zp = latent_dynamics(data.observations, zo, action) # [2?,B,D] | |
- if idx == 0 or not critic.borrowing_embedding: | |
+ zp = critic.latent_dynamics(data.observations, zo, action) # [2?,B,D] | |
+ if not critic.borrowing_embedding: | |
# action: [2?,B,A] | |
z = torch.stack(torch.broadcast_tensors(zo, zp), dim=0) # [2,2?,B,D] | |
- dist_noact, dist = quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B] | |
+ dist_noact, dist = critic.quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B] | |
dist_noact = dist_noact.detach() | |
else: | |
- dist = quasimetric_model(zp, zg) # [2?,B] | |
+ dist = critic.quasimetric_model(zp, zg) # [2?,B] | |
dist_noact = dists_noact[0] | |
info[f'dist_delta_{idx:02d}'] = (dist_noact - dist).mean() | |
info[f'dist_{idx:02d}'] = dist.mean() | |
diff --git a/quasimetric_rl/modules/optim.py b/quasimetric_rl/modules/optim.py | |
index 5019e1b..5131ce0 100644 | |
--- a/quasimetric_rl/modules/optim.py | |
+++ b/quasimetric_rl/modules/optim.py | |
@@ -1,7 +1,6 @@ | |
from typing import * | |
import attrs | |
-import logging | |
import contextlib | |
import torch | |
@@ -92,11 +91,9 @@ class AdamWSpec: | |
if len(params) == 0: | |
params = [dict(params=[])] # dummy param group so pytorch doesn't complain | |
for ii in range(len(params)): | |
- pg = params[ii] | |
- if isinstance(pg, Mapping) and 'lr_mul' in pg: # handle lr_multiplier | |
- pg['params'] = params_list = cast(List[torch.Tensor], list(pg['params'])) | |
+ if isinstance(params, Mapping) and 'lr_mul' in params: # handle lr_multiplier | |
+ pg = params[ii] | |
assert 'lr' not in pg | |
- logging.info(f'params (#tensor={len(params_list)}, #={sum(p.numel() for p in params_list)}): lr_mul={pg["lr_mul"]}') | |
pg['lr'] = self.lr * pg['lr_mul'] # type: ignore | |
del pg['lr_mul'] | |
return OptimWrapper( | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py b/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py | |
index 48ed700..94f59ba 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py | |
@@ -31,7 +31,7 @@ class CriticLossBase(LossBase): | |
return super().__call__(data, critic_batch_info) | |
-from .global_push import GlobalPushLoss, GlobalPushLinearLoss, GlobalPushLogLoss, GlobalPushRBFLoss, GlobalPushNextMSELoss | |
+from .global_push import GlobalPushLoss, GlobalPushLinearLoss, GlobalPushLogLoss, GlobalPushRBFLoss | |
from .local_constraint import LocalConstraintLoss | |
from .latent_dynamics import LatentDynamicsLoss | |
@@ -41,7 +41,6 @@ class QuasimetricCriticLosses(CriticLossBase): | |
class Conf: | |
global_push: GlobalPushLoss.Conf = GlobalPushLoss.Conf() | |
global_push_linear: GlobalPushLinearLoss.Conf = GlobalPushLinearLoss.Conf() | |
- global_push_next_mse: GlobalPushNextMSELoss.Conf = GlobalPushNextMSELoss.Conf() | |
global_push_log: GlobalPushLogLoss.Conf = GlobalPushLogLoss.Conf() | |
global_push_rbf: GlobalPushRBFLoss.Conf = GlobalPushRBFLoss.Conf() | |
local_constraint: LocalConstraintLoss.Conf = LocalConstraintLoss.Conf() | |
@@ -69,7 +68,6 @@ class QuasimetricCriticLosses(CriticLossBase): | |
# global losses | |
global_push=self.global_push.make(), | |
global_push_linear=self.global_push_linear.make(), | |
- global_push_next_mse=self.global_push_next_mse.make(), | |
global_push_log=self.global_push_log.make(), | |
global_push_rbf=self.global_push_rbf.make(), | |
# local loss | |
@@ -92,7 +90,6 @@ class QuasimetricCriticLosses(CriticLossBase): | |
borrowing_embedding: bool | |
global_push: Optional[GlobalPushLoss] | |
global_push_linear: Optional[GlobalPushLinearLoss] | |
- global_push_next_mse: Optional[GlobalPushNextMSELoss] | |
global_push_log: Optional[GlobalPushLogLoss] | |
global_push_rbf: Optional[GlobalPushRBFLoss] | |
local_constraint: Optional[LocalConstraintLoss] | |
@@ -109,8 +106,7 @@ class QuasimetricCriticLosses(CriticLossBase): | |
def __init__(self, critic: QuasimetricCritic, *, total_optim_steps: int, | |
share_embedding_from: Optional[QuasimetricCritic] = None, | |
global_push: Optional[GlobalPushLoss], global_push_linear: Optional[GlobalPushLinearLoss], | |
- global_push_next_mse: Optional[GlobalPushNextMSELoss], global_push_log: Optional[GlobalPushLogLoss], | |
- global_push_rbf: Optional[GlobalPushRBFLoss], | |
+ global_push_log: Optional[GlobalPushLogLoss], global_push_rbf: Optional[GlobalPushRBFLoss], | |
local_constraint: Optional[LocalConstraintLoss], latent_dynamics: LatentDynamicsLoss, | |
critic_optim_spec: AdamWSpec, | |
latent_dynamics_lr_mul: float, | |
@@ -130,7 +126,6 @@ class QuasimetricCriticLosses(CriticLossBase): | |
local_constraint = None | |
self.add_module('global_push', global_push) | |
self.add_module('global_push_linear', global_push_linear) | |
- self.add_module('global_push_next_mse', global_push_next_mse) | |
self.add_module('global_push_log', global_push_log) | |
self.add_module('global_push_rbf', global_push_rbf) | |
self.add_module('local_constraint', local_constraint) | |
@@ -189,8 +184,6 @@ class QuasimetricCriticLosses(CriticLossBase): | |
loss_results.update(global_push=self.global_push(data, critic_batch_info)) | |
if self.global_push_linear is not None: | |
loss_results.update(global_push_linear=self.global_push_linear(data, critic_batch_info)) | |
- if self.global_push_next_mse is not None: | |
- loss_results.update(global_push_next_mse=self.global_push_next_mse(data, critic_batch_info)) | |
if self.global_push_log is not None: | |
loss_results.update(global_push_log=self.global_push_log(data, critic_batch_info)) | |
if self.global_push_rbf is not None: | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py | |
index c4e2656..7c3f9f3 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py | |
@@ -53,104 +53,16 @@ from . import CriticLossBase, CriticBatchInfo | |
# return f"weight={self.weight:g}, softplus_beta={self.softplus_beta:g}, softplus_offset={self.softplus_offset:g}" | |
-class GlobalPushLossBase(CriticLossBase): | |
+ | |
+class GlobalPushLoss(CriticLossBase): | |
@attrs.define(kw_only=True) | |
- class Conf(abc.ABC): | |
+ class Conf: | |
# config / argparse uses this to specify behavior | |
enabled: bool = True | |
detach_goal: bool = False | |
detach_proj_goal: bool = False | |
- detach_qmet: bool = False | |
- step_cost: float = attrs.field(default=1., validator=attrs.validators.gt(0)) | |
weight: float = attrs.field(default=1., validator=attrs.validators.gt(0)) | |
- weight_future_goal: float = attrs.field(default=0., validator=attrs.validators.ge(0)) | |
- clamp_max_future_goal: bool = True | |
- | |
- @abc.abstractmethod | |
- def make(self) -> Optional['GlobalPushLossBase']: | |
- if not self.enabled: | |
- return None | |
- | |
- weight: float | |
- weight_future_goal: float | |
- detach_goal: bool | |
- detach_proj_goal: bool | |
- detach_qmet: bool | |
- step_cost: float | |
- clamp_max_future_goal: bool | |
- | |
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, | |
- detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool): | |
- super().__init__() | |
- self.weight = weight | |
- self.weight_future_goal = weight_future_goal | |
- self.detach_goal = detach_goal | |
- self.detach_proj_goal = detach_proj_goal | |
- self.detach_qmet = detach_qmet | |
- self.step_cost = step_cost | |
- self.clamp_max_future_goal = clamp_max_future_goal | |
- | |
- def generate_dist_weight(self, data: BatchData, critic_batch_info: CriticBatchInfo): | |
- def get_dist(za: torch.Tensor, zb: torch.Tensor): | |
- if self.detach_goal: | |
- zb = zb.detach() | |
- with critic_batch_info.critic.quasimetric_model.requiring_grad(not self.detach_qmet): | |
- return critic_batch_info.critic.quasimetric_model(za, zb, proj_grad_enabled=(True, not self.detach_proj_goal)) | |
- | |
- # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy | |
- # are latents of randomly ordered random batches. | |
- zgoal = torch.roll(critic_batch_info.zy, 1, dims=0) | |
- yield ( | |
- 'random_goal', | |
- zgoal, | |
- get_dist(critic_batch_info.zx, zgoal), | |
- self.weight, | |
- {}, | |
- ) | |
- if self.weight_future_goal > 0: | |
- zgoal = critic_batch_info.critic.encoder(data.future_observations) | |
- dist = get_dist(critic_batch_info.zx, zgoal) | |
- if self.clamp_max_future_goal: | |
- observed_upper_bound = self.step_cost * data.future_tdelta | |
- info = dict( | |
- ratio_future_observed_dist=(dist / observed_upper_bound).mean(), | |
- exceed_future_observed_dist_rate=(dist > observed_upper_bound).mean(dtype=torch.float32), | |
- ) | |
- dist = dist.clamp_max(self.step_cost * data.future_tdelta) | |
- else: | |
- info = {} | |
- yield ( | |
- 'future_goal', | |
- zgoal, | |
- dist, | |
- self.weight_future_goal,info, | |
- ) | |
- | |
- @abc.abstractmethod | |
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, | |
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult: | |
- raise NotImplementedError | |
- | |
- def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
- return LossResult.combine( | |
- { | |
- name: self.compute_loss(data, critic_batch_info, zgoal, dist, weight, info) | |
- for name, zgoal, dist, weight, info in self.generate_dist_weight(data, critic_batch_info) | |
- }, | |
- ) | |
- | |
- def extra_repr(self) -> str: | |
- return '\n'.join([ | |
- f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}", | |
- f"weight_future_goal={self.weight_future_goal:g}, detach_qmet={self.detach_qmet}", | |
- f"step_cost={self.step_cost:g}, clamp_max_future_goal={self.clamp_max_future_goal}", | |
- ]) | |
- | |
- | |
-class GlobalPushLoss(GlobalPushLossBase): | |
- @attrs.define(kw_only=True) | |
- class Conf(GlobalPushLossBase.Conf): | |
# smaller => smoother loss | |
softplus_beta: float = attrs.field(default=0.1, validator=attrs.validators.gt(0)) | |
@@ -163,175 +75,99 @@ class GlobalPushLoss(GlobalPushLossBase): | |
return None | |
return GlobalPushLoss( | |
weight=self.weight, | |
- weight_future_goal=self.weight_future_goal, | |
detach_goal=self.detach_goal, | |
detach_proj_goal=self.detach_proj_goal, | |
- detach_qmet=self.detach_qmet, | |
- step_cost=self.step_cost, | |
- clamp_max_future_goal=self.clamp_max_future_goal, | |
softplus_beta=self.softplus_beta, | |
softplus_offset=self.softplus_offset, | |
) | |
+ weight: float | |
+ detach_goal: bool | |
+ detach_proj_goal: bool | |
softplus_beta: float | |
softplus_offset: float | |
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool, | |
+ def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool, | |
softplus_beta: float, softplus_offset: float): | |
- super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal, | |
- detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal) | |
+ super().__init__() | |
+ self.weight = weight | |
+ self.detach_goal = detach_goal | |
+ self.detach_proj_goal = detach_proj_goal | |
self.softplus_beta = softplus_beta | |
self.softplus_offset = softplus_offset | |
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, | |
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult: | |
- dict_info: Dict[str, torch.Tensor] = dict(info) | |
+ def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
+ # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy | |
+ # are latents of randomly ordered random batches. | |
+ zgoal = torch.roll(critic_batch_info.zy, 1, dims=0) | |
+ if self.detach_goal: | |
+ zgoal = zgoal.detach() | |
+ dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal, | |
+ proj_grad_enabled=(True, not self.detach_proj_goal)) | |
# Sec 3.2. Transform so that we penalize large distances less. | |
- tsfm_dist: torch.Tensor = F.softplus(self.softplus_offset - dist, beta=self.softplus_beta) # type: ignore | |
+ tsfm_dist: torch.Tensor = F.softplus(self.softplus_offset - dists, beta=self.softplus_beta) # type: ignore | |
tsfm_dist = tsfm_dist.mean() | |
- dict_info.update(dist=dist.mean(), tsfm_dist=tsfm_dist) | |
- return LossResult(loss=tsfm_dist * weight, info=dict_info) | |
+ return LossResult(loss=tsfm_dist * self.weight, info=dict(dist=dists.mean(), tsfm_dist=tsfm_dist)) # type: ignore | |
def extra_repr(self) -> str: | |
- return '\n'.join([ | |
- super().extra_repr(), | |
- f"softplus_beta={self.softplus_beta:g}, softplus_offset={self.softplus_offset:g}", | |
- ]) | |
+ return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}, softplus_beta={self.softplus_beta:g}, softplus_offset={self.softplus_offset:g}" | |
+ | |
-class GlobalPushLinearLoss(GlobalPushLossBase): | |
+class GlobalPushLinearLoss(CriticLossBase): | |
@attrs.define(kw_only=True) | |
- class Conf(GlobalPushLossBase.Conf): | |
- enabled: bool = False | |
+ class Conf: | |
+ # config / argparse uses this to specify behavior | |
- clamp_max: Optional[float] = attrs.field(default=None, validator=attrs.validators.optional(attrs.validators.gt(0))) | |
+ enabled: bool = False | |
+ detach_goal: bool = False | |
+ detach_proj_goal: bool = False | |
+ weight: float = attrs.field(default=1., validator=attrs.validators.gt(0)) | |
def make(self) -> Optional['GlobalPushLinearLoss']: | |
if not self.enabled: | |
return None | |
return GlobalPushLinearLoss( | |
weight=self.weight, | |
- weight_future_goal=self.weight_future_goal, | |
detach_goal=self.detach_goal, | |
detach_proj_goal=self.detach_proj_goal, | |
- detach_qmet=self.detach_qmet, | |
- step_cost=self.step_cost, | |
- clamp_max_future_goal=self.clamp_max_future_goal, | |
- clamp_max=self.clamp_max, | |
) | |
- clamp_max: Optional[float] | |
- | |
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool, | |
- clamp_max: Optional[float]): | |
- super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal, | |
- detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal) | |
- self.clamp_max = clamp_max | |
- | |
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, | |
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult: | |
- dict_info: Dict[str, torch.Tensor] = dict(info) | |
- if self.clamp_max is None: | |
- dist = dist.mean() | |
- dict_info.update(dist=dist) | |
- neg_loss = dist | |
- else: | |
- tsfm_dist = dist.clamp_max(self.clamp_max) | |
- dict_info.update( | |
- dist=dist.mean(), | |
- tsfm_dist=tsfm_dist.mean(), | |
- exceed_rate=(dist >= self.clamp_max).mean(dtype=torch.float32), | |
- ) | |
- neg_loss = tsfm_dist | |
- return LossResult(loss=neg_loss * weight, info=dict_info) | |
- | |
- def extra_repr(self) -> str: | |
- return '\n'.join([ | |
- super().extra_repr(), | |
- f"clamp_max={self.clamp_max!r}", | |
- ]) | |
- | |
- | |
- | |
- | |
-class GlobalPushNextMSELoss(GlobalPushLossBase): | |
- @attrs.define(kw_only=True) | |
- class Conf(GlobalPushLossBase.Conf): | |
- enabled: bool = False | |
- detach_target_dist: bool = True | |
- allow_gt: bool = False | |
- gamma: Optional[float] = attrs.field( | |
- default=None, validator=attrs.validators.optional(attrs.validators.and_( | |
- attrs.validators.gt(0), | |
- attrs.validators.lt(1), | |
- ))) | |
- | |
- def make(self) -> Optional['GlobalPushNextMSELoss']: | |
- if not self.enabled: | |
- return None | |
- return GlobalPushNextMSELoss( | |
- weight=self.weight, | |
- weight_future_goal=self.weight_future_goal, | |
- detach_goal=self.detach_goal, | |
- detach_proj_goal=self.detach_proj_goal, | |
- detach_qmet=self.detach_qmet, | |
- clamp_max_future_goal=self.clamp_max_future_goal, | |
- step_cost=self.step_cost, | |
- detach_target_dist=self.detach_target_dist, | |
- allow_gt=self.allow_gt, | |
- gamma=self.gamma, | |
- ) | |
- | |
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool, | |
- detach_target_dist: bool, allow_gt: bool, gamma: Optional[float]): | |
- super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal, | |
- detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal) | |
- self.detach_target_dist = detach_target_dist | |
- self.allow_gt = allow_gt | |
- self.gamma = gamma | |
- | |
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, | |
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult: | |
- dict_info: Dict[str, torch.Tensor] = dict(info) | |
- | |
- with torch.enable_grad(self.detach_target_dist): | |
- # by tri-eq, the actual cost can't be larger that step_cost + d(s', g) | |
- next_dist = critic_batch_info.critic.quasimetric_model( | |
- critic_batch_info.zy, zgoal, proj_grad_enabled=(True, not self.detach_proj_goal) | |
- ) | |
- if self.detach_target_dist: | |
- next_dist = next_dist.detach() | |
- target_dist = self.step_cost + next_dist | |
- | |
- if self.allow_gt: | |
- dist = dist.clamp_max(target_dist) | |
- dict_info.update( | |
- exceed_rate=(dist >= target_dist).mean(dtype=torch.float32), | |
- ) | |
- | |
- if self.gamma is None: | |
- loss = F.mse_loss(dist, target_dist) | |
- else: | |
- loss = F.mse_loss(self.gamma ** dist, self.gamma ** target_dist) | |
+ weight: float | |
+ detach_goal: bool | |
+ detach_proj_goal: bool | |
- dict_info.update( | |
- loss=loss, dist=dist.mean(), target_dist=target_dist.mean() | |
- ) | |
+ def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool): | |
+ super().__init__() | |
+ self.weight = weight | |
+ self.detach_goal = detach_goal | |
+ self.detach_proj_goal = detach_proj_goal | |
- return LossResult(loss=loss * self.weight, info=dict_info) | |
+ def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
+ # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy | |
+ # are latents of randomly ordered random batches. | |
+ zgoal = torch.roll(critic_batch_info.zy, 1, dims=0) | |
+ if self.detach_goal: | |
+ zgoal = zgoal.detach() | |
+ dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal, | |
+ proj_grad_enabled=(True, not self.detach_proj_goal)) | |
+ dists = dists.mean() | |
+ return LossResult(loss=dists * (-self.weight), info=dict(dist=dists)) # type: ignore | |
def extra_repr(self) -> str: | |
- return '\n'.join([ | |
- super().extra_repr(), | |
- "detach_target_dist={self.detach_target_dist}", | |
- ]) | |
+ return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}" | |
+ | |
-class GlobalPushLogLoss(GlobalPushLossBase): | |
+class GlobalPushLogLoss(CriticLossBase): | |
@attrs.define(kw_only=True) | |
- class Conf(GlobalPushLossBase.Conf): | |
- enabled: bool = False | |
+ class Conf: | |
+ # config / argparse uses this to specify behavior | |
+ enabled: bool = False | |
+ detach_goal: bool = False | |
+ detach_proj_goal: bool = False | |
+ weight: float = attrs.field(default=1., validator=attrs.validators.gt(0)) | |
offset: float = attrs.field(default=1., validator=attrs.validators.gt(0)) | |
def make(self) -> Optional['GlobalPushLogLoss']: | |
@@ -339,46 +175,54 @@ class GlobalPushLogLoss(GlobalPushLossBase): | |
return None | |
return GlobalPushLogLoss( | |
weight=self.weight, | |
- weight_future_goal=self.weight_future_goal, | |
detach_goal=self.detach_goal, | |
detach_proj_goal=self.detach_proj_goal, | |
- detach_qmet=self.detach_qmet, | |
- step_cost=self.step_cost, | |
- clamp_max_future_goal=self.clamp_max_future_goal, | |
offset=self.offset, | |
) | |
+ weight: float | |
+ detach_goal: bool | |
+ detach_proj_goal: bool | |
offset: float | |
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool, | |
- offset: float): | |
- super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal, | |
- detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal) | |
+ def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool, offset: float): | |
+ super().__init__() | |
+ self.weight = weight | |
+ self.detach_goal = detach_goal | |
+ self.detach_proj_goal = detach_proj_goal | |
self.offset = offset | |
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, | |
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult: | |
- dict_info: Dict[str, torch.Tensor] = dict(info) | |
- tsfm_dist: torch.Tensor = -dist.add(self.offset).log() | |
+ def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
+ # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy | |
+ # are latents of randomly ordered random batches. | |
+ zgoal = torch.roll(critic_batch_info.zy, 1, dims=0) | |
+ if self.detach_goal: | |
+ zgoal = zgoal.detach() | |
+ dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal, | |
+ proj_grad_enabled=(True, not self.detach_proj_goal)) | |
+ # Sec 3.2. Transform so that we penalize large distances less. | |
+ tsfm_dist: torch.Tensor = -dists.add(self.offset).log() | |
tsfm_dist = tsfm_dist.mean() | |
- dict_info.update(dist=dist.mean(), tsfm_dist=tsfm_dist) | |
- return LossResult(loss=tsfm_dist * weight, info=dict_info) | |
+ return LossResult(loss=tsfm_dist * self.weight, info=dict(dist=dists.mean(), tsfm_dist=tsfm_dist)) # type: ignore | |
def extra_repr(self) -> str: | |
- return '\n'.join([ | |
- super().extra_repr(), | |
- f"offset={self.offset:g}", | |
- ]) | |
+ return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}, offset={self.offset:g}" | |
+ | |
-class GlobalPushRBFLoss(GlobalPushLossBase): | |
+class GlobalPushRBFLoss(CriticLossBase): | |
# say E[opt T] approx sqrt(2)/2 timeout, so E[opt T^2] approx 1/2 timeout^2 | |
# to emulate log E exp(-2 d^2), where 2 d^2 is around 4, we scale model T with r, and use log E exp(- r^2 T^2), and let r^2 T^2 to be around 4 | |
# so r^2 approx 8 / timeout^2 and r approx 2.82 / timeout. If timeout = 850, this = 300 | |
@attrs.define(kw_only=True) | |
- class Conf(GlobalPushLossBase.Conf): | |
+ class Conf: | |
+ # config / argparse uses this to specify behavior | |
+ | |
enabled: bool = False | |
+ detach_goal: bool = False | |
+ detach_proj_goal: bool = False | |
+ weight: float = attrs.field(default=1., validator=attrs.validators.gt(0)) | |
inv_scale: float = attrs.field(default=300., validator=attrs.validators.ge(1e-3)) | |
@@ -387,37 +231,37 @@ class GlobalPushRBFLoss(GlobalPushLossBase): | |
return None | |
return GlobalPushRBFLoss( | |
weight=self.weight, | |
- weight_future_goal=self.weight_future_goal, | |
detach_goal=self.detach_goal, | |
detach_proj_goal=self.detach_proj_goal, | |
- detach_qmet=self.detach_qmet, | |
- step_cost=self.step_cost, | |
- clamp_max_future_goal=self.clamp_max_future_goal, | |
inv_scale=self.inv_scale, | |
) | |
+ weight: float | |
+ detach_goal: bool | |
+ detach_proj_goal: bool | |
inv_scale: float | |
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool, | |
- inv_scale: float): | |
- super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal, | |
- detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal) | |
+ def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool, inv_scale: float): | |
+ super().__init__() | |
+ self.weight = weight | |
+ self.detach_goal = detach_goal | |
+ self.detach_proj_goal = detach_proj_goal | |
self.inv_scale = inv_scale | |
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, | |
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult: | |
- dict_info: Dict[str, torch.Tensor] = dict(info) | |
- inv_scale = dist.detach().square().mean().div(2).sqrt().clamp(1e-3, self.inv_scale) # make E[d^2]/r^2 approx 2 | |
- tsfm_dist: torch.Tensor = (dist / inv_scale).square().neg().exp() | |
+ def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
+ # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy | |
+ # are latents of randomly ordered random batches. | |
+ zgoal = torch.roll(critic_batch_info.zy, 1, dims=0) | |
+ if self.detach_goal: | |
+ zgoal = zgoal.detach() | |
+ dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal, | |
+ proj_grad_enabled=(True, not self.detach_proj_goal)) | |
+ inv_scale = dists.detach().square().mean().div(2).sqrt().clamp(1e-3, self.inv_scale) # make E[d^2]/r^2 approx 2 | |
+ tsfm_dist: torch.Tensor = (dists / inv_scale).square().neg().exp() | |
rbf_potential = tsfm_dist.mean().log() | |
- dict_info.update( | |
- dist=dist.mean(), inv_scale=inv_scale, | |
- tsfm_dist=tsfm_dist, rbf_potential=rbf_potential, | |
- ) | |
- return LossResult(loss=rbf_potential * self.weight, info=dict_info) | |
+ return LossResult(loss=rbf_potential * self.weight, | |
+ info=dict(dist=dists.mean(), inv_scale=inv_scale, | |
+ tsfm_dist=tsfm_dist, rbf_potential=rbf_potential)) # type: ignore | |
def extra_repr(self) -> str: | |
- return '\n'.join([ | |
- super().extra_repr(), | |
- f"inv_scale={self.inv_scale:g}", | |
- ]) | |
+ return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}, inv_scale={self.inv_scale:g}" | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py b/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py | |
index 19e4886..f5e8248 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py | |
@@ -2,7 +2,6 @@ from typing import * | |
import attrs | |
-import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
@@ -26,7 +25,7 @@ class LatentDynamicsLoss(CriticLossBase): | |
# weight: float = attrs.field(default=0.1, validator=attrs.validators.gt(0)) | |
epsilon: float = attrs.field(default=0.25, validator=attrs.validators.gt(0)) | |
- # bidirectional: bool = True | |
+ bidirectional: bool = True | |
gamma: Optional[float] = attrs.field( | |
default=None, validator=attrs.validators.optional(attrs.validators.and_( | |
attrs.validators.gt(0), | |
@@ -38,26 +37,22 @@ class LatentDynamicsLoss(CriticLossBase): | |
detach_qmet: bool = False | |
non_quasimetric_dim_mse_weight: float = attrs.field(default=0., validator=attrs.validators.ge(0)) | |
- kind: str = attrs.field( | |
- default='mean', validator=attrs.validators.in_({'mean', 'unidir', 'max', 'bound', 'bound_l2comp', 'l2comp'})) # type: ignore | |
- | |
def make(self) -> 'LatentDynamicsLoss': | |
return LatentDynamicsLoss( | |
# weight=self.weight, | |
epsilon=self.epsilon, | |
- # bidirectional=self.bidirectional, | |
+ bidirectional=self.bidirectional, | |
gamma=self.gamma, | |
detach_sp=self.detach_sp, | |
detach_proj_sp=self.detach_proj_sp, | |
detach_qmet=self.detach_qmet, | |
non_quasimetric_dim_mse_weight=self.non_quasimetric_dim_mse_weight, | |
init_lagrange_multiplier=self.init_lagrange_multiplier, | |
- kind=self.kind, | |
) | |
# weight: float | |
epsilon: float | |
- # bidirectional: bool | |
+ bidirectional: bool | |
gamma: Optional[float] | |
detach_sp: bool | |
detach_proj_sp: bool | |
@@ -65,24 +60,20 @@ class LatentDynamicsLoss(CriticLossBase): | |
non_quasimetric_dim_mse_weight: float | |
c: float | |
init_lagrange_multiplier: float | |
- kind: str | |
- def __init__(self, *, epsilon: float, | |
- # bidirectional: bool, | |
+ def __init__(self, *, epsilon: float, bidirectional: bool, | |
gamma: Optional[float], init_lagrange_multiplier: float, | |
# weight: float, | |
detach_qmet: bool, detach_proj_sp: bool, detach_sp: bool, | |
- non_quasimetric_dim_mse_weight: float, | |
- kind: str): | |
+ non_quasimetric_dim_mse_weight: float): | |
super().__init__() | |
# self.weight = weight | |
self.epsilon = epsilon | |
- # self.bidirectional = bidirectional | |
+ self.bidirectional = bidirectional | |
self.gamma = gamma | |
self.detach_qmet = detach_qmet | |
self.detach_sp = detach_sp | |
self.detach_proj_sp = detach_proj_sp | |
- self.kind = kind | |
self.non_quasimetric_dim_mse_weight = non_quasimetric_dim_mse_weight | |
self.init_lagrange_multiplier = init_lagrange_multiplier | |
self.raw_lagrange_multiplier = nn.Parameter( | |
@@ -95,58 +86,16 @@ class LatentDynamicsLoss(CriticLossBase): | |
pred_zy = critic.latent_dynamics(data.observations, zx, data.actions) | |
if self.detach_sp: | |
zy = zy.detach() | |
+ with critic.quasimetric_model.requiring_grad(not self.detach_qmet): | |
+ dists = critic.quasimetric_model(zy, pred_zy, bidirectional=self.bidirectional, # at least optimize d(s', hat{s'}) | |
+ proj_grad_enabled=(self.detach_proj_sp, True)) | |
+ | |
lagrange_mult = F.softplus(self.raw_lagrange_multiplier) # make positive | |
# lagrange multiplier is minimax training, so grad_mul -1 | |
lagrange_mult = grad_mul(lagrange_mult, -1) | |
- info: Dict[str, torch.Tensor] = {} | |
- | |
- if self.kind == 'unidir': | |
- with critic.quasimetric_model.requiring_grad(not self.detach_qmet): | |
- dists = critic.quasimetric_model(zy, pred_zy, bidirectional=False, # at least optimize d(s', hat{s'}) | |
- proj_grad_enabled=(self.detach_proj_sp, True)) | |
- info.update(dist_n2p=dists.mean()) | |
- elif self.kind in ['mean', 'max', 'bound', 'l2comp']: | |
- with critic.quasimetric_model.requiring_grad(not self.detach_qmet): | |
- dists_comps = critic.quasimetric_model( | |
- zy, pred_zy, bidirectional=True, proj_grad_enabled=(self.detach_proj_sp, True), | |
- reduced=False) | |
- dists = critic.quasimetric_model.quasimetric_head.reduction(dists_comps) | |
- | |
- dist_p2n, dist_n2p = dists.unbind(-1) | |
- info.update(dist_p2n=dist_p2n.mean(), dist_n2p=dist_n2p.mean()) | |
- if self.kind == 'max': | |
- dists = dists.max(-1).values | |
- elif self.kind == 'bound': | |
- with critic.quasimetric_model.requiring_grad(not self.detach_qmet): | |
- dists = critic.quasimetric_model( | |
- zy, pred_zy, bidirectional=False, # NB: false is enough for bound | |
- proj_grad_enabled=(self.detach_proj_sp, True), | |
- symmetric_upperbound=True, | |
- ) | |
- info.update(dist_upbnd=dists.mean()) | |
- elif self.kind == 'bound_l2comp': | |
- with critic.quasimetric_model.requiring_grad(not self.detach_qmet): | |
- dists_comps = critic.quasimetric_model( | |
- zy, pred_zy, bidirectional=False, # NB: false is enough for bound | |
- proj_grad_enabled=(self.detach_proj_sp, True), | |
- reduced=False, symmetric_upperbound=True, | |
- ) | |
- dists = cast( | |
- torch.Tensor, | |
- dists_comps.norm(dim=-1) / np.sqrt(dists_comps.shape[-1]), | |
- ) | |
- info.update(dist_upbnd_l2comp=dists.mean()) | |
- elif self.kind == 'l2comp': | |
- dists = cast( | |
- torch.Tensor, | |
- dists_comps.norm(dim=-1) / np.sqrt(dists_comps.shape[-1]), | |
- ) | |
- info.update(dist_l2comp=dists.mean()) | |
- else: | |
- raise NotImplementedError(self.kind) | |
- | |
+ info = {} | |
if self.gamma is None: | |
sq_dists = dists.square().mean() | |
violation = (sq_dists - self.epsilon ** 2) | |
@@ -168,6 +117,12 @@ class LatentDynamicsLoss(CriticLossBase): | |
info.update(non_quasimetric_dim_mse=non_quasimetric_dim_mse) | |
loss += non_quasimetric_dim_mse * self.non_quasimetric_dim_mse_weight | |
+ if self.bidirectional: | |
+ dist_p2n, dist_n2p = dists.flatten(0, -2).mean(0).unbind(-1) | |
+ info.update(dist_p2n=dist_p2n, dist_n2p=dist_n2p) | |
+ else: | |
+ info.update(dist_n2p=dists.mean()) | |
+ | |
return LossResult( | |
loss=loss, | |
info=info, # type: ignore | |
@@ -175,7 +130,7 @@ class LatentDynamicsLoss(CriticLossBase): | |
def extra_repr(self) -> str: | |
return '\n'.join([ | |
- f"kind={self.kind}, gamma={self.gamma!r}", | |
- f"epsilon={self.epsilon!r}, detach_sp={self.detach_sp}, detach_proj_sp={self.detach_proj_sp}, detach_qmet={self.detach_qmet}", | |
+ f"epsilon={self.epsilon!r}, bidirectional={self.bidirectional}", | |
+ f"gamma={self.gamma!r}, detach_sp={self.detach_sp}, detach_proj_sp={self.detach_proj_sp}, detach_qmet={self.detach_qmet}", | |
]) | |
# return f"weight={self.weight:g}, detach_sp={self.detach_sp}" | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/__init__.py b/quasimetric_rl/modules/quasimetric_critic/models/__init__.py | |
index 266f592..f9e05c4 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/models/__init__.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/models/__init__.py | |
@@ -25,11 +25,6 @@ class QuasimetricCritic(Module): | |
attrs.validators.le(1), | |
))) | |
encoder: Encoder.Conf = Encoder.Conf() | |
- target_quasimetric_model_ema: Optional[float] = attrs.field( | |
- default=None, validator=attrs.validators.optional(attrs.validators.and_( | |
- attrs.validators.ge(0), | |
- attrs.validators.le(1), | |
- ))) | |
quasimetric_model: QuasimetricModel.Conf = QuasimetricModel.Conf() | |
latent_dynamics: Dynamics.Conf = Dynamics.Conf() | |
@@ -47,46 +42,34 @@ class QuasimetricCritic(Module): | |
) | |
return QuasimetricCritic(encoder, quasimetric_model, latent_dynamics, | |
share_embedding_from=share_embedding_from, | |
- target_encoder_ema=self.target_encoder_ema, | |
- target_quasimetric_model_ema=self.target_quasimetric_model_ema) | |
+ target_encoder_ema=self.target_encoder_ema) | |
borrowing_embedding: bool | |
encoder: Encoder | |
_target_encoder: Optional[Encoder] | |
target_encoder_ema: Optional[float] | |
quasimetric_model: QuasimetricModel | |
- _target_quasimetric_model: Optional[QuasimetricModel] | |
- target_quasimetric_model_ema: Optional[float] | |
latent_dynamics: Dynamics # FIXME: add BC to model loading & change name | |
def __init__(self, encoder: Encoder, quasimetric_model: QuasimetricModel, latent_dynamics: Dynamics, | |
- share_embedding_from: Optional['QuasimetricCritic'], target_encoder_ema: Optional[float], | |
- target_quasimetric_model_ema: Optional[float] = None): | |
+ share_embedding_from: Optional['QuasimetricCritic'], target_encoder_ema: Optional[float]): | |
super().__init__() | |
self.borrowing_embedding = share_embedding_from is not None | |
if share_embedding_from is not None: | |
encoder = share_embedding_from.encoder | |
- _target_encoder = share_embedding_from._target_encoder | |
+ target_encoder = share_embedding_from.target_encoder | |
assert target_encoder_ema == share_embedding_from.target_encoder_ema | |
quasimetric_model = share_embedding_from.quasimetric_model | |
- assert target_quasimetric_model_ema == share_embedding_from.target_quasimetric_model_ema | |
- _target_quasimetric_model = share_embedding_from._target_quasimetric_model | |
else: | |
if target_encoder_ema is None: | |
- _target_encoder = None | |
- else: | |
- _target_encoder = copy.deepcopy(encoder).requires_grad_(False).eval() | |
- if target_quasimetric_model_ema is None: | |
- _target_quasimetric_model = None | |
+ target_encoder = None | |
else: | |
- _target_quasimetric_model = copy.deepcopy(quasimetric_model).requires_grad_(False).eval() | |
+ target_encoder = copy.deepcopy(encoder).requires_grad_(False).eval() | |
self.encoder = encoder | |
+ self.add_module('_target_encoder', target_encoder) | |
self.target_encoder_ema = target_encoder_ema | |
- self.add_module('_target_encoder', _target_encoder) | |
self.quasimetric_model = quasimetric_model | |
- self.target_quasimetric_model_ema = target_quasimetric_model_ema | |
- self.add_module('_target_quasimetric_model', _target_quasimetric_model) | |
self.latent_dynamics = latent_dynamics | |
def forward(self, x: torch.Tensor, y: torch.Tensor, *, action: Optional[torch.Tensor] = None) -> torch.Tensor: | |
@@ -108,34 +91,16 @@ class QuasimetricCritic(Module): | |
else: | |
return self.encoder | |
- def get_encoder(self, target: bool = False) -> Encoder: | |
- return self.target_encoder if target else self.encoder | |
- | |
- @property | |
- def target_quasimetric_model(self) -> QuasimetricModel: | |
- if self._target_quasimetric_model is not None: | |
- return self._target_quasimetric_model | |
- else: | |
- return self.quasimetric_model | |
- | |
- def get_quasimetric_model(self, target: bool = False) -> QuasimetricModel: | |
- return self.target_quasimetric_model if target else self.quasimetric_model | |
- | |
@torch.no_grad() | |
- def update_target_models_(self): | |
- if not self.borrowing_embedding: | |
- if self.target_encoder_ema is not None: | |
- assert self._target_encoder is not None | |
- for p, p_target in zip(self.encoder.parameters(), self._target_encoder.parameters()): | |
- p_target.data.lerp_(p.data, 1 - self.target_encoder_ema) | |
- if self.target_quasimetric_model_ema is not None: | |
- assert self._target_quasimetric_model is not None | |
- for p, p_target in zip(self.quasimetric_model.parameters(), self._target_quasimetric_model.parameters()): | |
- p_target.data.lerp_(p.data, 1 - self.target_quasimetric_model_ema) | |
+ def update_target_encoder_(self): | |
+ if not self.borrowing_embedding and self.target_encoder_ema is not None: | |
+ assert self._target_encoder is not None | |
+ for p, p_target in zip(self.encoder.parameters(), self._target_encoder.parameters()): | |
+ p_target.data.lerp_(p.data, 1 - self.target_encoder_ema) | |
# for type hints | |
def __call__(self, x: torch.Tensor, y: torch.Tensor, *, action: Optional[torch.Tensor] = None) -> torch.Tensor: | |
return super().__call__(x, y, action=action) | |
def extra_repr(self) -> str: | |
- return f"borrowing_embedding={self.borrowing_embedding}, target_encoder_ema={self.target_encoder_ema}, target_quasimetric_model_ema={self.target_quasimetric_model_ema}" | |
+ return f"borrowing_embedding={self.borrowing_embedding}" | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py b/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py | |
index 2774496..0847076 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py | |
@@ -20,7 +20,6 @@ class Dynamics(MLP): | |
# config / argparse uses this to specify behavior | |
arch: Tuple[int, ...] = (512, 512) | |
- action_arch: Tuple[int, ...] = () | |
layer_norm: bool = True | |
latent_input: bool = True | |
raw_observation_input_arch: Optional[Tuple[int, ...]] = None | |
@@ -35,7 +34,6 @@ class Dynamics(MLP): | |
latent_input=self.latent_input, | |
raw_observation_input_arch=self.raw_observation_input_arch, | |
env_spec=env_spec, | |
- action_arch=self.action_arch, | |
hidden_sizes=self.arch, | |
layer_norm=self.layer_norm, | |
use_latent_space_proj=self.use_latent_space_proj, | |
@@ -43,8 +41,7 @@ class Dynamics(MLP): | |
fc_fuse=self.fc_fuse, | |
) | |
- action_input: Union[InputEncoding, nn.Sequential] | |
- action_arch: Tuple[int, ...] | |
+ action_input: InputEncoding | |
latent_input: bool | |
raw_observation_input: Optional[nn.Sequential] | |
residual: bool | |
@@ -53,27 +50,11 @@ class Dynamics(MLP): | |
def __init__(self, *, latent: LatentSpaceConf, | |
latent_input: bool, | |
raw_observation_input_arch: Optional[Tuple[int, ...]], | |
- env_spec: EnvSpec, action_arch: Tuple[int, ...], | |
- hidden_sizes: Tuple[int, ...], layer_norm: bool, | |
+ env_spec: EnvSpec, hidden_sizes: Tuple[int, ...], layer_norm: bool, | |
use_latent_space_proj: bool, residual: bool, | |
fc_fuse: bool): | |
- self.action_arch = action_arch | |
action_input = env_spec.make_action_input() | |
- action_outsize = action_input.output_size | |
- if action_arch: | |
- action_input = nn.Sequential( | |
- action_input, | |
- MLP( | |
- action_input.output_size, | |
- action_arch[-1], | |
- hidden_sizes=action_arch[:-1], | |
- activation_last_fc=True, | |
- layer_norm_last_fc=layer_norm, | |
- layer_norm=layer_norm, | |
- ), | |
- ) | |
- action_outsize = action_arch[-1] | |
- mlp_input_size = action_outsize | |
+ mlp_input_size = action_input.output_size | |
mlp_output_size = latent.latent_size | |
self.latent_input = latent_input | |
if latent_input: | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py b/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py | |
index 11577ee..589accf 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py | |
@@ -20,7 +20,7 @@ class L2(torchqmet.QuasimetricBase): | |
super().__init__(input_size, num_components=1, guaranteed_quasimetric=True, | |
transforms=[], reduction='sum', discount=None) | |
- def compute_components(self, x: torch.Tensor, y: torch.Tensor, symmetric_upperbound: bool = False) -> torch.Tensor: | |
+ def compute_components(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
r''' | |
Inputs: | |
x (torch.Tensor): Shape [..., input_size] | |
@@ -41,11 +41,11 @@ def create_quasimetric_head_from_spec(spec: str) -> torchqmet.QuasimetricBase: | |
assert dim % components == 0, "IQE: dim must be divisible by components" | |
return torchqmet.IQE(dim, dim // components) | |
- def iqe2(*, dim: int, components: int, norm_delta: bool = False, fake_grad: bool = False, reduction: str = 'maxl12_sm') -> torchqmet.IQE: | |
+ def iqe2(*, dim: int, components: int, scale: bool = False, norm_delta: bool = False, fake_grad: bool = False) -> torchqmet.IQE: | |
assert dim % components == 0, "IQE: dim must be divisible by components" | |
return torchqmet.IQE2( | |
dim, dim // components, | |
- reduction=reduction, | |
+ reduction='maxl12_sm' if not scale else 'maxl12_sm_scale', | |
learned_delta=True, | |
learned_div=False, | |
div_init_mul=0.25, | |
@@ -125,12 +125,10 @@ class QuasimetricModel(Module): | |
layer_norm=projector_layer_norm, | |
dropout=projector_dropout, | |
weight_norm_last_fc=projector_weight_norm, | |
- unit_norm_last_fc=projector_unit_norm, | |
- ) | |
+ unit_norm_last_fc=projector_unit_norm) | |
def forward(self, zx: LatentTensor, zy: LatentTensor, *, bidirectional: bool = False, | |
- proj_grad_enabled: Tuple[bool, bool] = (True, True), reduced: bool = True, | |
- symmetric_upperbound: bool = False) -> torch.Tensor: | |
+ proj_grad_enabled: Tuple[bool, bool] = (True, True)) -> torch.Tensor: | |
zx = zx[..., :self.input_slice_size] | |
zy = zy[..., :self.input_slice_size] | |
with self.projector.requiring_grad(proj_grad_enabled[0]): | |
@@ -142,7 +140,7 @@ class QuasimetricModel(Module): | |
px, py = torch.broadcast_tensors(px, py) | |
px, py = torch.stack([px, py], dim=-2), torch.stack([py, px], dim=-2) # [B x 2 x D] | |
- return self.quasimetric_head(px, py, reduced=reduced, symmetric_upperbound=symmetric_upperbound) | |
+ return self.quasimetric_head(px, py) | |
def parameters(self, recurse: bool = True, *, include_head: bool = True) -> Iterator[Parameter]: | |
if include_head: | |
@@ -155,12 +153,8 @@ class QuasimetricModel(Module): | |
# for type hint | |
def __call__(self, zx: LatentTensor, zy: LatentTensor, *, bidirectional: bool = False, | |
- proj_grad_enabled: Tuple[bool, bool] = (True, True), reduced: bool = True, | |
- symmetric_upperbound: bool = False) -> torch.Tensor: | |
- return super().__call__(zx, zy, bidirectional=bidirectional, | |
- proj_grad_enabled=proj_grad_enabled, | |
- reduced=reduced, | |
- symmetric_upperbound=symmetric_upperbound) | |
+ proj_grad_enabled: Tuple[bool, bool] = (True, True)) -> torch.Tensor: | |
+ return super().__call__(zx, zy, bidirectional=bidirectional, proj_grad_enabled=proj_grad_enabled) | |
def extra_repr(self) -> str: | |
return f"input_size={self.input_size}, input_slice_size={self.input_slice_size}" | |
diff --git a/quasimetric_rl/utils/logging.py b/quasimetric_rl/utils/logging.py | |
index 6efab8d..2a9e23f 100644 | |
--- a/quasimetric_rl/utils/logging.py | |
+++ b/quasimetric_rl/utils/logging.py | |
@@ -4,8 +4,6 @@ | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
-from typing import * | |
- | |
import sys | |
import os | |
import logging | |
@@ -36,12 +34,10 @@ class TqdmLoggingHandler(logging.Handler): | |
class MultiLineFormatter(logging.Formatter): | |
- _fmt: str | |
def __init__(self, fmt=None, datefmt=None, style='%'): | |
assert style == '%' | |
super(MultiLineFormatter, self).__init__(fmt, datefmt, style) | |
- assert fmt is not None | |
self.multiline_fmt = fmt | |
def format(self, record): | |
@@ -79,7 +75,7 @@ class MultiLineFormatter(logging.Formatter): | |
output += '\n'.join( | |
self.multiline_fmt % dict(record.__dict__, message=line) | |
for index, line | |
- in enumerate(record.exc_text.decode(sys.getfilesystemencoding(), 'replace').splitlines()) # type: ignore | |
+ in enumerate(record.exc_text.decode(sys.getfilesystemencoding(), 'replace').splitlines()) | |
) | |
return output | |
@@ -100,7 +96,7 @@ def configure(logging_file, log_level=logging.INFO, level_prefix='', prefix='', | |
sys.excepthook = handle_exception # automatically log uncaught errors | |
- handlers: List[logging.Handler] = [] | |
+ handlers = [] | |
if write_to_stdout: | |
handlers.append(TqdmLoggingHandler()) | |
Submodule third_party/torch-quasimetric 186ed87..129eb90 (rewind): | |
diff --git a/third_party/torch-quasimetric/torchqmet/__init__.py b/third_party/torch-quasimetric/torchqmet/__init__.py | |
index b776de5..3afb467 100644 | |
--- a/third_party/torch-quasimetric/torchqmet/__init__.py | |
+++ b/third_party/torch-quasimetric/torchqmet/__init__.py | |
@@ -54,22 +54,19 @@ class QuasimetricBase(nn.Module, metaclass=abc.ABCMeta): | |
''' | |
pass | |
- def forward(self, x: torch.Tensor, y: torch.Tensor, *, reduced: bool = True, **kwargs) -> torch.Tensor: | |
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
assert x.shape[-1] == y.shape[-1] == self.input_size | |
- d = self.compute_components(x, y, **kwargs) | |
+ d = self.compute_components(x, y) | |
d: torch.Tensor = self.transforms(d) | |
scale = self.scale | |
if not self.training: | |
scale = scale.detach() | |
- if reduced: | |
- return self.reduction(d) * scale | |
- else: | |
- return d * scale | |
+ return self.reduction(d) * scale | |
- def __call__(self, x: torch.Tensor, y: torch.Tensor, reduced: bool = True, **kwargs) -> torch.Tensor: | |
+ def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
# Manually define for typing | |
# https://github.com/pytorch/pytorch/issues/45414 | |
- return super().__call__(x, y, reduced=reduced, **kwargs) | |
+ return super().__call__(x, y) | |
def extra_repr(self) -> str: | |
return f"guaranteed_quasimetric={self.guaranteed_quasimetric}\ninput_size={self.input_size}, num_components={self.num_components}" + ( | |
@@ -82,5 +79,5 @@ from .iqe import IQE, IQE2 | |
from .mrn import MRN, MRNFixed | |
from .neural_norms import DeepNorm, WideNorm | |
-__all__ = ['PQE', 'PQELH', 'PQEGG', 'IQE', 'IQE2', 'MRN', 'MRNFixed', 'DeepNorm', 'WideNorm'] | |
+__all__ = ['PQE', 'PQELH', 'PQEGG', 'IQE', 'MRN', 'MRNFixed', 'DeepNorm', 'WideNorm'] | |
__version__ = "0.1.0" | |
diff --git a/third_party/torch-quasimetric/torchqmet/iqe.py b/third_party/torch-quasimetric/torchqmet/iqe.py | |
index f084b6d..bc03f05 100644 | |
--- a/third_party/torch-quasimetric/torchqmet/iqe.py | |
+++ b/third_party/torch-quasimetric/torchqmet/iqe.py | |
@@ -12,6 +12,7 @@ from . import QuasimetricBase | |
# The PQELH function. | |
+@torch.jit.script | |
def f_PQELH(h: torch.Tensor): # PQELH: strictly monotonically increasing mapping from [0, +infty) -> [0, 1) | |
return -torch.expm1(-h) | |
@@ -20,28 +21,28 @@ def iqe_tensor_delta(x: torch.Tensor, y: torch.Tensor, delta: torch.Tensor, div_ | |
D = x.shape[-1] # D: component_dim | |
# ignore pairs that x >= y | |
- valid = (x < y) # [..., K, D] | |
+ valid = (x < y) | |
# sort to better count | |
- xy = torch.cat(torch.broadcast_tensors(x, y), dim=-1) # [..., K, 2D] | |
+ xy = torch.cat(torch.broadcast_tensors(x, y), dim=-1) | |
sxy, ixy = xy.sort(dim=-1) | |
# neg_inc: the **negated** increment of **input** of f at sorted locations | |
# inc = torch.gather(delta * valid, dim=-1, index=ixy % D) * torch.where(ixy < D, 1, -1) | |
- neg_inc = torch.gather(delta * valid, dim=-1, index=ixy % D) * torch.where(ixy < D, -1, 1) # [..., K, 2D-sort] | |
+ neg_inc = torch.gather(delta * valid, dim=-1, index=ixy % D) * torch.where(ixy < D, -1, 1) | |
# neg_incf: the **negated** increment of **output** of f at sorted locations | |
neg_f_input = torch.cumsum(neg_inc, dim=-1) / div_pre_f[:, None] | |
if fake_grad: | |
neg_f_input__grad_path = neg_f_input.clone() | |
- neg_f_input__grad_path.data.clamp_(min=-15) # fake grad | |
+ neg_f_input__grad_path.data.clamp_(max=17) # fake grad | |
neg_f_input = neg_f_input__grad_path + ( | |
neg_f_input - neg_f_input__grad_path | |
).detach() | |
neg_f = torch.expm1(neg_f_input) | |
- neg_incf = torch.diff(neg_f, dim=-1, prepend=neg_f.new_zeros(()).expand_as(neg_f[..., :1])) | |
+ neg_incf = torch.cat([neg_f.narrow(-1, 0, 1), torch.diff(neg_f, dim=-1)], dim=-1) | |
# reduction | |
if neg_incf.ndim == 3: | |
@@ -62,6 +63,7 @@ def iqe_tensor_delta(x: torch.Tensor, y: torch.Tensor, delta: torch.Tensor, div_ | |
return comp | |
+ | |
def iqe(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
D = x.shape[-1] # D: dim_per_component | |
@@ -87,35 +89,19 @@ def iqe(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
# neg_inp = torch.where(neg_inp_copies == 0, 0., -delta) | |
# f output: 0 -> 0, x -> 1. | |
neg_f = (neg_inp_copies < 0) * (-1.) | |
- neg_incf = torch.diff(neg_f, dim=-1, prepend=neg_f.new_zeros(()).expand_as(neg_f[..., :1])) | |
+ neg_incf = torch.cat([neg_f.narrow(-1, 0, 1), torch.diff(neg_f, dim=-1)], dim=-1) | |
# reduction | |
return (sxy * neg_incf).sum(-1) | |
-def is_notebook(): | |
- r""" | |
- Inspired by | |
- https://github.com/tqdm/tqdm/blob/cc372d09dcd5a5eabdc6ed4cf365bdb0be004d44/tqdm/autonotebook.py | |
- """ | |
- import sys | |
- try: | |
- get_ipython = sys.modules['IPython'].get_ipython | |
- if 'IPKernelApp' not in get_ipython().config: # pragma: no cover | |
- raise ImportError("console") | |
- except Exception: | |
- return False | |
- else: # pragma: no cover | |
- return True | |
- | |
- | |
-if torch.__version__ >= '2.0.1' and not is_notebook(): # well, broken process pool in notebooks | |
+if torch.__version__ >= '2.0.1' and False: # well, broken process pool in notebooks | |
iqe = torch.compile(iqe) | |
iqe_tensor_delta = torch.compile(iqe_tensor_delta) | |
# iqe = torch.compile(iqe, dynamic=True) | |
else: | |
- iqe = torch.jit.script(iqe) # type: ignore | |
- iqe_tensor_delta = torch.jit.script(iqe_tensor_delta) # type: ignore | |
+ iqe = torch.jit.script(iqe) | |
+ iqe_tensor_delta = torch.jit.script(iqe_tensor_delta) | |
class IQE(QuasimetricBase): | |
@@ -245,8 +231,8 @@ class IQE2(IQE): | |
ema_weight: float = 0.95): | |
super().__init__(input_size, dim_per_component, transforms=transforms, reduction=reduction, | |
discount=discount, warn_if_not_quasimetric=warn_if_not_quasimetric) | |
- self.component_dropout_thresh = tuple(component_dropout_thresh) # type: ignore | |
- self.dropout_p_thresh = tuple(dropout_p_thresh) # type: ignore | |
+ self.component_dropout_thresh = tuple(component_dropout_thresh) | |
+ self.dropout_p_thresh = tuple(dropout_p_thresh) | |
self.dropout_batch_frac = float(dropout_batch_frac) | |
self.fake_grad = fake_grad | |
assert 0 <= self.dropout_batch_frac <= 1 | |
@@ -263,7 +249,7 @@ class IQE2(IQE): | |
# ) | |
self.register_parameter( | |
'raw_delta', | |
- torch.nn.Parameter( # type: ignore | |
+ torch.nn.Parameter( | |
torch.zeros(self.latent_2d_shape).requires_grad_() | |
) | |
) | |
@@ -284,7 +270,7 @@ class IQE2(IQE): | |
self.register_parameter( | |
'raw_div', | |
- torch.nn.Parameter(torch.zeros(self.num_components).requires_grad_()) # type: ignore | |
+ torch.nn.Parameter(torch.zeros(self.num_components).requires_grad_()) | |
) | |
else: | |
self.register_buffer( | |
@@ -299,11 +285,11 @@ class IQE2(IQE): | |
self.div_init_mul = div_init_mul | |
self.mul_kind = mul_kind | |
- self.last_components = None # type: ignore | |
- self.last_drop_p = None # type: ignore | |
+ self.last_components = None | |
+ self.last_drop_p = None | |
- def compute_components(self, x: torch.Tensor, y: torch.Tensor, *, symmetric_upperbound: bool = False) -> torch.Tensor: | |
+ def compute_components(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
# if self.raw_delta is None: | |
# components = super().compute_components(x, y) | |
# else: | |
@@ -313,9 +299,6 @@ class IQE2(IQE): | |
delta.data.clamp_(max=1e3 / (self.latent_2d_shape[-1] / 8)) | |
div_pre_f.data.clamp_(min=1e-3) | |
- if symmetric_upperbound: | |
- x, y = torch.minimum(x, y), torch.maximum(x, y) | |
- | |
components = iqe_tensor_delta( | |
x=x.unflatten(-1, self.latent_2d_shape), | |
y=y.unflatten(-1, self.latent_2d_shape), | |
diff --git a/third_party/torch-quasimetric/torchqmet/reductions.py b/third_party/torch-quasimetric/torchqmet/reductions.py | |
index a87be8b..7681242 100644 | |
--- a/third_party/torch-quasimetric/torchqmet/reductions.py | |
+++ b/third_party/torch-quasimetric/torchqmet/reductions.py | |
@@ -59,11 +59,6 @@ class Mean(ReductionBase): | |
return d.mean(dim=-1) | |
-class L2(ReductionBase): | |
- def reduce_distance(self, d: torch.Tensor) -> torch.Tensor: | |
- return d.norm(p=2, dim=-1) | |
- | |
- | |
class MaxMean(ReductionBase): | |
r''' | |
`maxmean` from Neural Norms paper: | |
@@ -149,7 +144,7 @@ class MaxL12_PGsm(ReductionBase): | |
super().__init__(input_num_components=input_num_components, discount=discount) | |
self.raw_alpha = nn.Parameter(torch.tensor([0., 0., 0., 0.], dtype=torch.float32).requires_grad_()) # pre normalizing | |
self.raw_alpha_w = nn.Parameter(torch.tensor([0., 0., 0.], dtype=torch.float32).requires_grad_()) # pre normalizing | |
- self.last_logp = None # type: ignore | |
+ self.last_logp = None | |
self.on_pi = True | |
# self.last_p = None | |
@@ -227,7 +222,7 @@ class MaxL12_PG3(ReductionBase): | |
super().__init__(input_num_components=input_num_components, discount=discount) | |
self.raw_alpha = nn.Parameter(torch.tensor([0., 0., 0.], dtype=torch.float32).requires_grad_()) # pre normalizing | |
self.raw_alpha_w = torch.tensor([], dtype=torch.float32) # just to make logging easier | |
- self.last_logp = None # type: ignore | |
+ self.last_logp = None | |
self.on_pi = True | |
# self.last_p = None | |
@@ -327,7 +322,6 @@ class DeepLinearNetWeightedSum(ReductionBase): | |
REDUCTIONS: Mapping[str, Type[ReductionBase]] = dict( | |
sum=Sum, | |
mean=Mean, | |
- l2=L2, | |
maxmean=MaxMean, | |
maxl12=MaxL12, | |
maxl12_sm=MaxL12_sm, |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment