@ssnl
Created March 12, 2024 22:21
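Patch summary, for orientation: this appears to be a revert-style diff of the quasimetric-rl code. It stops returning the wandb run handle from setup_for_experiment(), drops the tqdm progress-bar descriptions threaded through evaluate() and the training loops, removes per-device placement of the offline Dataset and ReplayBuffer tensors, deletes the future_tdelta batch field and the future-goal variants of the global push losses (including GlobalPushNextMSELoss), removes the antmaze -nonorm registrations, collapses the target-encoder/target-quasimetric-model toggles back to a single target encoder, and restores the simpler bidirectional LatentDynamicsLoss in place of the kind-dispatched version.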
diff --git a/offline/main.py b/offline/main.py
index 1581d30..6647508 100644
--- a/offline/main.py
+++ b/offline/main.py
@@ -58,7 +58,7 @@ cs.store(name='config', node=Conf())
@hydra.main(version_base=None, config_name="config")
def train(dict_cfg: DictConfig):
cfg: Conf = Conf.from_DictConfig(dict_cfg) # type: ignore
- wandb_run = cfg.setup_for_experiment() # checking & setup logging
+ cfg.setup_for_experiment() # checking & setup logging
dataset = cfg.env.make()
@@ -117,7 +117,7 @@ def train(dict_cfg: DictConfig):
logging.info(f"Checkpointed to {relpath}")
def eval(epoch, it, optim_steps):
- val_result_allenvs = trainer.evaluate(desc=f"opt{optim_steps:08d}")
+ val_result_allenvs = trainer.evaluate()
val_results.clear()
val_results.append(dict(
epoch=epoch,
@@ -170,8 +170,6 @@ def train(dict_cfg: DictConfig):
)
num_total_epochs = int(np.ceil(cfg.total_optim_steps / trainer.num_batches))
- logging.info(f"save folder: {cfg.output_dir}")
- logging.info(f"wandb: {wandb_run.get_url()}")
# Training loop
optim_steps = 0
@@ -181,7 +179,7 @@ def train(dict_cfg: DictConfig):
if start_epoch < num_total_epochs:
for epoch in range(num_total_epochs):
epoch_desc = f"Train epoch {epoch:05d}/{num_total_epochs:05d}"
- for it, (data, data_info) in enumerate(tqdm(trainer.iter_training_data(), total=trainer.num_batches, desc=epoch_desc, leave=False)):
+ for it, (data, data_info) in enumerate(tqdm(trainer.iter_training_data(), total=trainer.num_batches, desc=epoch_desc)):
step_counter.update_then_record_alerts()
optim_steps += 1
diff --git a/offline/trainer.py b/offline/trainer.py
index c75abed..bf29c9d 100644
--- a/offline/trainer.py
+++ b/offline/trainer.py
@@ -93,15 +93,12 @@ class Trainer(object):
max_episode_length=env.max_episode_steps)
return rollout
- def evaluate(self, desc=None) -> Mapping[str, interaction.EvalEpisodeResult]:
+ def evaluate(self) -> Mapping[str, interaction.EvalEpisodeResult]:
envs = self.dataset.create_eval_envs(self.eval_seed)
results: Dict[str, interaction.EvalEpisodeResult] = {}
for k, env in envs.items():
rollouts: List[EpisodeData] = []
- this_desc = f'eval/{k}'
- if desc is not None:
- this_desc = f'{desc}/{this_desc}'
- for _ in tqdm(range(self.num_eval_episodes), desc=this_desc):
+ for _ in tqdm(range(self.num_eval_episodes), desc=f'eval/{k}'):
rollouts.append(self.collect_eval_rollout(env=env))
results[k] = interaction.EvalEpisodeResult.from_episode_rollouts(self.dataset, rollouts)
return results
diff --git a/online/main.py b/online/main.py
index c16e094..2e4400b 100644
--- a/online/main.py
+++ b/online/main.py
@@ -50,7 +50,7 @@ cs.store(name='config', node=Conf())
@hydra.main(version_base=None, config_name="config")
def train(dict_cfg: DictConfig):
cfg: Conf = Conf.from_DictConfig(dict_cfg) # type: ignore
- wandb_run = cfg.setup_for_experiment() # checking & setup logging
+ cfg.setup_for_experiment() # checking & setup logging
replay_buffer = cfg.env.make()
@@ -85,7 +85,7 @@ def train(dict_cfg: DictConfig):
logging.info(f"Checkpointed to {relpath}")
def eval(env_steps, optim_steps):
- val_result_allenvs = trainer.evaluate(desc=f'env{env_steps:08d}_opt{optim_steps:08d}')
+ val_result_allenvs = trainer.evaluate()
val_results.clear()
val_results.append(dict(
env_steps=env_steps,
@@ -113,8 +113,6 @@ def train(dict_cfg: DictConfig):
),
)
- logging.info(f"save folder: {cfg.output_dir}")
- logging.info(f"wandb: {wandb_run.get_url()}")
# Training loop
eval(0, 0); save(0, 0)
for optim_steps, (env_steps, next_iter_new_env_step, data, data_info) in enumerate(trainer.iter_training_data(), start=1):
diff --git a/online/trainer.py b/online/trainer.py
index 8926871..4e35dff 100644
--- a/online/trainer.py
+++ b/online/trainer.py
@@ -133,15 +133,12 @@ class Trainer(object):
self.replay.add_rollout(rollout)
return rollout
- def evaluate(self, desc=None) -> Mapping[str, interaction.EvalEpisodeResult]:
+ def evaluate(self) -> Mapping[str, interaction.EvalEpisodeResult]:
envs = self.make_evaluate_envs()
results: Dict[str, interaction.EvalEpisodeResult] = {}
for k, env in envs.items():
rollouts = []
- this_desc = f'eval/{k}'
- if desc is not None:
- this_desc = f'{desc}/{this_desc}'
- for _ in tqdm(range(self.num_eval_episodes), desc=this_desc):
+ for _ in tqdm(range(self.num_eval_episodes), desc=f'eval/{k}'):
rollouts.append(self.collect_rollout(eval=True, store=False, env=env))
results[k] = interaction.EvalEpisodeResult.from_episode_rollouts(
self.replay, rollouts)
@@ -160,7 +157,7 @@ class Trainer(object):
"""
def yield_data():
num_transitions = self.replay.num_transitions_realized
- for icyc in tqdm(range(self.num_samples_per_cycle), desc=f"{num_transitions} env steps, train batches", leave=False):
+ for icyc in tqdm(range(self.num_samples_per_cycle), desc=f"{num_transitions} env steps, train batches"):
data_t0 = time.time()
data = self.sample()
info = dict(
diff --git a/quasimetric_rl/base_conf.py b/quasimetric_rl/base_conf.py
index e9f2202..b3a332e 100644
--- a/quasimetric_rl/base_conf.py
+++ b/quasimetric_rl/base_conf.py
@@ -172,7 +172,7 @@ class BaseConf(abc.ABC):
else:
raise RuntimeError(f'Output directory {self.output_dir} exists and is complete')
- run = wandb.init(
+ wandb.init(
project=self.wandb_project,
name=self.output_folder.replace('/', '__') + '__' + datetime.now().strftime(r"%Y%m%d_%H:%M:%S"),
config=yaml.safe_load(OmegaConf.to_yaml(self)),
@@ -219,6 +219,3 @@ class BaseConf(abc.ABC):
if self.device.type == 'cuda' and self.device.index is not None:
torch.cuda.set_device(self.device.index)
-
- assert run is not None
- return run
diff --git a/quasimetric_rl/data/base.py b/quasimetric_rl/data/base.py
index 09711c4..bf5383f 100644
--- a/quasimetric_rl/data/base.py
+++ b/quasimetric_rl/data/base.py
@@ -36,7 +36,6 @@ class BatchData(TensorCollectionAttrsMixin): # TensorCollectionAttrsMixin has s
timeouts: torch.Tensor
future_observations: torch.Tensor # sampled!
- future_tdelta: torch.Tensor
@property
def device(self) -> torch.device:
@@ -144,12 +143,6 @@ class EpisodeData(MultiEpisodeData):
rewards=self.rewards[:t],
terminals=self.terminals[:t],
timeouts=self.timeouts[:t],
- observation_infos={
- k: v[:t + 1] for k, v in self.observation_infos.items()
- },
- transition_infos={
- k: v[:t] for k, v in self.transition_infos.items()
- },
)
@@ -242,7 +235,6 @@ class Dataset(torch.utils.data.Dataset):
indices_to_episode_timesteps: torch.Tensor
max_episode_length: int
# -----
- device: torch.device
def create_env(self, *, dict_obseravtion: Optional[bool] = None, seed: Optional[int] = None, **kwargs) -> 'OfflineEnv':
from .offline import OfflineEnv
@@ -280,7 +272,6 @@ class Dataset(torch.utils.data.Dataset):
def __init__(self, kind: str, name: str, *,
future_observation_discount: float,
dummy: bool = False, # when you don't want to load data, e.g., in analysis
- device: torch.device = torch.device('cpu'), # FIXME: get some heuristic
) -> None:
self.kind = kind
self.name = name
@@ -307,22 +298,18 @@ class Dataset(torch.utils.data.Dataset):
indices_to_episode_timesteps.append(torch.arange(l, dtype=torch.int64))
assert len(episodes) > 0, "must have at least one episode"
- self.raw_data = MultiEpisodeData.cat(episodes).to(device)
+ self.raw_data = MultiEpisodeData.cat(episodes)
- self.obs_indices_to_obs_index_in_episode = torch.cat(obs_indices_to_obs_index_in_episode, dim=0).to(device)
- self.indices_to_episode_indices = torch.cat(indices_to_episode_indices, dim=0).to(device)
- self.indices_to_episode_timesteps = torch.cat(indices_to_episode_timesteps, dim=0).to(device)
+ self.obs_indices_to_obs_index_in_episode = torch.cat(obs_indices_to_obs_index_in_episode, dim=0)
+ self.indices_to_episode_indices = torch.cat(indices_to_episode_indices, dim=0)
+ self.indices_to_episode_timesteps = torch.cat(indices_to_episode_timesteps, dim=0)
self.max_episode_length = int(self.raw_data.episode_lengths.max().item())
- self.device = device
-
- # def max_bytes_used(self):
- # return self
def get_observations(self, obs_indices: torch.Tensor):
- return self.raw_data.all_observations[obs_indices.to(self.device)]
+ return self.raw_data.all_observations[obs_indices]
def __getitem__(self, indices: torch.Tensor) -> BatchData:
- indices = torch.as_tensor(indices, device=self.device)
+ indices = torch.as_tensor(indices)
eindices = self.indices_to_episode_indices[indices]
obs_indices = indices + eindices # index for `observation`: skip the s_last from previous episodes
obs = self.get_observations(obs_indices)
@@ -332,7 +319,7 @@ class Dataset(torch.utils.data.Dataset):
tindices = self.indices_to_episode_timesteps[indices]
epilengths = self.raw_data.episode_lengths[eindices] # max idx is this
- deltas = torch.arange(self.max_episode_length, device=self.device)
+ deltas = torch.arange(self.max_episode_length)
pdeltas = torch.where(
# test tidx + 1 + delta <= max_idx = epi_length
(tindices[:, None] + deltas) < epilengths[:, None],
@@ -341,16 +328,14 @@ class Dataset(torch.utils.data.Dataset):
)
deltas = torch.distributions.Categorical(
probs=pdeltas,
- validate_args=False,
- ).sample() + 1
- future_observations = self.get_observations(obs_indices + deltas)
+ ).sample()
+ future_observations = self.get_observations(obs_indices + 1 + deltas)
return BatchData(
observations=obs,
actions=self.raw_data.actions[indices],
next_observations=nobs,
future_observations=future_observations,
- future_tdelta=deltas,
rewards=self.raw_data.rewards[indices],
terminals=terminals,
timeouts=self.raw_data.timeouts[indices],
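For orientation, here is a self-contained sketch of the future-observation sampling that the __getitem__ hunk above adjusts (toy tensors; the names mirror the diff). A future offset is drawn with probability proportional to discount**delta over the offsets still inside the episode; the two sides differ only in where the +1 lands (`.sample() + 1` on the removed side vs. `obs_indices + 1 + deltas` on the added side):

import torch

discount = 0.99                          # future_observation_discount
max_episode_length = 5
tindices = torch.tensor([0, 3])          # timestep of each sampled transition
epilengths = torch.tensor([5, 5])        # length of the episode it came from

deltas = torch.arange(max_episode_length)
pdeltas = torch.where(
    # valid iff t + 1 + delta <= episode_length
    (tindices[:, None] + deltas) < epilengths[:, None],
    discount ** deltas,
    torch.zeros(()),
)
# removed side: offsets = Categorical(...).sample() + 1; index obs_indices + offsets
# added side:   deltas  = Categorical(...).sample();     index obs_indices + 1 + deltas
offsets = torch.distributions.Categorical(probs=pdeltas).sample() + 1
print(offsets)  # per-row future offset, each >= 1 and inside its episode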
diff --git a/quasimetric_rl/data/offline/d4rl/antmaze.py b/quasimetric_rl/data/offline/d4rl/antmaze.py
index dbf5198..c48ab34 100644
--- a/quasimetric_rl/data/offline/d4rl/antmaze.py
+++ b/quasimetric_rl/data/offline/d4rl/antmaze.py
@@ -182,13 +182,12 @@ def create_env_antmaze(name, dict_obseravtion: Optional[bool] = None, *, random_
return env
-def load_episodes_antmaze(name, normalize_observation=True):
+def load_episodes_antmaze(name):
env = load_environment(name)
d4rl_dataset = cached_d4rl_dataset(name)
- if normalize_observation:
- # normalize
- d4rl_dataset['observations'] = obs_norm(name, d4rl_dataset['observations'])
- d4rl_dataset['next_observations'] = obs_norm(name, d4rl_dataset['next_observations'])
+ # normalize
+ d4rl_dataset['observations'] = obs_norm(name, d4rl_dataset['observations'])
+ d4rl_dataset['next_observations'] = obs_norm(name, d4rl_dataset['next_observations'])
yield from convert_dict_to_EpisodeData_iter(
sequence_dataset(
env,
@@ -207,11 +206,3 @@ for name in ['antmaze-umaze-v2', 'antmaze-umaze-diverse-v2',
normalize_score_fn=functools.partial(get_normalized_score, name),
eval_specs=dict(single_task=dict(random_start_goal=False), multi_task=dict(random_start_goal=True)),
)
- register_offline_env(
- 'd4rl', name + '-nonorm',
- create_env_fn=functools.partial(create_env_antmaze, name, normalize_observation=False),
- load_episodes_fn=functools.partial(load_episodes_antmaze, name, normalize_observation=False),
- normalize_score_fn=functools.partial(get_normalized_score, name),
- eval_specs=dict(single_task=dict(random_start_goal=False, normalize_observation=False),
- multi_task=dict(random_start_goal=True, normalize_observation=False)),
- )
diff --git a/quasimetric_rl/data/online/memory.py b/quasimetric_rl/data/online/memory.py
index a3445f4..df3aeb6 100644
--- a/quasimetric_rl/data/online/memory.py
+++ b/quasimetric_rl/data/online/memory.py
@@ -154,7 +154,7 @@ class ReplayBuffer(Dataset):
get_empty_episodes(
self.env_spec, self.episode_length,
int(np.ceil(self.increment_num_transitions / self.episode_length)),
- ).to(self.device),
+ ),
],
dim=0,
)
@@ -165,11 +165,11 @@ class ReplayBuffer(Dataset):
# indices_to_episode_timesteps: torch.Tensor
self.indices_to_episode_indices = torch.cat([
self.indices_to_episode_indices,
- torch.repeat_interleave(torch.arange(original_capacity, new_capacity, device=self.device), self.episode_length),
+ torch.repeat_interleave(torch.arange(original_capacity, new_capacity), self.episode_length),
], dim=0)
self.indices_to_episode_timesteps = torch.cat([
self.indices_to_episode_timesteps,
- torch.arange(self.episode_length, device=self.device).repeat(new_capacity - original_capacity),
+ torch.arange(self.episode_length).repeat(new_capacity - original_capacity),
], dim=0)
logging.info(f'ReplayBuffer: Expanded from capacity={original_capacity} to {new_capacity} episodes')
@@ -215,8 +215,7 @@ class ReplayBuffer(Dataset):
def sample(self, batch_size: int) -> BatchData:
indices = torch.as_tensor(
- np.random.choice(self.num_transitions_realized, size=[batch_size]),
- device=self.device,
+ np.random.choice(self.num_transitions_realized, size=[batch_size])
)
return self[indices]
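The memory.py hunks above drop explicit device placement when the buffer's tensors are created. A minimal sketch of the resulting pattern, assuming (not shown in this patch) that each sampled batch is moved to the training device afterwards; sizes and names are hypothetical:

import numpy as np
import torch

storage = torch.randn(100_000, 17)       # CPU-resident transitions (toy)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def sample(batch_size: int) -> torch.Tensor:
    idx = torch.as_tensor(np.random.choice(len(storage), size=[batch_size]))
    return storage[idx].to(device, non_blocking=True)

batch = sample(256)                      # only the batch crosses to the device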
diff --git a/quasimetric_rl/modules/__init__.py b/quasimetric_rl/modules/__init__.py
index 66704ff..79fd35f 100644
--- a/quasimetric_rl/modules/__init__.py
+++ b/quasimetric_rl/modules/__init__.py
@@ -32,15 +32,13 @@ class QRLLosses(Module):
critic_losses: Collection[quasimetric_critic.QuasimetricCriticLosses],
critics_total_grad_clip_norm: Optional[float],
recompute_critic_for_actor_loss: bool,
- critics_share_embedding: bool,
- critic_losses_use_target_encoder: bool):
+ critics_share_embedding: bool):
super().__init__()
self.add_module('actor_loss', actor_loss)
self.critic_losses = torch.nn.ModuleList(critic_losses) # type: ignore
self.critics_total_grad_clip_norm = critics_total_grad_clip_norm
self.recompute_critic_for_actor_loss = recompute_critic_for_actor_loss
self.critics_share_embedding = critics_share_embedding
- self.critic_losses_use_target_encoder = critic_losses_use_target_encoder
def forward(self, agent: QRLAgent, data: BatchData, *, optimize: bool = True) -> LossResult:
# compute CriticBatchInfo
@@ -59,8 +57,8 @@ class QRLLosses(Module):
)
else:
zx = critic.encoder(data.observations)
- zy = critic.get_encoder(target=self.critic_losses_use_target_encoder)(data.next_observations)
- if critic.has_separate_target_encoder and self.critic_losses_use_target_encoder:
+ zy = critic.target_encoder(data.next_observations)
+ if critic.has_separate_target_encoder:
assert not zy.requires_grad
critic_batch_info = quasimetric_critic.CriticBatchInfo(
critic=critic,
@@ -76,24 +74,19 @@ class QRLLosses(Module):
).sum().backward()
if self.critics_total_grad_clip_norm is not None:
- critic_grad_norm = torch.nn.utils.clip_grad_norm_(
- cast(torch.nn.ModuleList, agent.critics).parameters(),
- max_norm=self.critics_total_grad_clip_norm)
- else:
- critic_grad_norm = sum(p.grad.pow(2).sum() for p in cast(torch.nn.ModuleList, agent.critics).parameters() if p.grad is not None)
- critic_grad_norm = cast(torch.Tensor, critic_grad_norm).sqrt()
+ torch.nn.utils.clip_grad_norm_(cast(torch.nn.ModuleList, agent.critics).parameters(),
+ max_norm=self.critics_total_grad_clip_norm)
if optimize:
for critic in agent.critics:
- critic.update_target_models_()
+ critic.update_target_encoder_()
if self.actor_loss is not None:
assert agent.actor is not None
with torch.no_grad(), torch.inference_mode():
- for idx, critic in enumerate(agent.critics): # FIXME?
- if self.recompute_critic_for_actor_loss or (critic.has_separate_target_encoder and self.actor_loss.use_target_encoder):
- zx, zy = critic.get_encoder(target=self.actor_loss.use_target_encoder)(
- torch.stack([data.observations, data.next_observations], dim=0)).unbind(0)
+ for idx, critic in enumerate(agent.critics):
+ if self.recompute_critic_for_actor_loss or critic.has_separate_target_encoder:
+ zx, zy = critic.target_encoder(torch.stack([data.observations, data.next_observations], dim=0)).unbind(0)
critic_batch_infos[idx] = quasimetric_critic.CriticBatchInfo(
critic=critic,
zx=zx,
@@ -103,13 +96,7 @@ class QRLLosses(Module):
loss_results['actor'] = loss_r = self.actor_loss(agent.actor, critic_batch_infos, data)
cast(torch.Tensor, loss_r.loss).backward()
- actor_grad_norm = sum(p.grad.pow(2).sum() for p in cast(torch.nn.ModuleList, agent.actor).parameters() if p.grad is not None)
- actor_grad_norm = cast(torch.Tensor, actor_grad_norm).sqrt()
-
- else:
- actor_grad_norm = 0
-
- return LossResult.combine(loss_results, grad_norm=dict(critic=critic_grad_norm, actor=actor_grad_norm))
+ return LossResult.combine(loss_results)
# for type hints
def __call__(self, agent: QRLAgent, data: BatchData, *, optimize: bool = True) -> LossResult:
@@ -155,12 +142,7 @@ class QRLLosses(Module):
critic_loss.dynamics_lagrange_mult_sched.load_state_dict(optim_scheds[f"critic_{idx:02d}"]['dynamics_lagrange_mult_sched'])
def extra_repr(self) -> str:
- return '\n'.join([
- f'recompute_critic_for_actor_loss={self.recompute_critic_for_actor_loss}',
- f'critics_share_embedding={self.critics_share_embedding}',
- f'critics_total_grad_clip_norm={self.critics_total_grad_clip_norm}',
- f'critic_losses_use_target_encoder={self.critic_losses_use_target_encoder}',
- ])
+ return f'recompute_critic_for_actor_loss={self.recompute_critic_for_actor_loss}'
@attrs.define(kw_only=True)
@@ -173,7 +155,6 @@ class QRLConf:
default=None, validator=attrs.validators.optional(attrs.validators.gt(0)),
)
recompute_critic_for_actor_loss: bool = False
- critic_losses_use_target_encoder: bool = True
def make(self, *, env_spec: EnvSpec, total_optim_steps: int) -> Tuple[QRLAgent, QRLLosses]:
if self.actor is None:
@@ -193,12 +174,9 @@ class QRLConf:
critics.append(critic)
critic_losses.append(critic_loss)
- return QRLAgent(actor=actor, critics=critics), QRLLosses(
- actor_loss=actor_losses, critic_losses=critic_losses,
- critics_share_embedding=self.critics_share_embedding,
- critics_total_grad_clip_norm=self.critics_total_grad_clip_norm,
- recompute_critic_for_actor_loss=self.recompute_critic_for_actor_loss,
- critic_losses_use_target_encoder=self.critic_losses_use_target_encoder,
- )
+ return QRLAgent(actor=actor, critics=critics), QRLLosses(actor_loss=actor_losses, critic_losses=critic_losses,
+ critics_share_embedding=self.critics_share_embedding,
+ critics_total_grad_clip_norm=self.critics_total_grad_clip_norm,
+ recompute_critic_for_actor_loss=self.recompute_critic_for_actor_loss)
__all__ = ['QRLAgent', 'QRLLosses', 'QRLConf', 'InfoT', 'InfoValT']
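Several hunks above replace get_encoder(target=...) / update_target_models_() with a plain target_encoder and update_target_encoder_(). A generic sketch of the EMA update such a target network typically uses, matching the target_encoder_ema config field; the helper below is illustrative, not the repo's code:

import copy
import torch

def ema_update_(target: torch.nn.Module, source: torch.nn.Module, ema: float):
    # p_target <- ema * p_target + (1 - ema) * p_source
    with torch.no_grad():
        for p_t, p_s in zip(target.parameters(), source.parameters()):
            p_t.lerp_(p_s, 1.0 - ema)

encoder = torch.nn.Linear(8, 4)
target_encoder = copy.deepcopy(encoder).requires_grad_(False)
ema_update_(target_encoder, encoder, ema=0.99)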
diff --git a/quasimetric_rl/modules/actor/losses/__init__.py b/quasimetric_rl/modules/actor/losses/__init__.py
index f9c78a1..c8f1ec7 100644
--- a/quasimetric_rl/modules/actor/losses/__init__.py
+++ b/quasimetric_rl/modules/actor/losses/__init__.py
@@ -15,15 +15,12 @@ from ...optim import OptimWrapper, AdamWSpec, LRScheduler
class ActorLossBase(LossBase):
@abc.abstractmethod
- def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData,
- use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult:
+ def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData) -> LossResult:
pass
# for type hints
- def __call__(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData,
- use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult:
- return super().__call__(actor, critic_batch_infos, data, use_target_encoder=use_target_encoder,
- use_target_quasimetric_model=use_target_quasimetric_model)
+ def __call__(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData) -> LossResult:
+ return super().__call__(actor, critic_batch_infos, data)
from .min_dist import MinDistLoss
@@ -43,9 +40,6 @@ class ActorLosses(ActorLossBase):
actor_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=3e-5)
entropy_weight_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=3e-4) # TODO: move to loss
- use_target_encoder: bool = True
- use_target_quasimetric_model: bool = True
-
def make(self, actor: Actor, total_optim_steps: int, env_spec: EnvSpec) -> 'ActorLosses':
return ActorLosses(
actor,
@@ -55,8 +49,6 @@ class ActorLosses(ActorLossBase):
advantage_weighted_regression=self.advantage_weighted_regression.make(),
actor_optim_spec=self.actor_optim.make(),
entropy_weight_optim_spec=self.entropy_weight_optim.make(),
- use_target_encoder=self.use_target_encoder,
- use_target_quasimetric_model=self.use_target_quasimetric_model,
)
min_dist: Optional[MinDistLoss]
@@ -68,14 +60,10 @@ class ActorLosses(ActorLossBase):
entropy_weight_optim: OptimWrapper
entropy_weight_sched: LRScheduler
- use_target_encoder: bool
- use_target_quasimetric_model: bool
-
def __init__(self, actor: Actor, *, total_optim_steps: int,
min_dist: Optional[MinDistLoss], behavior_cloning: Optional[BCLoss],
advantage_weighted_regression: Optional[AWRLoss],
- actor_optim_spec: AdamWSpec, entropy_weight_optim_spec: AdamWSpec,
- use_target_encoder: bool, use_target_quasimetric_model: bool):
+ actor_optim_spec: AdamWSpec, entropy_weight_optim_spec: AdamWSpec):
super().__init__()
self.add_module('min_dist', min_dist)
self.add_module('behavior_cloning', behavior_cloning)
@@ -88,9 +76,6 @@ class ActorLosses(ActorLossBase):
if min_dist is not None:
assert len(list(min_dist.parameters())) <= 1
- self.use_target_encoder = use_target_encoder
- self.use_target_quasimetric_model = use_target_quasimetric_model
-
def optimizers(self) -> Iterable[OptimWrapper]:
return [self.actor_optim, self.entropy_weight_optim]
@@ -101,24 +86,18 @@ class ActorLosses(ActorLossBase):
loss_results: Dict[str, LossResult] = {}
if self.min_dist is not None:
loss_results.update(
- min_dist=self.min_dist(actor, critic_batch_infos, data, self.use_target_encoder, self.use_target_quasimetric_model),
+ min_dist=self.min_dist(actor, critic_batch_infos, data),
)
if self.behavior_cloning is not None:
loss_results.update(
- bc=self.behavior_cloning(actor, critic_batch_infos, data, self.use_target_encoder, self.use_target_quasimetric_model),
+ bc=self.behavior_cloning(actor, critic_batch_infos, data),
)
if self.advantage_weighted_regression is not None:
loss_results.update(
- awr=self.advantage_weighted_regression(actor, critic_batch_infos, data, self.use_target_encoder, self.use_target_quasimetric_model),
+ awr=self.advantage_weighted_regression(actor, critic_batch_infos, data),
)
return LossResult.combine(loss_results)
# for type hints
def __call__(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData) -> LossResult:
return torch.nn.Module.__call__(self, actor, critic_batch_infos, data)
-
- def extra_repr(self) -> str:
- return '\n'.join([
- f"actor_optim={self.actor_optim}, entropy_weight_optim={self.entropy_weight_optim}",
- f"use_target_encoder={self.use_target_encoder}, use_target_quasimetric_model={self.use_target_quasimetric_model}",
- ])
diff --git a/quasimetric_rl/modules/actor/losses/awr.py b/quasimetric_rl/modules/actor/losses/awr.py
index 1313dc3..f6cf4de 100644
--- a/quasimetric_rl/modules/actor/losses/awr.py
+++ b/quasimetric_rl/modules/actor/losses/awr.py
@@ -81,8 +81,8 @@ class AWRLoss(ActorLossBase):
self.clamp = clamp
self.add_goal_as_future_state = add_goal_as_future_state
- def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData,
- use_target_encoder: bool) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]:
+ def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo],
+ data: BatchData) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]:
r"""
Returns (
obs,
@@ -115,7 +115,7 @@ class AWRLoss(ActorLossBase):
# add future_observations
zg = torch.stack([
zg,
- critic_batch_info.critic.get_encoder(target=use_target_encoder)(data.future_observations),
+ critic_batch_info.critic.target_encoder(data.future_observations),
], 0)
# zo = zo.expand_as(zg)
@@ -127,11 +127,9 @@ class AWRLoss(ActorLossBase):
return obs, goal, actor_obs_goal_critic_infos
- def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData,
- use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult:
-
+ def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData) -> LossResult:
with torch.no_grad():
- obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data, use_target_encoder)
+ obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data)
info: Dict[str, Union[float, torch.Tensor]] = {}
@@ -141,13 +139,11 @@ class AWRLoss(ActorLossBase):
for idx, actor_obs_goal_critic_info in enumerate(actor_obs_goal_critic_infos):
critic = actor_obs_goal_critic_info.critic
- latent_dynamics = critic.latent_dynamics
- quasimetric_model = critic.get_quasimetric_model(target=use_target_quasimetric_model)
zo = actor_obs_goal_critic_info.zo.detach() # [B,D]
zg = actor_obs_goal_critic_info.zg.detach() # [2?,B,D]
with torch.no_grad(), critic.mode(False):
- zp = latent_dynamics(data.observations, zo, data.actions) # [B,D]
- if idx == 0 or not critic.borrowing_embedding:
+ zp = critic.latent_dynamics(data.observations, zo, data.actions) # [B,D]
+ if not critic.borrowing_embedding:
zo, zp, zg = bcast_bshape(
(zo, 1),
(zp, 1),
@@ -155,10 +151,10 @@ class AWRLoss(ActorLossBase):
)
z = torch.stack([zo, zp], dim=0) # [2,2?,B,D]
# z = z[(slice(None),) + (None,) * (zg.ndim - zo.ndim) + (Ellipsis,)] # if zg is batched, add enough dim to be broadcast-able
- dist_noact, dist = quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B]
+ dist_noact, dist = critic.quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B]
dist_noact = dist_noact.detach()
else:
- dist = quasimetric_model(zp, zg)
+ dist = critic.quasimetric_model(zp, zg)
dist_noact = dists_noact[0]
info[f'dist_delta_{idx:02d}'] = (dist_noact - dist).mean()
info[f'dist_{idx:02d}'] = dist.mean()
diff --git a/quasimetric_rl/modules/actor/losses/behavior_cloning.py b/quasimetric_rl/modules/actor/losses/behavior_cloning.py
index eba36af..632221e 100644
--- a/quasimetric_rl/modules/actor/losses/behavior_cloning.py
+++ b/quasimetric_rl/modules/actor/losses/behavior_cloning.py
@@ -34,8 +34,7 @@ class BCLoss(ActorLossBase):
super().__init__()
self.weight = weight
- def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData,
- use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult:
+ def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData) -> LossResult:
actor_distn = actor(data.observations, data.future_observations)
log_prob: torch.Tensor = actor_distn.log_prob(data.actions).mean()
loss = -log_prob * self.weight
diff --git a/quasimetric_rl/modules/actor/losses/min_dist.py b/quasimetric_rl/modules/actor/losses/min_dist.py
index 1d57609..bc36ee8 100644
--- a/quasimetric_rl/modules/actor/losses/min_dist.py
+++ b/quasimetric_rl/modules/actor/losses/min_dist.py
@@ -87,8 +87,8 @@ class MinDistLoss(ActorLossBase):
self.register_parameter('raw_entropy_weight', None)
self.target_entropy = None
- def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData,
- use_target_encoder: bool) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]:
+ def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo],
+ data: BatchData) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]:
r"""
Returns (
obs,
@@ -121,7 +121,7 @@ class MinDistLoss(ActorLossBase):
# add future_observations
zg = torch.stack([
zg,
- critic_batch_info.critic.get_encoder(target=use_target_encoder)(data.future_observations),
+ critic_batch_info.critic.target_encoder(data.future_observations),
], 0)
# zo = zo.expand_as(zg)
@@ -133,10 +133,9 @@ class MinDistLoss(ActorLossBase):
return obs, goal, actor_obs_goal_critic_infos
- def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData,
- use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult:
+ def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData) -> LossResult:
with torch.no_grad():
- obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data, use_target_encoder)
+ obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data)
actor_distn = actor(obs, goal)
action = actor_distn.rsample()
@@ -148,19 +147,17 @@ class MinDistLoss(ActorLossBase):
for idx, actor_obs_goal_critic_info in enumerate(actor_obs_goal_critic_infos):
critic = actor_obs_goal_critic_info.critic
- latent_dynamics = critic.latent_dynamics
- quasimetric_model = critic.get_quasimetric_model(target=use_target_quasimetric_model)
zo = actor_obs_goal_critic_info.zo.detach() # [B,D]
zg = actor_obs_goal_critic_info.zg.detach() # [2?,B,D]
with critic.requiring_grad(False), critic.mode(False):
- zp = latent_dynamics(data.observations, zo, action) # [2?,B,D]
- if idx == 0 or not critic.borrowing_embedding:
+ zp = critic.latent_dynamics(data.observations, zo, action) # [2?,B,D]
+ if not critic.borrowing_embedding:
# action: [2?,B,A]
z = torch.stack(torch.broadcast_tensors(zo, zp), dim=0) # [2,2?,B,D]
- dist_noact, dist = quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B]
+ dist_noact, dist = critic.quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B]
dist_noact = dist_noact.detach()
else:
- dist = quasimetric_model(zp, zg) # [2?,B]
+ dist = critic.quasimetric_model(zp, zg) # [2?,B]
dist_noact = dists_noact[0]
info[f'dist_delta_{idx:02d}'] = (dist_noact - dist).mean()
info[f'dist_{idx:02d}'] = dist.mean()
diff --git a/quasimetric_rl/modules/optim.py b/quasimetric_rl/modules/optim.py
index 5019e1b..5131ce0 100644
--- a/quasimetric_rl/modules/optim.py
+++ b/quasimetric_rl/modules/optim.py
@@ -1,7 +1,6 @@
from typing import *
import attrs
-import logging
import contextlib
import torch
@@ -92,11 +91,9 @@ class AdamWSpec:
if len(params) == 0:
params = [dict(params=[])] # dummy param group so pytorch doesn't complain
for ii in range(len(params)):
- pg = params[ii]
- if isinstance(pg, Mapping) and 'lr_mul' in pg: # handle lr_multiplier
- pg['params'] = params_list = cast(List[torch.Tensor], list(pg['params']))
+ if isinstance(params, Mapping) and 'lr_mul' in params: # handle lr_multiplier
+ pg = params[ii]
assert 'lr' not in pg
- logging.info(f'params (#tensor={len(params_list)}, #={sum(p.numel() for p in params_list)}): lr_mul={pg["lr_mul"]}')
pg['lr'] = self.lr * pg['lr_mul'] # type: ignore
del pg['lr_mul']
return OptimWrapper(
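The optim.py hunk above changes how 'lr_mul' entries in parameter groups are resolved into per-group learning rates. A toy sketch of the resolved result with plain torch.optim.AdamW (model and multiplier values are hypothetical):

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
base_lr = 1e-4
param_groups = [
    dict(params=model[0].parameters()),               # plain group: base lr
    dict(params=model[1].parameters(), lr_mul=10.0),  # scaled group
]
for pg in param_groups:
    if 'lr_mul' in pg:
        pg['lr'] = base_lr * pg.pop('lr_mul')         # resolve the multiplier
optim = torch.optim.AdamW(param_groups, lr=base_lr)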
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py b/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py
index 48ed700..94f59ba 100644
--- a/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py
@@ -31,7 +31,7 @@ class CriticLossBase(LossBase):
return super().__call__(data, critic_batch_info)
-from .global_push import GlobalPushLoss, GlobalPushLinearLoss, GlobalPushLogLoss, GlobalPushRBFLoss, GlobalPushNextMSELoss
+from .global_push import GlobalPushLoss, GlobalPushLinearLoss, GlobalPushLogLoss, GlobalPushRBFLoss
from .local_constraint import LocalConstraintLoss
from .latent_dynamics import LatentDynamicsLoss
@@ -41,7 +41,6 @@ class QuasimetricCriticLosses(CriticLossBase):
class Conf:
global_push: GlobalPushLoss.Conf = GlobalPushLoss.Conf()
global_push_linear: GlobalPushLinearLoss.Conf = GlobalPushLinearLoss.Conf()
- global_push_next_mse: GlobalPushNextMSELoss.Conf = GlobalPushNextMSELoss.Conf()
global_push_log: GlobalPushLogLoss.Conf = GlobalPushLogLoss.Conf()
global_push_rbf: GlobalPushRBFLoss.Conf = GlobalPushRBFLoss.Conf()
local_constraint: LocalConstraintLoss.Conf = LocalConstraintLoss.Conf()
@@ -69,7 +68,6 @@ class QuasimetricCriticLosses(CriticLossBase):
# global losses
global_push=self.global_push.make(),
global_push_linear=self.global_push_linear.make(),
- global_push_next_mse=self.global_push_next_mse.make(),
global_push_log=self.global_push_log.make(),
global_push_rbf=self.global_push_rbf.make(),
# local loss
@@ -92,7 +90,6 @@ class QuasimetricCriticLosses(CriticLossBase):
borrowing_embedding: bool
global_push: Optional[GlobalPushLoss]
global_push_linear: Optional[GlobalPushLinearLoss]
- global_push_next_mse: Optional[GlobalPushNextMSELoss]
global_push_log: Optional[GlobalPushLogLoss]
global_push_rbf: Optional[GlobalPushRBFLoss]
local_constraint: Optional[LocalConstraintLoss]
@@ -109,8 +106,7 @@ class QuasimetricCriticLosses(CriticLossBase):
def __init__(self, critic: QuasimetricCritic, *, total_optim_steps: int,
share_embedding_from: Optional[QuasimetricCritic] = None,
global_push: Optional[GlobalPushLoss], global_push_linear: Optional[GlobalPushLinearLoss],
- global_push_next_mse: Optional[GlobalPushNextMSELoss], global_push_log: Optional[GlobalPushLogLoss],
- global_push_rbf: Optional[GlobalPushRBFLoss],
+ global_push_log: Optional[GlobalPushLogLoss], global_push_rbf: Optional[GlobalPushRBFLoss],
local_constraint: Optional[LocalConstraintLoss], latent_dynamics: LatentDynamicsLoss,
critic_optim_spec: AdamWSpec,
latent_dynamics_lr_mul: float,
@@ -130,7 +126,6 @@ class QuasimetricCriticLosses(CriticLossBase):
local_constraint = None
self.add_module('global_push', global_push)
self.add_module('global_push_linear', global_push_linear)
- self.add_module('global_push_next_mse', global_push_next_mse)
self.add_module('global_push_log', global_push_log)
self.add_module('global_push_rbf', global_push_rbf)
self.add_module('local_constraint', local_constraint)
@@ -189,8 +184,6 @@ class QuasimetricCriticLosses(CriticLossBase):
loss_results.update(global_push=self.global_push(data, critic_batch_info))
if self.global_push_linear is not None:
loss_results.update(global_push_linear=self.global_push_linear(data, critic_batch_info))
- if self.global_push_next_mse is not None:
- loss_results.update(global_push_next_mse=self.global_push_next_mse(data, critic_batch_info))
if self.global_push_log is not None:
loss_results.update(global_push_log=self.global_push_log(data, critic_batch_info))
if self.global_push_rbf is not None:
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
index c4e2656..7c3f9f3 100644
--- a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
@@ -53,104 +53,16 @@ from . import CriticLossBase, CriticBatchInfo
# return f"weight={self.weight:g}, softplus_beta={self.softplus_beta:g}, softplus_offset={self.softplus_offset:g}"
-class GlobalPushLossBase(CriticLossBase):
+
+class GlobalPushLoss(CriticLossBase):
@attrs.define(kw_only=True)
- class Conf(abc.ABC):
+ class Conf:
# config / argparse uses this to specify behavior
enabled: bool = True
detach_goal: bool = False
detach_proj_goal: bool = False
- detach_qmet: bool = False
- step_cost: float = attrs.field(default=1., validator=attrs.validators.gt(0))
weight: float = attrs.field(default=1., validator=attrs.validators.gt(0))
- weight_future_goal: float = attrs.field(default=0., validator=attrs.validators.ge(0))
- clamp_max_future_goal: bool = True
-
- @abc.abstractmethod
- def make(self) -> Optional['GlobalPushLossBase']:
- if not self.enabled:
- return None
-
- weight: float
- weight_future_goal: float
- detach_goal: bool
- detach_proj_goal: bool
- detach_qmet: bool
- step_cost: float
- clamp_max_future_goal: bool
-
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool,
- detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool):
- super().__init__()
- self.weight = weight
- self.weight_future_goal = weight_future_goal
- self.detach_goal = detach_goal
- self.detach_proj_goal = detach_proj_goal
- self.detach_qmet = detach_qmet
- self.step_cost = step_cost
- self.clamp_max_future_goal = clamp_max_future_goal
-
- def generate_dist_weight(self, data: BatchData, critic_batch_info: CriticBatchInfo):
- def get_dist(za: torch.Tensor, zb: torch.Tensor):
- if self.detach_goal:
- zb = zb.detach()
- with critic_batch_info.critic.quasimetric_model.requiring_grad(not self.detach_qmet):
- return critic_batch_info.critic.quasimetric_model(za, zb, proj_grad_enabled=(True, not self.detach_proj_goal))
-
- # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy
- # are latents of randomly ordered random batches.
- zgoal = torch.roll(critic_batch_info.zy, 1, dims=0)
- yield (
- 'random_goal',
- zgoal,
- get_dist(critic_batch_info.zx, zgoal),
- self.weight,
- {},
- )
- if self.weight_future_goal > 0:
- zgoal = critic_batch_info.critic.encoder(data.future_observations)
- dist = get_dist(critic_batch_info.zx, zgoal)
- if self.clamp_max_future_goal:
- observed_upper_bound = self.step_cost * data.future_tdelta
- info = dict(
- ratio_future_observed_dist=(dist / observed_upper_bound).mean(),
- exceed_future_observed_dist_rate=(dist > observed_upper_bound).mean(dtype=torch.float32),
- )
- dist = dist.clamp_max(self.step_cost * data.future_tdelta)
- else:
- info = {}
- yield (
- 'future_goal',
- zgoal,
- dist,
- self.weight_future_goal, info,
- )
-
- @abc.abstractmethod
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor,
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult:
- raise NotImplementedError
-
- def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
- return LossResult.combine(
- {
- name: self.compute_loss(data, critic_batch_info, zgoal, dist, weight, info)
- for name, zgoal, dist, weight, info in self.generate_dist_weight(data, critic_batch_info)
- },
- )
-
- def extra_repr(self) -> str:
- return '\n'.join([
- f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}",
- f"weight_future_goal={self.weight_future_goal:g}, detach_qmet={self.detach_qmet}",
- f"step_cost={self.step_cost:g}, clamp_max_future_goal={self.clamp_max_future_goal}",
- ])
-
-
-class GlobalPushLoss(GlobalPushLossBase):
- @attrs.define(kw_only=True)
- class Conf(GlobalPushLossBase.Conf):
# smaller => smoother loss
softplus_beta: float = attrs.field(default=0.1, validator=attrs.validators.gt(0))
@@ -163,175 +75,99 @@ class GlobalPushLoss(GlobalPushLossBase):
return None
return GlobalPushLoss(
weight=self.weight,
- weight_future_goal=self.weight_future_goal,
detach_goal=self.detach_goal,
detach_proj_goal=self.detach_proj_goal,
- detach_qmet=self.detach_qmet,
- step_cost=self.step_cost,
- clamp_max_future_goal=self.clamp_max_future_goal,
softplus_beta=self.softplus_beta,
softplus_offset=self.softplus_offset,
)
+ weight: float
+ detach_goal: bool
+ detach_proj_goal: bool
softplus_beta: float
softplus_offset: float
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool,
+ def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool,
softplus_beta: float, softplus_offset: float):
- super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal,
- detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal)
+ super().__init__()
+ self.weight = weight
+ self.detach_goal = detach_goal
+ self.detach_proj_goal = detach_proj_goal
self.softplus_beta = softplus_beta
self.softplus_offset = softplus_offset
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor,
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult:
- dict_info: Dict[str, torch.Tensor] = dict(info)
+ def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
+ # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy
+ # are latents of randomly ordered random batches.
+ zgoal = torch.roll(critic_batch_info.zy, 1, dims=0)
+ if self.detach_goal:
+ zgoal = zgoal.detach()
+ dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal,
+ proj_grad_enabled=(True, not self.detach_proj_goal))
# Sec 3.2. Transform so that we penalize large distances less.
- tsfm_dist: torch.Tensor = F.softplus(self.softplus_offset - dist, beta=self.softplus_beta) # type: ignore
+ tsfm_dist: torch.Tensor = F.softplus(self.softplus_offset - dists, beta=self.softplus_beta) # type: ignore
tsfm_dist = tsfm_dist.mean()
- dict_info.update(dist=dist.mean(), tsfm_dist=tsfm_dist)
- return LossResult(loss=tsfm_dist * weight, info=dict_info)
+ return LossResult(loss=tsfm_dist * self.weight, info=dict(dist=dists.mean(), tsfm_dist=tsfm_dist)) # type: ignore
def extra_repr(self) -> str:
- return '\n'.join([
- super().extra_repr(),
- f"softplus_beta={self.softplus_beta:g}, softplus_offset={self.softplus_offset:g}",
- ])
+ return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}, softplus_beta={self.softplus_beta:g}, softplus_offset={self.softplus_offset:g}"
+
-class GlobalPushLinearLoss(GlobalPushLossBase):
+class GlobalPushLinearLoss(CriticLossBase):
@attrs.define(kw_only=True)
- class Conf(GlobalPushLossBase.Conf):
- enabled: bool = False
+ class Conf:
+ # config / argparse uses this to specify behavior
- clamp_max: Optional[float] = attrs.field(default=None, validator=attrs.validators.optional(attrs.validators.gt(0)))
+ enabled: bool = False
+ detach_goal: bool = False
+ detach_proj_goal: bool = False
+ weight: float = attrs.field(default=1., validator=attrs.validators.gt(0))
def make(self) -> Optional['GlobalPushLinearLoss']:
if not self.enabled:
return None
return GlobalPushLinearLoss(
weight=self.weight,
- weight_future_goal=self.weight_future_goal,
detach_goal=self.detach_goal,
detach_proj_goal=self.detach_proj_goal,
- detach_qmet=self.detach_qmet,
- step_cost=self.step_cost,
- clamp_max_future_goal=self.clamp_max_future_goal,
- clamp_max=self.clamp_max,
)
- clamp_max: Optional[float]
-
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool,
- clamp_max: Optional[float]):
- super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal,
- detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal)
- self.clamp_max = clamp_max
-
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor,
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult:
- dict_info: Dict[str, torch.Tensor] = dict(info)
- if self.clamp_max is None:
- dist = dist.mean()
- dict_info.update(dist=dist)
- neg_loss = dist
- else:
- tsfm_dist = dist.clamp_max(self.clamp_max)
- dict_info.update(
- dist=dist.mean(),
- tsfm_dist=tsfm_dist.mean(),
- exceed_rate=(dist >= self.clamp_max).mean(dtype=torch.float32),
- )
- neg_loss = tsfm_dist
- return LossResult(loss=neg_loss * weight, info=dict_info)
-
- def extra_repr(self) -> str:
- return '\n'.join([
- super().extra_repr(),
- f"clamp_max={self.clamp_max!r}",
- ])
-
-
-
-
-class GlobalPushNextMSELoss(GlobalPushLossBase):
- @attrs.define(kw_only=True)
- class Conf(GlobalPushLossBase.Conf):
- enabled: bool = False
- detach_target_dist: bool = True
- allow_gt: bool = False
- gamma: Optional[float] = attrs.field(
- default=None, validator=attrs.validators.optional(attrs.validators.and_(
- attrs.validators.gt(0),
- attrs.validators.lt(1),
- )))
-
- def make(self) -> Optional['GlobalPushNextMSELoss']:
- if not self.enabled:
- return None
- return GlobalPushNextMSELoss(
- weight=self.weight,
- weight_future_goal=self.weight_future_goal,
- detach_goal=self.detach_goal,
- detach_proj_goal=self.detach_proj_goal,
- detach_qmet=self.detach_qmet,
- clamp_max_future_goal=self.clamp_max_future_goal,
- step_cost=self.step_cost,
- detach_target_dist=self.detach_target_dist,
- allow_gt=self.allow_gt,
- gamma=self.gamma,
- )
-
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool,
- detach_target_dist: bool, allow_gt: bool, gamma: Optional[float]):
- super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal,
- detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal)
- self.detach_target_dist = detach_target_dist
- self.allow_gt = allow_gt
- self.gamma = gamma
-
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor,
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult:
- dict_info: Dict[str, torch.Tensor] = dict(info)
-
- with torch.enable_grad(self.detach_target_dist):
- # by tri-eq, the actual cost can't be larger that step_cost + d(s', g)
- next_dist = critic_batch_info.critic.quasimetric_model(
- critic_batch_info.zy, zgoal, proj_grad_enabled=(True, not self.detach_proj_goal)
- )
- if self.detach_target_dist:
- next_dist = next_dist.detach()
- target_dist = self.step_cost + next_dist
-
- if self.allow_gt:
- dist = dist.clamp_max(target_dist)
- dict_info.update(
- exceed_rate=(dist >= target_dist).mean(dtype=torch.float32),
- )
-
- if self.gamma is None:
- loss = F.mse_loss(dist, target_dist)
- else:
- loss = F.mse_loss(self.gamma ** dist, self.gamma ** target_dist)
+ weight: float
+ detach_goal: bool
+ detach_proj_goal: bool
- dict_info.update(
- loss=loss, dist=dist.mean(), target_dist=target_dist.mean()
- )
+ def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool):
+ super().__init__()
+ self.weight = weight
+ self.detach_goal = detach_goal
+ self.detach_proj_goal = detach_proj_goal
- return LossResult(loss=loss * self.weight, info=dict_info)
+ def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
+ # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy
+ # are latents of randomly ordered random batches.
+ zgoal = torch.roll(critic_batch_info.zy, 1, dims=0)
+ if self.detach_goal:
+ zgoal = zgoal.detach()
+ dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal,
+ proj_grad_enabled=(True, not self.detach_proj_goal))
+ dists = dists.mean()
+ return LossResult(loss=dists * (-self.weight), info=dict(dist=dists)) # type: ignore
def extra_repr(self) -> str:
- return '\n'.join([
- super().extra_repr(),
- "detach_target_dist={self.detach_target_dist}",
- ])
+ return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}"
+
-class GlobalPushLogLoss(GlobalPushLossBase):
+class GlobalPushLogLoss(CriticLossBase):
@attrs.define(kw_only=True)
- class Conf(GlobalPushLossBase.Conf):
- enabled: bool = False
+ class Conf:
+ # config / argparse uses this to specify behavior
+ enabled: bool = False
+ detach_goal: bool = False
+ detach_proj_goal: bool = False
+ weight: float = attrs.field(default=1., validator=attrs.validators.gt(0))
offset: float = attrs.field(default=1., validator=attrs.validators.gt(0))
def make(self) -> Optional['GlobalPushLogLoss']:
@@ -339,46 +175,54 @@ class GlobalPushLogLoss(GlobalPushLossBase):
return None
return GlobalPushLogLoss(
weight=self.weight,
- weight_future_goal=self.weight_future_goal,
detach_goal=self.detach_goal,
detach_proj_goal=self.detach_proj_goal,
- detach_qmet=self.detach_qmet,
- step_cost=self.step_cost,
- clamp_max_future_goal=self.clamp_max_future_goal,
offset=self.offset,
)
+ weight: float
+ detach_goal: bool
+ detach_proj_goal: bool
offset: float
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool,
- offset: float):
- super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal,
- detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal)
+ def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool, offset: float):
+ super().__init__()
+ self.weight = weight
+ self.detach_goal = detach_goal
+ self.detach_proj_goal = detach_proj_goal
self.offset = offset
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor,
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult:
- dict_info: Dict[str, torch.Tensor] = dict(info)
- tsfm_dist: torch.Tensor = -dist.add(self.offset).log()
+ def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
+ # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy
+ # are latents of randomly ordered random batches.
+ zgoal = torch.roll(critic_batch_info.zy, 1, dims=0)
+ if self.detach_goal:
+ zgoal = zgoal.detach()
+ dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal,
+ proj_grad_enabled=(True, not self.detach_proj_goal))
+ # Sec 3.2. Transform so that we penalize large distances less.
+ tsfm_dist: torch.Tensor = -dists.add(self.offset).log()
tsfm_dist = tsfm_dist.mean()
- dict_info.update(dist=dist.mean(), tsfm_dist=tsfm_dist)
- return LossResult(loss=tsfm_dist * weight, info=dict_info)
+ return LossResult(loss=tsfm_dist * self.weight, info=dict(dist=dists.mean(), tsfm_dist=tsfm_dist)) # type: ignore
def extra_repr(self) -> str:
- return '\n'.join([
- super().extra_repr(),
- f"offset={self.offset:g}",
- ])
+ return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}, offset={self.offset:g}"
+
-class GlobalPushRBFLoss(GlobalPushLossBase):
+class GlobalPushRBFLoss(CriticLossBase):
# say E[opt T] approx sqrt(2)/2 timeout, so E[opt T^2] approx 1/2 timeout^2
# to emulate log E exp(-2 d^2), where 2 d^2 is around 4, we scale model T with r, and use log E exp(- r^2 T^2), and let r^2 T^2 to be around 4
# so r^2 approx 8 / timeout^2 and r approx 2.82 / timeout. If timeout = 850, this = 300
@attrs.define(kw_only=True)
- class Conf(GlobalPushLossBase.Conf):
+ class Conf:
+ # config / argparse uses this to specify behavior
+
enabled: bool = False
+ detach_goal: bool = False
+ detach_proj_goal: bool = False
+ weight: float = attrs.field(default=1., validator=attrs.validators.gt(0))
inv_scale: float = attrs.field(default=300., validator=attrs.validators.ge(1e-3))
@@ -387,37 +231,37 @@ class GlobalPushRBFLoss(GlobalPushLossBase):
return None
return GlobalPushRBFLoss(
weight=self.weight,
- weight_future_goal=self.weight_future_goal,
detach_goal=self.detach_goal,
detach_proj_goal=self.detach_proj_goal,
- detach_qmet=self.detach_qmet,
- step_cost=self.step_cost,
- clamp_max_future_goal=self.clamp_max_future_goal,
inv_scale=self.inv_scale,
)
+ weight: float
+ detach_goal: bool
+ detach_proj_goal: bool
inv_scale: float
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, clamp_max_future_goal: bool,
- inv_scale: float):
- super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal,
- detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal)
+ def __init__(self, *, weight: float, detach_goal: bool, detach_proj_goal: bool, inv_scale: float):
+ super().__init__()
+ self.weight = weight
+ self.detach_goal = detach_goal
+ self.detach_proj_goal = detach_proj_goal
self.inv_scale = inv_scale
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor,
- dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult:
- dict_info: Dict[str, torch.Tensor] = dict(info)
- inv_scale = dist.detach().square().mean().div(2).sqrt().clamp(1e-3, self.inv_scale) # make E[d^2]/r^2 approx 2
- tsfm_dist: torch.Tensor = (dist / inv_scale).square().neg().exp()
+ def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
+ # To randomly pair zx, zy, we just roll over zy by 1, because zx and zy
+ # are latents of randomly ordered random batches.
+ zgoal = torch.roll(critic_batch_info.zy, 1, dims=0)
+ if self.detach_goal:
+ zgoal = zgoal.detach()
+ dists = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, zgoal,
+ proj_grad_enabled=(True, not self.detach_proj_goal))
+ inv_scale = dists.detach().square().mean().div(2).sqrt().clamp(1e-3, self.inv_scale) # make E[d^2]/r^2 approx 2
+ tsfm_dist: torch.Tensor = (dists / inv_scale).square().neg().exp()
rbf_potential = tsfm_dist.mean().log()
- dict_info.update(
- dist=dist.mean(), inv_scale=inv_scale,
- tsfm_dist=tsfm_dist, rbf_potential=rbf_potential,
- )
- return LossResult(loss=rbf_potential * self.weight, info=dict_info)
+ return LossResult(loss=rbf_potential * self.weight,
+ info=dict(dist=dists.mean(), inv_scale=inv_scale,
+ tsfm_dist=tsfm_dist, rbf_potential=rbf_potential)) # type: ignore
def extra_repr(self) -> str:
- return '\n'.join([
- super().extra_repr(),
- f"inv_scale={self.inv_scale:g}",
- ])
+ return f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}, inv_scale={self.inv_scale:g}"
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py b/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py
index 19e4886..f5e8248 100644
--- a/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py
@@ -2,7 +2,6 @@ from typing import *
import attrs
-import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -26,7 +25,7 @@ class LatentDynamicsLoss(CriticLossBase):
# weight: float = attrs.field(default=0.1, validator=attrs.validators.gt(0))
epsilon: float = attrs.field(default=0.25, validator=attrs.validators.gt(0))
- # bidirectional: bool = True
+ bidirectional: bool = True
gamma: Optional[float] = attrs.field(
default=None, validator=attrs.validators.optional(attrs.validators.and_(
attrs.validators.gt(0),
@@ -38,26 +37,22 @@ class LatentDynamicsLoss(CriticLossBase):
detach_qmet: bool = False
non_quasimetric_dim_mse_weight: float = attrs.field(default=0., validator=attrs.validators.ge(0))
- kind: str = attrs.field(
- default='mean', validator=attrs.validators.in_({'mean', 'unidir', 'max', 'bound', 'bound_l2comp', 'l2comp'})) # type: ignore
-
def make(self) -> 'LatentDynamicsLoss':
return LatentDynamicsLoss(
# weight=self.weight,
epsilon=self.epsilon,
- # bidirectional=self.bidirectional,
+ bidirectional=self.bidirectional,
gamma=self.gamma,
detach_sp=self.detach_sp,
detach_proj_sp=self.detach_proj_sp,
detach_qmet=self.detach_qmet,
non_quasimetric_dim_mse_weight=self.non_quasimetric_dim_mse_weight,
init_lagrange_multiplier=self.init_lagrange_multiplier,
- kind=self.kind,
)
# weight: float
epsilon: float
- # bidirectional: bool
+ bidirectional: bool
gamma: Optional[float]
detach_sp: bool
detach_proj_sp: bool
@@ -65,24 +60,20 @@ class LatentDynamicsLoss(CriticLossBase):
non_quasimetric_dim_mse_weight: float
c: float
init_lagrange_multiplier: float
- kind: str
- def __init__(self, *, epsilon: float,
- # bidirectional: bool,
+ def __init__(self, *, epsilon: float, bidirectional: bool,
gamma: Optional[float], init_lagrange_multiplier: float,
# weight: float,
detach_qmet: bool, detach_proj_sp: bool, detach_sp: bool,
- non_quasimetric_dim_mse_weight: float,
- kind: str):
+ non_quasimetric_dim_mse_weight: float):
super().__init__()
# self.weight = weight
self.epsilon = epsilon
- # self.bidirectional = bidirectional
+ self.bidirectional = bidirectional
self.gamma = gamma
self.detach_qmet = detach_qmet
self.detach_sp = detach_sp
self.detach_proj_sp = detach_proj_sp
- self.kind = kind
self.non_quasimetric_dim_mse_weight = non_quasimetric_dim_mse_weight
self.init_lagrange_multiplier = init_lagrange_multiplier
self.raw_lagrange_multiplier = nn.Parameter(
@@ -95,58 +86,16 @@ class LatentDynamicsLoss(CriticLossBase):
pred_zy = critic.latent_dynamics(data.observations, zx, data.actions)
if self.detach_sp:
zy = zy.detach()
+ with critic.quasimetric_model.requiring_grad(not self.detach_qmet):
+ dists = critic.quasimetric_model(zy, pred_zy, bidirectional=self.bidirectional, # at least optimize d(s', hat{s'})
+ proj_grad_enabled=(self.detach_proj_sp, True))
+
lagrange_mult = F.softplus(self.raw_lagrange_multiplier) # make positive
# lagrange multiplier is minimax training, so grad_mul -1
lagrange_mult = grad_mul(lagrange_mult, -1)
- info: Dict[str, torch.Tensor] = {}
-
- if self.kind == 'unidir':
- with critic.quasimetric_model.requiring_grad(not self.detach_qmet):
- dists = critic.quasimetric_model(zy, pred_zy, bidirectional=False, # at least optimize d(s', hat{s'})
- proj_grad_enabled=(self.detach_proj_sp, True))
- info.update(dist_n2p=dists.mean())
- elif self.kind in ['mean', 'max', 'bound', 'l2comp']:
- with critic.quasimetric_model.requiring_grad(not self.detach_qmet):
- dists_comps = critic.quasimetric_model(
- zy, pred_zy, bidirectional=True, proj_grad_enabled=(self.detach_proj_sp, True),
- reduced=False)
- dists = critic.quasimetric_model.quasimetric_head.reduction(dists_comps)
-
- dist_p2n, dist_n2p = dists.unbind(-1)
- info.update(dist_p2n=dist_p2n.mean(), dist_n2p=dist_n2p.mean())
- if self.kind == 'max':
- dists = dists.max(-1).values
- elif self.kind == 'bound':
- with critic.quasimetric_model.requiring_grad(not self.detach_qmet):
- dists = critic.quasimetric_model(
- zy, pred_zy, bidirectional=False, # NB: false is enough for bound
- proj_grad_enabled=(self.detach_proj_sp, True),
- symmetric_upperbound=True,
- )
- info.update(dist_upbnd=dists.mean())
- elif self.kind == 'bound_l2comp':
- with critic.quasimetric_model.requiring_grad(not self.detach_qmet):
- dists_comps = critic.quasimetric_model(
- zy, pred_zy, bidirectional=False, # NB: false is enough for bound
- proj_grad_enabled=(self.detach_proj_sp, True),
- reduced=False, symmetric_upperbound=True,
- )
- dists = cast(
- torch.Tensor,
- dists_comps.norm(dim=-1) / np.sqrt(dists_comps.shape[-1]),
- )
- info.update(dist_upbnd_l2comp=dists.mean())
- elif self.kind == 'l2comp':
- dists = cast(
- torch.Tensor,
- dists_comps.norm(dim=-1) / np.sqrt(dists_comps.shape[-1]),
- )
- info.update(dist_l2comp=dists.mean())
- else:
- raise NotImplementedError(self.kind)
-
+ info = {}
if self.gamma is None:
sq_dists = dists.square().mean()
violation = (sq_dists - self.epsilon ** 2)
@@ -168,6 +117,12 @@ class LatentDynamicsLoss(CriticLossBase):
info.update(non_quasimetric_dim_mse=non_quasimetric_dim_mse)
loss += non_quasimetric_dim_mse * self.non_quasimetric_dim_mse_weight
+ if self.bidirectional:
+ dist_p2n, dist_n2p = dists.flatten(0, -2).mean(0).unbind(-1)
+ info.update(dist_p2n=dist_p2n, dist_n2p=dist_n2p)
+ else:
+ info.update(dist_n2p=dists.mean())
+
return LossResult(
loss=loss,
info=info, # type: ignore
@@ -175,7 +130,7 @@ class LatentDynamicsLoss(CriticLossBase):
def extra_repr(self) -> str:
return '\n'.join([
- f"kind={self.kind}, gamma={self.gamma!r}",
- f"epsilon={self.epsilon!r}, detach_sp={self.detach_sp}, detach_proj_sp={self.detach_proj_sp}, detach_qmet={self.detach_qmet}",
+ f"epsilon={self.epsilon!r}, bidirectional={self.bidirectional}",
+ f"gamma={self.gamma!r}, detach_sp={self.detach_sp}, detach_proj_sp={self.detach_proj_sp}, detach_qmet={self.detach_qmet}",
])
# return f"weight={self.weight:g}, detach_sp={self.detach_sp}"
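Note: the constraint machinery that survives unchanged in this hunk (softplus to keep the multiplier positive, gradient reversal for the minimax game) is self-contained enough to sketch. `grad_mul` below is a hypothetical stand-in for the repo's helper of the same name; everything else is plain PyTorch, a minimal sketch rather than the repo's implementation.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class _GradMul(torch.autograd.Function):
    # Identity in the forward pass; scales the gradient by `mul` in the backward pass.
    @staticmethod
    def forward(ctx, x: torch.Tensor, mul: float) -> torch.Tensor:
        ctx.mul = mul
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        return grad_out * ctx.mul, None

def grad_mul(x: torch.Tensor, mul: float) -> torch.Tensor:
    return _GradMul.apply(x, mul)

class SquaredDistanceConstraint(nn.Module):
    # Penalty enforcing E[d^2] <= epsilon^2 via a learned Lagrange multiplier.
    def __init__(self, epsilon: float, init_lagrange_multiplier: float):
        super().__init__()
        self.epsilon = epsilon
        # Inverse softplus, so F.softplus(raw) starts at the requested init value.
        self.raw_lagrange_multiplier = nn.Parameter(
            torch.tensor(math.log(math.expm1(init_lagrange_multiplier))))

    def forward(self, dists: torch.Tensor) -> torch.Tensor:
        lagrange_mult = F.softplus(self.raw_lagrange_multiplier)  # make positive
        # Minimax: reversing the gradient makes descent on the returned loss
        # perform ascent on the multiplier whenever the constraint is violated.
        lagrange_mult = grad_mul(lagrange_mult, -1)
        violation = dists.square().mean() - self.epsilon ** 2
        return lagrange_mult * violation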
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/__init__.py b/quasimetric_rl/modules/quasimetric_critic/models/__init__.py
index 266f592..f9e05c4 100644
--- a/quasimetric_rl/modules/quasimetric_critic/models/__init__.py
+++ b/quasimetric_rl/modules/quasimetric_critic/models/__init__.py
@@ -25,11 +25,6 @@ class QuasimetricCritic(Module):
attrs.validators.le(1),
)))
encoder: Encoder.Conf = Encoder.Conf()
- target_quasimetric_model_ema: Optional[float] = attrs.field(
- default=None, validator=attrs.validators.optional(attrs.validators.and_(
- attrs.validators.ge(0),
- attrs.validators.le(1),
- )))
quasimetric_model: QuasimetricModel.Conf = QuasimetricModel.Conf()
latent_dynamics: Dynamics.Conf = Dynamics.Conf()
@@ -47,46 +42,34 @@ class QuasimetricCritic(Module):
)
return QuasimetricCritic(encoder, quasimetric_model, latent_dynamics,
share_embedding_from=share_embedding_from,
- target_encoder_ema=self.target_encoder_ema,
- target_quasimetric_model_ema=self.target_quasimetric_model_ema)
+ target_encoder_ema=self.target_encoder_ema)
borrowing_embedding: bool
encoder: Encoder
_target_encoder: Optional[Encoder]
target_encoder_ema: Optional[float]
quasimetric_model: QuasimetricModel
- _target_quasimetric_model: Optional[QuasimetricModel]
- target_quasimetric_model_ema: Optional[float]
latent_dynamics: Dynamics # FIXME: add BC to model loading & change name
def __init__(self, encoder: Encoder, quasimetric_model: QuasimetricModel, latent_dynamics: Dynamics,
- share_embedding_from: Optional['QuasimetricCritic'], target_encoder_ema: Optional[float],
- target_quasimetric_model_ema: Optional[float] = None):
+ share_embedding_from: Optional['QuasimetricCritic'], target_encoder_ema: Optional[float]):
super().__init__()
self.borrowing_embedding = share_embedding_from is not None
if share_embedding_from is not None:
encoder = share_embedding_from.encoder
- _target_encoder = share_embedding_from._target_encoder
+ target_encoder = share_embedding_from.target_encoder
assert target_encoder_ema == share_embedding_from.target_encoder_ema
quasimetric_model = share_embedding_from.quasimetric_model
- assert target_quasimetric_model_ema == share_embedding_from.target_quasimetric_model_ema
- _target_quasimetric_model = share_embedding_from._target_quasimetric_model
else:
if target_encoder_ema is None:
- _target_encoder = None
- else:
- _target_encoder = copy.deepcopy(encoder).requires_grad_(False).eval()
- if target_quasimetric_model_ema is None:
- _target_quasimetric_model = None
+ target_encoder = None
else:
- _target_quasimetric_model = copy.deepcopy(quasimetric_model).requires_grad_(False).eval()
+ target_encoder = copy.deepcopy(encoder).requires_grad_(False).eval()
self.encoder = encoder
+ self.add_module('_target_encoder', target_encoder)
self.target_encoder_ema = target_encoder_ema
- self.add_module('_target_encoder', _target_encoder)
self.quasimetric_model = quasimetric_model
- self.target_quasimetric_model_ema = target_quasimetric_model_ema
- self.add_module('_target_quasimetric_model', _target_quasimetric_model)
self.latent_dynamics = latent_dynamics
def forward(self, x: torch.Tensor, y: torch.Tensor, *, action: Optional[torch.Tensor] = None) -> torch.Tensor:
@@ -108,34 +91,16 @@ class QuasimetricCritic(Module):
else:
return self.encoder
- def get_encoder(self, target: bool = False) -> Encoder:
- return self.target_encoder if target else self.encoder
-
- @property
- def target_quasimetric_model(self) -> QuasimetricModel:
- if self._target_quasimetric_model is not None:
- return self._target_quasimetric_model
- else:
- return self.quasimetric_model
-
- def get_quasimetric_model(self, target: bool = False) -> QuasimetricModel:
- return self.target_quasimetric_model if target else self.quasimetric_model
-
@torch.no_grad()
- def update_target_models_(self):
- if not self.borrowing_embedding:
- if self.target_encoder_ema is not None:
- assert self._target_encoder is not None
- for p, p_target in zip(self.encoder.parameters(), self._target_encoder.parameters()):
- p_target.data.lerp_(p.data, 1 - self.target_encoder_ema)
- if self.target_quasimetric_model_ema is not None:
- assert self._target_quasimetric_model is not None
- for p, p_target in zip(self.quasimetric_model.parameters(), self._target_quasimetric_model.parameters()):
- p_target.data.lerp_(p.data, 1 - self.target_quasimetric_model_ema)
+ def update_target_encoder_(self):
+ if not self.borrowing_embedding and self.target_encoder_ema is not None:
+ assert self._target_encoder is not None
+ for p, p_target in zip(self.encoder.parameters(), self._target_encoder.parameters()):
+ p_target.data.lerp_(p.data, 1 - self.target_encoder_ema)
# for type hints
def __call__(self, x: torch.Tensor, y: torch.Tensor, *, action: Optional[torch.Tensor] = None) -> torch.Tensor:
return super().__call__(x, y, action=action)
def extra_repr(self) -> str:
- return f"borrowing_embedding={self.borrowing_embedding}, target_encoder_ema={self.target_encoder_ema}, target_quasimetric_model_ema={self.target_quasimetric_model_ema}"
+ return f"borrowing_embedding={self.borrowing_embedding}"
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py b/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py
index 2774496..0847076 100644
--- a/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py
+++ b/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py
@@ -20,7 +20,6 @@ class Dynamics(MLP):
# config / argparse uses this to specify behavior
arch: Tuple[int, ...] = (512, 512)
- action_arch: Tuple[int, ...] = ()
layer_norm: bool = True
latent_input: bool = True
raw_observation_input_arch: Optional[Tuple[int, ...]] = None
@@ -35,7 +34,6 @@ class Dynamics(MLP):
latent_input=self.latent_input,
raw_observation_input_arch=self.raw_observation_input_arch,
env_spec=env_spec,
- action_arch=self.action_arch,
hidden_sizes=self.arch,
layer_norm=self.layer_norm,
use_latent_space_proj=self.use_latent_space_proj,
@@ -43,8 +41,7 @@ class Dynamics(MLP):
fc_fuse=self.fc_fuse,
)
- action_input: Union[InputEncoding, nn.Sequential]
- action_arch: Tuple[int, ...]
+ action_input: InputEncoding
latent_input: bool
raw_observation_input: Optional[nn.Sequential]
residual: bool
@@ -53,27 +50,11 @@ class Dynamics(MLP):
def __init__(self, *, latent: LatentSpaceConf,
latent_input: bool,
raw_observation_input_arch: Optional[Tuple[int, ...]],
- env_spec: EnvSpec, action_arch: Tuple[int, ...],
- hidden_sizes: Tuple[int, ...], layer_norm: bool,
+ env_spec: EnvSpec, hidden_sizes: Tuple[int, ...], layer_norm: bool,
use_latent_space_proj: bool, residual: bool,
fc_fuse: bool):
- self.action_arch = action_arch
action_input = env_spec.make_action_input()
- action_outsize = action_input.output_size
- if action_arch:
- action_input = nn.Sequential(
- action_input,
- MLP(
- action_input.output_size,
- action_arch[-1],
- hidden_sizes=action_arch[:-1],
- activation_last_fc=True,
- layer_norm_last_fc=layer_norm,
- layer_norm=layer_norm,
- ),
- )
- action_outsize = action_arch[-1]
- mlp_input_size = action_outsize
+ mlp_input_size = action_input.output_size
mlp_output_size = latent.latent_size
self.latent_input = latent_input
if latent_input:
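Note: the rewound `action_arch` option inserts a small MLP trunk between the raw action encoding and the fused dynamics input. A sketch of that wiring only, with nn.Identity standing in for env_spec.make_action_input() and plain Linear/LayerNorm layers standing in for the repo's MLP class:

from typing import Tuple
import torch.nn as nn

def build_action_input(action_dim: int, action_arch: Tuple[int, ...],
                       layer_norm: bool = True) -> Tuple[nn.Module, int]:
    action_input: nn.Module = nn.Identity()  # stand-in for the env-specific encoder
    out_size = action_dim
    if action_arch:  # optionally post-process the encoding with an MLP trunk
        layers = []
        in_size = out_size
        for width in action_arch:
            layers.append(nn.Linear(in_size, width))
            if layer_norm:
                layers.append(nn.LayerNorm(width))
            layers.append(nn.ReLU())
            in_size = width
        action_input = nn.Sequential(action_input, *layers)
        out_size = action_arch[-1]
    return action_input, out_size  # out_size sets the fused MLP's input width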
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py b/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py
index 11577ee..589accf 100644
--- a/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py
+++ b/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py
@@ -20,7 +20,7 @@ class L2(torchqmet.QuasimetricBase):
super().__init__(input_size, num_components=1, guaranteed_quasimetric=True,
transforms=[], reduction='sum', discount=None)
- def compute_components(self, x: torch.Tensor, y: torch.Tensor, symmetric_upperbound: bool = False) -> torch.Tensor:
+ def compute_components(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
r'''
Inputs:
x (torch.Tensor): Shape [..., input_size]
@@ -41,11 +41,11 @@ def create_quasimetric_head_from_spec(spec: str) -> torchqmet.QuasimetricBase:
assert dim % components == 0, "IQE: dim must be divisible by components"
return torchqmet.IQE(dim, dim // components)
- def iqe2(*, dim: int, components: int, norm_delta: bool = False, fake_grad: bool = False, reduction: str = 'maxl12_sm') -> torchqmet.IQE:
+ def iqe2(*, dim: int, components: int, scale: bool = False, norm_delta: bool = False, fake_grad: bool = False) -> torchqmet.IQE:
assert dim % components == 0, "IQE: dim must be divisible by components"
return torchqmet.IQE2(
dim, dim // components,
- reduction=reduction,
+ reduction='maxl12_sm' if not scale else 'maxl12_sm_scale',
learned_delta=True,
learned_div=False,
div_init_mul=0.25,
@@ -125,12 +125,10 @@ class QuasimetricModel(Module):
layer_norm=projector_layer_norm,
dropout=projector_dropout,
weight_norm_last_fc=projector_weight_norm,
- unit_norm_last_fc=projector_unit_norm,
- )
+ unit_norm_last_fc=projector_unit_norm)
def forward(self, zx: LatentTensor, zy: LatentTensor, *, bidirectional: bool = False,
- proj_grad_enabled: Tuple[bool, bool] = (True, True), reduced: bool = True,
- symmetric_upperbound: bool = False) -> torch.Tensor:
+ proj_grad_enabled: Tuple[bool, bool] = (True, True)) -> torch.Tensor:
zx = zx[..., :self.input_slice_size]
zy = zy[..., :self.input_slice_size]
with self.projector.requiring_grad(proj_grad_enabled[0]):
@@ -142,7 +140,7 @@ class QuasimetricModel(Module):
px, py = torch.broadcast_tensors(px, py)
px, py = torch.stack([px, py], dim=-2), torch.stack([py, px], dim=-2) # [B x 2 x D]
- return self.quasimetric_head(px, py, reduced=reduced, symmetric_upperbound=symmetric_upperbound)
+ return self.quasimetric_head(px, py)
def parameters(self, recurse: bool = True, *, include_head: bool = True) -> Iterator[Parameter]:
if include_head:
@@ -155,12 +153,8 @@ class QuasimetricModel(Module):
# for type hint
def __call__(self, zx: LatentTensor, zy: LatentTensor, *, bidirectional: bool = False,
- proj_grad_enabled: Tuple[bool, bool] = (True, True), reduced: bool = True,
- symmetric_upperbound: bool = False) -> torch.Tensor:
- return super().__call__(zx, zy, bidirectional=bidirectional,
- proj_grad_enabled=proj_grad_enabled,
- reduced=reduced,
- symmetric_upperbound=symmetric_upperbound)
+ proj_grad_enabled: Tuple[bool, bool] = (True, True)) -> torch.Tensor:
+ return super().__call__(zx, zy, bidirectional=bidirectional, proj_grad_enabled=proj_grad_enabled)
def extra_repr(self) -> str:
return f"input_size={self.input_size}, input_slice_size={self.input_slice_size}"
diff --git a/quasimetric_rl/utils/logging.py b/quasimetric_rl/utils/logging.py
index 6efab8d..2a9e23f 100644
--- a/quasimetric_rl/utils/logging.py
+++ b/quasimetric_rl/utils/logging.py
@@ -4,8 +4,6 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
-from typing import *
-
import sys
import os
import logging
@@ -36,12 +34,10 @@ class TqdmLoggingHandler(logging.Handler):
class MultiLineFormatter(logging.Formatter):
- _fmt: str
def __init__(self, fmt=None, datefmt=None, style='%'):
assert style == '%'
super(MultiLineFormatter, self).__init__(fmt, datefmt, style)
- assert fmt is not None
self.multiline_fmt = fmt
def format(self, record):
@@ -79,7 +75,7 @@ class MultiLineFormatter(logging.Formatter):
output += '\n'.join(
self.multiline_fmt % dict(record.__dict__, message=line)
for index, line
- in enumerate(record.exc_text.decode(sys.getfilesystemencoding(), 'replace').splitlines()) # type: ignore
+ in enumerate(record.exc_text.decode(sys.getfilesystemencoding(), 'replace').splitlines())
)
return output
@@ -100,7 +96,7 @@ def configure(logging_file, log_level=logging.INFO, level_prefix='', prefix='',
sys.excepthook = handle_exception # automatically log uncaught errors
- handlers: List[logging.Handler] = []
+ handlers = []
if write_to_stdout:
handlers.append(TqdmLoggingHandler())
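Note: TqdmLoggingHandler is referenced but not shown in this hunk. The usual shape of such a handler routes records through tqdm.write so log lines do not clobber active progress bars; a sketch of that pattern, not necessarily the repo's exact implementation:

import logging
import tqdm

class TqdmLoggingHandler(logging.Handler):
    def emit(self, record: logging.LogRecord) -> None:
        try:
            tqdm.tqdm.write(self.format(record))  # cooperates with active bars
            self.flush()
        except Exception:
            self.handleError(record)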
Submodule third_party/torch-quasimetric 186ed87..129eb90 (rewind):
diff --git a/third_party/torch-quasimetric/torchqmet/__init__.py b/third_party/torch-quasimetric/torchqmet/__init__.py
index b776de5..3afb467 100644
--- a/third_party/torch-quasimetric/torchqmet/__init__.py
+++ b/third_party/torch-quasimetric/torchqmet/__init__.py
@@ -54,22 +54,19 @@ class QuasimetricBase(nn.Module, metaclass=abc.ABCMeta):
'''
pass
- def forward(self, x: torch.Tensor, y: torch.Tensor, *, reduced: bool = True, **kwargs) -> torch.Tensor:
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
assert x.shape[-1] == y.shape[-1] == self.input_size
- d = self.compute_components(x, y, **kwargs)
+ d = self.compute_components(x, y)
d: torch.Tensor = self.transforms(d)
scale = self.scale
if not self.training:
scale = scale.detach()
- if reduced:
- return self.reduction(d) * scale
- else:
- return d * scale
+ return self.reduction(d) * scale
- def __call__(self, x: torch.Tensor, y: torch.Tensor, reduced: bool = True, **kwargs) -> torch.Tensor:
+ def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# Manually define for typing
# https://github.com/pytorch/pytorch/issues/45414
- return super().__call__(x, y, reduced=reduced, **kwargs)
+ return super().__call__(x, y)
def extra_repr(self) -> str:
return f"guaranteed_quasimetric={self.guaranteed_quasimetric}\ninput_size={self.input_size}, num_components={self.num_components}" + (
@@ -82,5 +79,5 @@ from .iqe import IQE, IQE2
from .mrn import MRN, MRNFixed
from .neural_norms import DeepNorm, WideNorm
-__all__ = ['PQE', 'PQELH', 'PQEGG', 'IQE', 'IQE2', 'MRN', 'MRNFixed', 'DeepNorm', 'WideNorm']
+__all__ = ['PQE', 'PQELH', 'PQEGG', 'IQE', 'MRN', 'MRNFixed', 'DeepNorm', 'WideNorm']
__version__ = "0.1.0"
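Note: the `reduced` flag being rewound above chooses between per-component distances and the scalar reduction, with the learned scale detached outside training. The control flow in miniature, with a toy component function and mean reduction as stand-ins:

import torch
import torch.nn as nn

class ToyQuasimetric(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(()))

    def forward(self, x: torch.Tensor, y: torch.Tensor, *,
                reduced: bool = True) -> torch.Tensor:
        d = (y - x).clamp(min=0)  # per-component distances (stand-in for compute_components)
        scale = self.scale if self.training else self.scale.detach()  # freeze scale at eval
        return d.mean(-1) * scale if reduced else d * scale  # mean as a stand-in reduction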
diff --git a/third_party/torch-quasimetric/torchqmet/iqe.py b/third_party/torch-quasimetric/torchqmet/iqe.py
index f084b6d..bc03f05 100644
--- a/third_party/torch-quasimetric/torchqmet/iqe.py
+++ b/third_party/torch-quasimetric/torchqmet/iqe.py
@@ -12,6 +12,7 @@ from . import QuasimetricBase
# The PQELH function.
+@torch.jit.script
def f_PQELH(h: torch.Tensor): # PQELH: strictly monotonically increasing mapping from [0, +infty) -> [0, 1)
return -torch.expm1(-h)
@@ -20,28 +21,28 @@ def iqe_tensor_delta(x: torch.Tensor, y: torch.Tensor, delta: torch.Tensor, div_
D = x.shape[-1] # D: component_dim
# ignore pairs that x >= y
- valid = (x < y) # [..., K, D]
+ valid = (x < y)
# sort to better count
- xy = torch.cat(torch.broadcast_tensors(x, y), dim=-1) # [..., K, 2D]
+ xy = torch.cat(torch.broadcast_tensors(x, y), dim=-1)
sxy, ixy = xy.sort(dim=-1)
# neg_inc: the **negated** increment of **input** of f at sorted locations
# inc = torch.gather(delta * valid, dim=-1, index=ixy % D) * torch.where(ixy < D, 1, -1)
- neg_inc = torch.gather(delta * valid, dim=-1, index=ixy % D) * torch.where(ixy < D, -1, 1) # [..., K, 2D-sort]
+ neg_inc = torch.gather(delta * valid, dim=-1, index=ixy % D) * torch.where(ixy < D, -1, 1)
# neg_incf: the **negated** increment of **output** of f at sorted locations
neg_f_input = torch.cumsum(neg_inc, dim=-1) / div_pre_f[:, None]
if fake_grad:
neg_f_input__grad_path = neg_f_input.clone()
- neg_f_input__grad_path.data.clamp_(min=-15) # fake grad
+ neg_f_input__grad_path.data.clamp_(max=17) # fake grad
neg_f_input = neg_f_input__grad_path + (
neg_f_input - neg_f_input__grad_path
).detach()
neg_f = torch.expm1(neg_f_input)
- neg_incf = torch.diff(neg_f, dim=-1, prepend=neg_f.new_zeros(()).expand_as(neg_f[..., :1]))
+ neg_incf = torch.cat([neg_f.narrow(-1, 0, 1), torch.diff(neg_f, dim=-1)], dim=-1)
# reduction
if neg_incf.ndim == 3:
@@ -62,6 +63,7 @@ def iqe_tensor_delta(x: torch.Tensor, y: torch.Tensor, delta: torch.Tensor, div_
return comp
+
def iqe(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
D = x.shape[-1] # D: dim_per_component
@@ -87,35 +89,19 @@ def iqe(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# neg_inp = torch.where(neg_inp_copies == 0, 0., -delta)
# f output: 0 -> 0, x -> 1.
neg_f = (neg_inp_copies < 0) * (-1.)
- neg_incf = torch.diff(neg_f, dim=-1, prepend=neg_f.new_zeros(()).expand_as(neg_f[..., :1]))
+ neg_incf = torch.cat([neg_f.narrow(-1, 0, 1), torch.diff(neg_f, dim=-1)], dim=-1)
# reduction
return (sxy * neg_incf).sum(-1)
-def is_notebook():
- r"""
- Inspired by
- https://github.com/tqdm/tqdm/blob/cc372d09dcd5a5eabdc6ed4cf365bdb0be004d44/tqdm/autonotebook.py
- """
- import sys
- try:
- get_ipython = sys.modules['IPython'].get_ipython
- if 'IPKernelApp' not in get_ipython().config: # pragma: no cover
- raise ImportError("console")
- except Exception:
- return False
- else: # pragma: no cover
- return True
-
-
-if torch.__version__ >= '2.0.1' and not is_notebook(): # well, broken process pool in notebooks
+if torch.__version__ >= '2.0.1' and False: # well, broken process pool in notebooks
iqe = torch.compile(iqe)
iqe_tensor_delta = torch.compile(iqe_tensor_delta)
# iqe = torch.compile(iqe, dynamic=True)
else:
- iqe = torch.jit.script(iqe) # type: ignore
- iqe_tensor_delta = torch.jit.script(iqe_tensor_delta) # type: ignore
+ iqe = torch.jit.script(iqe)
+ iqe_tensor_delta = torch.jit.script(iqe_tensor_delta)
class IQE(QuasimetricBase):
@@ -245,8 +231,8 @@ class IQE2(IQE):
ema_weight: float = 0.95):
super().__init__(input_size, dim_per_component, transforms=transforms, reduction=reduction,
discount=discount, warn_if_not_quasimetric=warn_if_not_quasimetric)
- self.component_dropout_thresh = tuple(component_dropout_thresh) # type: ignore
- self.dropout_p_thresh = tuple(dropout_p_thresh) # type: ignore
+ self.component_dropout_thresh = tuple(component_dropout_thresh)
+ self.dropout_p_thresh = tuple(dropout_p_thresh)
self.dropout_batch_frac = float(dropout_batch_frac)
self.fake_grad = fake_grad
assert 0 <= self.dropout_batch_frac <= 1
@@ -263,7 +249,7 @@ class IQE2(IQE):
# )
self.register_parameter(
'raw_delta',
- torch.nn.Parameter( # type: ignore
+ torch.nn.Parameter(
torch.zeros(self.latent_2d_shape).requires_grad_()
)
)
@@ -284,7 +270,7 @@ class IQE2(IQE):
self.register_parameter(
'raw_div',
- torch.nn.Parameter(torch.zeros(self.num_components).requires_grad_()) # type: ignore
+ torch.nn.Parameter(torch.zeros(self.num_components).requires_grad_())
)
else:
self.register_buffer(
@@ -299,11 +285,11 @@ class IQE2(IQE):
self.div_init_mul = div_init_mul
self.mul_kind = mul_kind
- self.last_components = None # type: ignore
- self.last_drop_p = None # type: ignore
+ self.last_components = None
+ self.last_drop_p = None
- def compute_components(self, x: torch.Tensor, y: torch.Tensor, *, symmetric_upperbound: bool = False) -> torch.Tensor:
+ def compute_components(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# if self.raw_delta is None:
# components = super().compute_components(x, y)
# else:
@@ -313,9 +299,6 @@ class IQE2(IQE):
delta.data.clamp_(max=1e3 / (self.latent_2d_shape[-1] / 8))
div_pre_f.data.clamp_(min=1e-3)
- if symmetric_upperbound:
- x, y = torch.minimum(x, y), torch.maximum(x, y)
-
components = iqe_tensor_delta(
x=x.unflatten(-1, self.latent_2d_shape),
y=y.unflatten(-1, self.latent_2d_shape),
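Note: the two `neg_incf` forms swapped in the hunks above are numerically identical; both keep the first column and take adjacent differences after it. A quick equivalence check:

import torch

neg_f = torch.randn(3, 5)
# Newer form: diff with a zero prepended, so element 0 becomes neg_f[..., 0] - 0.
a = torch.diff(neg_f, dim=-1, prepend=neg_f.new_zeros(()).expand_as(neg_f[..., :1]))
# Older form: keep the first column, then append the remaining differences.
b = torch.cat([neg_f.narrow(-1, 0, 1), torch.diff(neg_f, dim=-1)], dim=-1)
assert torch.allclose(a, b)  # same tensor either way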
diff --git a/third_party/torch-quasimetric/torchqmet/reductions.py b/third_party/torch-quasimetric/torchqmet/reductions.py
index a87be8b..7681242 100644
--- a/third_party/torch-quasimetric/torchqmet/reductions.py
+++ b/third_party/torch-quasimetric/torchqmet/reductions.py
@@ -59,11 +59,6 @@ class Mean(ReductionBase):
return d.mean(dim=-1)
-class L2(ReductionBase):
- def reduce_distance(self, d: torch.Tensor) -> torch.Tensor:
- return d.norm(p=2, dim=-1)
-
-
class MaxMean(ReductionBase):
r'''
`maxmean` from Neural Norms paper:
@@ -149,7 +144,7 @@ class MaxL12_PGsm(ReductionBase):
super().__init__(input_num_components=input_num_components, discount=discount)
self.raw_alpha = nn.Parameter(torch.tensor([0., 0., 0., 0.], dtype=torch.float32).requires_grad_()) # pre normalizing
self.raw_alpha_w = nn.Parameter(torch.tensor([0., 0., 0.], dtype=torch.float32).requires_grad_()) # pre normalizing
- self.last_logp = None # type: ignore
+ self.last_logp = None
self.on_pi = True
# self.last_p = None
@@ -227,7 +222,7 @@ class MaxL12_PG3(ReductionBase):
super().__init__(input_num_components=input_num_components, discount=discount)
self.raw_alpha = nn.Parameter(torch.tensor([0., 0., 0.], dtype=torch.float32).requires_grad_()) # pre normalizing
self.raw_alpha_w = torch.tensor([], dtype=torch.float32) # just to make logging easier
- self.last_logp = None # type: ignore
+ self.last_logp = None
self.on_pi = True
# self.last_p = None
@@ -327,7 +322,6 @@ class DeepLinearNetWeightedSum(ReductionBase):
REDUCTIONS: Mapping[str, Type[ReductionBase]] = dict(
sum=Sum,
mean=Mean,
- l2=L2,
maxmean=MaxMean,
maxl12=MaxL12,
maxl12_sm=MaxL12_sm,