Skip to content

Instantly share code, notes, and snippets.

@ssnl
Created March 16, 2024 14:03
Show Gist options
  • Save ssnl/ebd0e033ba7c87641709b82164370f36 to your computer and use it in GitHub Desktop.
diff --git a/quasimetric_rl/data/base.py b/quasimetric_rl/data/base.py
index 09711c4..8412f7d 100644
--- a/quasimetric_rl/data/base.py
+++ b/quasimetric_rl/data/base.py
@@ -208,6 +208,7 @@ class Dataset(torch.utils.data.Dataset):
kind: str = MISSING # d4rl, gcrl, etc.
name: str = MISSING # maze2d-umaze-v1, etc.
+ horizon: int = attrs.field(default=1, validator=attrs.validators.gt(0)) # type: ignore
# Defines how to fetch the future observation. smaller -> more recent
future_observation_discount: float = attrs.field(default=0.99, validator=attrs.validators.and_(
@@ -217,6 +218,7 @@ class Dataset(torch.utils.data.Dataset):
def make(self, *, dummy: bool = False) -> 'Dataset':
return Dataset(self.kind, self.name,
+ horizon=self.horizon,
future_observation_discount=self.future_observation_discount,
dummy=dummy)
@@ -278,6 +280,7 @@ class Dataset(torch.utils.data.Dataset):
return NORMALIZE_SCORE_REGISTRY[self.kind, self.name](timestep_reward)
def __init__(self, kind: str, name: str, *,
+ horizon: int,
future_observation_discount: float,
dummy: bool = False, # when you don't want to load data, e.g., in analysis
device: torch.device = torch.device('cpu'), # FIXME: get some heuristic
@@ -285,6 +288,7 @@ class Dataset(torch.utils.data.Dataset):
self.kind = kind
self.name = name
self.future_observation_discount = future_observation_discount
+ self.horizon = horizon
self.env_spec = EnvSpec.from_env(self.create_env())
@@ -297,19 +301,19 @@ class Dataset(torch.utils.data.Dataset):
from .utils import get_empty_episode
episodes = (get_empty_episode(self.env_spec, episode_length=1),)
- obs_indices_to_obs_index_in_episode = []
+ # obs_indices_to_obs_index_in_episode = []
indices_to_episode_indices = []
indices_to_episode_timesteps = []
for eidx, episode in enumerate(episodes):
l = episode.num_transitions
- obs_indices_to_obs_index_in_episode.append(torch.arange(l + 1, dtype=torch.int64))
- indices_to_episode_indices.append(torch.full([l], eidx, dtype=torch.int64))
- indices_to_episode_timesteps.append(torch.arange(l, dtype=torch.int64))
+ # obs_indices_to_obs_index_in_episode.append(torch.arange(l + 1, dtype=torch.int64))
+ indices_to_episode_indices.append(torch.full([l + 1 - horizon], eidx, dtype=torch.int64))
+ indices_to_episode_timesteps.append(torch.arange(l + 1 - horizon, dtype=torch.int64))
assert len(episodes) > 0, "must have at least one episode"
self.raw_data = MultiEpisodeData.cat(episodes).to(device)
- self.obs_indices_to_obs_index_in_episode = torch.cat(obs_indices_to_obs_index_in_episode, dim=0).to(device)
+ # self.obs_indices_to_obs_index_in_episode = torch.cat(obs_indices_to_obs_index_in_episode, dim=0).to(device)
self.indices_to_episode_indices = torch.cat(indices_to_episode_indices, dim=0).to(device)
self.indices_to_episode_timesteps = torch.cat(indices_to_episode_timesteps, dim=0).to(device)
self.max_episode_length = int(self.raw_data.episode_lengths.max().item())
@@ -324,19 +328,22 @@ class Dataset(torch.utils.data.Dataset):
def __getitem__(self, indices: torch.Tensor) -> BatchData:
indices = torch.as_tensor(indices, device=self.device)
eindices = self.indices_to_episode_indices[indices]
+ indices = indices[..., None] + torch.arange(self.horizon + 1, device=self.device)
+ eindices = eindices[..., None].expand(indices.shape)
obs_indices = indices + eindices # index for `observation`: skip the s_last from previous episodes
obs = self.get_observations(obs_indices)
- nobs = self.get_observations(obs_indices + 1)
-
- terminals = self.raw_data.terminals[indices]
+ # nobs = self.get_observations(obs_indices + 1)
+ indices = indices[..., :-1] # remove the last one which is the next observation
+ obs_indices = obs_indices[..., :-1]
+ eindices = eindices[..., :-1]
tindices = self.indices_to_episode_timesteps[indices]
epilengths = self.raw_data.episode_lengths[eindices] # max idx is this
deltas = torch.arange(self.max_episode_length, device=self.device)
pdeltas = torch.where(
# test tidx + 1 + delta <= max_idx = epi_length
- (tindices[:, None] + deltas) < epilengths[:, None],
- self.future_observation_discount ** deltas,
+ (tindices[..., None] + deltas) < epilengths[..., None],
+ (self.future_observation_discount ** deltas).expand(tindices.shape + (self.max_episode_length,)),
0,
)
deltas = torch.distributions.Categorical(
@@ -346,24 +353,25 @@ class Dataset(torch.utils.data.Dataset):
future_observations = self.get_observations(obs_indices + deltas)
return BatchData(
- observations=obs,
- actions=self.raw_data.actions[indices],
- next_observations=nobs,
- future_observations=future_observations,
- future_tdelta=deltas,
- rewards=self.raw_data.rewards[indices],
- terminals=terminals,
- timeouts=self.raw_data.timeouts[indices],
+ observations=obs.select(indices.ndim - 1, 0),
+ actions=self.raw_data.actions[indices].squeeze(indices.ndim - 1),
+ next_observations=obs.select(indices.ndim - 1, -1),
+ future_observations=future_observations.squeeze(indices.ndim - 1),
+ future_tdelta=deltas.squeeze(indices.ndim - 1),
+ rewards=self.raw_data.rewards[indices].squeeze(indices.ndim - 1),
+ terminals=self.raw_data.terminals[indices].squeeze(indices.ndim - 1),
+ timeouts=self.raw_data.timeouts[indices].squeeze(indices.ndim - 1),
)
def __len__(self):
- return self.raw_data.num_transitions
+ return self.indices_to_episode_indices.shape[0]
def __repr__(self):
return rf"""
{self.__class__.__name__}(
kind={self.kind!r},
name={self.name!r},
+ horizon={self.horizon!r},
future_observation_discount={self.future_observation_discount!r},
env_spec={self.env_spec!r},
)""".lstrip('\n')
diff --git a/quasimetric_rl/modules/__init__.py b/quasimetric_rl/modules/__init__.py
index 1ed2315..81b5443 100644
--- a/quasimetric_rl/modules/__init__.py
+++ b/quasimetric_rl/modules/__init__.py
@@ -9,6 +9,7 @@ from . import actor, quasimetric_critic
from ..data import EnvSpec, BatchData
from .utils import LossResult, Module, InfoT, InfoValT
+from ..flags import FLAGS
class QRLAgent(Module):
@@ -33,8 +34,7 @@ class QRLLosses(Module):
critics_total_grad_clip_norm: Optional[float],
recompute_critic_for_actor_loss: bool,
critics_share_embedding: bool,
- critic_losses_use_target_encoder: bool,
- actor_loss_uses_target_encoder: bool):
+ critic_losses_use_target_encoder: bool):
super().__init__()
self.add_module('actor_loss', actor_loss)
self.critic_losses = torch.nn.ModuleList(critic_losses) # type: ignore
@@ -42,7 +42,6 @@ class QRLLosses(Module):
self.recompute_critic_for_actor_loss = recompute_critic_for_actor_loss
self.critics_share_embedding = critics_share_embedding
self.critic_losses_use_target_encoder = critic_losses_use_target_encoder
- self.actor_loss_uses_target_encoder = actor_loss_uses_target_encoder
def forward(self, agent: QRLAgent, data: BatchData, *, optimize: bool = True) -> LossResult:
# compute CriticBatchInfo
@@ -54,11 +53,7 @@ class QRLLosses(Module):
stack.enter_context(critic_loss.optim_update_context(optimize=optimize))
if self.critics_share_embedding and idx > 0:
- critic_batch_info = quasimetric_critic.CriticBatchInfo(
- critic=critic,
- zx=critic_batch_infos[0].zx,
- zy=critic_batch_infos[0].zy,
- )
+ critic_batch_info = attrs.evolve(critic_batch_infos[0], critic=critic)
else:
zx = critic.encoder(data.observations)
zy = critic.get_encoder(target=self.critic_losses_use_target_encoder)(data.next_observations)
@@ -68,40 +63,96 @@ class QRLLosses(Module):
critic=critic,
zx=zx,
zy=zy,
+ zy_from_target_encoder=self.critic_losses_use_target_encoder,
)
loss_results[f"critic_{idx:02d}"] = critic_loss(data, critic_batch_info) # we update together to handle shared embedding
critic_batch_infos.append(critic_batch_info)
+ critic_grad_norm: InfoValT = {}
+
+ if FLAGS.DEBUG:
+ def get_grad(loss: Union[torch.Tensor, float]) -> Union[torch.Tensor, float]:
+ if isinstance(loss, (int, float)):
+ return 0
+ loss_grads = torch.autograd.grad(
+ loss,
+ list(cast(torch.nn.ModuleList, agent.critics).parameters()),
+ retain_graph=True,
+ allow_unused=True,
+ )
+ return cast(
+ torch.Tensor,
+ sum(pg.pow(2).sum() for pg in loss_grads if pg is not None),
+ ).sqrt()
+
+ for k, loss_r in loss_results.items():
+ critic_grad_norm.update({
+ k: loss_r.map_losses(get_grad),
+ })
+
torch.stack(
- [cast(torch.Tensor, loss_r.loss) for loss_r in loss_results.values()]
+ [loss_r.total_loss for loss_r in loss_results.values()]
).sum().backward()
if self.critics_total_grad_clip_norm is not None:
- torch.nn.utils.clip_grad_norm_(cast(torch.nn.ModuleList, agent.critics).parameters(),
- max_norm=self.critics_total_grad_clip_norm)
+ critic_grad_norm['total'] = torch.nn.utils.clip_grad_norm_(
+ cast(torch.nn.ModuleList, agent.critics).parameters(),
+ max_norm=self.critics_total_grad_clip_norm)
+ else:
+ critic_grad_norm['total'] = cast(
+ torch.Tensor,
+ sum(p.grad.pow(2).sum() for p in cast(torch.nn.ModuleList, agent.critics).parameters() if p.grad is not None),
+ ).sqrt()
if optimize:
for critic in agent.critics:
- critic.update_target_encoder_()
+ critic.update_target_models_()
+ actor_grad_norm: InfoValT = {}
if self.actor_loss is not None:
assert agent.actor is not None
with torch.no_grad(), torch.inference_mode():
- for idx, critic in enumerate(agent.critics):
- if self.recompute_critic_for_actor_loss or (critic.has_separate_target_encoder and self.actor_loss_uses_target_encoder):
- zx, zy = critic.get_encoder(target=self.actor_loss_uses_target_encoder)(
+ for idx, critic in enumerate(agent.critics): # FIXME?
+ if self.recompute_critic_for_actor_loss or (critic.has_separate_target_encoder and self.actor_loss.use_target_encoder):
+ zx, zy = critic.get_encoder(target=self.actor_loss.use_target_encoder)(
torch.stack([data.observations, data.next_observations], dim=0)).unbind(0)
critic_batch_infos[idx] = quasimetric_critic.CriticBatchInfo(
critic=critic,
zx=zx,
zy=zy,
+ zx_from_target_encoder=self.actor_loss.use_target_encoder,
+ zy_from_target_encoder=self.actor_loss.use_target_encoder,
)
with self.actor_loss.optim_update_context(optimize=optimize):
loss_results['actor'] = loss_r = self.actor_loss(agent.actor, critic_batch_infos, data)
- cast(torch.Tensor, loss_r.loss).backward()
- return LossResult.combine(loss_results)
+ if FLAGS.DEBUG:
+ def get_grad(loss: Union[torch.Tensor, float]) -> Union[torch.Tensor, float]:
+ if isinstance(loss, (int, float)):
+ return 0
+ assert agent.actor is not None
+ loss_grads = torch.autograd.grad(
+ loss,
+ list(agent.actor.parameters()),
+ retain_graph=True,
+ allow_unused=True,
+ )
+ return cast(
+ torch.Tensor,
+ sum(pg.pow(2).sum() for pg in loss_grads if pg is not None),
+ ).sqrt()
+
+ actor_grad_norm.update(cast(Mapping, loss_r.map_losses(get_grad)))
+
+ loss_r.total_loss.backward()
+
+ actor_grad_norm['total'] = cast(
+ torch.Tensor,
+ sum(p.grad.pow(2).sum() for p in agent.actor.parameters() if p.grad is not None),
+ ).sqrt()
+
+ return LossResult.combine(loss_results, grad_norm=dict(critic=critic_grad_norm, actor=actor_grad_norm))
# for type hints
def __call__(self, agent: QRLAgent, data: BatchData, *, optimize: bool = True) -> LossResult:
@@ -152,7 +203,6 @@ class QRLLosses(Module):
f'critics_share_embedding={self.critics_share_embedding}',
f'critics_total_grad_clip_norm={self.critics_total_grad_clip_norm}',
f'critic_losses_use_target_encoder={self.critic_losses_use_target_encoder}',
- f'actor_loss_uses_target_encoder={self.actor_loss_uses_target_encoder}',
])
@@ -167,7 +217,6 @@ class QRLConf:
)
recompute_critic_for_actor_loss: bool = False
critic_losses_use_target_encoder: bool = True
- actor_loss_uses_target_encoder: bool = True
def make(self, *, env_spec: EnvSpec, total_optim_steps: int) -> Tuple[QRLAgent, QRLLosses]:
if self.actor is None:
@@ -193,7 +242,6 @@ class QRLConf:
critics_total_grad_clip_norm=self.critics_total_grad_clip_norm,
recompute_critic_for_actor_loss=self.recompute_critic_for_actor_loss,
critic_losses_use_target_encoder=self.critic_losses_use_target_encoder,
- actor_loss_uses_target_encoder=self.actor_loss_uses_target_encoder,
)
__all__ = ['QRLAgent', 'QRLLosses', 'QRLConf', 'InfoT', 'InfoValT']
diff --git a/quasimetric_rl/modules/actor/losses/__init__.py b/quasimetric_rl/modules/actor/losses/__init__.py
index c8f1ec7..f9c78a1 100644
--- a/quasimetric_rl/modules/actor/losses/__init__.py
+++ b/quasimetric_rl/modules/actor/losses/__init__.py
@@ -15,12 +15,15 @@ from ...optim import OptimWrapper, AdamWSpec, LRScheduler
class ActorLossBase(LossBase):
@abc.abstractmethod
- def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData) -> LossResult:
+ def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData,
+ use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult:
pass
# for type hints
- def __call__(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData) -> LossResult:
- return super().__call__(actor, critic_batch_infos, data)
+ def __call__(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData,
+ use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult:
+ return super().__call__(actor, critic_batch_infos, data, use_target_encoder=use_target_encoder,
+ use_target_quasimetric_model=use_target_quasimetric_model)
from .min_dist import MinDistLoss
@@ -40,6 +43,9 @@ class ActorLosses(ActorLossBase):
actor_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=3e-5)
entropy_weight_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=3e-4) # TODO: move to loss
+ use_target_encoder: bool = True
+ use_target_quasimetric_model: bool = True
+
def make(self, actor: Actor, total_optim_steps: int, env_spec: EnvSpec) -> 'ActorLosses':
return ActorLosses(
actor,
@@ -49,6 +55,8 @@ class ActorLosses(ActorLossBase):
advantage_weighted_regression=self.advantage_weighted_regression.make(),
actor_optim_spec=self.actor_optim.make(),
entropy_weight_optim_spec=self.entropy_weight_optim.make(),
+ use_target_encoder=self.use_target_encoder,
+ use_target_quasimetric_model=self.use_target_quasimetric_model,
)
min_dist: Optional[MinDistLoss]
@@ -60,10 +68,14 @@ class ActorLosses(ActorLossBase):
entropy_weight_optim: OptimWrapper
entropy_weight_sched: LRScheduler
+ use_target_encoder: bool
+ use_target_quasimetric_model: bool
+
def __init__(self, actor: Actor, *, total_optim_steps: int,
min_dist: Optional[MinDistLoss], behavior_cloning: Optional[BCLoss],
advantage_weighted_regression: Optional[AWRLoss],
- actor_optim_spec: AdamWSpec, entropy_weight_optim_spec: AdamWSpec):
+ actor_optim_spec: AdamWSpec, entropy_weight_optim_spec: AdamWSpec,
+ use_target_encoder: bool, use_target_quasimetric_model: bool):
super().__init__()
self.add_module('min_dist', min_dist)
self.add_module('behavior_cloning', behavior_cloning)
@@ -76,6 +88,9 @@ class ActorLosses(ActorLossBase):
if min_dist is not None:
assert len(list(min_dist.parameters())) <= 1
+ self.use_target_encoder = use_target_encoder
+ self.use_target_quasimetric_model = use_target_quasimetric_model
+
def optimizers(self) -> Iterable[OptimWrapper]:
return [self.actor_optim, self.entropy_weight_optim]
@@ -86,18 +101,24 @@ class ActorLosses(ActorLossBase):
loss_results: Dict[str, LossResult] = {}
if self.min_dist is not None:
loss_results.update(
- min_dist=self.min_dist(actor, critic_batch_infos, data),
+ min_dist=self.min_dist(actor, critic_batch_infos, data, self.use_target_encoder, self.use_target_quasimetric_model),
)
if self.behavior_cloning is not None:
loss_results.update(
- bc=self.behavior_cloning(actor, critic_batch_infos, data),
+ bc=self.behavior_cloning(actor, critic_batch_infos, data, self.use_target_encoder, self.use_target_quasimetric_model),
)
if self.advantage_weighted_regression is not None:
loss_results.update(
- awr=self.advantage_weighted_regression(actor, critic_batch_infos, data),
+ awr=self.advantage_weighted_regression(actor, critic_batch_infos, data, self.use_target_encoder, self.use_target_quasimetric_model),
)
return LossResult.combine(loss_results)
# for type hints
def __call__(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData) -> LossResult:
return torch.nn.Module.__call__(self, actor, critic_batch_infos, data)
+
+ def extra_repr(self) -> str:
+ return '\n'.join([
+ f"actor_optim={self.actor_optim}, entropy_weight_optim={self.entropy_weight_optim}",
+ f"use_target_encoder={self.use_target_encoder}, use_target_quasimetric_model={self.use_target_quasimetric_model}",
+ ])
diff --git a/quasimetric_rl/modules/actor/losses/awr.py b/quasimetric_rl/modules/actor/losses/awr.py
index f6cf4de..1313dc3 100644
--- a/quasimetric_rl/modules/actor/losses/awr.py
+++ b/quasimetric_rl/modules/actor/losses/awr.py
@@ -81,8 +81,8 @@ class AWRLoss(ActorLossBase):
self.clamp = clamp
self.add_goal_as_future_state = add_goal_as_future_state
- def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo],
- data: BatchData) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]:
+ def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData,
+ use_target_encoder: bool) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]:
r"""
Returns (
obs,
@@ -115,7 +115,7 @@ class AWRLoss(ActorLossBase):
# add future_observations
zg = torch.stack([
zg,
- critic_batch_info.critic.target_encoder(data.future_observations),
+ critic_batch_info.critic.get_encoder(target=use_target_encoder)(data.future_observations),
], 0)
# zo = zo.expand_as(zg)
@@ -127,9 +127,11 @@ class AWRLoss(ActorLossBase):
return obs, goal, actor_obs_goal_critic_infos
- def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData) -> LossResult:
+ def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData,
+ use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult:
+
with torch.no_grad():
- obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data)
+ obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data, use_target_encoder)
info: Dict[str, Union[float, torch.Tensor]] = {}
@@ -139,11 +141,13 @@ class AWRLoss(ActorLossBase):
for idx, actor_obs_goal_critic_info in enumerate(actor_obs_goal_critic_infos):
critic = actor_obs_goal_critic_info.critic
+ latent_dynamics = critic.latent_dynamics
+ quasimetric_model = critic.get_quasimetric_model(target=use_target_quasimetric_model)
zo = actor_obs_goal_critic_info.zo.detach() # [B,D]
zg = actor_obs_goal_critic_info.zg.detach() # [2?,B,D]
with torch.no_grad(), critic.mode(False):
- zp = critic.latent_dynamics(data.observations, zo, data.actions) # [B,D]
- if not critic.borrowing_embedding:
+ zp = latent_dynamics(data.observations, zo, data.actions) # [B,D]
+ if idx == 0 or not critic.borrowing_embedding:
zo, zp, zg = bcast_bshape(
(zo, 1),
(zp, 1),
@@ -151,10 +155,10 @@ class AWRLoss(ActorLossBase):
)
z = torch.stack([zo, zp], dim=0) # [2,2?,B,D]
# z = z[(slice(None),) + (None,) * (zg.ndim - zo.ndim) + (Ellipsis,)] # if zg is batched, add enough dim to be broadcast-able
- dist_noact, dist = critic.quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B]
+ dist_noact, dist = quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B]
dist_noact = dist_noact.detach()
else:
- dist = critic.quasimetric_model(zp, zg)
+ dist = quasimetric_model(zp, zg)
dist_noact = dists_noact[0]
info[f'dist_delta_{idx:02d}'] = (dist_noact - dist).mean()
info[f'dist_{idx:02d}'] = dist.mean()
diff --git a/quasimetric_rl/modules/actor/losses/behavior_cloning.py b/quasimetric_rl/modules/actor/losses/behavior_cloning.py
index 632221e..eba36af 100644
--- a/quasimetric_rl/modules/actor/losses/behavior_cloning.py
+++ b/quasimetric_rl/modules/actor/losses/behavior_cloning.py
@@ -34,7 +34,8 @@ class BCLoss(ActorLossBase):
super().__init__()
self.weight = weight
- def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData) -> LossResult:
+ def forward(self, actor: Actor, critic_batch_infos: Collection[CriticBatchInfo], data: BatchData,
+ use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult:
actor_distn = actor(data.observations, data.future_observations)
log_prob: torch.Tensor = actor_distn.log_prob(data.actions).mean()
loss = -log_prob * self.weight
diff --git a/quasimetric_rl/modules/actor/losses/min_dist.py b/quasimetric_rl/modules/actor/losses/min_dist.py
index bc36ee8..1d57609 100644
--- a/quasimetric_rl/modules/actor/losses/min_dist.py
+++ b/quasimetric_rl/modules/actor/losses/min_dist.py
@@ -87,8 +87,8 @@ class MinDistLoss(ActorLossBase):
self.register_parameter('raw_entropy_weight', None)
self.target_entropy = None
- def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo],
- data: BatchData) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]:
+ def gather_obs_goal_pairs(self, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData,
+ use_target_encoder: bool) -> Tuple[torch.Tensor, torch.Tensor, Collection[ActorObsGoalCriticInfo]]:
r"""
Returns (
obs,
@@ -121,7 +121,7 @@ class MinDistLoss(ActorLossBase):
# add future_observations
zg = torch.stack([
zg,
- critic_batch_info.critic.target_encoder(data.future_observations),
+ critic_batch_info.critic.get_encoder(target=use_target_encoder)(data.future_observations),
], 0)
# zo = zo.expand_as(zg)
@@ -133,9 +133,10 @@ class MinDistLoss(ActorLossBase):
return obs, goal, actor_obs_goal_critic_infos
- def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData) -> LossResult:
+ def forward(self, actor: Actor, critic_batch_infos: Sequence[CriticBatchInfo], data: BatchData,
+ use_target_encoder: bool, use_target_quasimetric_model: bool) -> LossResult:
with torch.no_grad():
- obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data)
+ obs, goal, actor_obs_goal_critic_infos = self.gather_obs_goal_pairs(critic_batch_infos, data, use_target_encoder)
actor_distn = actor(obs, goal)
action = actor_distn.rsample()
@@ -147,17 +148,19 @@ class MinDistLoss(ActorLossBase):
for idx, actor_obs_goal_critic_info in enumerate(actor_obs_goal_critic_infos):
critic = actor_obs_goal_critic_info.critic
+ latent_dynamics = critic.latent_dynamics
+ quasimetric_model = critic.get_quasimetric_model(target=use_target_quasimetric_model)
zo = actor_obs_goal_critic_info.zo.detach() # [B,D]
zg = actor_obs_goal_critic_info.zg.detach() # [2?,B,D]
with critic.requiring_grad(False), critic.mode(False):
- zp = critic.latent_dynamics(data.observations, zo, action) # [2?,B,D]
- if not critic.borrowing_embedding:
+ zp = latent_dynamics(data.observations, zo, action) # [2?,B,D]
+ if idx == 0 or not critic.borrowing_embedding:
# action: [2?,B,A]
z = torch.stack(torch.broadcast_tensors(zo, zp), dim=0) # [2,2?,B,D]
- dist_noact, dist = critic.quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B]
+ dist_noact, dist = quasimetric_model(z, zg).unbind(0) # [2,2?,B] -> 2x [2?,B]
dist_noact = dist_noact.detach()
else:
- dist = critic.quasimetric_model(zp, zg) # [2?,B]
+ dist = quasimetric_model(zp, zg) # [2?,B]
dist_noact = dists_noact[0]
info[f'dist_delta_{idx:02d}'] = (dist_noact - dist).mean()
info[f'dist_{idx:02d}'] = dist.mean()
diff --git a/quasimetric_rl/modules/optim.py b/quasimetric_rl/modules/optim.py
index 5131ce0..5019e1b 100644
--- a/quasimetric_rl/modules/optim.py
+++ b/quasimetric_rl/modules/optim.py
@@ -1,6 +1,7 @@
from typing import *
import attrs
+import logging
import contextlib
import torch
@@ -91,9 +92,11 @@ class AdamWSpec:
if len(params) == 0:
params = [dict(params=[])] # dummy param group so pytorch doesn't complain
for ii in range(len(params)):
- if isinstance(params, Mapping) and 'lr_mul' in params: # handle lr_multiplier
- pg = params[ii]
+ pg = params[ii]
+ if isinstance(pg, Mapping) and 'lr_mul' in pg: # handle lr_multiplier
+ pg['params'] = params_list = cast(List[torch.Tensor], list(pg['params']))
assert 'lr' not in pg
+ logging.info(f'params (#tensor={len(params_list)}, #={sum(p.numel() for p in params_list)}): lr_mul={pg["lr_mul"]}')
pg['lr'] = self.lr * pg['lr_mul'] # type: ignore
del pg['lr_mul']
return OptimWrapper(
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py b/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py
index 48ed700..5487442 100644
--- a/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py
@@ -19,6 +19,8 @@ class CriticBatchInfo:
critic: QuasimetricCritic
zx: LatentTensor
zy: LatentTensor
+ zx_from_target_encoder: bool = False
+ zy_from_target_encoder: bool = False
class CriticLossBase(LossBase):
@@ -48,10 +50,10 @@ class QuasimetricCriticLosses(CriticLossBase):
latent_dynamics: LatentDynamicsLoss.Conf = LatentDynamicsLoss.Conf()
critic_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=1e-4)
- latent_dynamics_lr_mul: float = 1
- quasimetric_model_lr_mul: float = 1
- encoder_lr_mul: float = 1 # TD-MPC2 uses 0.3
- quasimetric_head_lr_mul: float = 1 # IQE2 can benefit from smaller lr, ~1e-5
+ latent_dynamics_lr_mul: float = attrs.field(default=1., validator=attrs.validators.ge(0))
+ quasimetric_model_lr_mul: float = attrs.field(default=1., validator=attrs.validators.ge(0))
+ encoder_lr_mul: float = attrs.field(default=1., validator=attrs.validators.ge(0)) # TD-MPC2 uses 0.3
+ quasimetric_head_lr_mul: float = attrs.field(default=1., validator=attrs.validators.ge(0)) # IQE2 can benefit from smaller lr, ~1e-5
local_lagrange_mult_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=1e-2)
dynamics_lagrange_mult_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=0)
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
index e469f13..3879518 100644
--- a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
@@ -106,28 +106,37 @@ class GlobalPushLossBase(CriticLossBase):
zgoal,
get_dist(critic_batch_info.zx, zgoal),
self.weight,
+ {},
)
if self.weight_future_goal > 0:
- zgoal = critic_batch_info.critic.encoder(data.future_observations)
+ zgoal = critic_batch_info.critic.get_encoder(critic_batch_info.zy_from_target_encoder)(data.future_observations)
dist = get_dist(critic_batch_info.zx, zgoal)
if self.clamp_max_future_goal:
+ observed_upper_bound = self.step_cost * data.future_tdelta
+ info = dict(
+ ratio_future_observed_dist=(dist / observed_upper_bound).mean(),
+ exceed_future_observed_dist_rate=(dist > observed_upper_bound).mean(dtype=torch.float32),
+ )
dist = dist.clamp_max(self.step_cost * data.future_tdelta)
+ else:
+ info = {}
yield (
'future_goal',
zgoal,
dist,
- self.weight_future_goal,
+ self.weight_future_goal,info,
)
@abc.abstractmethod
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult:
+ def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor,
+ dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult:
raise NotImplementedError
def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
return LossResult.combine(
{
- name: self.compute_loss(data, critic_batch_info, zgoal, dist, weight)
- for name, zgoal, dist, weight in self.generate_dist_weight(data, critic_batch_info)
+ name: self.compute_loss(data, critic_batch_info, zgoal, dist, weight, info)
+ for name, zgoal, dist, weight, info in self.generate_dist_weight(data, critic_batch_info)
},
)
@@ -174,11 +183,14 @@ class GlobalPushLoss(GlobalPushLossBase):
self.softplus_beta = softplus_beta
self.softplus_offset = softplus_offset
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult:
+ def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor,
+ dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult:
+ dict_info: Dict[str, torch.Tensor] = dict(info)
# Sec 3.2. Transform so that we penalize large distances less.
tsfm_dist: torch.Tensor = F.softplus(self.softplus_offset - dist, beta=self.softplus_beta) # type: ignore
tsfm_dist = tsfm_dist.mean()
- return LossResult(loss=tsfm_dist * weight, info=dict(dist=dist.mean(), tsfm_dist=tsfm_dist))
+ dict_info.update(dist=dist.mean(), tsfm_dist=tsfm_dist)
+ return LossResult(loss=tsfm_dist * weight, info=dict_info)
def extra_repr(self) -> str:
return '\n'.join([
@@ -216,21 +228,22 @@ class GlobalPushLinearLoss(GlobalPushLossBase):
detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal)
self.clamp_max = clamp_max
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult:
- info: Dict[str, torch.Tensor]
+ def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor,
+ dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult:
+ dict_info: Dict[str, torch.Tensor] = dict(info)
if self.clamp_max is None:
dist = dist.mean()
- info = dict(dist=dist)
+ dict_info.update(dist=dist)
neg_loss = dist
else:
- info = dict(dist=dist.mean())
tsfm_dist = dist.clamp_max(self.clamp_max)
- info.update(
+ dict_info.update(
+ dist=dist.mean(),
tsfm_dist=tsfm_dist.mean(),
exceed_rate=(dist >= self.clamp_max).mean(dtype=torch.float32),
)
neg_loss = tsfm_dist
- return LossResult(loss=neg_loss * weight, info=info)
+ return LossResult(loss=neg_loss * weight, info=dict_info)
def extra_repr(self) -> str:
return '\n'.join([
@@ -277,9 +290,12 @@ class GlobalPushNextMSELoss(GlobalPushLossBase):
self.allow_gt = allow_gt
self.gamma = gamma
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult:
+ def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor,
+ dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult:
+ dict_info: Dict[str, torch.Tensor] = dict(info)
+
with torch.enable_grad(self.detach_target_dist):
- # by tri-eq, the actual cost can't be larger that step_cost + d(s', g)
+ # by tri-eq, the actual cost can't be larger than step_cost + d(s', g)
next_dist = critic_batch_info.critic.quasimetric_model(
critic_batch_info.zy, zgoal, proj_grad_enabled=(True, not self.detach_proj_goal)
)
@@ -289,13 +305,20 @@ class GlobalPushNextMSELoss(GlobalPushLossBase):
if self.allow_gt:
dist = dist.clamp_max(target_dist)
+ dict_info.update(
+ exceed_rate=(dist >= target_dist).mean(dtype=torch.float32),
+ )
if self.gamma is None:
loss = F.mse_loss(dist, target_dist)
else:
loss = F.mse_loss(self.gamma ** dist, self.gamma ** target_dist)
- return LossResult(loss=loss * self.weight, info=dict(loss=loss, dist=dist.mean(), target_dist=target_dist.mean()))
+ dict_info.update(
+ loss=loss, dist=dist.mean(), target_dist=target_dist.mean()
+ )
+
+ return LossResult(loss=loss * self.weight, info=dict_info)
def extra_repr(self) -> str:
return '\n'.join([
@@ -333,10 +356,13 @@ class GlobalPushLogLoss(GlobalPushLossBase):
detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal)
self.offset = offset
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult:
+ def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor,
+ dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult:
+ dict_info: Dict[str, torch.Tensor] = dict(info)
tsfm_dist: torch.Tensor = -dist.add(self.offset).log()
tsfm_dist = tsfm_dist.mean()
- return LossResult(loss=tsfm_dist * weight, info=dict(dist=dist.mean(), tsfm_dist=tsfm_dist))
+ dict_info.update(dist=dist.mean(), tsfm_dist=tsfm_dist)
+ return LossResult(loss=tsfm_dist * weight, info=dict_info)
def extra_repr(self) -> str:
return '\n'.join([
@@ -378,13 +404,17 @@ class GlobalPushRBFLoss(GlobalPushLossBase):
detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, clamp_max_future_goal=clamp_max_future_goal)
self.inv_scale = inv_scale
- def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor, dist: torch.Tensor, weight: float) -> LossResult:
+ def compute_loss(self, data: BatchData, critic_batch_info: CriticBatchInfo, zgoal: torch.Tensor,
+ dist: torch.Tensor, weight: float, info: Mapping[str, torch.Tensor]) -> LossResult:
+ dict_info: Dict[str, torch.Tensor] = dict(info)
inv_scale = dist.detach().square().mean().div(2).sqrt().clamp(1e-3, self.inv_scale) # make E[d^2]/r^2 approx 2
tsfm_dist: torch.Tensor = (dist / inv_scale).square().neg().exp()
rbf_potential = tsfm_dist.mean().log()
- return LossResult(loss=rbf_potential * self.weight,
- info=dict(dist=dist.mean(), inv_scale=inv_scale,
- tsfm_dist=tsfm_dist, rbf_potential=rbf_potential)) # type: ignore
+ dict_info.update(
+ dist=dist.mean(), inv_scale=inv_scale,
+ tsfm_dist=tsfm_dist, rbf_potential=rbf_potential,
+ )
+ return LossResult(loss=rbf_potential * self.weight, info=dict_info)
def extra_repr(self) -> str:
return '\n'.join([
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py b/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py
index f5e8248..19e4886 100644
--- a/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py
@@ -2,6 +2,7 @@ from typing import *
import attrs
+import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -25,7 +26,7 @@ class LatentDynamicsLoss(CriticLossBase):
# weight: float = attrs.field(default=0.1, validator=attrs.validators.gt(0))
epsilon: float = attrs.field(default=0.25, validator=attrs.validators.gt(0))
- bidirectional: bool = True
+ # bidirectional: bool = True
gamma: Optional[float] = attrs.field(
default=None, validator=attrs.validators.optional(attrs.validators.and_(
attrs.validators.gt(0),
@@ -37,22 +38,26 @@ class LatentDynamicsLoss(CriticLossBase):
detach_qmet: bool = False
non_quasimetric_dim_mse_weight: float = attrs.field(default=0., validator=attrs.validators.ge(0))
+ kind: str = attrs.field(
+ default='mean', validator=attrs.validators.in_({'mean', 'unidir', 'max', 'bound', 'bound_l2comp', 'l2comp'})) # type: ignore
+
def make(self) -> 'LatentDynamicsLoss':
return LatentDynamicsLoss(
# weight=self.weight,
epsilon=self.epsilon,
- bidirectional=self.bidirectional,
+ # bidirectional=self.bidirectional,
gamma=self.gamma,
detach_sp=self.detach_sp,
detach_proj_sp=self.detach_proj_sp,
detach_qmet=self.detach_qmet,
non_quasimetric_dim_mse_weight=self.non_quasimetric_dim_mse_weight,
init_lagrange_multiplier=self.init_lagrange_multiplier,
+ kind=self.kind,
)
# weight: float
epsilon: float
- bidirectional: bool
+ # bidirectional: bool
gamma: Optional[float]
detach_sp: bool
detach_proj_sp: bool
@@ -60,20 +65,24 @@ class LatentDynamicsLoss(CriticLossBase):
non_quasimetric_dim_mse_weight: float
c: float
init_lagrange_multiplier: float
+ kind: str
- def __init__(self, *, epsilon: float, bidirectional: bool,
+ def __init__(self, *, epsilon: float,
+ # bidirectional: bool,
gamma: Optional[float], init_lagrange_multiplier: float,
# weight: float,
detach_qmet: bool, detach_proj_sp: bool, detach_sp: bool,
- non_quasimetric_dim_mse_weight: float):
+ non_quasimetric_dim_mse_weight: float,
+ kind: str):
super().__init__()
# self.weight = weight
self.epsilon = epsilon
- self.bidirectional = bidirectional
+ # self.bidirectional = bidirectional
self.gamma = gamma
self.detach_qmet = detach_qmet
self.detach_sp = detach_sp
self.detach_proj_sp = detach_proj_sp
+ self.kind = kind
self.non_quasimetric_dim_mse_weight = non_quasimetric_dim_mse_weight
self.init_lagrange_multiplier = init_lagrange_multiplier
self.raw_lagrange_multiplier = nn.Parameter(
@@ -86,16 +95,58 @@ class LatentDynamicsLoss(CriticLossBase):
pred_zy = critic.latent_dynamics(data.observations, zx, data.actions)
if self.detach_sp:
zy = zy.detach()
- with critic.quasimetric_model.requiring_grad(not self.detach_qmet):
- dists = critic.quasimetric_model(zy, pred_zy, bidirectional=self.bidirectional, # at least optimize d(s', hat{s'})
- proj_grad_enabled=(self.detach_proj_sp, True))
-
lagrange_mult = F.softplus(self.raw_lagrange_multiplier) # make positive
# lagrange multiplier is minimax training, so grad_mul -1
lagrange_mult = grad_mul(lagrange_mult, -1)
- info = {}
+ info: Dict[str, torch.Tensor] = {}
+
+ if self.kind == 'unidir':
+ with critic.quasimetric_model.requiring_grad(not self.detach_qmet):
+ dists = critic.quasimetric_model(zy, pred_zy, bidirectional=False, # at least optimize d(s', hat{s'})
+ proj_grad_enabled=(self.detach_proj_sp, True))
+ info.update(dist_n2p=dists.mean())
+ elif self.kind in ['mean', 'max', 'bound', 'l2comp']:
+ with critic.quasimetric_model.requiring_grad(not self.detach_qmet):
+ dists_comps = critic.quasimetric_model(
+ zy, pred_zy, bidirectional=True, proj_grad_enabled=(self.detach_proj_sp, True),
+ reduced=False)
+ dists = critic.quasimetric_model.quasimetric_head.reduction(dists_comps)
+
+ dist_p2n, dist_n2p = dists.unbind(-1)
+ info.update(dist_p2n=dist_p2n.mean(), dist_n2p=dist_n2p.mean())
+ if self.kind == 'max':
+ dists = dists.max(-1).values
+ elif self.kind == 'bound':
+ with critic.quasimetric_model.requiring_grad(not self.detach_qmet):
+ dists = critic.quasimetric_model(
+ zy, pred_zy, bidirectional=False, # NB: false is enough for bound
+ proj_grad_enabled=(self.detach_proj_sp, True),
+ symmetric_upperbound=True,
+ )
+ info.update(dist_upbnd=dists.mean())
+ elif self.kind == 'bound_l2comp':
+ with critic.quasimetric_model.requiring_grad(not self.detach_qmet):
+ dists_comps = critic.quasimetric_model(
+ zy, pred_zy, bidirectional=False, # NB: false is enough for bound
+ proj_grad_enabled=(self.detach_proj_sp, True),
+ reduced=False, symmetric_upperbound=True,
+ )
+ dists = cast(
+ torch.Tensor,
+ dists_comps.norm(dim=-1) / np.sqrt(dists_comps.shape[-1]),
+ )
+ info.update(dist_upbnd_l2comp=dists.mean())
+ elif self.kind == 'l2comp':
+ dists = cast(
+ torch.Tensor,
+ dists_comps.norm(dim=-1) / np.sqrt(dists_comps.shape[-1]),
+ )
+ info.update(dist_l2comp=dists.mean())
+ else:
+ raise NotImplementedError(self.kind)
+
if self.gamma is None:
sq_dists = dists.square().mean()
violation = (sq_dists - self.epsilon ** 2)
@@ -117,12 +168,6 @@ class LatentDynamicsLoss(CriticLossBase):
info.update(non_quasimetric_dim_mse=non_quasimetric_dim_mse)
loss += non_quasimetric_dim_mse * self.non_quasimetric_dim_mse_weight
- if self.bidirectional:
- dist_p2n, dist_n2p = dists.flatten(0, -2).mean(0).unbind(-1)
- info.update(dist_p2n=dist_p2n, dist_n2p=dist_n2p)
- else:
- info.update(dist_n2p=dists.mean())
-
return LossResult(
loss=loss,
info=info, # type: ignore
@@ -130,7 +175,7 @@ class LatentDynamicsLoss(CriticLossBase):
def extra_repr(self) -> str:
return '\n'.join([
- f"epsilon={self.epsilon!r}, bidirectional={self.bidirectional}",
- f"gamma={self.gamma!r}, detach_sp={self.detach_sp}, detach_proj_sp={self.detach_proj_sp}, detach_qmet={self.detach_qmet}",
+ f"kind={self.kind}, gamma={self.gamma!r}",
+ f"epsilon={self.epsilon!r}, detach_sp={self.detach_sp}, detach_proj_sp={self.detach_proj_sp}, detach_qmet={self.detach_qmet}",
])
# return f"weight={self.weight:g}, detach_sp={self.detach_sp}"
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py b/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py
index 1806248..eb87576 100644
--- a/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py
@@ -2,6 +2,7 @@ from typing import *
import attrs
+import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -30,8 +31,14 @@ class LocalConstraintLoss(CriticLossBase):
init_lagrange_multiplier: float = attrs.field(default=0.01, validator=attrs.validators.gt(0))
- kind: str = attrs.field(
- default='mse', validator=attrs.validators.in_(['mse', 'sq'])) # type: ignore
+ p: float = attrs.field(default=2., validator=attrs.validators.ge(1))
+ compare_before_pow: bool = True
+ # kind: str = attrs.field(
+ # default='mse', validator=attrs.validators.in_(['mse', 'sq'])) # type: ignore
+ # batch_reduction: str = attrs.field(
+ # default='mean', validator=attrs.validators.in_(['mean', 'l2'])) # type: ignore
+ invpow_after_batch_agg: bool = False
+ log: bool = False
detach_proj_sp: bool = False
detach_sp: bool = False
@@ -49,12 +56,22 @@ class LocalConstraintLoss(CriticLossBase):
step_cost=self.step_cost,
step_cost_high=self.step_cost_high,
init_lagrange_multiplier=self.init_lagrange_multiplier,
- kind=self.kind,
+ # kind=self.kind,
+ p=self.p,
+ compare_before_pow=self.compare_before_pow,
+ log=self.log,
+ # batch_reduction=self.batch_reduction,
+ invpow_after_batch_agg=self.invpow_after_batch_agg,
detach_sp=self.detach_sp,
detach_proj_sp=self.detach_proj_sp,
)
- kind: str
+ # kind: str
+ p: float
+ compare_before_pow: bool
+ log: bool
+ # batch_reduction: str
+ invpow_after_batch_agg: bool
epsilon: float
allow_le: bool
step_cost: float
@@ -65,11 +82,21 @@ class LocalConstraintLoss(CriticLossBase):
raw_lagrange_multiplier: nn.Parameter # for the QRL constrained optimization
- def __init__(self, *, kind: str, epsilon: float, allow_le: bool,
+ def __init__(self, *,
+ # kind: str,
+ p: float, compare_before_pow: bool, invpow_after_batch_agg: bool,
+ log: bool,
+ # batch_reduction: str,
+ epsilon: float, allow_le: bool,
step_cost: float, step_cost_high: float, init_lagrange_multiplier: float,
detach_sp: bool, detach_proj_sp: bool):
super().__init__()
- self.kind = kind
+ # self.kind = kind
+ self.p = p
+ self.compare_before_pow = compare_before_pow
+ self.log = log
+ # self.batch_reduction = batch_reduction
+ self.invpow_after_batch_agg = invpow_after_batch_agg
self.epsilon = epsilon
self.allow_le = allow_le
self.step_cost = step_cost
@@ -100,6 +127,7 @@ class LocalConstraintLoss(CriticLossBase):
info['dist_090'],
info['dist_100'],
) = dist.quantile(dist.new_tensor([0, 0.1, 0.25, 0.5, 0.75, 0.9, 1])).unbind()
+ info.update(dist=dist.mean())
lagrange_mult = F.softplus(self.raw_lagrange_multiplier) # make positive
# lagrange multiplier is minimax training, so grad_mul -1
@@ -110,31 +138,63 @@ class LocalConstraintLoss(CriticLossBase):
else:
target = dist.detach().clamp(self.step_cost, self.step_cost_high)
- if self.kind == 'mse':
- if self.allow_le:
- sq_deviation = (dist - target).relu().square().mean()
+ if self.log:
+ dist = dist.log()
+ if isinstance(target, torch.Tensor):
+ target = target.log()
else:
- sq_deviation = (dist - target).square().mean()
- elif self.kind == 'sq':
- if self.allow_le:
- sq_deviation = (dist.square() - target ** 2).relu().mean()
- else:
- sq_deviation = (dist.square() - target ** 2).abs().mean()
+ target = np.log(target)
+ info.update(dist_log=dist.mean())
+
+ if self.compare_before_pow:
+ deviation = dist - target
+ deviation = (torch.relu if self.allow_le else torch.abs)(deviation)
+ deviation = deviation ** self.p
else:
- raise ValueError(f"Unknown kind {self.kind}")
- violation = (sq_deviation - self.epsilon ** 2)
- loss = violation * lagrange_mult
+ deviation = (dist ** self.p - target ** self.p)
+ deviation = (torch.relu if self.allow_le else torch.abs)(deviation)
- info.update(
- dist=dist.mean(), sq_deviation=sq_deviation,
- violation=violation, lagrange_mult=lagrange_mult,
- )
+ if self.invpow_after_batch_agg:
+ deviation = deviation.mean().pow(1 / self.p)
+ violation = deviation - self.epsilon
+ else:
+ violation = deviation.mean() - (self.epsilon ** self.p)
+ info.update(deviation=deviation.mean(), violation=violation, lagrange_mult=lagrange_mult)
+
+ # if self.kind == 'mse':
+ # if self.allow_le:
+ # sq_deviation = (dist - target).relu().square()
+ # else:
+ # sq_deviation = (dist - target).square()
+ # elif self.kind == 'sq':
+ # if self.allow_le:
+ # sq_deviation = (dist.square() - target ** 2).relu()
+ # else:
+ # sq_deviation = (dist.square() - target ** 2).abs()
+ # else:
+ # raise ValueError(f"Unknown kind {self.kind}")
+ # info.update(sq_deviation=sq_deviation.mean())
+
+ # if self.batch_reduction == 'mean':
+ # batch_sq_deviation = sq_deviation.mean()
+ # elif self.batch_reduction == 'l2':
+ # batch_sq_deviation = sq_deviation.square().mean().sqrt()
+ # else:
+ # raise ValueError(f"Unknown batch_reduction {self.batch_reduction}")
+ # info.update(batch_sq_deviation=batch_sq_deviation)
+
+ # violation = (batch_sq_deviation - self.epsilon ** 2)
+
+ loss = violation * lagrange_mult
+ # info.update(violation=violation, lagrange_mult=lagrange_mult)
return LossResult(loss=loss, info=info)
def extra_repr(self) -> str:
return '\n'.join([
- f"{self.kind}, epsilon={self.epsilon:g}, allow_le={self.allow_le}",
+ # f"kind={self.kind}, log={self.log}, batch_reduction={self.batch_reduction}",
+ f"p={self.p}, compare_before_pow={self.compare_before_pow}, invpow_after_batch_agg={self.invpow_after_batch_agg}",
+ f"epsilon={self.epsilon:g}, allow_le={self.allow_le}",
f"step_cost={self.step_cost:g}, step_cost_high={self.step_cost_high:g}",
f"detach_sp={self.detach_sp}, detach_proj_sp={self.detach_proj_sp}",
])
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/__init__.py b/quasimetric_rl/modules/quasimetric_critic/models/__init__.py
index 8789800..266f592 100644
--- a/quasimetric_rl/modules/quasimetric_critic/models/__init__.py
+++ b/quasimetric_rl/modules/quasimetric_critic/models/__init__.py
@@ -25,6 +25,11 @@ class QuasimetricCritic(Module):
attrs.validators.le(1),
)))
encoder: Encoder.Conf = Encoder.Conf()
+ target_quasimetric_model_ema: Optional[float] = attrs.field(
+ default=None, validator=attrs.validators.optional(attrs.validators.and_(
+ attrs.validators.ge(0),
+ attrs.validators.le(1),
+ )))
quasimetric_model: QuasimetricModel.Conf = QuasimetricModel.Conf()
latent_dynamics: Dynamics.Conf = Dynamics.Conf()
@@ -42,34 +47,46 @@ class QuasimetricCritic(Module):
)
return QuasimetricCritic(encoder, quasimetric_model, latent_dynamics,
share_embedding_from=share_embedding_from,
- target_encoder_ema=self.target_encoder_ema)
+ target_encoder_ema=self.target_encoder_ema,
+ target_quasimetric_model_ema=self.target_quasimetric_model_ema)
borrowing_embedding: bool
encoder: Encoder
_target_encoder: Optional[Encoder]
target_encoder_ema: Optional[float]
quasimetric_model: QuasimetricModel
+ _target_quasimetric_model: Optional[QuasimetricModel]
+ target_quasimetric_model_ema: Optional[float]
latent_dynamics: Dynamics # FIXME: add BC to model loading & change name
def __init__(self, encoder: Encoder, quasimetric_model: QuasimetricModel, latent_dynamics: Dynamics,
- share_embedding_from: Optional['QuasimetricCritic'], target_encoder_ema: Optional[float]):
+ share_embedding_from: Optional['QuasimetricCritic'], target_encoder_ema: Optional[float],
+ target_quasimetric_model_ema: Optional[float] = None):
super().__init__()
self.borrowing_embedding = share_embedding_from is not None
if share_embedding_from is not None:
encoder = share_embedding_from.encoder
- target_encoder = share_embedding_from.target_encoder
+ _target_encoder = share_embedding_from._target_encoder
assert target_encoder_ema == share_embedding_from.target_encoder_ema
quasimetric_model = share_embedding_from.quasimetric_model
+ assert target_quasimetric_model_ema == share_embedding_from.target_quasimetric_model_ema
+ _target_quasimetric_model = share_embedding_from._target_quasimetric_model
else:
if target_encoder_ema is None:
- target_encoder = None
+ _target_encoder = None
+ else:
+ _target_encoder = copy.deepcopy(encoder).requires_grad_(False).eval()
+ if target_quasimetric_model_ema is None:
+ _target_quasimetric_model = None
else:
- target_encoder = copy.deepcopy(encoder).requires_grad_(False).eval()
+ _target_quasimetric_model = copy.deepcopy(quasimetric_model).requires_grad_(False).eval()
self.encoder = encoder
- self.add_module('_target_encoder', target_encoder)
self.target_encoder_ema = target_encoder_ema
+ self.add_module('_target_encoder', _target_encoder)
self.quasimetric_model = quasimetric_model
+ self.target_quasimetric_model_ema = target_quasimetric_model_ema
+ self.add_module('_target_quasimetric_model', _target_quasimetric_model)
self.latent_dynamics = latent_dynamics
def forward(self, x: torch.Tensor, y: torch.Tensor, *, action: Optional[torch.Tensor] = None) -> torch.Tensor:
@@ -94,16 +111,31 @@ class QuasimetricCritic(Module):
def get_encoder(self, target: bool = False) -> Encoder:
return self.target_encoder if target else self.encoder
+ @property
+ def target_quasimetric_model(self) -> QuasimetricModel:
+ if self._target_quasimetric_model is not None:
+ return self._target_quasimetric_model
+ else:
+ return self.quasimetric_model
+
+ def get_quasimetric_model(self, target: bool = False) -> QuasimetricModel:
+ return self.target_quasimetric_model if target else self.quasimetric_model
+
@torch.no_grad()
- def update_target_encoder_(self):
- if not self.borrowing_embedding and self.target_encoder_ema is not None:
- assert self._target_encoder is not None
- for p, p_target in zip(self.encoder.parameters(), self._target_encoder.parameters()):
- p_target.data.lerp_(p.data, 1 - self.target_encoder_ema)
+ def update_target_models_(self):
+ if not self.borrowing_embedding:
+ if self.target_encoder_ema is not None:
+ assert self._target_encoder is not None
+ for p, p_target in zip(self.encoder.parameters(), self._target_encoder.parameters()):
+ p_target.data.lerp_(p.data, 1 - self.target_encoder_ema)
+ if self.target_quasimetric_model_ema is not None:
+ assert self._target_quasimetric_model is not None
+ for p, p_target in zip(self.quasimetric_model.parameters(), self._target_quasimetric_model.parameters()):
+ p_target.data.lerp_(p.data, 1 - self.target_quasimetric_model_ema)
# for type hints
def __call__(self, x: torch.Tensor, y: torch.Tensor, *, action: Optional[torch.Tensor] = None) -> torch.Tensor:
return super().__call__(x, y, action=action)
def extra_repr(self) -> str:
- return f"borrowing_embedding={self.borrowing_embedding}"
+ return f"borrowing_embedding={self.borrowing_embedding}, target_encoder_ema={self.target_encoder_ema}, target_quasimetric_model_ema={self.target_quasimetric_model_ema}"
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/encoder.py b/quasimetric_rl/modules/quasimetric_critic/models/encoder.py
index ddd6240..2b53bd5 100644
--- a/quasimetric_rl/modules/quasimetric_critic/models/encoder.py
+++ b/quasimetric_rl/modules/quasimetric_critic/models/encoder.py
@@ -46,6 +46,7 @@ class Encoder(nn.Module):
arch: Tuple[int, ...] = (512, 512)
layer_norm: bool = True
+ bias_last_fc: bool = True
latent: LatentSpaceConf = LatentSpaceConf()
def make(self, *, env_spec: EnvSpec) -> 'Encoder':
@@ -53,6 +54,7 @@ class Encoder(nn.Module):
env_spec=env_spec,
arch=self.arch,
layer_norm=self.layer_norm,
+ bias_last_fc=self.bias_last_fc,
latent=self.latent,
)
@@ -62,14 +64,14 @@ class Encoder(nn.Module):
projection: nn.Module
latent: LatentSpaceConf
- def __init__(self, *, env_spec: EnvSpec, arch: Tuple[int, ...], layer_norm: bool,
+ def __init__(self, *, env_spec: EnvSpec, arch: Tuple[int, ...], layer_norm: bool, bias_last_fc: bool,
latent: LatentSpaceConf, **kwargs):
super().__init__(**kwargs)
self.input_shape = env_spec.observation_shape
self.input_encoding = env_spec.make_observation_input()
encoder_input_size = self.input_encoding.output_size
self.encoder = nn.Sequential(
- MLP(encoder_input_size, latent.latent_size, hidden_sizes=arch, layer_norm=layer_norm),
+ MLP(encoder_input_size, latent.latent_size, hidden_sizes=arch, layer_norm=layer_norm, bias_last_fc=bias_last_fc),
latent.make_projection(),
)
self.latent = latent
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py b/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py
index 0847076..3683454 100644
--- a/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py
+++ b/quasimetric_rl/modules/quasimetric_critic/models/latent_dynamics.py
@@ -20,10 +20,13 @@ class Dynamics(MLP):
# config / argparse uses this to specify behavior
arch: Tuple[int, ...] = (512, 512)
+ action_arch: Tuple[int, ...] = ()
layer_norm: bool = True
latent_input: bool = True
raw_observation_input_arch: Optional[Tuple[int, ...]] = None
+ bias: bool = True
+ bias_last_fc: bool = True
use_latent_space_proj: bool = True
residual: bool = False # True -> False following TD-MPC2
fc_fuse: bool = False # whether last fc does fusion
@@ -34,14 +37,18 @@ class Dynamics(MLP):
latent_input=self.latent_input,
raw_observation_input_arch=self.raw_observation_input_arch,
env_spec=env_spec,
+ action_arch=self.action_arch,
hidden_sizes=self.arch,
layer_norm=self.layer_norm,
use_latent_space_proj=self.use_latent_space_proj,
+ bias=self.bias,
+ bias_last_fc=self.bias_last_fc,
residual=self.residual,
fc_fuse=self.fc_fuse,
)
- action_input: InputEncoding
+ action_input: Union[InputEncoding, nn.Sequential]
+ action_arch: Tuple[int, ...]
latent_input: bool
raw_observation_input: Optional[nn.Sequential]
residual: bool
@@ -50,32 +57,53 @@ class Dynamics(MLP):
def __init__(self, *, latent: LatentSpaceConf,
latent_input: bool,
raw_observation_input_arch: Optional[Tuple[int, ...]],
- env_spec: EnvSpec, hidden_sizes: Tuple[int, ...], layer_norm: bool,
- use_latent_space_proj: bool, residual: bool,
+ env_spec: EnvSpec, action_arch: Tuple[int, ...],
+ hidden_sizes: Tuple[int, ...], layer_norm: bool,
+ use_latent_space_proj: bool,
+ bias: bool, bias_last_fc: bool, residual: bool,
fc_fuse: bool):
+ self.action_arch = action_arch
action_input = env_spec.make_action_input()
- mlp_input_size = action_input.output_size
+ action_outsize = action_input.output_size
+ if action_arch:
+ action_input = nn.Sequential(
+ action_input,
+ MLP(
+ action_input.output_size,
+ action_arch[-1],
+ hidden_sizes=action_arch[:-1],
+ activation_last_fc=True,
+ layer_norm_last_fc=layer_norm,
+ layer_norm=layer_norm,
+ bias=bias,
+ ),
+ )
+ action_outsize = action_arch[-1]
+ mlp_input_size = action_outsize
mlp_output_size = latent.latent_size
self.latent_input = latent_input
if latent_input:
mlp_input_size += latent.latent_size
if raw_observation_input_arch is not None:
- assert raw_observation_input_arch is not None
- assert len(raw_observation_input_arch) >= 1
input_encoding = env_spec.make_observation_input()
- raw_observation_input = MLP(
- input_encoding.output_size,
- raw_observation_input_arch[-1],
- hidden_sizes=raw_observation_input_arch[:-1],
- activation_last_fc=True,
- layer_norm_last_fc=layer_norm,
- layer_norm=layer_norm,
- )
- mlp_input_size += raw_observation_input.output_size
- raw_observation_input = nn.Sequential(
- input_encoding,
- raw_observation_input,
- )
+ if len(raw_observation_input_arch) >= 1:
+ raw_observation_input = MLP(
+ input_encoding.output_size,
+ raw_observation_input_arch[-1],
+ hidden_sizes=raw_observation_input_arch[:-1],
+ activation_last_fc=True,
+ layer_norm_last_fc=layer_norm,
+ layer_norm=layer_norm,
+ bias=bias,
+ )
+ mlp_input_size += raw_observation_input.output_size
+ raw_observation_input = nn.Sequential(
+ input_encoding,
+ raw_observation_input,
+ )
+ else:
+ raw_observation_input = input_encoding
+ mlp_input_size += input_encoding.output_size
else:
raw_observation_input = None
super().__init__(
@@ -84,6 +112,8 @@ class Dynamics(MLP):
hidden_sizes=hidden_sizes,
zero_init_last_fc=residual,
layer_norm=layer_norm,
+ bias=bias,
+ bias_last_fc=bias_last_fc,
)
self.action_input = action_input
self.add_module('raw_observation_input', raw_observation_input)
@@ -99,6 +129,7 @@ class Dynamics(MLP):
latent.latent_size,
hidden_sizes=[], # no hidden layers
zero_init_last_fc=residual,
+ bias_last_fc=bias_last_fc,
).module[0])
with torch.no_grad():
fuse_fc.weight[:, -latent.latent_size:].add_(
@@ -113,7 +144,7 @@ class Dynamics(MLP):
inputs: List[torch.Tensor] = []
if self.raw_observation_input is not None:
inputs.append(self.raw_observation_input(observation))
- if self.latent_input is not None:
+ if self.latent_input:
inputs.append(zx)
inputs.append(self.action_input(action))
diff --git a/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py b/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py
index 92fbc58..6464a45 100644
--- a/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py
+++ b/quasimetric_rl/modules/quasimetric_critic/models/quasimetric_model.py
@@ -20,7 +20,7 @@ class L2(torchqmet.QuasimetricBase):
super().__init__(input_size, num_components=1, guaranteed_quasimetric=True,
transforms=[], reduction='sum', discount=None)
- def compute_components(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ def compute_components(self, x: torch.Tensor, y: torch.Tensor, symmetric_upperbound: bool = False) -> torch.Tensor:
r'''
Inputs:
x (torch.Tensor): Shape [..., input_size]
@@ -41,7 +41,7 @@ def create_quasimetric_head_from_spec(spec: str) -> torchqmet.QuasimetricBase:
assert dim % components == 0, "IQE: dim must be divisible by components"
return torchqmet.IQE(dim, dim // components)
- def iqe2(*, dim: int, components: int, norm_delta: bool = False, fake_grad: bool = False, reduction: str = 'maxl12_sm') -> torchqmet.IQE:
+ def iqe2(*, dim: int, components: int, norm_delta: bool = False, fake_grad: bool = False, reduction: str = 'maxl12_sm', version=None) -> torchqmet.IQE2:
assert dim % components == 0, "IQE: dim must be divisible by components"
return torchqmet.IQE2(
dim, dim // components,
@@ -51,6 +51,7 @@ def create_quasimetric_head_from_spec(spec: str) -> torchqmet.QuasimetricBase:
div_init_mul=0.25,
mul_kind='normdeltadiv' if norm_delta else 'normdiv',
fake_grad=fake_grad,
+ version=version,
)
def l2(*, dim: int) -> L2:
@@ -125,10 +126,12 @@ class QuasimetricModel(Module):
layer_norm=projector_layer_norm,
dropout=projector_dropout,
weight_norm_last_fc=projector_weight_norm,
- unit_norm_last_fc=projector_unit_norm)
+ unit_norm_last_fc=projector_unit_norm,
+ )
def forward(self, zx: LatentTensor, zy: LatentTensor, *, bidirectional: bool = False,
- proj_grad_enabled: Tuple[bool, bool] = (True, True)) -> torch.Tensor:
+ proj_grad_enabled: Tuple[bool, bool] = (True, True), reduced: bool = True,
+ symmetric_upperbound: bool = False) -> torch.Tensor:
zx = zx[..., :self.input_slice_size]
zy = zy[..., :self.input_slice_size]
with self.projector.requiring_grad(proj_grad_enabled[0]):
@@ -140,7 +143,7 @@ class QuasimetricModel(Module):
px, py = torch.broadcast_tensors(px, py)
px, py = torch.stack([px, py], dim=-2), torch.stack([py, px], dim=-2) # [B x 2 x D]
- return self.quasimetric_head(px, py)
+ return self.quasimetric_head(px, py, reduced=reduced, symmetric_upperbound=symmetric_upperbound)
def parameters(self, recurse: bool = True, *, include_head: bool = True) -> Iterator[Parameter]:
if include_head:
@@ -153,8 +156,12 @@ class QuasimetricModel(Module):
# for type hint
def __call__(self, zx: LatentTensor, zy: LatentTensor, *, bidirectional: bool = False,
- proj_grad_enabled: Tuple[bool, bool] = (True, True)) -> torch.Tensor:
- return super().__call__(zx, zy, bidirectional=bidirectional, proj_grad_enabled=proj_grad_enabled)
+ proj_grad_enabled: Tuple[bool, bool] = (True, True), reduced: bool = True,
+ symmetric_upperbound: bool = False) -> torch.Tensor:
+ return super().__call__(zx, zy, bidirectional=bidirectional,
+ proj_grad_enabled=proj_grad_enabled,
+ reduced=reduced,
+ symmetric_upperbound=symmetric_upperbound)
def extra_repr(self) -> str:
return f"input_size={self.input_size}, input_slice_size={self.input_slice_size}"
diff --git a/quasimetric_rl/modules/utils.py b/quasimetric_rl/modules/utils.py
index 5d7a845..8fbe219 100644
--- a/quasimetric_rl/modules/utils.py
+++ b/quasimetric_rl/modules/utils.py
@@ -33,11 +33,18 @@ InfoValT = Union[InfoT, float, torch.Tensor]
@attrs.define(kw_only=True)
class LossResult:
- loss: Union[torch.Tensor, float]
+ loss: InfoValT
info: InfoT
def __attrs_post_init__(self):
- assert isinstance(self.loss, (int, float)) or self.loss.numel() == 1
+ def test_loss(loss: InfoValT):
+ if isinstance(loss, Mapping):
+ for v in loss.values():
+ test_loss(v)
+ else:
+ assert isinstance(loss, (int, float)) or loss.numel() == 1
+
+ test_loss(self.loss)
# detach info tensors
def detach(d: InfoValT) -> InfoValT:
@@ -50,17 +57,40 @@ class LossResult:
object.__setattr__(self, "info", detach(self.info))
+ def map_losses(self, fn: Callable[[torch.Tensor | float], torch.Tensor | float]) -> InfoValT:
+ def map_loss(loss: InfoValT):
+ if isinstance(loss, Mapping):
+ return {k: map_loss(v) for k, v in loss.items()}
+ else:
+ return fn(loss)
+
+ return map_loss(self.loss)
+
@classmethod
def empty(cls) -> 'LossResult':
return LossResult(loss=0, info={})
+ @property
+ def total_loss(self) -> torch.Tensor:
+ l = 0
+ def add_loss(loss: InfoValT):
+ nonlocal l
+ if isinstance(loss, Mapping):
+ for v in loss.values():
+ add_loss(v)
+ else:
+ l += loss
+ add_loss(self.loss)
+ assert isinstance(l, torch.Tensor)
+ return l
+
@classmethod
def combine(cls, results: Mapping[str, 'LossResult'], **kwargs) -> 'LossResult':
info = {k: r.info for k, r in results.items()}
assert info.keys().isdisjoint(kwargs.keys())
info.update(**kwargs)
return LossResult(
- loss=sum(r.loss for r in results.values()),
+ loss={k: r.loss for k, r in results.items()},
info=info,
)
@@ -246,6 +276,7 @@ class MLP(Module):
layer_norm_last_fc: bool = False,
weight_norm_last_fc: bool = False,
unit_norm_last_fc: bool = False,
+ bias: bool = True,
bias_last_fc: bool = True,
# final_layer_norm: bool = False, # useful when not output
# final_layer: Callable[[int], nn.Module] = lambda _: nn.Identity,
@@ -261,7 +292,7 @@ class MLP(Module):
for ii, sz in enumerate(hidden_sizes):
modules.append(
TDMPC2Linear(layer_in_size, sz, dropout=(ii == 0) * dropout,
- layer_norm=layer_norm, activation_fn=activation_fn),
+ layer_norm=layer_norm, activation_fn=activation_fn, bias=bias),
)
layer_in_size = sz
modules.append(TDMPC2Linear(layer_in_size, output_size,
Submodule third_party/torch-quasimetric 0fce12e..89f107c:
diff --git a/third_party/torch-quasimetric/torchqmet/__init__.py b/third_party/torch-quasimetric/torchqmet/__init__.py
index 3afb467..b776de5 100644
--- a/third_party/torch-quasimetric/torchqmet/__init__.py
+++ b/third_party/torch-quasimetric/torchqmet/__init__.py
@@ -54,19 +54,22 @@ class QuasimetricBase(nn.Module, metaclass=abc.ABCMeta):
'''
pass
- def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ def forward(self, x: torch.Tensor, y: torch.Tensor, *, reduced: bool = True, **kwargs) -> torch.Tensor:
assert x.shape[-1] == y.shape[-1] == self.input_size
- d = self.compute_components(x, y)
+ d = self.compute_components(x, y, **kwargs)
d: torch.Tensor = self.transforms(d)
scale = self.scale
if not self.training:
scale = scale.detach()
- return self.reduction(d) * scale
+ if reduced:
+ return self.reduction(d) * scale
+ else:
+ return d * scale
- def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ def __call__(self, x: torch.Tensor, y: torch.Tensor, reduced: bool = True, **kwargs) -> torch.Tensor:
# Manually define for typing
# https://github.com/pytorch/pytorch/issues/45414
- return super().__call__(x, y)
+ return super().__call__(x, y, reduced=reduced, **kwargs)
def extra_repr(self) -> str:
return f"guaranteed_quasimetric={self.guaranteed_quasimetric}\ninput_size={self.input_size}, num_components={self.num_components}" + (
@@ -79,5 +82,5 @@ from .iqe import IQE, IQE2
from .mrn import MRN, MRNFixed
from .neural_norms import DeepNorm, WideNorm
-__all__ = ['PQE', 'PQELH', 'PQEGG', 'IQE', 'MRN', 'MRNFixed', 'DeepNorm', 'WideNorm']
+__all__ = ['PQE', 'PQELH', 'PQEGG', 'IQE', 'IQE2', 'MRN', 'MRNFixed', 'DeepNorm', 'WideNorm']
__version__ = "0.1.0"
diff --git a/third_party/torch-quasimetric/torchqmet/iqe.py b/third_party/torch-quasimetric/torchqmet/iqe.py
index a8e8c92..29470a8 100644
--- a/third_party/torch-quasimetric/torchqmet/iqe.py
+++ b/third_party/torch-quasimetric/torchqmet/iqe.py
@@ -41,6 +41,7 @@ def iqe_tensor_delta(x: torch.Tensor, y: torch.Tensor, delta: torch.Tensor, div_
).detach()
neg_f = torch.expm1(neg_f_input)
+ # neg_incf = torch.diff(neg_f, dim=-1, prepend=neg_f.new_zeros(()).expand_as(neg_f[..., :1]))
neg_incf = torch.cat([neg_f.narrow(-1, 0, 1), torch.diff(neg_f, dim=-1)], dim=-1)
# reduction
@@ -62,7 +63,6 @@ def iqe_tensor_delta(x: torch.Tensor, y: torch.Tensor, delta: torch.Tensor, div_
return comp
-
def iqe(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
D = x.shape[-1] # D: dim_per_component
@@ -88,6 +88,7 @@ def iqe(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# neg_inp = torch.where(neg_inp_copies == 0, 0., -delta)
# f output: 0 -> 0, x -> 1.
neg_f = (neg_inp_copies < 0) * (-1.)
+ # neg_incf = torch.diff(neg_f, dim=-1, prepend=neg_f.new_zeros(()).expand_as(neg_f[..., :1]))
neg_incf = torch.cat([neg_f.narrow(-1, 0, 1), torch.diff(neg_f, dim=-1)], dim=-1)
# reduction
@@ -111,12 +112,23 @@ def is_notebook():
if torch.__version__ >= '2.0.1' and not is_notebook(): # well, broken process pool in notebooks
- iqe = torch.compile(iqe, mode="max-autotune")
- iqe_tensor_delta = torch.compile(iqe_tensor_delta, mode="max-autotune")
+ iqe = torch.compile(iqe)
+ _iqe_tensor_delta = torch.compile(iqe_tensor_delta)
+ _iqe_tensor_delta_jit = torch.jit.script(iqe_tensor_delta)
+ _default_iqe_tensor_delta_version = 'compile'
# iqe = torch.compile(iqe, dynamic=True)
else:
iqe = torch.jit.script(iqe) # type: ignore
- iqe_tensor_delta = torch.jit.script(iqe_tensor_delta) # type: ignore
+ _iqe_tensor_delta = None
+ _iqe_tensor_delta_jit = torch.jit.script(iqe_tensor_delta) # type: ignore
+ _default_iqe_tensor_delta_version = 'jit'
+
+
+
+def get_iqe_tensor_delta(version):
+    fn = dict(jit=_iqe_tensor_delta_jit, compile=_iqe_tensor_delta).get(version)
+ assert fn is not None, f"version={version} not supported"
+ return fn
class IQE(QuasimetricBase):
@@ -243,7 +255,8 @@ class IQE2(IQE):
component_dropout_thresh: Tuple[float, float] = (0.5, 2),
dropout_p_thresh: Tuple[float, float] = (0.005, 0.995),
dropout_batch_frac: float = 0.2,
- ema_weight: float = 0.95):
+ ema_weight: float = 0.95,
+ version: Optional[str] = None):
super().__init__(input_size, dim_per_component, transforms=transforms, reduction=reduction,
discount=discount, warn_if_not_quasimetric=warn_if_not_quasimetric)
self.component_dropout_thresh = tuple(component_dropout_thresh) # type: ignore
@@ -302,9 +315,10 @@ class IQE2(IQE):
self.mul_kind = mul_kind
self.last_components = None # type: ignore
self.last_drop_p = None # type: ignore
+ self.version = version or _default_iqe_tensor_delta_version
- def compute_components(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ def compute_components(self, x: torch.Tensor, y: torch.Tensor, *, symmetric_upperbound: bool = False) -> torch.Tensor:
# if self.raw_delta is None:
# components = super().compute_components(x, y)
# else:
@@ -314,7 +328,10 @@ class IQE2(IQE):
delta.data.clamp_(max=1e3 / (self.latent_2d_shape[-1] / 8))
div_pre_f.data.clamp_(min=1e-3)
- components = iqe_tensor_delta(
+ if symmetric_upperbound:
+ x, y = torch.minimum(x, y), torch.maximum(x, y)
+
+ components = get_iqe_tensor_delta(self.version)( # type: ignore
x=x.unflatten(-1, self.latent_2d_shape),
y=y.unflatten(-1, self.latent_2d_shape),
delta=delta,
@@ -384,4 +401,5 @@ learned_delta={self.raw_delta is not None},
learned_div={self.raw_div.requires_grad},
div_init_mul={self.div_init_mul:g},
mul_kind={self.mul_kind},
-fake_grad={self.fake_grad},"""
+fake_grad={self.fake_grad},
+version={self.version}"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment