Created
March 16, 2024 14:03
-
-
Save ssnl/ebd0e033ba7c87641709b82164370f36 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/quasimetric_rl/data/base.py b/quasimetric_rl/data/base.py | |
index 09711c4..8412f7d 100644 | |
--- a/quasimetric_rl/data/base.py | |
+++ b/quasimetric_rl/data/base.py | |
@@ -208,6 +208,7 @@ class Dataset(torch.utils.data.Dataset): | |
kind: str = MISSING # d4rl, gcrl, etc. | |
name: str = MISSING # maze2d-umaze-v1, etc. | |
+ horizon: int = attrs.field(default=1, validator=attrs.validators.gt(0)) # type: ignore | |
# Defines how to fetch the future observation. smaller -> more recent | |
future_observation_discount: float = attrs.field(default=0.99, validator=attrs.validators.and_( | |
@@ -217,6 +218,7 @@ class Dataset(torch.utils.data.Dataset): | |
def make(self, *, dummy: bool = False) -> 'Dataset': | |
return Dataset(self.kind, self.name, | |
+ horizon=self.horizon, | |
future_observation_discount=self.future_observation_discount, | |
dummy=dummy) | |
@@ -278,6 +280,7 @@ class Dataset(torch.utils.data.Dataset): | |
return NORMALIZE_SCORE_REGISTRY[self.kind, self.name](timestep_reward) | |
def __init__(self, kind: str, name: str, *, | |
+ horizon: int, | |
future_observation_discount: float, | |
dummy: bool = False, # when you don't want to load data, e.g., in analysis | |
device: torch.device = torch.device('cpu'), # FIXME: get some heuristic | |
@@ -285,6 +288,7 @@ class Dataset(torch.utils.data.Dataset): | |
self.kind = kind | |
self.name = name | |
self.future_observation_discount = future_observation_discount | |
+ self.horizon = horizon | |
self.env_spec = EnvSpec.from_env(self.create_env()) | |
@@ -297,19 +301,19 @@ class Dataset(torch.utils.data.Dataset): | |
from .utils import get_empty_episode | |
episodes = (get_empty_episode(self.env_spec, episode_length=1),) | |
- obs_indices_to_obs_index_in_episode = [] | |
+ # obs_indices_to_obs_index_in_episode = [] | |
indices_to_episode_indices = [] | |
indices_to_episode_timesteps = [] | |
for eidx, episode in enumerate(episodes): | |
l = episode.num_transitions | |
- obs_indices_to_obs_index_in_episode.append(torch.arange(l + 1, dtype=torch.int64)) | |
- indices_to_episode_indices.append(torch.full([l], eidx, dtype=torch.int64)) | |
- indices_to_episode_timesteps.append(torch.arange(l, dtype=torch.int64)) | |
+ # obs_indices_to_obs_index_in_episode.append(torch.arange(l + 1, dtype=torch.int64)) | |
+ indices_to_episode_indices.append(torch.full([l + 1 - horizon], eidx, dtype=torch.int64)) | |
+ indices_to_episode_timesteps.append(torch.arange(l + 1 - horizon, dtype=torch.int64)) | |
assert len(episodes) > 0, "must have at least one episode" | |
self.raw_data = MultiEpisodeData.cat(episodes).to(device) | |
- self.obs_indices_to_obs_index_in_episode = torch.cat(obs_indices_to_obs_index_in_episode, dim=0).to(device) | |
+ # self.obs_indices_to_obs_index_in_episode = torch.cat(obs_indices_to_obs_index_in_episode, dim=0).to(device) | |
self.indices_to_episode_indices = torch.cat(indices_to_episode_indices, dim=0).to(device) | |
self.indices_to_episode_timesteps = torch.cat(indices_to_episode_timesteps, dim=0).to(device) | |
self.max_episode_length = int(self.raw_data.episode_lengths.max().item()) | |
@@ -324,19 +328,22 @@ class Dataset(torch.utils.data.Dataset): | |
def __getitem__(self, indices: torch.Tensor) -> BatchData: | |
indices = torch.as_tensor(indices, device=self.device) | |
eindices = self.indices_to_episode_indices[indices] | |
+ indices = indices[..., None] + torch.arange(self.horizon + 1, device=self.device) | |
+ eindices = eindices[..., None].expand(indices.shape) | |
obs_indices = indices + eindices # index for `observation`: skip the s_last from previous episodes | |
obs = self.get_observations(obs_indices) | |
- nobs = self.get_observations(obs_indices + 1) | |
- | |
- terminals = self.raw_data.terminals[indices] | |
+ # nobs = self.get_observations(obs_indices + 1) | |
+ indices = indices[..., :-1] # remove the last one which is the next observation | |
+ obs_indices = obs_indices[..., :-1] | |
+ eindices = eindices[..., :-1] | |
tindices = self.indices_to_episode_timesteps[indices] | |
epilengths = self.raw_data.episode_lengths[eindices] # max idx is this | |
deltas = torch.arange(self.max_episode_length, device=self.device) | |
pdeltas = torch.where( | |
# test tidx + 1 + delta <= max_idx = epi_length | |
- (tindices[:, None] + deltas) < epilengths[:, None], | |
- self.future_observation_discount ** deltas, | |
+ (tindices[..., None] + deltas) < epilengths[..., None], | |
+ (self.future_observation_discount ** deltas).expand(tindices.shape + (self.max_episode_length,)), | |
0, | |
) | |
deltas = torch.distributions.Categorical( | |
@@ -346,24 +353,25 @@ class Dataset(torch.utils.data.Dataset): | |
future_observations = self.get_observations(obs_indices + deltas) | |
return BatchData( | |
- observations=obs, | |
- actions=self.raw_data.actions[indices], | |
- next_observations=nobs, | |
- future_observations=future_observations, | |
- future_tdelta=deltas, | |
- rewards=self.raw_data.rewards[indices], | |
- terminals=terminals, | |
- timeouts=self.raw_data.timeouts[indices], | |
+ observations=obs.select(indices.ndim - 1, 0), | |
+ actions=self.raw_data.actions[indices].squeeze(indices.ndim - 1), | |
+ next_observations=obs.select(indices.ndim - 1, -1), | |
+ future_observations=future_observations.squeeze(indices.ndim - 1), | |
+ future_tdelta=deltas.squeeze(indices.ndim - 1), | |
+ rewards=self.raw_data.rewards[indices].squeeze(indices.ndim - 1), | |
+ terminals=self.raw_data.terminals[indices].squeeze(indices.ndim - 1), | |
+ timeouts=self.raw_data.timeouts[indices].squeeze(indices.ndim - 1), | |
) | |
def __len__(self): | |
- return self.raw_data.num_transitions | |
+ return self.indices_to_episode_indices.shape[0] | |
def __repr__(self): | |
return rf""" | |
{self.__class__.__name__}( | |
kind={self.kind!r}, | |
name={self.name!r}, | |
+ horizon={self.horizon!r}, | |
future_observation_discount={self.future_observation_discount!r}, | |
env_spec={self.env_spec!r}, | |
)""".lstrip('\n') | |
diff --git a/quasimetric_rl/modules/__init__.py b/quasimetric_rl/modules/__init__.py | |
index 1ed2315..81b5443 100644 | |
--- a/quasimetric_rl/modules/__init__.py | |
+++ b/quasimetric_rl/modules/__init__.py | |
@@ -9,6 +9,7 @@ from . import actor, quasimetric_critic | |
from ..data import EnvSpec, BatchData | |
from .utils import LossResult, Module, InfoT, InfoValT | |
+from ..flags import FLAGS | |
class QRLAgent(Module): | |
@@ -33,8 +34,7 @@ class QRLLosses(Module): | |
critics_total_grad_clip_norm: Optional[float], | |
recompute_critic_for_actor_loss: bool, | |
critics_share_embedding: bool, | |
- critic_losses_use_target_encoder: bool, | |
- actor_loss_uses_target_encoder: bool): | |
+ critic_losses_use_target_encoder: bool): | |
super().__init__() | |
self.add_module('actor_loss', actor_loss) | |
self.critic_losses = torch.nn.ModuleList(critic_losses) # type: ignore | |
@@ -42,7 +42,6 @@ class QRLLosses(Module): | |
self.recompute_critic_for_actor_loss = recompute_critic_for_actor_loss | |
self.critics_share_embedding = critics_share_embedding | |
self.critic_losses_use_target_encoder = critic_losses_use_target_encoder | |
- self.actor_loss_uses_target_encoder = actor_loss_uses_target_encoder | |
def forward(self, agent: QRLAgent, data: BatchData, *, optimize: bool = True) -> LossResult: | |
# compute CriticBatchInfo | |
@@ -54,11 +53,7 @@ class QRLLosses(Module): | |
stack.enter_context(critic_loss.optim_update_context(optimize=optimize)) | |
if self.critics_share_embedding and idx > 0: | |
- critic_batch_info = quasimetric_critic.CriticBatchInfo( | |
- critic=critic, | |
- zx=critic_batch_infos[0].zx, | |
- zy=critic_batch_infos[0].zy, | |
- ) | |
+ critic_batch_info = attrs.evolve(critic_batch_infos[0], critic=critic) | |
else: | |
zx = critic.encoder(data.observations) | |
zy = critic.get_encoder(target=self.critic_losses_use_target_encoder)(data.next_observations) | |
@@ -68,40 +63,96 @@ class QRLLosses(Module): | |
critic=critic, | |
zx=zx, | |
zy=zy, | |
+ zy_from_target_encoder=self.critic_losses_use_target_encoder, | |
) | |
loss_results[f"critic_{idx:02d}"] = critic_loss(data, critic_batch_info) # we update together to handle shared embedding | |
critic_batch_infos.append(critic_batch_info) | |
+ critic_grad_norm: InfoValT = {} | |
+ | |
+ if FLAGS.DEBUG: | |
+ def get_grad(loss: Union[torch.Tensor, float]) -> Union[torch.Tensor, float]: | |
+ if isinstance(loss, (int, float)): | |
+ return 0 | |
+ loss_grads = torch.autograd.grad( | |
+ loss, | |
+ list(cast(torch.nn.ModuleList, agent.critics).parameters()), | |
+ retain_graph=True, | |
+ allow_unused=True, | |
+ ) | |
+ return cast( | |
+ torch.Tensor, | |
+ sum(pg.pow(2).sum() for pg in loss_grads if pg is not None), | |
+ ).sqrt() | |
+ | |
+ for k, loss_r in loss_results.items(): | |
+ critic_grad_norm.update({ | |
+ k: loss_r.map_losses(get_grad), | |
+ }) | |
+ | |
torch.stack( | |
- [cast(torch.Tensor, loss_r.loss) for loss_r in loss_results.values()] | |
+ [loss_r.total_loss for loss_r in loss_results.values()] | |
).sum().backward() | |
if self.critics_total_grad_clip_norm is not None: | |
- torch.nn.utils.clip_grad_norm_(cast(torch.nn.ModuleList, agent.critics).parameters(), | |
- max_norm=self.critics_total_grad_clip_norm) | |
+ critic_grad_norm['total'] = torch.nn.utils.clip_grad_norm_( | |
+ cast(torch.nn.ModuleList, agent.critics).parameters(), | |
+ max_norm=self.critics_total_grad_clip_norm) | |
+ else: | |
+ critic_grad_norm['total'] = cast( | |
+ torch.Tensor, | |
+ sum(p.grad.pow(2).sum() for p in cast(torch.nn.ModuleList, agent.critics).parameters() if p.grad is not None), | |
+ ).sqrt() | |
if optimize: | |
for critic in agent.critics: | |
- critic.update_target_encoder_() | |
+ critic.update_target_models_() | |
+ actor_grad_norm: InfoValT = {} | |
if self.actor_loss is not None: | |
assert agent.actor is not None | |
with torch.no_grad(), torch.inference_mode(): | |
- for idx, critic in enumerate(agent.critics): | |
- if self.recompute_critic_for_actor_loss or (critic.has_separate_target_encoder and self.actor_loss_uses_target_encoder): | |
- zx, zy = critic.get_encoder(target=self.actor_loss_uses_target_encoder)( | |
+ for idx, critic in enumerate(agent.critics): # FIXME? | |
+ if self.recompute_critic_for_actor_loss or (critic.has_separate_target_encoder and self.actor_loss.use_target_encoder): | |
+ zx, zy = critic.get_encoder(target=self.actor_loss.use_target_encoder)( | |
torch.stack([data.observations, data.next_observations], dim=0)).unbind(0) | |
critic_batch_infos[idx] = quasimetric_critic.CriticBatchInfo( | |
critic=critic, | |
zx=zx, | |
zy=zy, | |
+ zx_from_target_encoder=self.actor_loss.use_target_encoder, | |
+ zy_from_target_encoder=self.actor_loss.use_target_encoder, | |
) | |
with self.actor_loss.optim_update_context(optimize=optimize): | |
loss_results['actor'] = loss_r = self.actor_loss(agent.actor, critic_batch_infos, data) | |
- cast(torch.Tensor, loss_r.loss).backward() | |
- return LossResult.combine(loss_results) | |
+ if FLAGS.DEBUG: | |
+ def get_grad(loss: Union[torch.Tensor, float]) -> Union[torch.Tensor, float]: | |
+ if isinstance(loss, (int, float)): | |
+ return 0 | |
+ assert agent.actor is not None | |
+ loss_grads = torch.autograd.grad( | |
+ loss, | |
+ list(agent.actor.parameters()), | |
+ retain_graph=True, | |
+ allow_unused=True, | |
+ ) | |
+ return cast( | |
+ torch.Tensor, | |
+ sum(pg.pow(2).sum() for pg in loss_grads if pg is not None), | |
+ ).sqrt() | |
+ | |
+ actor_grad_norm.update(cast(Mapping, loss_r.map_losses(get_grad))) | |
+ | |
+ loss_r.total_loss.backward() | |
+ | |
+ actor_grad_norm['total'] = cast( | |
+ torch.Tensor, | |
+ sum(p.grad.pow(2).sum() for p in agent.actor.parameters() if p.grad is not None), | |
+ ).sqrt() | |
+ | |
+ return LossResult.combine(loss_results, grad_norm=dict(critic=critic_grad_norm, actor=actor_grad_norm)) | |
# for type hints | |
def __call__(self, agent: QRLAgent, data: BatchData, *, optimize: bool = True) -> LossResult: | |
@@ -152,7 +203,6 @@ class QRLLosses(Module): | |
f'critics_share_embedding={self.critics_share_embedding}', | |
f'critics_total_grad_clip_norm={self.critics_total_grad_clip_norm}', | |
f'critic_losses_use_target_encoder={self.critic_losses_use_target_encoder}', | |
- f'actor_loss_uses_target_encoder={self.actor_loss_uses_target_encoder}', | |
]) | |
@@ -167,7 +217,6 @@ class QRLConf: | |
) | |
recompute_critic_for_actor_loss: bool = False | |
critic_losses_use_target_encoder: bool = True | |
- actor_loss_uses_target_encoder: bool = True | |
def make(self, *, env_spec: EnvSpec, total_optim_steps: int) -> Tuple[QRLAgent, QRLLosses]: | |
if self.actor is None: | |
@@ -193,7 +242,6 @@ class QRLConf: | |
critics_total_grad_clip_norm=self.critics_total_grad_clip_norm, | |
recompute_critic_for_actor_loss=self.recompute_critic_for_actor_loss, | |
critic_losses_use_target_encoder=self.critic_losses_use_target_encoder, | |
- actor_loss_uses_target_encoder=self.actor_loss_uses_target_encoder, | |
) | |
__all__ = ['QRLAgent', 'QRLLosses', 'QRLConf', 'InfoT', 'InfoValT'] | |
diff --git a/quasimetric_rl/modules/actor/losses/__init__.py b/quasimetric_rl/modules/actor/losses/__init__.py | |
index c8f1ec7..f9c78a1 100644 | |
--- a/quasimetric_rl/modules/actor/losses/__init__.py | |
+++ b/quasimetric_rl/modules/actor/losses/__init__.py | |
@@ -15,12 +15,15 @@ from ...optim import OptimWrapper, AdamWSpec, LRScheduler | |
class ActorLossBase(LossBase): | |
@abc.abstractmethod | |
- def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData) -> LossResult: | |
+ def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData, | |
+ use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult: | |
pass | |
# for type hints | |
- def __call__(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData) -> LossResult: | |
- return super().__call__(actor, critic_batch_infos, data) | |
+ def __call__(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData, | |
+ use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult: | |
+ return super().__call__(actor, critic_batch_infos, data, use_target_encoder=use_target_encoder, | |
+ use_target_quasimetric_model=use_target_quasimetric_model) | |
from .min_dist import MinDistLoss | |
@@ -40,6 +43,9 @@ class ActorLosses(ActorLossBase): | |
actor_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=3e-5) | |
entropy_weight_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=3e-4) # TODO: move to loss | |
+ use_target_encoder: bool = True | |
+ use_target_quasimetric_model: bool = True | |
+ | |
def make(self, actor: Actor, total_optim_steps: int, env_spec: EnvSpec) -> 'ActorLosses': | |
return ActorLosses( | |
actor, | |
@@ -49,6 +55,8 @@ class ActorLosses(ActorLossBase): | |
advantage_weighted_regression=self.advantage_weighted_regression.make(), | |
actor_optim_spec=self.actor_optim.make(), | |
entropy_weight_optim_spec=self.entropy_weight_optim.make(), | |
+ use_target_encoder=self.use_target_encoder, | |
+ use_target_quasimetric_model=self.use_target_quasimetric_model, | |
) | |
min_dist: Optional[MinDistLoss] | |
@@ -60,10 +68,14 @@ class ActorLosses(ActorLossBase): | |
entropy_weight_optim: OptimWrapper | |
entropy_weight_sched: LRScheduler | |
+ use_target_encoder: bool | |
+ use_target_quasimetric_model: bool | |
+ | |
def __init__(self, actor: Actor, *, total_optim_steps: int, | |
min_dist: Optional[MinDistLoss], behavior_cloning: Optional[BCLoss], | |
advantage_weighted_regression: Optional[AWRLoss], | |
- actor_optim_spec: AdamWSpec, entropy_weight_optim_spec: AdamWSpec): | |
+ actor_optim_spec: AdamWSpec, entropy_weight_optim_spec: AdamWSpec, | |
+ use_target_encoder: bool, use_target_quasimetric_model: bool): | |
super().__init__() | |
self.add_module('min_dist', min_dist) | |
self.add_module('behavior_cloning', behavior_cloning) | |
@@ -76,6 +88,9 @@ class ActorLosses(ActorLossBase): | |
if min_dist is not None: | |
assert len(list(min_dist.parameters())) <= 1 | |
+ self.use_target_encoder = use_target_encoder | |
+ self.use_target_quasimetric_model = use_target_quasimetric_model | |
+ | |
def optimizers(self) -> Iterable[OptimWrapper]: | |
return [self.actor_optim, self.entropy_weight_optim] | |
@@ -86,18 +101,24 @@ class ActorLosses(ActorLossBase): | |
loss_results: Dict[str, LossResult] = {} | |
if self.min_dist is not None: | |
loss_results.update( | |
- min_dist=self.min_dist(actor, critic_batch_infos, data), | |
+ min_dist=self.min_dist(actor, critic_batch_infos, data, self.use_target_encoder, self.use_target_quasimetric_model), | |
) | |
if self.behavior_cloning is not None: | |
loss_results.update( | |
- bc=self.behavior_cloning(actor, critic_batch_infos, data), | |
+ bc=self.behavior_cloning(actor, critic_batch_infos, data, self.use_target_encoder, self.use_target_quasimetric_model), | |
) | |
if self.advantage_weighted_regression is not None: | |
loss_results.update( | |
- awr=self.advantage_weighted_regression(actor, critic_batch_infos, data), | |
+ awr=self.advantage_weighted_regression(actor, critic_batch_infos, data, self.use_target_encoder, self.use_target_quasimetric_model), | |
) | |
return LossResult.combine(loss_results) | |
# for type hints | |
def __call__(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData) -> LossResult: | |
return torch.nn.Module.__call__(self, actor, critic_batch_infos, data) | |
+ | |
+ def extra_repr(self) -> str: | |
+ return '\n'.join([ | |
+ f"actor_optim={self.actor_optim}, entropy_weight_optim={self.entropy_weight_optim}", | |
+ f"use_target_encoder={self.use_target_encoder}, use_target_quasimetric_model={self.use_target_quasimetric_model}", | |
+ ]) | |
diff --git a/quasimetric_rl/modules/actor/losses/awr.py b/quasimetric_rl/modules/actor/losses/awr.py | |
index f6cf4de..1313dc3 100644 | |
--- a/quasimetric_rl/modules/actor/losses/awr.py | |
+++ b/quasimetric_rl/modules/actor/losses/awr.py | |
@@ -81,8 +81,8 @@ class AWRLoss(ActorLossBase): | |
self.clamp = clamp | |
self.add_goal_as_future_state = add_goal_as_future_state | |
- def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo], | |
- data: BatchData) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]: | |
+ def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData, | |
+ use_target_encoder: bool) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]: | |
r""" | |
Returns ( | |
obs, | |
@@ -115,7 +115,7 @@ class AWRLoss(ActorLossBase): | |
# add future_observations | |
zg = torch.stack([ | |
zg, | |
- critic_batch_info.critic.target_encoder(data.future_observations), | |
+ critic_batch_info.critic.get_encoder(target=use_target_encoder)(data.future_observations), | |
], 0) | |
# zo = zo.expand_as(zg) | |
@@ -127,9 +127,11 @@ class AWRLoss(ActorLossBase): | |
return obs, goal, actor_obs_goal_critic_infos | |
- def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData) -> LossResult: | |
+ def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData, | |
+ use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult: | |
+ | |
with torch.no_grad(): | |
- obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data) | |
+ obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data, use_target_encoder) | |
info: Dict[str, Union[float, torch.Tensor]] = {} | |
@@ -139,11 +141,13 @@ class AWRLoss(ActorLossBase): | |
for idx, actor_obs_goal_critic_info in enumerate(actor_obs_goal_critic_infos): | |
critic = actor_obs_goal_critic_info.critic | |
+ latent_dynamics = critic.latent_dynamics | |
+ quasimetric_model = critic.get_quasimetric_model(target=use_target_quasimetric_model) | |
zo = actor_obs_goal_critic_info.zo.detach() # [B,D] | |
zg = actor_obs_goal_critic_info.zg.detach() # [2?,B,D] | |
with torch.no_grad(), critic.mode(False): | |
- zp = critic.latent_dynamics(data.observations, zo, data.actions) # [B,D] | |
- if not critic.borrowing_embedding: | |
+ zp = latent_dynamics(data.observations, zo, data.actions) # [B,D] | |
+ if idx == 0 or not critic.borrowing_embedding: | |
zo, zp, zg = bcast_bshape( | |
(zo, 1), | |
(zp, 1), | |
@@ -151,10 +155,10 @@ class AWRLoss(ActorLossBase): | |
) | |
z = torch.stack([zo, zp], dim=0) # [2,2?,B,D] | |
# z = z[(slice(None),) + (None,) * (zg.ndim - zo.ndim) + (Ellipsis,)] # if zg is batched, add enough dim to be broadcast-able | |
- dist_noact, dist = critic.quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B] | |
+ dist_noact, dist = quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B] | |
dist_noact = dist_noact.detach() | |
else: | |
- dist = critic.quasimetric_model(zp, zg) | |
+ dist = quasimetric_model(zp, zg) | |
dist_noact = dists_noact[0] | |
info[f'dist_delta_{idx:02d}'] = (dist_noact - dist).mean() | |
info[f'dist_{idx:02d}'] = dist.mean() | |
diff --git a/quasimetric_rl/modules/actor/losses/behavior_cloning.py b/quasimetric_rl/modules/actor/losses/behavior_cloning.py | |
index 632221e..eba36af 100644 | |
--- a/quasimetric_rl/modules/actor/losses/behavior_cloning.py | |
+++ b/quasimetric_rl/modules/actor/losses/behavior_cloning.py | |
@@ -34,7 +34,8 @@ class BCLoss(ActorLossBase): | |
super().__init__() | |
self.weight = weight | |
- def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData) -> LossResult: | |
+ def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData, | |
+ use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult: | |
actor_distn = actor(data.observations, data.future_observations) | |
log_prob: torch.Tensor = actor_distn.log_prob(data.actions).mean() | |
loss = -log_prob * self.weight | |
diff --git a/quasimetric_rl/modules/actor/losses/min_dist.py b/quasimetric_rl/modules/actor/losses/min_dist.py | |
index bc36ee8..1d57609 100644 | |
--- a/quasimetric_rl/modules/actor/losses/min_dist.py | |
+++ b/quasimetric_rl/modules/actor/losses/min_dist.py | |
@@ -87,8 +87,8 @@ class MinDistLoss(ActorLossBase): | |
self.register_parameter('raw_entropy_weight', None) | |
self.target_entropy = None | |
- def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo], | |
- data: BatchData) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]: | |
+ def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData, | |
+ use_target_encoder: bool) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]: | |
r""" | |
Returns ( | |
obs, | |
@@ -121,7 +121,7 @@ class MinDistLoss(ActorLossBase): | |
# add future_observations | |
zg = torch.stack([ | |
zg, | |
- critic_batch_info.critic.target_encoder(data.future_observations), | |
+ critic_batch_info.critic.get_encoder(target=use_target_encoder)(data.future_observations), | |
], 0) | |
# zo = zo.expand_as(zg) | |
@@ -133,9 +133,10 @@ class MinDistLoss(ActorLossBase): | |
return obs, goal, actor_obs_goal_critic_infos | |
- def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData) -> LossResult: | |
+ def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData, | |
+ use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult: | |
with torch.no_grad(): | |
- obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data) | |
+ obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data, use_target_encoder) | |
actor_distn = actor(obs, goal) | |
action = actor_distn.rsample() | |
@@ -147,17 +148,19 @@ class MinDistLoss(ActorLossBase): | |
for idx, actor_obs_goal_critic_info in enumerate(actor_obs_goal_critic_infos): | |
critic = actor_obs_goal_critic_info.critic | |
+ latent_dynamics = critic.latent_dynamics | |
+ quasimetric_model = critic.get_quasimetric_model(target=use_target_quasimetric_model) | |
zo = actor_obs_goal_critic_info.zo.detach() # [B,D] | |
zg = actor_obs_goal_critic_info.zg.detach() # [2?,B,D] | |
with critic.requiring_grad(False), critic.mode(False): | |
- zp = critic.latent_dynamics(data.observations, zo, action) # [2?,B,D] | |
- if not critic.borrowing_embedding: | |
+ zp = latent_dynamics(data.observations, zo, action) # [2?,B,D] | |
+ if idx == 0 or not critic.borrowing_embedding: | |
# action: [2?,B,A] | |
z = torch.stack(torch.broadcast_tensors(zo, zp), dim=0) # [2,2?,B,D] | |
- dist_noact, dist = critic.quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B] | |
+ dist_noact, dist = quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B] | |
dist_noact = dist_noact.detach() | |
else: | |
- dist = critic.quasimetric_model(zp, zg) # [2?,B] | |
+ dist = quasimetric_model(zp, zg) # [2?,B] | |
dist_noact = dists_noact[0] | |
info[f'dist_delta_{idx:02d}'] = (dist_noact - dist).mean() | |
info[f'dist_{idx:02d}'] = dist.mean() | |
diff --git a/quasimetric_rl/modules/optim.py b/quasimetric_rl/modules/optim.py | |
index 5131ce0..5019e1b 100644 | |
--- a/quasimetric_rl/modules/optim.py | |
+++ b/quasimetric_rl/modules/optim.py | |
@@ -1,6 +1,7 @@ | |
from typing import * | |
import attrs | |
+import logging | |
import contextlib | |
import torch | |
@@ -91,9 +92,11 @@ class AdamWSpec: | |
if len(params) == 0: | |
params = [dict(params=[])] # dummy param group so pytorch doesn't complain | |
for ii in range(len(params)): | |
- if isinstance(params, Mapping) and 'lr_mul' in params: # handle lr_multiplier | |
- pg = params[ii] | |
+ pg = params[ii] | |
+ if isinstance(pg, Mapping) and 'lr_mul' in pg: # handle lr_multiplier | |
+ pg['params'] = params_list = cast(List[torch.Tensor], list(pg['params'])) | |
assert 'lr' not in pg | |
+ logging.info(f'params (#tensor={len(params_list)}, #={sum(p.numel() for p in params_list)}): lr_mul={pg["lr_mul"]}') | |
pg['lr'] = self.lr * pg['lr_mul'] # type: ignore | |
del pg['lr_mul'] | |
return OptimWrapper( | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py b/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py | |
index 48ed700..5487442 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py | |
@@ -19,6 +19,8 @@ class CriticBatchInfo: | |
critic: QuasimetricCritic | |
zx: LatentTensor | |
zy: LatentTensor | |
+ zx_from_target_encoder: bool = False | |
+ zy_from_target_encoder: bool = False | |
class CriticLossBase(LossBase): | |
@@ -48,10 +50,10 @@ class QuasimetricCriticLosses(CriticLossBase): | |
latent_dynamics: LatentDynamicsLoss.Conf = LatentDynamicsLoss.Conf() | |
critic_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=1e-4) | |
- latent_dynamics_lr_mul: float = 1 | |
- quasimetric_model_lr_mul: float = 1 | |
- encoder_lr_mul: float = 1 # TD-MPC2 uses 0.3 | |
- quasimetric_head_lr_mul: float = 1 # IQE2 can benefit from smaller lr, ~1e-5 | |
+ latent_dynamics_lr_mul: float = attrs.field(default=1., validator=attrs.validators.ge(0)) | |
+ quasimetric_model_lr_mul: float = attrs.field(default=1., validator=attrs.validators.ge(0)) | |
+ encoder_lr_mul: float = attrs.field(default=1., validator=attrs.validators.ge(0)) # TD-MPC2 uses 0.3 | |
+ quasimetric_head_lr_mul: float = attrs.field(default=1., validator=attrs.validators.ge(0)) # IQE2 can benefit from smaller lr, ~1e-5 | |
local_lagrange_mult_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=1e-2) | |
dynamics_lagrange_mult_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=0) | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py | |
index e469f13..3879518 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py | |
@@ -106,28 +106,37 @@ class GlobalPushLossBase(CriticLossBase): | |
zgoal, | |
get_dist(critic_batch_info.zx, zgoal), | |
self.weight, | |
+ {}, | |
) | |
if self.weight_future_goal > 0: | |
- zgoal = critic_batch_info.critic.encoder(data.future_observations) | |
+ zgoal = critic_batch_info.critic.get_encoder(critic_batch_info.zy_from_target_encoder)(data.future_observations) | |
dist = get_dist(critic_batch_info.zx, zgoal) | |
if self.clamp_max_future_goal: | |
+ observed_upper_bound = self.step_cost * data.future_tdelta | |
+ info = dict( | |
+ ratio_future_observed_dist=(dist / observed_upper_bound).mean(), | |
+ exceed_future_observed_dist_rate=(dist > observed_upper_bound).mean(dtype=torch.float32), | |
+ ) | |
dist = dist.clamp_max(self.step_cost * data.future_tdelta) | |
+ else: | |
+ info = {} | |
yield ( | |
'future_goal', | |
zgoal, | |
dist, | |
- self.weight_future_goal, | |
+ self.weight_future_goal, info, | |
) | |
@abc.abstractmethod | |
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult: | |
+ def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, | |
+ dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult: | |
raise NotImplementedError | |
def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
return LossResult.combine( | |
{ | |
- name: self.compute_loss(data, critic_batch_info, zgoal, dist, weight) | |
- for name, zgoal, dist, weight in self.generate_dist_weight(data, critic_batch_info) | |
+ name: self.compute_loss(data, critic_batch_info, zgoal, dist, weight, info) | |
+ for name, zgoal, dist, weight, info in self.generate_dist_weight(data, critic_batch_info) | |
}, | |
) | |
@@ -174,11 +183,14 @@ class GlobalPushLoss(GlobalPushLossBase): | |
self.softplus_beta = softplus_beta | |
self.softplus_offset = softplus_offset | |
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult: | |
+ def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, | |
+ dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult: | |
+ dict_info: Dict[str, torch.Tensor] = dict(info) | |
# Sec 3.2. Transform so that we penalize large distances less. | |
tsfm_dist: torch.Tensor = F.softplus(self.softplus_offset - dist, beta=self.softplus_beta) # type: ignore | |
tsfm_dist = tsfm_dist.mean() | |
- return LossResult(loss=tsfm_dist * weight, info=dict(dist=dist.mean(), tsfm_dist=tsfm_dist)) | |
+ dict_info.update(dist=dist.mean(), tsfm_dist=tsfm_dist) | |
+ return LossResult(loss=tsfm_dist * weight, info=dict_info) | |
def extra_repr(self) -> str: | |
return '\n'.join([ | |
@@ -216,21 +228,22 @@ class GlobalPushLinearLoss(GlobalPushLossBase): | |
detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal) | |
self.clamp_max = clamp_max | |
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult: | |
- info: Dict[str, torch.Tensor] | |
+ def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, | |
+ dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult: | |
+ dict_info: Dict[str, torch.Tensor] = dict(info) | |
if self.clamp_max is None: | |
dist = dist.mean() | |
- info = dict(dist=dist) | |
+ dict_info.update(dist=dist) | |
neg_loss = dist | |
else: | |
- info = dict(dist=dist.mean()) | |
tsfm_dist = dist.clamp_max(self.clamp_max) | |
- info.update( | |
+ dict_info.update( | |
+ dist=dist.mean(), | |
tsfm_dist=tsfm_dist.mean(), | |
exceed_rate=(dist >= self.clamp_max).mean(dtype=torch.float32), | |
) | |
neg_loss = tsfm_dist | |
- return LossResult(loss=neg_loss * weight, info=info) | |
+ return LossResult(loss=neg_loss * weight, info=dict_info) | |
def extra_repr(self) -> str: | |
return '\n'.join([ | |
@@ -277,9 +290,12 @@ class GlobalPushNextMSELoss(GlobalPushLossBase): | |
self.allow_gt = allow_gt | |
self.gamma = gamma | |
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult: | |
+ def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, | |
+ dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult: | |
+ dict_info: Dict[str, torch.Tensor] = dict(info) | |
+ | |
with torch.enable_grad(self.detach_target_dist): | |
- # by tri-eq, the actual cost can't be larger that step_cost + d(s', g) | |
+ # by tri-eq, the actual cost can't be larger than step_cost + d(s', g) | |
next_dist = critic_batch_info.critic.quasimetric_model( | |
critic_batch_info.zy, zgoal, proj_grad_enabled=(True, not self.detach_proj_goal) | |
) | |
@@ -289,13 +305,20 @@ class GlobalPushNextMSELoss(GlobalPushLossBase): | |
if self.allow_gt: | |
dist = dist.clamp_max(target_dist) | |
+ dict_info.update( | |
+ exceed_rate=(dist >= target_dist).mean(dtype=torch.float32), | |
+ ) | |
if self.gamma is None: | |
loss = F.mse_loss(dist, target_dist) | |
else: | |
loss = F.mse_loss(self.gamma ** dist, self.gamma ** target_dist) | |
- return LossResult(loss=loss * self.weight, info=dict(loss=loss, dist=dist.mean(), target_dist=target_dist.mean())) | |
+ dict_info.update( | |
+ loss=loss, dist=dist.mean(), target_dist=target_dist.mean() | |
+ ) | |
+ | |
+ return LossResult(loss=loss * self.weight, info=dict_info) | |
def extra_repr(self) -> str: | |
return '\n'.join([ | |
@@ -333,10 +356,13 @@ class GlobalPushLogLoss(GlobalPushLossBase): | |
detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal) | |
self.offset = offset | |
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult: | |
+ def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, | |
+ dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult: | |
+ dict_info: Dict[str, torch.Tensor] = dict(info) | |
tsfm_dist: torch.Tensor = -dist.add(self.offset).log() | |
tsfm_dist = tsfm_dist.mean() | |
- return LossResult(loss=tsfm_dist * weight, info=dict(dist=dist.mean(), tsfm_dist=tsfm_dist)) | |
+ dict_info.update(dist=dist.mean(), tsfm_dist=tsfm_dist) | |
+ return LossResult(loss=tsfm_dist * weight, info=dict_info) | |
def extra_repr(self) -> str: | |
return '\n'.join([ | |
@@ -378,13 +404,17 @@ class GlobalPushRBFLoss(GlobalPushLossBase): | |
detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal) | |
self.inv_scale = inv_scale | |
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult: | |
+ def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, | |
+ dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult: | |
+ dict_info: Dict[str, torch.Tensor] = dict(info) | |
inv_scale = dist.detach().square().mean().div(2).sqrt().clamp(1e-3, self.inv_scale) # make E[d^2]/r^2 approx 2 | |
tsfm_dist: torch.Tensor = (dist / inv_scale).square().neg().exp() | |
rbf_potential = tsfm_dist.mean().log() | |
- return LossResult(loss=rbf_potential * self.weight, | |
- info=dict(dist=dist.mean(), inv_scale=inv_scale, | |
- tsfm_dist=tsfm_dist, rbf_potential=rbf_potential)) # type: ignore | |
+ dict_info.update( | |
+ dist=dist.mean(), inv_scale=inv_scale, | |
+ tsfm_dist=tsfm_dist, rbf_potential=rbf_potential, | |
+ ) | |
+ return LossResult(loss=rbf_potential * self.weight, info=dict_info) | |
def extra_repr(self) -> str: | |
return '\n'.join([ | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py b/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py | |
index f5e8248..19e4886 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py | |
@@ -2,6 +2,7 @@ from typing import * | |
import attrs | |
+import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
@@ -25,7 +26,7 @@ class LatentDynamicsLoss(CriticLossBase): | |
# weight: float = attrs.field(default=0.1, validator=attrs.validators.gt(0)) | |
epsilon: float = attrs.field(default=0.25, validator=attrs.validators.gt(0)) | |
- bidirectional: bool = True | |
+ # bidirectional: bool = True | |
gamma: Optional[float] = attrs.field( | |
default=None, validator=attrs.validators.optional(attrs.validators.and_( | |
attrs.validators.gt(0), | |
@@ -37,22 +38,26 @@ class LatentDynamicsLoss(CriticLossBase): | |
detach_qmet: bool = False | |
non_quasimetric_dim_mse_weight: float = attrs.field(default=0., validator=attrs.validators.ge(0)) | |
+ kind: str = attrs.field( | |
+ default='mean', validator=attrs.validators.in_({'mean', 'unidir', 'max', 'bound', 'bound_l2comp', 'l2comp'})) # type: ignore | |
+ | |
def make(self) -> 'LatentDynamicsLoss': | |
return LatentDynamicsLoss( | |
# weight=self.weight, | |
epsilon=self.epsilon, | |
- bidirectional=self.bidirectional, | |
+ # bidirectional=self.bidirectional, | |
gamma=self.gamma, | |
detach_sp=self.detach_sp, | |
detach_proj_sp=self.detach_proj_sp, | |
detach_qmet=self.detach_qmet, | |
non_quasimetric_dim_mse_weight=self.non_quasimetric_dim_mse_weight, | |
init_lagrange_multiplier=self.init_lagrange_multiplier, | |
+ kind=self.kind, | |
) | |
# weight: float | |
epsilon: float | |
- bidirectional: bool | |
+ # bidirectional: bool | |
gamma: Optional[float] | |
detach_sp: bool | |
detach_proj_sp: bool | |
@@ -60,20 +65,24 @@ class LatentDynamicsLoss(CriticLossBase): | |
non_quasimetric_dim_mse_weight: float | |
c: float | |
init_lagrange_multiplier: float | |
+ kind: str | |
- def __init__(self, *, epsilon: float, bidirectional: bool, | |
+ def __init__(self, *, epsilon: float, | |
+ # bidirectional: bool, | |
gamma: Optional[float], init_lagrange_multiplier: float, | |
# weight: float, | |
detach_qmet: bool, detach_proj_sp: bool, detach_sp: bool, | |
- non_quasimetric_dim_mse_weight: float): | |
+ non_quasimetric_dim_mse_weight: float, | |
+ kind: str): | |
super().__init__() | |
# self.weight = weight | |
self.epsilon = epsilon | |
- self.bidirectional = bidirectional | |
+ # self.bidirectional = bidirectional | |
self.gamma = gamma | |
self.detach_qmet = detach_qmet | |
self.detach_sp = detach_sp | |
self.detach_proj_sp = detach_proj_sp | |
+ self.kind = kind | |
self.non_quasimetric_dim_mse_weight = non_quasimetric_dim_mse_weight | |
self.init_lagrange_multiplier = init_lagrange_multiplier | |
self.raw_lagrange_multiplier = nn.Parameter( | |
@@ -86,16 +95,58 @@ class LatentDynamicsLoss(CriticLossBase): | |
pred_zy = critic.latent_dynamics(data.observations, zx, data.actions) | |
if self.detach_sp: | |
zy = zy.detach() | |
- with critic.quasimetric_model.requiring_grad(not self.detach_qmet): | |
- dists = critic.quasimetric_model(zy, pred_zy, bidirectional=self.bidirectional, # at least optimize d(s', hat{s'}) | |
- proj_grad_enabled=(self.detach_proj_sp, True)) | |
- | |
lagrange_mult = F.softplus(self.raw_lagrange_multiplier) # make positive | |
# lagrange multiplier is minimax training, so grad_mul -1 | |
lagrange_mult = grad_mul(lagrange_mult, -1) | |
- info = {} | |
+ info: Dict[str, torch.Tensor] = {} | |
+ | |
+ if self.kind == 'unidir': | |
+ with critic.quasimetric_model.requiring_grad(not self.detach_qmet): | |
+ dists = critic.quasimetric_model(zy, pred_zy, bidirectional=False, # at least optimize d(s', hat{s'}) | |
+ proj_grad_enabled=(self.detach_proj_sp, True)) | |
+ info.update(dist_n2p=dists.mean()) | |
+ elif self.kind in ['mean', 'max', 'bound', 'l2comp']: | |
+ with critic.quasimetric_model.requiring_grad(not self.detach_qmet): | |
+ dists_comps = critic.quasimetric_model( | |
+ zy, pred_zy, bidirectional=True, proj_grad_enabled=(self.detach_proj_sp, True), | |
+ reduced=False) | |
+ dists = critic.quasimetric_model.quasimetric_head.reduction(dists_comps) | |
+ | |
+ dist_p2n, dist_n2p = dists.unbind(-1) | |
+ info.update(dist_p2n=dist_p2n.mean(), dist_n2p=dist_n2p.mean()) | |
+ if self.kind == 'max': | |
+ dists = dists.max(-1).values | |
+ elif self.kind == 'bound': | |
+ with critic.quasimetric_model.requiring_grad(not self.detach_qmet): | |
+ dists = critic.quasimetric_model( | |
+ zy, pred_zy, bidirectional=False, # NB: false is enough for bound | |
+ proj_grad_enabled=(self.detach_proj_sp, True), | |
+ symmetric_upperbound=True, | |
+ ) | |
+ info.update(dist_upbnd=dists.mean()) | |
+ elif self.kind == 'bound_l2comp': | |
+ with critic.quasimetric_model.requiring_grad(not self.detach_qmet): | |
+ dists_comps = critic.quasimetric_model( | |
+ zy, pred_zy, bidirectional=False, # NB: false is enough for bound | |
+ proj_grad_enabled=(self.detach_proj_sp, True), | |
+ reduced=False, symmetric_upperbound=True, | |
+ ) | |
+ dists = cast( | |
+ torch.Tensor, | |
+ dists_comps.norm(dim=-1) / np.sqrt(dists_comps.shape[-1]), | |
+ ) | |
+ info.update(dist_upbnd_l2comp=dists.mean()) | |
+ elif self.kind == 'l2comp': | |
+ dists = cast( | |
+ torch.Tensor, | |
+ dists_comps.norm(dim=-1) / np.sqrt(dists_comps.shape[-1]), | |
+ ) | |
+ info.update(dist_l2comp=dists.mean()) | |
+ else: | |
+ raise NotImplementedError(self.kind) | |
+ | |
if self.gamma is None: | |
sq_dists = dists.square().mean() | |
violation = (sq_dists - self.epsilon ** 2) | |
@@ -117,12 +168,6 @@ class LatentDynamicsLoss(CriticLossBase): | |
info.update(non_quasimetric_dim_mse=non_quasimetric_dim_mse) | |
loss += non_quasimetric_dim_mse * self.non_quasimetric_dim_mse_weight | |
- if self.bidirectional: | |
- dist_p2n, dist_n2p = dists.flatten(0, -2).mean(0).unbind(-1) | |
- info.update(dist_p2n=dist_p2n, dist_n2p=dist_n2p) | |
- else: | |
- info.update(dist_n2p=dists.mean()) | |
- | |
return LossResult( | |
loss=loss, | |
info=info, # type: ignore | |
@@ -130,7 +175,7 @@ class LatentDynamicsLoss(CriticLossBase): | |
def extra_repr(self) -> str: | |
return '\n'.join([ | |
- f"epsilon={self.epsilon!r}, bidirectional={self.bidirectional}", | |
- f"gamma={self.gamma!r}, detach_sp={self.detach_sp}, detach_proj_sp={self.detach_proj_sp}, detach_qmet={self.detach_qmet}", | |
+ f"kind={self.kind}, gamma={self.gamma!r}", | |
+ f"epsilon={self.epsilon!r}, detach_sp={self.detach_sp}, detach_proj_sp={self.detach_proj_sp}, detach_qmet={self.detach_qmet}", | |
]) | |
# return f"weight={self.weight:g}, detach_sp={self.detach_sp}" | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py b/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py | |
index 1806248..eb87576 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py | |
@@ -2,6 +2,7 @@ from typing import * | |
import attrs | |
+import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
@@ -30,8 +31,14 @@ class LocalConstraintLoss(CriticLossBase): | |
init_lagrange_multiplier: float = attrs.field(default=0.01, validator=attrs.validators.gt(0)) | |
- kind: str = attrs.field( | |
- default='mse', validator=attrs.validators.in_(['mse', 'sq'])) # type: ignore | |
+ p: float = attrs.field(default=2., validator=attrs.validators.ge(1)) | |
+ compare_before_pow: bool = True | |
+ # kind: str = attrs.field( | |
+ # default='mse', validator=attrs.validators.in_(['mse', 'sq'])) # type: ignore | |
+ # batch_reduction: str = attrs.field( | |
+ # default='mean', validator=attrs.validators.in_(['mean', 'l2'])) # type: ignore | |
+ invpow_after_batch_agg: bool = False | |
+ log: bool = False | |
detach_proj_sp: bool = False | |
detach_sp: bool = False | |
@@ -49,12 +56,22 @@ class LocalConstraintLoss(CriticLossBase): | |
step_cost=self.step_cost, | |
step_cost_high=self.step_cost_high, | |
init_lagrange_multiplier=self.init_lagrange_multiplier, | |
- kind=self.kind, | |
+ # kind=self.kind, | |
+ p=self.p, | |
+ compare_before_pow=self.compare_before_pow, | |
+ log=self.log, | |
+ # batch_reduction=self.batch_reduction, | |
+ invpow_after_batch_agg=self.invpow_after_batch_agg, | |
detach_sp=self.detach_sp, | |
detach_proj_sp=self.detach_proj_sp, | |
) | |
- kind: str | |
+ # kind: str | |
+ p: float | |
+ compare_before_pow: bool | |
+ log: bool | |
+ # batch_reduction: str | |
+ invpow_after_batch_agg: bool | |
epsilon: float | |
allow_le: bool | |
step_cost: float | |
@@ -65,11 +82,21 @@ class LocalConstraintLoss(CriticLossBase): | |
raw_lagrange_multiplier: nn.Parameter # for the QRL constrained optimization | |
- def __init__(self, *, kind: str, epsilon: float, allow_le: bool, | |
+ def __init__(self, *, | |
+ # kind: str, | |
+ p: float, compare_before_pow: bool, invpow_after_batch_agg: bool, | |
+ log: bool, | |
+ # batch_reduction: str, | |
+ epsilon: float, allow_le: bool, | |
step_cost: float, step_cost_high: float, init_lagrange_multiplier: float, | |
detach_sp: bool, detach_proj_sp: bool): | |
super().__init__() | |
- self.kind = kind | |
+ # self.kind = kind | |
+ self.p = p | |
+ self.compare_before_pow = compare_before_pow | |
+ self.log = log | |
+ # self.batch_reduction = batch_reduction | |
+ self.invpow_after_batch_agg = invpow_after_batch_agg | |
self.epsilon = epsilon | |
self.allow_le = allow_le | |
self.step_cost = step_cost | |
@@ -100,6 +127,7 @@ class LocalConstraintLoss(CriticLossBase): | |
info['dist_090'], | |
info['dist_100'], | |
) = dist.quantile(dist.new_tensor([0, 0.1, 0.25, 0.5, 0.75, 0.9, 1])).unbind() | |
+ info.update(dist=dist.mean()) | |
lagrange_mult = F.softplus(self.raw_lagrange_multiplier) # make positive | |
# lagrange multiplier is minimax training, so grad_mul -1 | |
@@ -110,31 +138,63 @@ class LocalConstraintLoss(CriticLossBase): | |
else: | |
target = dist.detach().clamp(self.step_cost, self.step_cost_high) | |
- if self.kind == 'mse': | |
- if self.allow_le: | |
- sq_deviation = (dist - target).relu().square().mean() | |
+ if self.log: | |
+ dist = dist.log() | |
+ if isinstance(target, torch.Tensor): | |
+ target = target.log() | |
else: | |
- sq_deviation = (dist - target).square().mean() | |
- elif self.kind == 'sq': | |
- if self.allow_le: | |
- sq_deviation = (dist.square() - target ** 2).relu().mean() | |
- else: | |
- sq_deviation = (dist.square() - target ** 2).abs().mean() | |
+ target = np.log(target) | |
+ info.update(dist_log=dist.mean()) | |
+ | |
+ if self.compare_before_pow: | |
+ deviation = dist - target | |
+ deviation = (torch.relu if self.allow_le else torch.abs)(deviation) | |
+ deviation = deviation ** self.p | |
else: | |
- raise ValueError(f"Unknown kind {self.kind}") | |
- violation = (sq_deviation - self.epsilon ** 2) | |
- loss = violation * lagrange_mult | |
+ deviation = (dist ** self.p - target ** self.p) | |
+ deviation = (torch.relu if self.allow_le else torch.abs)(deviation) | |
- info.update( | |
- dist=dist.mean(), sq_deviation=sq_deviation, | |
- violation=violation, lagrange_mult=lagrange_mult, | |
- ) | |
+ if self.invpow_after_batch_agg: | |
+ deviation = deviation.mean().pow(1 / self.p) | |
+ violation = deviation - self.epsilon | |
+ else: | |
+ violation = deviation.mean() - (self.epsilon ** self.p) | |
+ info.update(deviation=deviation.mean(), violation=violation, lagrange_mult=lagrange_mult) | |
+ | |
+ # if self.kind == 'mse': | |
+ # if self.allow_le: | |
+ # sq_deviation = (dist - target).relu().square() | |
+ # else: | |
+ # sq_deviation = (dist - target).square() | |
+ # elif self.kind == 'sq': | |
+ # if self.allow_le: | |
+ # sq_deviation = (dist.square() - target ** 2).relu() | |
+ # else: | |
+ # sq_deviation = (dist.square() - target ** 2).abs() | |
+ # else: | |
+ # raise ValueError(f"Unknown kind {self.kind}") | |
+ # info.update(sq_deviation=sq_deviation.mean()) | |
+ | |
+ # if self.batch_reduction == 'mean': | |
+ # batch_sq_deviation = sq_deviation.mean() | |
+ # elif self.batch_reduction == 'l2': | |
+ # batch_sq_deviation = sq_deviation.square().mean().sqrt() | |
+ # else: | |
+ # raise ValueError(f"Unknown batch_reduction {self.batch_reduction}") | |
+ # info.update(batch_sq_deviation=batch_sq_deviation) | |
+ | |
+ # violation = (batch_sq_deviation - self.epsilon ** 2) | |
+ | |
+ loss = violation * lagrange_mult | |
+ # info.update(violation=violation, lagrange_mult=lagrange_mult) | |
return LossResult(loss=loss, info=info) | |
def extra_repr(self) -> str: | |
return '\n'.join([ | |
- f"{self.kind}, epsilon={self.epsilon:g}, allow_le={self.allow_le}", | |
+ # f"kind={self.kind}, log={self.log}, batch_reduction={self.batch_reduction}", | |
+ f"p={self.p}, compare_before_pow={self.compare_before_pow}, invpow_after_batch_agg={self.invpow_after_batch_agg}", | |
+ f"epsilon={self.epsilon:g}, allow_le={self.allow_le}", | |
f"step_cost={self.step_cost:g}, step_cost_high={self.step_cost_high:g}", | |
f"detach_sp={self.detach_sp}, detach_proj_sp={self.detach_proj_sp}", | |
]) | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/__init__.py b/quasimetric_rl/modules/quasimetric_critic/models/__init__.py | |
index 8789800..266f592 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/models/__init__.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/models/__init__.py | |
@@ -25,6 +25,11 @@ class QuasimetricCritic(Module): | |
attrs.validators.le(1), | |
))) | |
encoder: Encoder.Conf = Encoder.Conf() | |
+ target_quasimetric_model_ema: Optional[float] = attrs.field( | |
+ default=None, validator=attrs.validators.optional(attrs.validators.and_( | |
+ attrs.validators.ge(0), | |
+ attrs.validators.le(1), | |
+ ))) | |
quasimetric_model: QuasimetricModel.Conf = QuasimetricModel.Conf() | |
latent_dynamics: Dynamics.Conf = Dynamics.Conf() | |
@@ -42,34 +47,46 @@ class QuasimetricCritic(Module): | |
) | |
return QuasimetricCritic(encoder, quasimetric_model, latent_dynamics, | |
share_embedding_from=share_embedding_from, | |
- target_encoder_ema=self.target_encoder_ema) | |
+ target_encoder_ema=self.target_encoder_ema, | |
+ target_quasimetric_model_ema=self.target_quasimetric_model_ema) | |
borrowing_embedding: bool | |
encoder: Encoder | |
_target_encoder: Optional[Encoder] | |
target_encoder_ema: Optional[float] | |
quasimetric_model: QuasimetricModel | |
+ _target_quasimetric_model: Optional[QuasimetricModel] | |
+ target_quasimetric_model_ema: Optional[float] | |
latent_dynamics: Dynamics # FIXME: add BC to model loading & change name | |
def __init__(self, encoder: Encoder, quasimetric_model: QuasimetricModel, latent_dynamics: Dynamics, | |
- share_embedding_from: Optional['QuasimetricCritic'], target_encoder_ema: Optional[float]): | |
+ share_embedding_from: Optional['QuasimetricCritic'], target_encoder_ema: Optional[float], | |
+ target_quasimetric_model_ema: Optional[float] = None): | |
super().__init__() | |
self.borrowing_embedding = share_embedding_from is not None | |
if share_embedding_from is not None: | |
encoder = share_embedding_from.encoder | |
- target_encoder = share_embedding_from.target_encoder | |
+ _target_encoder = share_embedding_from._target_encoder | |
assert target_encoder_ema == share_embedding_from.target_encoder_ema | |
quasimetric_model = share_embedding_from.quasimetric_model | |
+ assert target_quasimetric_model_ema == share_embedding_from.target_quasimetric_model_ema | |
+ _target_quasimetric_model = share_embedding_from._target_quasimetric_model | |
else: | |
if target_encoder_ema is None: | |
- target_encoder = None | |
+ _target_encoder = None | |
+ else: | |
+ _target_encoder = copy.deepcopy(encoder).requires_grad_(False).eval() | |
+ if target_quasimetric_model_ema is None: | |
+ _target_quasimetric_model = None | |
else: | |
- target_encoder = copy.deepcopy(encoder).requires_grad_(False).eval() | |
+ _target_quasimetric_model = copy.deepcopy(quasimetric_model).requires_grad_(False).eval() | |
self.encoder = encoder | |
- self.add_module('_target_encoder', target_encoder) | |
self.target_encoder_ema = target_encoder_ema | |
+ self.add_module('_target_encoder', _target_encoder) | |
self.quasimetric_model = quasimetric_model | |
+ self.target_quasimetric_model_ema = target_quasimetric_model_ema | |
+ self.add_module('_target_quasimetric_model', _target_quasimetric_model) | |
self.latent_dynamics = latent_dynamics | |
def forward(self, x: torch.Tensor, y: torch.Tensor, *, action: Optional[torch.Tensor] = None) -> torch.Tensor: | |
@@ -94,16 +111,31 @@ class QuasimetricCritic(Module): | |
def get_encoder(self, target: bool = False) -> Encoder: | |
return self.target_encoder if target else self.encoder | |
+ @property | |
+ def target_quasimetric_model(self) -> QuasimetricModel: | |
+ if self._target_quasimetric_model is not None: | |
+ return self._target_quasimetric_model | |
+ else: | |
+ return self.quasimetric_model | |
+ | |
+ def get_quasimetric_model(self, target: bool = False) -> QuasimetricModel: | |
+ return self.target_quasimetric_model if target else self.quasimetric_model | |
+ | |
@torch.no_grad() | |
- def update_target_encoder_(self): | |
- if not self.borrowing_embedding and self.target_encoder_ema is not None: | |
- assert self._target_encoder is not None | |
- for p, p_target in zip(self.encoder.parameters(), self._target_encoder.parameters()): | |
- p_target.data.lerp_(p.data, 1 - self.target_encoder_ema) | |
+ def update_target_models_(self): | |
+ if not self.borrowing_embedding: | |
+ if self.target_encoder_ema is not None: | |
+ assert self._target_encoder is not None | |
+ for p, p_target in zip(self.encoder.parameters(), self._target_encoder.parameters()): | |
+ p_target.data.lerp_(p.data, 1 - self.target_encoder_ema) | |
+ if self.target_quasimetric_model_ema is not None: | |
+ assert self._target_quasimetric_model is not None | |
+ for p, p_target in zip(self.quasimetric_model.parameters(), self._target_quasimetric_model.parameters()): | |
+ p_target.data.lerp_(p.data, 1 - self.target_quasimetric_model_ema) | |
# for type hints | |
def __call__(self, x: torch.Tensor, y: torch.Tensor, *, action: Optional[torch.Tensor] = None) -> torch.Tensor: | |
return super().__call__(x, y, action=action) | |
def extra_repr(self) -> str: | |
- return f"borrowing_embedding={self.borrowing_embedding}" | |
+ return f"borrowing_embedding={self.borrowing_embedding}, target_encoder_ema={self.target_encoder_ema}, target_quasimetric_model_ema={self.target_quasimetric_model_ema}" | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/encoder.py b/quasimetric_rl/modules/quasimetric_critic/models/encoder.py | |
index ddd6240..2b53bd5 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/models/encoder.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/models/encoder.py | |
@@ -46,6 +46,7 @@ class Encoder(nn.Module): | |
arch: Tuple[int, ...] = (512, 512) | |
layer_norm: bool = True | |
+ bias_last_fc: bool = True | |
latent: LatentSpaceConf = LatentSpaceConf() | |
def make(self, *, env_spec: EnvSpec) -> 'Encoder': | |
@@ -53,6 +54,7 @@ class Encoder(nn.Module): | |
env_spec=env_spec, | |
arch=self.arch, | |
layer_norm=self.layer_norm, | |
+ bias_last_fc=self.bias_last_fc, | |
latent=self.latent, | |
) | |
@@ -62,14 +64,14 @@ class Encoder(nn.Module): | |
projection: nn.Module | |
latent: LatentSpaceConf | |
- def __init__(self, *, env_spec: EnvSpec, arch: Tuple[int, ...], layer_norm: bool, | |
+ def __init__(self, *, env_spec: EnvSpec, arch: Tuple[int, ...], layer_norm: bool, bias_last_fc: bool, | |
latent: LatentSpaceConf, **kwargs): | |
super().__init__(**kwargs) | |
self.input_shape = env_spec.observation_shape | |
self.input_encoding = env_spec.make_observation_input() | |
encoder_input_size = self.input_encoding.output_size | |
self.encoder = nn.Sequential( | |
- MLP(encoder_input_size, latent.latent_size, hidden_sizes=arch, layer_norm=layer_norm), | |
+ MLP(encoder_input_size, latent.latent_size, hidden_sizes=arch, layer_norm=layer_norm, bias_last_fc=bias_last_fc), | |
latent.make_projection(), | |
) | |
self.latent = latent | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py b/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py | |
index 0847076..3683454 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py | |
@@ -20,10 +20,13 @@ class Dynamics(MLP): | |
# config / argparse uses this to specify behavior | |
arch: Tuple[int, ...] = (512, 512) | |
+ action_arch: Tuple[int, ...] = () | |
layer_norm: bool = True | |
latent_input: bool = True | |
raw_observation_input_arch: Optional[Tuple[int, ...]] = None | |
+ bias: bool = True | |
+ bias_last_fc: bool = True | |
use_latent_space_proj: bool = True | |
residual: bool = False # True -> False following TD-MPC2 | |
fc_fuse: bool = False # whether last fc does fusion | |
@@ -34,14 +37,18 @@ class Dynamics(MLP): | |
latent_input=self.latent_input, | |
raw_observation_input_arch=self.raw_observation_input_arch, | |
env_spec=env_spec, | |
+ action_arch=self.action_arch, | |
hidden_sizes=self.arch, | |
layer_norm=self.layer_norm, | |
use_latent_space_proj=self.use_latent_space_proj, | |
+ bias=self.bias, | |
+ bias_last_fc=self.bias_last_fc, | |
residual=self.residual, | |
fc_fuse=self.fc_fuse, | |
) | |
- action_input: InputEncoding | |
+ action_input: Union[InputEncoding, nn.Sequential] | |
+ action_arch: Tuple[int, ...] | |
latent_input: bool | |
raw_observation_input: Optional[nn.Sequential] | |
residual: bool | |
@@ -50,32 +57,53 @@ class Dynamics(MLP): | |
def __init__(self, *, latent: LatentSpaceConf, | |
latent_input: bool, | |
raw_observation_input_arch: Optional[Tuple[int, ...]], | |
- env_spec: EnvSpec, hidden_sizes: Tuple[int, ...], layer_norm: bool, | |
- use_latent_space_proj: bool, residual: bool, | |
+ env_spec: EnvSpec, action_arch: Tuple[int, ...], | |
+ hidden_sizes: Tuple[int, ...], layer_norm: bool, | |
+ use_latent_space_proj: bool, | |
+ bias: bool, bias_last_fc: bool, residual: bool, | |
fc_fuse: bool): | |
+ self.action_arch = action_arch | |
action_input = env_spec.make_action_input() | |
- mlp_input_size = action_input.output_size | |
+ action_outsize = action_input.output_size | |
+ if action_arch: | |
+ action_input = nn.Sequential( | |
+ action_input, | |
+ MLP( | |
+ action_input.output_size, | |
+ action_arch[-1], | |
+ hidden_sizes=action_arch[:-1], | |
+ activation_last_fc=True, | |
+ layer_norm_last_fc=layer_norm, | |
+ layer_norm=layer_norm, | |
+ bias=bias, | |
+ ), | |
+ ) | |
+ action_outsize = action_arch[-1] | |
+ mlp_input_size = action_outsize | |
mlp_output_size = latent.latent_size | |
self.latent_input = latent_input | |
if latent_input: | |
mlp_input_size += latent.latent_size | |
if raw_observation_input_arch is not None: | |
- assert raw_observation_input_arch is not None | |
- assert len(raw_observation_input_arch) >= 1 | |
input_encoding = env_spec.make_observation_input() | |
- raw_observation_input = MLP( | |
- input_encoding.output_size, | |
- raw_observation_input_arch[-1], | |
- hidden_sizes=raw_observation_input_arch[:-1], | |
- activation_last_fc=True, | |
- layer_norm_last_fc=layer_norm, | |
- layer_norm=layer_norm, | |
- ) | |
- mlp_input_size += raw_observation_input.output_size | |
- raw_observation_input = nn.Sequential( | |
- input_encoding, | |
- raw_observation_input, | |
- ) | |
+ if len(raw_observation_input_arch) >= 1: | |
+ raw_observation_input = MLP( | |
+ input_encoding.output_size, | |
+ raw_observation_input_arch[-1], | |
+ hidden_sizes=raw_observation_input_arch[:-1], | |
+ activation_last_fc=True, | |
+ layer_norm_last_fc=layer_norm, | |
+ layer_norm=layer_norm, | |
+ bias=bias, | |
+ ) | |
+ mlp_input_size += raw_observation_input.output_size | |
+ raw_observation_input = nn.Sequential( | |
+ input_encoding, | |
+ raw_observation_input, | |
+ ) | |
+ else: | |
+ raw_observation_input = input_encoding | |
+ mlp_input_size += input_encoding.output_size | |
else: | |
raw_observation_input = None | |
super().__init__( | |
@@ -84,6 +112,8 @@ class Dynamics(MLP): | |
hidden_sizes=hidden_sizes, | |
zero_init_last_fc=residual, | |
layer_norm=layer_norm, | |
+ bias=bias, | |
+ bias_last_fc=bias_last_fc, | |
) | |
self.action_input = action_input | |
self.add_module('raw_observation_input', raw_observation_input) | |
@@ -99,6 +129,7 @@ class Dynamics(MLP): | |
latent.latent_size, | |
hidden_sizes=[], # no hidden layers | |
zero_init_last_fc=residual, | |
+ bias_last_fc=bias_last_fc, | |
).module[0]) | |
with torch.no_grad(): | |
fuse_fc.weight[:, -latent.latent_size:].add_( | |
@@ -113,7 +144,7 @@ class Dynamics(MLP): | |
inputs: List[torch.Tensor] = [] | |
if self.raw_observation_input is not None: | |
inputs.append(self.raw_observation_input(observation)) | |
- if self.latent_input is not None: | |
+ if self.latent_input: | |
inputs.append(zx) | |
inputs.append(self.action_input(action)) | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py b/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py | |
index 92fbc58..6464a45 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py | |
@@ -20,7 +20,7 @@ class L2(torchqmet.QuasimetricBase): | |
super().__init__(input_size, num_components=1, guaranteed_quasimetric=True, | |
transforms=[], reduction='sum', discount=None) | |
- def compute_components(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
+ def compute_components(self, x: torch.Tensor, y: torch.Tensor, symmetric_upperbound: bool = False) -> torch.Tensor: | |
r''' | |
Inputs: | |
x (torch.Tensor): Shape [..., input_size] | |
@@ -41,7 +41,7 @@ def create_quasimetric_head_from_spec(spec: str) -> torchqmet.QuasimetricBase: | |
assert dim % components == 0, "IQE: dim must be divisible by components" | |
return torchqmet.IQE(dim, dim // components) | |
- def iqe2(*, dim: int, components: int, norm_delta: bool = False, fake_grad: bool = False, reduction: str = 'maxl12_sm') -> torchqmet.IQE: | |
+ def iqe2(*, dim: int, components: int, norm_delta: bool = False, fake_grad: bool = False, reduction: str = 'maxl12_sm', version=None) -> torchqmet.IQE2: | |
assert dim % components == 0, "IQE: dim must be divisible by components" | |
return torchqmet.IQE2( | |
dim, dim // components, | |
@@ -51,6 +51,7 @@ def create_quasimetric_head_from_spec(spec: str) -> torchqmet.QuasimetricBase: | |
div_init_mul=0.25, | |
mul_kind='normdeltadiv' if norm_delta else 'normdiv', | |
fake_grad=fake_grad, | |
+ version=version, | |
) | |
def l2(*, dim: int) -> L2: | |
@@ -125,10 +126,12 @@ class QuasimetricModel(Module): | |
layer_norm=projector_layer_norm, | |
dropout=projector_dropout, | |
weight_norm_last_fc=projector_weight_norm, | |
- unit_norm_last_fc=projector_unit_norm) | |
+ unit_norm_last_fc=projector_unit_norm, | |
+ ) | |
def forward(self, zx: LatentTensor, zy: LatentTensor, *, bidirectional: bool = False, | |
- proj_grad_enabled: Tuple[bool, bool] = (True, True)) -> torch.Tensor: | |
+ proj_grad_enabled: Tuple[bool, bool] = (True, True), reduced: bool = True, | |
+ symmetric_upperbound: bool = False) -> torch.Tensor: | |
zx = zx[..., :self.input_slice_size] | |
zy = zy[..., :self.input_slice_size] | |
with self.projector.requiring_grad(proj_grad_enabled[0]): | |
@@ -140,7 +143,7 @@ class QuasimetricModel(Module): | |
px, py = torch.broadcast_tensors(px, py) | |
px, py = torch.stack([px, py], dim=-2), torch.stack([py, px], dim=-2) # [B x 2 x D] | |
- return self.quasimetric_head(px, py) | |
+ return self.quasimetric_head(px, py, reduced=reduced, symmetric_upperbound=symmetric_upperbound) | |
def parameters(self, recurse: bool = True, *, include_head: bool = True) -> Iterator[Parameter]: | |
if include_head: | |
@@ -153,8 +156,12 @@ class QuasimetricModel(Module): | |
# for type hint | |
def __call__(self, zx: LatentTensor, zy: LatentTensor, *, bidirectional: bool = False, | |
- proj_grad_enabled: Tuple[bool, bool] = (True, True)) -> torch.Tensor: | |
- return super().__call__(zx, zy, bidirectional=bidirectional, proj_grad_enabled=proj_grad_enabled) | |
+ proj_grad_enabled: Tuple[bool, bool] = (True, True), reduced: bool = True, | |
+ symmetric_upperbound: bool = False) -> torch.Tensor: | |
+ return super().__call__(zx, zy, bidirectional=bidirectional, | |
+ proj_grad_enabled=proj_grad_enabled, | |
+ reduced=reduced, | |
+ symmetric_upperbound=symmetric_upperbound) | |
def extra_repr(self) -> str: | |
return f"input_size={self.input_size}, input_slice_size={self.input_slice_size}" | |
diff --git a/quasimetric_rl/modules/utils.py b/quasimetric_rl/modules/utils.py | |
index 5d7a845..8fbe219 100644 | |
--- a/quasimetric_rl/modules/utils.py | |
+++ b/quasimetric_rl/modules/utils.py | |
@@ -33,11 +33,18 @@ InfoValT = Union[InfoT, float, torch.Tensor] | |
@attrs.define(kw_only=True) | |
class LossResult: | |
- loss: Union[torch.Tensor, float] | |
+ loss: InfoValT | |
info: InfoT | |
def __attrs_post_init__(self): | |
- assert isinstance(self.loss, (int, float)) or self.loss.numel() == 1 | |
+ def test_loss(loss: InfoValT): | |
+ if isinstance(loss, Mapping): | |
+ for v in loss.values(): | |
+ test_loss(v) | |
+ else: | |
+ assert isinstance(loss, (int, float)) or loss.numel() == 1 | |
+ | |
+ test_loss(self.loss) | |
# detach info tensors | |
def detach(d: InfoValT) -> InfoValT: | |
@@ -50,17 +57,40 @@ class LossResult: | |
object.__setattr__(self, "info", detach(self.info)) | |
+ def map_losses(self, fn: Callable[[torch.Tensor | float], torch.Tensor | float]) -> InfoValT: | |
+ def map_loss(loss: InfoValT): | |
+ if isinstance(loss, Mapping): | |
+ return {k: map_loss(v) for k, v in loss.items()} | |
+ else: | |
+ return fn(loss) | |
+ | |
+ return map_loss(self.loss) | |
+ | |
@classmethod | |
def empty(cls) -> 'LossResult': | |
return LossResult(loss=0, info={}) | |
+ @property | |
+ def total_loss(self) -> torch.Tensor: | |
+ l = 0 | |
+ def add_loss(loss: InfoValT): | |
+ nonlocal l | |
+ if isinstance(loss, Mapping): | |
+ for v in loss.values(): | |
+ add_loss(v) | |
+ else: | |
+ l += loss | |
+ add_loss(self.loss) | |
+ assert isinstance(l, torch.Tensor) | |
+ return l | |
+ | |
@classmethod | |
def combine(cls, results: Mapping[str, 'LossResult'], **kwargs) -> 'LossResult': | |
info = {k: r.info for k, r in results.items()} | |
assert info.keys().isdisjoint(kwargs.keys()) | |
info.update(**kwargs) | |
return LossResult( | |
- loss=sum(r.loss for r in results.values()), | |
+ loss={k: r.loss for k, r in results.items()}, | |
info=info, | |
) | |
@@ -246,6 +276,7 @@ class MLP(Module): | |
layer_norm_last_fc: bool = False, | |
weight_norm_last_fc: bool = False, | |
unit_norm_last_fc: bool = False, | |
+ bias: bool = True, | |
bias_last_fc: bool = True, | |
# final_layer_norm: bool = False, # useful when not output | |
# final_layer: Callable[[int], nn.Module] = lambda _: nn.Identity, | |
@@ -261,7 +292,7 @@ class MLP(Module): | |
for ii, sz in enumerate(hidden_sizes): | |
modules.append( | |
TDMPC2Linear(layer_in_size, sz, dropout=(ii == 0) * dropout, | |
- layer_norm=layer_norm, activation_fn=activation_fn), | |
+ layer_norm=layer_norm, activation_fn=activation_fn, bias=bias), | |
) | |
layer_in_size = sz | |
modules.append(TDMPC2Linear(layer_in_size, output_size, | |
Submodule third_party/torch-quasimetric 0fce12e..89f107c: | |
diff --git a/third_party/torch-quasimetric/torchqmet/__init__.py b/third_party/torch-quasimetric/torchqmet/__init__.py | |
index 3afb467..b776de5 100644 | |
--- a/third_party/torch-quasimetric/torchqmet/__init__.py | |
+++ b/third_party/torch-quasimetric/torchqmet/__init__.py | |
@@ -54,19 +54,22 @@ class QuasimetricBase(nn.Module, metaclass=abc.ABCMeta): | |
''' | |
pass | |
- def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
+ def forward(self, x: torch.Tensor, y: torch.Tensor, *, reduced: bool = True, **kwargs) -> torch.Tensor: | |
assert x.shape[-1] == y.shape[-1] == self.input_size | |
- d = self.compute_components(x, y) | |
+ d = self.compute_components(x, y, **kwargs) | |
d: torch.Tensor = self.transforms(d) | |
scale = self.scale | |
if not self.training: | |
scale = scale.detach() | |
- return self.reduction(d) * scale | |
+ if reduced: | |
+ return self.reduction(d) * scale | |
+ else: | |
+ return d * scale | |
- def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
+ def __call__(self, x: torch.Tensor, y: torch.Tensor, reduced: bool = True, **kwargs) -> torch.Tensor: | |
# Manually define for typing | |
# https://github.com/pytorch/pytorch/issues/45414 | |
- return super().__call__(x, y) | |
+ return super().__call__(x, y, reduced=reduced, **kwargs) | |
def extra_repr(self) -> str: | |
return f"guaranteed_quasimetric={self.guaranteed_quasimetric}\ninput_size={self.input_size}, num_components={self.num_components}" + ( | |
@@ -79,5 +82,5 @@ from .iqe import IQE, IQE2 | |
from .mrn import MRN, MRNFixed | |
from .neural_norms import DeepNorm, WideNorm | |
-__all__ = ['PQE', 'PQELH', 'PQEGG', 'IQE', 'MRN', 'MRNFixed', 'DeepNorm', 'WideNorm'] | |
+__all__ = ['PQE', 'PQELH', 'PQEGG', 'IQE', 'IQE2', 'MRN', 'MRNFixed', 'DeepNorm', 'WideNorm'] | |
__version__ = "0.1.0" | |
diff --git a/third_party/torch-quasimetric/torchqmet/iqe.py b/third_party/torch-quasimetric/torchqmet/iqe.py | |
index a8e8c92..29470a8 100644 | |
--- a/third_party/torch-quasimetric/torchqmet/iqe.py | |
+++ b/third_party/torch-quasimetric/torchqmet/iqe.py | |
@@ -41,6 +41,7 @@ def iqe_tensor_delta(x: torch.Tensor, y: torch.Tensor, delta: torch.Tensor, div_ | |
).detach() | |
neg_f = torch.expm1(neg_f_input) | |
+ # neg_incf = torch.diff(neg_f, dim=-1, prepend=neg_f.new_zeros(()).expand_as(neg_f[..., :1])) | |
neg_incf = torch.cat([neg_f.narrow(-1, 0, 1), torch.diff(neg_f, dim=-1)], dim=-1) | |
# reduction | |
@@ -62,7 +63,6 @@ def iqe_tensor_delta(x: torch.Tensor, y: torch.Tensor, delta: torch.Tensor, div_ | |
return comp | |
- | |
def iqe(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
D = x.shape[-1] # D: dim_per_component | |
@@ -88,6 +88,7 @@ def iqe(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
# neg_inp = torch.where(neg_inp_copies == 0, 0., -delta) | |
# f output: 0 -> 0, x -> 1. | |
neg_f = (neg_inp_copies < 0) * (-1.) | |
+ # neg_incf = torch.diff(neg_f, dim=-1, prepend=neg_f.new_zeros(()).expand_as(neg_f[..., :1])) | |
neg_incf = torch.cat([neg_f.narrow(-1, 0, 1), torch.diff(neg_f, dim=-1)], dim=-1) | |
# reduction | |
@@ -111,12 +112,23 @@ def is_notebook(): | |
if torch.__version__ >= '2.0.1' and not is_notebook(): # well, broken process pool in notebooks | |
- iqe = torch.compile(iqe, mode="max-autotune") | |
- iqe_tensor_delta = torch.compile(iqe_tensor_delta, mode="max-autotune") | |
+ iqe = torch.compile(iqe) | |
+ _iqe_tensor_delta = torch.compile(iqe_tensor_delta) | |
+ _iqe_tensor_delta_jit = torch.jit.script(iqe_tensor_delta) | |
+ _default_iqe_tensor_delta_version = 'compile' | |
# iqe = torch.compile(iqe, dynamic=True) | |
else: | |
iqe = torch.jit.script(iqe) # type: ignore | |
- iqe_tensor_delta = torch.jit.script(iqe_tensor_delta) # type: ignore | |
+ _iqe_tensor_delta = None | |
+ _iqe_tensor_delta_jit = torch.jit.script(iqe_tensor_delta) # type: ignore | |
+ _default_iqe_tensor_delta_version = 'jit' | |
+ | |
+ | |
+ | |
+def get_iqe_tensor_delta(version): | |
+ fn = dict(jit=_iqe_tensor_delta_jit, compile=_iqe_tensor_delta)[version] | |
+ assert fn is not None, f"version={version} not supported" | |
+ return fn | |
class IQE(QuasimetricBase): | |
@@ -243,7 +255,8 @@ class IQE2(IQE): | |
component_dropout_thresh: Tuple[float, float] = (0.5, 2), | |
dropout_p_thresh: Tuple[float, float] = (0.005, 0.995), | |
dropout_batch_frac: float = 0.2, | |
- ema_weight: float = 0.95): | |
+ ema_weight: float = 0.95, | |
+ version: Optional[str] = None): | |
super().__init__(input_size, dim_per_component, transforms=transforms, reduction=reduction, | |
discount=discount, warn_if_not_quasimetric=warn_if_not_quasimetric) | |
self.component_dropout_thresh = tuple(component_dropout_thresh) # type: ignore | |
@@ -302,9 +315,10 @@ class IQE2(IQE): | |
self.mul_kind = mul_kind | |
self.last_components = None # type: ignore | |
self.last_drop_p = None # type: ignore | |
+ self.version = version or _default_iqe_tensor_delta_version | |
- def compute_components(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
+ def compute_components(self, x: torch.Tensor, y: torch.Tensor, *, symmetric_upperbound: bool = False) -> torch.Tensor: | |
# if self.raw_delta is None: | |
# components = super().compute_components(x, y) | |
# else: | |
@@ -314,7 +328,10 @@ class IQE2(IQE): | |
delta.data.clamp_(max=1e3 / (self.latent_2d_shape[-1] / 8)) | |
div_pre_f.data.clamp_(min=1e-3) | |
- components = iqe_tensor_delta( | |
+ if symmetric_upperbound: | |
+ x, y = torch.minimum(x, y), torch.maximum(x, y) | |
+ | |
+ components = get_iqe_tensor_delta(self.version)( # type: ignore | |
x=x.unflatten(-1, self.latent_2d_shape), | |
y=y.unflatten(-1, self.latent_2d_shape), | |
delta=delta, | |
@@ -384,4 +401,5 @@ learned_delta={self.raw_delta is not None}, | |
learned_div={self.raw_div.requires_grad}, | |
div_init_mul={self.div_init_mul:g}, | |
mul_kind={self.mul_kind}, | |
-fake_grad={self.fake_grad},""" | |
+fake_grad={self.fake_grad}, | |
+version={self.version}""" |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment