Skip to content

Instantly share code, notes, and snippets.

@ssnl
Created March 19, 2024 16:31
Show Gist options
  • Save ssnl/e57b14adc51edee30122cc8d07ce3734 to your computer and use it in GitHub Desktop.
diff --git a/quasimetric_rl/base_conf.py b/quasimetric_rl/base_conf.py
index 32de62f..df6afea 100644
--- a/quasimetric_rl/base_conf.py
+++ b/quasimetric_rl/base_conf.py
@@ -140,8 +140,8 @@ class BaseConf(abc.ABC):
]
if self.agent.quasimetric_critic.losses.dynamics_lagrange_mult_optim.lr > 0:
specs[-1] += '-opt'
- if self.agent.num_critics > 1:
- specs.append(f'{self.agent.num_critics}critic')
+ if self.agent.quasimetric_critic.num_critics > 1:
+ specs.append(f'{self.agent.quasimetric_critic.num_critics}critic')
if self.agent.actor is not None:
aspecs = []
if self.agent.actor.losses.min_dist.add_goal_as_future_state:
diff --git a/quasimetric_rl/modules/__init__.py b/quasimetric_rl/modules/__init__.py
index 81b5443..c2e05ba 100644
--- a/quasimetric_rl/modules/__init__.py
+++ b/quasimetric_rl/modules/__init__.py
@@ -24,52 +24,47 @@ class QRLAgent(Module):
class QRLLosses(Module):
actor_loss: Optional[actor.ActorLosses]
- critic_losses: Collection[quasimetric_critic.QuasimetricCriticLosses]
+ critic_loss: quasimetric_critic.QuasimetricCriticLosses
critics_total_grad_clip_norm: Optional[float]
recompute_critic_for_actor_loss: bool
- critics_share_embedding: bool
- def __init__(self, actor_loss: Optional['actor.ActorLosses'],
- critic_losses: Collection[quasimetric_critic.QuasimetricCriticLosses],
+ def __init__(self,
+ actor_loss: Optional['actor.ActorLosses'],
+ critic_loss: quasimetric_critic.QuasimetricCriticLosses,
critics_total_grad_clip_norm: Optional[float],
- recompute_critic_for_actor_loss: bool,
- critics_share_embedding: bool,
- critic_losses_use_target_encoder: bool):
+ recompute_critic_for_actor_loss: bool):
super().__init__()
self.add_module('actor_loss', actor_loss)
- self.critic_losses = torch.nn.ModuleList(critic_losses) # type: ignore
+ self.critic_loss = critic_loss
self.critics_total_grad_clip_norm = critics_total_grad_clip_norm
self.recompute_critic_for_actor_loss = recompute_critic_for_actor_loss
- self.critics_share_embedding = critics_share_embedding
- self.critic_losses_use_target_encoder = critic_losses_use_target_encoder
- def forward(self, agent: QRLAgent, data: BatchData, *, optimize: bool = True) -> LossResult:
- # compute CriticBatchInfo
+ def compute_critic_batch_infos(self, agent: QRLAgent, data: BatchData, *,
+ x_use_target_encoder: bool, y_use_target_encoder: bool) -> List[quasimetric_critic.CriticBatchInfo]:
critic_batch_infos: List[quasimetric_critic.CriticBatchInfo] = []
- loss_results: Dict[str, LossResult] = {}
-
- with contextlib.ExitStack() as stack:
- for idx, (critic, critic_loss) in enumerate(zip(agent.critics, self.critic_losses)):
- stack.enter_context(critic_loss.optim_update_context(optimize=optimize))
-
- if self.critics_share_embedding and idx > 0:
- critic_batch_info = attrs.evolve(critic_batch_infos[0], critic=critic)
- else:
- zx = critic.encoder(data.observations)
- zy = critic.get_encoder(target=self.critic_losses_use_target_encoder)(data.next_observations)
- if critic.has_separate_target_encoder and self.critic_losses_use_target_encoder:
- assert not zy.requires_grad
- critic_batch_info = quasimetric_critic.CriticBatchInfo(
- critic=critic,
- zx=zx,
- zy=zy,
- zy_from_target_encoder=self.critic_losses_use_target_encoder,
- )
+ for idx, critic in enumerate(agent.critics):
+ zx = critic.get_encoder(target=x_use_target_encoder)(data.observations)
+        with torch.set_grad_enabled(not y_use_target_encoder):
+ zy = critic.get_encoder(target=y_use_target_encoder)(data.next_observations)
+        critic_batch_infos.append(quasimetric_critic.CriticBatchInfo(
+            critic=critic,
+            zx=zx,
+            zy=zy,
+            zx_from_target_encoder=x_use_target_encoder,
+            zy_from_target_encoder=y_use_target_encoder,
+        ))
+ return critic_batch_infos
- loss_results[f"critic_{idx:02d}"] = critic_loss(data, critic_batch_info) # we update together to handle shared embedding
- critic_batch_infos.append(critic_batch_info)
+ def forward(self, agent: QRLAgent, data: BatchData, *, optimize: bool = True) -> LossResult:
+ loss_results: Dict[str, LossResult] = dict()
+
+ # compute CriticBatchInfo
+ critic_batch_infos: List[quasimetric_critic.CriticBatchInfo] = self.compute_critic_batch_infos(
+ agent, data, x_use_target_encoder=False, y_use_target_encoder=self.critic_loss.goal_use_target_encoder)
+ critic_loss_result = self.critic_loss(data, critic_batch_infos)
- critic_grad_norm: InfoValT = {}
+ with self.critic_loss.optim_update_context(optimize=optimize):
+ critic_grad_norm: Dict[str, InfoValT] = {}
if FLAGS.DEBUG:
def get_grad(loss: Union[torch.Tensor, float]) -> Union[torch.Tensor, float]:
@@ -86,14 +81,9 @@ class QRLLosses(Module):
sum(pg.pow(2).sum() for pg in loss_grads if pg is not None),
).sqrt()
- for k, loss_r in loss_results.items():
- critic_grad_norm.update({
- k: loss_r.map_losses(get_grad),
- })
+ critic_grad_norm = critic_loss_result.map_losses(get_grad) # type: ignore lol
- torch.stack(
- [loss_r.total_loss for loss_r in loss_results.values()]
- ).sum().backward()
+ critic_loss_result.total_loss.backward()
if self.critics_total_grad_clip_norm is not None:
critic_grad_norm['total'] = torch.nn.utils.clip_grad_norm_(
@@ -105,10 +95,14 @@ class QRLLosses(Module):
sum(p.grad.pow(2).sum() for p in cast(torch.nn.ModuleList, agent.critics).parameters() if p.grad is not None),
).sqrt()
+ critic_loss_result = critic_loss_result.evolve_info(grad_norm=critic_grad_norm)
+
if optimize:
for critic in agent.critics:
critic.update_target_models_()
+ loss_results['critic'] = critic_loss_result
+
actor_grad_norm: InfoValT = {}
if self.actor_loss is not None:
assert agent.actor is not None
@@ -125,7 +119,7 @@ class QRLLosses(Module):
zy_from_target_encoder=self.actor_loss.use_target_encoder,
)
with self.actor_loss.optim_update_context(optimize=optimize):
- loss_results['actor'] = loss_r = self.actor_loss(agent.actor, critic_batch_infos, data)
+ actor_loss_result = self.actor_loss(agent.actor, critic_batch_infos, data)
if FLAGS.DEBUG:
def get_grad(loss: Union[torch.Tensor, float]) -> Union[torch.Tensor, float]:
@@ -143,16 +137,19 @@ class QRLLosses(Module):
sum(pg.pow(2).sum() for pg in loss_grads if pg is not None),
).sqrt()
- actor_grad_norm.update(cast(Mapping, loss_r.map_losses(get_grad)))
+ actor_grad_norm.update(cast(Mapping, actor_loss_result.map_losses(get_grad)))
- loss_r.total_loss.backward()
+ actor_loss_result.total_loss.backward()
actor_grad_norm['total'] = cast(
torch.Tensor,
sum(p.grad.pow(2).sum() for p in agent.actor.parameters() if p.grad is not None),
).sqrt()
- return LossResult.combine(loss_results, grad_norm=dict(critic=critic_grad_norm, actor=actor_grad_norm))
+ actor_loss_result = actor_loss_result.evolve_info(grad_norm=actor_grad_norm)
+ loss_results['actor'] = actor_loss_result
+
+ return LossResult.combine(loss_results)
# for type hints
def __call__(self, agent: QRLAgent, data: BatchData, *, optimize: bool = True) -> LossResult:
@@ -167,15 +164,14 @@ class QRLLosses(Module):
entropy_weight_optim=self.actor_loss.entropy_weight_optim.state_dict(),
entropy_weight_sched=self.actor_loss.entropy_weight_sched.state_dict(),
)
- for idx, critic_loss in enumerate(self.critic_losses):
- optim_scheds[f"critic_{idx:02d}"] = dict(
- critic_optim=critic_loss.critic_optim.state_dict(),
- critic_sched=critic_loss.critic_sched.state_dict(),
- local_lagrange_mult_optim=critic_loss.local_lagrange_mult_optim.state_dict(),
- local_lagrange_mult_sched=critic_loss.local_lagrange_mult_sched.state_dict(),
- dynamics_lagrange_mult_optim=critic_loss.dynamics_lagrange_mult_optim.state_dict(),
- dynamics_lagrange_mult_sched=critic_loss.dynamics_lagrange_mult_sched.state_dict(),
- )
+        optim_scheds['critic'] = dict(
+ critic_optim=self.critic_loss.critic_optim.state_dict(),
+ critic_sched=self.critic_loss.critic_sched.state_dict(),
+ local_lagrange_mult_optim=self.critic_loss.local_lagrange_mult_optim.state_dict(),
+ local_lagrange_mult_sched=self.critic_loss.local_lagrange_mult_sched.state_dict(),
+ dynamics_lagrange_mult_optim=self.critic_loss.dynamics_lagrange_mult_optim.state_dict(),
+ dynamics_lagrange_mult_sched=self.critic_loss.dynamics_lagrange_mult_sched.state_dict(),
+ )
return dict(
module=super().state_dict(),
optim_scheds=optim_scheds,
@@ -189,20 +185,17 @@ class QRLLosses(Module):
self.actor_loss.actor_sched.load_state_dict(optim_scheds['actor']['actor_sched'])
self.actor_loss.entropy_weight_optim.load_state_dict(optim_scheds['actor']['entropy_weight_optim'])
self.actor_loss.entropy_weight_sched.load_state_dict(optim_scheds['actor']['entropy_weight_sched'])
- for idx, critic_loss in enumerate(self.critic_losses):
- critic_loss.critic_optim.load_state_dict(optim_scheds[f"critic_{idx:02d}"]['critic_optim'])
- critic_loss.critic_sched.load_state_dict(optim_scheds[f"critic_{idx:02d}"]['critic_sched'])
- critic_loss.local_lagrange_mult_optim.load_state_dict(optim_scheds[f"critic_{idx:02d}"]['local_lagrange_mult_optim'])
- critic_loss.local_lagrange_mult_sched.load_state_dict(optim_scheds[f"critic_{idx:02d}"]['local_lagrange_mult_sched'])
- critic_loss.dynamics_lagrange_mult_optim.load_state_dict(optim_scheds[f"critic_{idx:02d}"]['dynamics_lagrange_mult_optim'])
- critic_loss.dynamics_lagrange_mult_sched.load_state_dict(optim_scheds[f"critic_{idx:02d}"]['dynamics_lagrange_mult_sched'])
+        self.critic_loss.critic_optim.load_state_dict(optim_scheds['critic']['critic_optim'])
+        self.critic_loss.critic_sched.load_state_dict(optim_scheds['critic']['critic_sched'])
+        self.critic_loss.local_lagrange_mult_optim.load_state_dict(optim_scheds['critic']['local_lagrange_mult_optim'])
+        self.critic_loss.local_lagrange_mult_sched.load_state_dict(optim_scheds['critic']['local_lagrange_mult_sched'])
+        self.critic_loss.dynamics_lagrange_mult_optim.load_state_dict(optim_scheds['critic']['dynamics_lagrange_mult_optim'])
+        self.critic_loss.dynamics_lagrange_mult_sched.load_state_dict(optim_scheds['critic']['dynamics_lagrange_mult_sched'])
def extra_repr(self) -> str:
return '\n'.join([
f'recompute_critic_for_actor_loss={self.recompute_critic_for_actor_loss}',
- f'critics_share_embedding={self.critics_share_embedding}',
f'critics_total_grad_clip_norm={self.critics_total_grad_clip_norm}',
- f'critic_losses_use_target_encoder={self.critic_losses_use_target_encoder}',
])
@@ -210,13 +203,10 @@ class QRLLosses(Module):
class QRLConf:
actor: Optional['actor.ActorConf'] = actor.ActorConf()
quasimetric_critic: 'quasimetric_critic.QuasimetricCriticConf' = quasimetric_critic.QuasimetricCriticConf()
- num_critics: int = attrs.field(default=2, validator=attrs.validators.gt(0)) # NB that TD-MPC2 uses 5 # type: ignore
- critics_share_embedding: bool = False
critics_total_grad_clip_norm: Optional[float] = attrs.field(
default=None, validator=attrs.validators.optional(attrs.validators.gt(0)),
)
recompute_critic_for_actor_loss: bool = False
- critic_losses_use_target_encoder: bool = True
def make(self, *, env_spec: EnvSpec, total_optim_steps: int) -> Tuple[QRLAgent, QRLLosses]:
if self.actor is None:
@@ -224,24 +214,12 @@ class QRLConf:
else:
actor, actor_losses = self.actor.make(env_spec=env_spec, total_optim_steps=total_optim_steps)
- # first critic
- critic, critic_loss = self.quasimetric_critic.make(env_spec=env_spec, total_optim_steps=total_optim_steps)
-
- critics: List[quasimetric_critic.QuasimetricCritic] = [critic]
- critic_losses: List[quasimetric_critic.QuasimetricCriticLosses] = [critic_loss]
-
- for _ in range(self.num_critics - 1):
- critic, critic_loss = self.quasimetric_critic.make(env_spec=env_spec, total_optim_steps=total_optim_steps,
- share_embedding_from=(critics[0] if self.critics_share_embedding else None))
- critics.append(critic)
- critic_losses.append(critic_loss)
+ critics, critic_loss = self.quasimetric_critic.make(env_spec=env_spec, total_optim_steps=total_optim_steps)
return QRLAgent(actor=actor, critics=critics), QRLLosses(
- actor_loss=actor_losses, critic_losses=critic_losses,
- critics_share_embedding=self.critics_share_embedding,
+ actor_loss=actor_losses, critic_loss=critic_loss,
critics_total_grad_clip_norm=self.critics_total_grad_clip_norm,
recompute_critic_for_actor_loss=self.recompute_critic_for_actor_loss,
- critic_losses_use_target_encoder=self.critic_losses_use_target_encoder,
)
__all__ = ['QRLAgent', 'QRLLosses', 'QRLConf', 'InfoT', 'InfoValT']
diff --git a/quasimetric_rl/modules/quasimetric_critic/__init__.py b/quasimetric_rl/modules/quasimetric_critic/__init__.py
index 6986704..dcbb2aa 100644
--- a/quasimetric_rl/modules/quasimetric_critic/__init__.py
+++ b/quasimetric_rl/modules/quasimetric_critic/__init__.py
@@ -11,12 +11,17 @@ from ...data import EnvSpec
@attrs.define(kw_only=True)
class QuasimetricCriticConf:
model: QuasimetricCritic.Conf = QuasimetricCritic.Conf()
+ num_critics: int = attrs.field(default=2, validator=attrs.validators.gt(0)) # NB that TD-MPC2 uses 5 # type: ignore
+ critics_share_embedding: bool = False
losses: QuasimetricCriticLosses.Conf = QuasimetricCriticLosses.Conf()
- def make(self, *, env_spec: EnvSpec, total_optim_steps: int,
- share_embedding_from: Optional[QuasimetricCritic] = None) -> Tuple[QuasimetricCritic, QuasimetricCriticLosses]:
- critic = self.model.make(env_spec=env_spec, share_embedding_from=share_embedding_from)
- return critic, self.losses.make(critic, total_optim_steps, share_embedding_from=share_embedding_from)
+ def make(self, *, env_spec: EnvSpec, total_optim_steps: int) -> Tuple[Sequence[QuasimetricCritic], QuasimetricCriticLosses]:
+ critics = [self.model.make(env_spec=env_spec)]
+ for _ in range(self.num_critics - 1):
+ critics.append(
+ self.model.make(env_spec=env_spec, share_embedding_from=critics[0] if self.critics_share_embedding else None)
+ )
+ return critics, self.losses.make(critics, total_optim_steps)
__all__ = ['QuasimetricCritic', 'QuasimetricCriticLosses', 'CriticBatchInfo', 'QuasimetricCriticConf']
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py b/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py
index 5487442..cfaaa7f 100644
--- a/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py
@@ -2,6 +2,7 @@ from typing import *
import abc
import attrs
+from collections import defaultdict
import torch
@@ -25,15 +26,15 @@ class CriticBatchInfo:
class CriticLossBase(LossBase):
@abc.abstractmethod
- def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
+ def forward(self, data: BatchData, critic_batch_infos: Sequence[CriticBatchInfo]) -> LossResult:
pass
# for type hints
- def __call__(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
- return super().__call__(data, critic_batch_info)
+ def __call__(self, data: BatchData, critic_batch_infos: Sequence[CriticBatchInfo]) -> LossResult:
+ return super().__call__(data, critic_batch_infos)
-from .global_push import GlobalPushLoss, GlobalPushLinearLoss, GlobalPushLogLoss, GlobalPushRBFLoss, GlobalPushNextMSELoss
+from .global_push import GlobalPushLoss, GlobalPushLinearLoss, GlobalPushLogLoss, GlobalPushRBFLoss, GlobalPushNextLoss
from .local_constraint import LocalConstraintLoss
from .latent_dynamics import LatentDynamicsLoss
@@ -43,12 +44,14 @@ class QuasimetricCriticLosses(CriticLossBase):
class Conf:
global_push: GlobalPushLoss.Conf = GlobalPushLoss.Conf()
global_push_linear: GlobalPushLinearLoss.Conf = GlobalPushLinearLoss.Conf()
- global_push_next_mse: GlobalPushNextMSELoss.Conf = GlobalPushNextMSELoss.Conf()
+ global_push_next: GlobalPushNextLoss.Conf = GlobalPushNextLoss.Conf()
global_push_log: GlobalPushLogLoss.Conf = GlobalPushLogLoss.Conf()
global_push_rbf: GlobalPushRBFLoss.Conf = GlobalPushRBFLoss.Conf()
local_constraint: LocalConstraintLoss.Conf = LocalConstraintLoss.Conf()
latent_dynamics: LatentDynamicsLoss.Conf = LatentDynamicsLoss.Conf()
+ goal_use_target_encoder: bool = False
+
critic_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=1e-4)
latent_dynamics_lr_mul: float = attrs.field(default=1., validator=attrs.validators.ge(0))
quasimetric_model_lr_mul: float = attrs.field(default=1., validator=attrs.validators.ge(0))
@@ -62,22 +65,22 @@ class QuasimetricCriticLosses(CriticLossBase):
['best_local_fit', 'best_local_fit_clip5', 'best_local_fit_clip10',
'best_local_fit_detach']))) # type: ignore
- def make(self, critic: QuasimetricCritic, total_optim_steps: int,
- share_embedding_from: Optional[QuasimetricCritic] = None) -> 'QuasimetricCriticLosses':
+ def make(self, critics: Sequence[QuasimetricCritic], total_optim_steps: int) -> 'QuasimetricCriticLosses':
return QuasimetricCriticLosses(
- critic,
+ critics,
total_optim_steps=total_optim_steps,
- share_embedding_from=share_embedding_from,
# global losses
global_push=self.global_push.make(),
global_push_linear=self.global_push_linear.make(),
- global_push_next_mse=self.global_push_next_mse.make(),
+ global_push_next=self.global_push_next.make(),
global_push_log=self.global_push_log.make(),
global_push_rbf=self.global_push_rbf.make(),
# local loss
local_constraint=self.local_constraint.make(),
# dyn loss
latent_dynamics=self.latent_dynamics.make(),
+ #
+ goal_use_target_encoder=self.goal_use_target_encoder,
# critic optim
critic_optim_spec=self.critic_optim.make(),
# lr mult
@@ -91,15 +94,16 @@ class QuasimetricCriticLosses(CriticLossBase):
quasimetric_scale=self.quasimetric_scale,
)
- borrowing_embedding: bool
global_push: Optional[GlobalPushLoss]
global_push_linear: Optional[GlobalPushLinearLoss]
- global_push_next_mse: Optional[GlobalPushNextMSELoss]
+ global_push_next: Optional[GlobalPushNextLoss]
global_push_log: Optional[GlobalPushLogLoss]
global_push_rbf: Optional[GlobalPushRBFLoss]
- local_constraint: Optional[LocalConstraintLoss]
+ local_constraint: LocalConstraintLoss
latent_dynamics: LatentDynamicsLoss
+ goal_use_target_encoder: bool
+
critic_optim: OptimWrapper
critic_sched: LRScheduler
local_lagrange_mult_optim: OptimWrapper
@@ -108,12 +112,12 @@ class QuasimetricCriticLosses(CriticLossBase):
dynamics_lagrange_mult_sched: LRScheduler
quasimetric_scale: Optional[str]
- def __init__(self, critic: QuasimetricCritic, *, total_optim_steps: int,
- share_embedding_from: Optional[QuasimetricCritic] = None,
+ def __init__(self, critics: Sequence[QuasimetricCritic], *, total_optim_steps: int,
global_push: Optional[GlobalPushLoss], global_push_linear: Optional[GlobalPushLinearLoss],
- global_push_next_mse: Optional[GlobalPushNextMSELoss], global_push_log: Optional[GlobalPushLogLoss],
+ global_push_next: Optional[GlobalPushNextLoss], global_push_log: Optional[GlobalPushLogLoss],
global_push_rbf: Optional[GlobalPushRBFLoss],
- local_constraint: Optional[LocalConstraintLoss], latent_dynamics: LatentDynamicsLoss,
+ local_constraint: LocalConstraintLoss, latent_dynamics: LatentDynamicsLoss,
+ goal_use_target_encoder: bool,
critic_optim_spec: AdamWSpec,
latent_dynamics_lr_mul: float,
quasimetric_model_lr_mul: float,
@@ -123,31 +127,34 @@ class QuasimetricCriticLosses(CriticLossBase):
dynamics_lagrange_mult_optim_spec: AdamWSpec,
quasimetric_scale: Optional[str]):
super().__init__()
- self.borrowing_embedding = share_embedding_from is not None
- if self.borrowing_embedding:
- global_push = None
- global_push_linear = None
- global_push_log = None
- global_push_rbf = None
- local_constraint = None
+ # if self.borrowing_embedding:
+ # global_push = None
+ # global_push_linear = None
+ # global_push_log = None
+ # global_push_rbf = None
+ # local_constraint = None
self.add_module('global_push', global_push)
self.add_module('global_push_linear', global_push_linear)
- self.add_module('global_push_next_mse', global_push_next_mse)
+ self.add_module('global_push_next', global_push_next)
self.add_module('global_push_log', global_push_log)
self.add_module('global_push_rbf', global_push_rbf)
self.add_module('local_constraint', local_constraint)
self.latent_dynamics = latent_dynamics
- critic_param_groups = [
- dict(params=critic.latent_dynamics.parameters(), lr_mul=latent_dynamics_lr_mul),
- ]
- if not self.borrowing_embedding:
- # add encoder and quasimetric head
+ self.goal_use_target_encoder = goal_use_target_encoder
+
+ critic_param_groups = []
+ for critic in critics:
critic_param_groups += [
- dict(params=critic.quasimetric_model.parameters(include_head=False), lr_mul=quasimetric_model_lr_mul),
- dict(params=critic.quasimetric_model.quasimetric_head.parameters(), lr_mul=quasimetric_model_lr_mul * quasimetric_head_lr_mul),
- dict(params=critic.encoder.parameters(), lr_mul=encoder_lr_mul),
+ dict(params=critic.latent_dynamics.parameters(), lr_mul=latent_dynamics_lr_mul),
]
+ if not critic.borrowing_embedding:
+ # add encoder and quasimetric head
+ critic_param_groups += [
+ dict(params=critic.quasimetric_model.parameters(include_head=False), lr_mul=quasimetric_model_lr_mul),
+ dict(params=critic.quasimetric_model.quasimetric_head.parameters(), lr_mul=quasimetric_model_lr_mul * quasimetric_head_lr_mul),
+ dict(params=critic.encoder.parameters(), lr_mul=encoder_lr_mul),
+ ]
self.critic_optim, self.critic_sched = critic_optim_spec.create_optim_scheduler(critic_param_groups, total_optim_steps)
self.local_lagrange_mult_optim, self.local_lagrange_mult_sched = local_lagrange_mult_optim_spec.create_optim_scheduler(
@@ -166,48 +173,49 @@ class QuasimetricCriticLosses(CriticLossBase):
return [self.critic_sched, self.local_lagrange_mult_sched, self.dynamics_lagrange_mult_sched]
def compute_best_quasimetric_scale(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> Tuple[torch.Tensor, torch.Tensor]:
- assert self.local_constraint is not None and not self.borrowing_embedding
critic_batch_info.critic.quasimetric_model.quasimetric_head.scale.detach_().fill_(1) # reset
dist = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, critic_batch_info.zy)
return dist, (self.local_constraint.step_cost * (dist.mean() / dist.square().mean().clamp_min(1e-12))) # .detach().clamp_(1e-1, 1e1)
- def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
- extra_info: Dict[str, torch.Tensor] = {}
- if self.quasimetric_scale is not None and not self.borrowing_embedding:
- unscaled_dist, scale = self.compute_best_quasimetric_scale(data, critic_batch_info)
- assert scale.grad_fn is not None # allow bp
- if self.quasimetric_scale == 'best_local_fit_detach':
- scale = scale.detach()
- elif self.quasimetric_scale == 'best_local_fit_clip5':
- scale = scale.clamp(1 / 5, 5)
- elif self.quasimetric_scale == 'best_local_fit_clip10':
- scale = scale.clamp(1 / 10, 10)
- extra_info['unscaled_dist'] = unscaled_dist
- extra_info['quasimetric_autoscale'] = scale
- critic_batch_info.critic.quasimetric_model.quasimetric_head.scale = scale
+ def forward(self, data: BatchData, critic_batch_infos: Sequence[CriticBatchInfo]) -> LossResult:
+ extra_info: DefaultDict[str, Dict[str, torch.Tensor]] = defaultdict(dict)
+ if self.quasimetric_scale is not None:
+ for idx, critic_batch_info in enumerate(critic_batch_infos):
+ if critic_batch_info.critic.borrowing_embedding:
+ continue
+ unscaled_dist, scale = self.compute_best_quasimetric_scale(data, critic_batch_info)
+ assert scale.grad_fn is not None # allow bp
+ if self.quasimetric_scale == 'best_local_fit_detach':
+ scale = scale.detach()
+ elif self.quasimetric_scale == 'best_local_fit_clip5':
+ scale = scale.clamp(1 / 5, 5)
+ elif self.quasimetric_scale == 'best_local_fit_clip10':
+ scale = scale.clamp(1 / 10, 10)
+ extra_info[f"critic_{idx:02d}"]['unscaled_dist'] = unscaled_dist
+ extra_info[f"critic_{idx:02d}"]['quasimetric_autoscale'] = scale
+ critic_batch_info.critic.quasimetric_model.quasimetric_head.scale = scale
loss_results: Dict[str, LossResult] = {}
if self.global_push is not None:
- loss_results.update(global_push=self.global_push(data, critic_batch_info))
+ loss_results.update(global_push=self.global_push(data, critic_batch_infos))
if self.global_push_linear is not None:
- loss_results.update(global_push_linear=self.global_push_linear(data, critic_batch_info))
- if self.global_push_next_mse is not None:
- loss_results.update(global_push_next_mse=self.global_push_next_mse(data, critic_batch_info))
+ loss_results.update(global_push_linear=self.global_push_linear(data, critic_batch_infos))
+ if self.global_push_next is not None:
+ loss_results.update(global_push_next=self.global_push_next(data, critic_batch_infos))
if self.global_push_log is not None:
- loss_results.update(global_push_log=self.global_push_log(data, critic_batch_info))
+ loss_results.update(global_push_log=self.global_push_log(data, critic_batch_infos))
if self.global_push_rbf is not None:
- loss_results.update(global_push_rbf=self.global_push_rbf(data, critic_batch_info))
- if self.local_constraint is not None:
- loss_results.update(local_constraint=self.local_constraint(data, critic_batch_info))
+ loss_results.update(global_push_rbf=self.global_push_rbf(data, critic_batch_infos))
loss_results.update(
- latent_dynamics=self.latent_dynamics(data, critic_batch_info),
+ local_constraint=self.local_constraint(data, critic_batch_infos),
+ latent_dynamics=self.latent_dynamics(data, critic_batch_infos),
)
result = LossResult.combine(loss_results, **extra_info)
return result
# for type hints
- def __call__(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
- return torch.nn.Module.__call__(self, data, critic_batch_info)
+ def __call__(self, data: BatchData, critic_batch_infos: Sequence[CriticBatchInfo]) -> LossResult:
+ return torch.nn.Module.__call__(self, data, critic_batch_infos)
def extra_repr(self) -> str:
- return f"borrowing_embedding={self.borrowing_embedding}, quasimetric_scale={self.quasimetric_scale!r}"
+ return f"goal_use_target_encoder={self.goal_use_target_encoder}, quasimetric_scale={self.quasimetric_scale!r}"
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
index fedce27..ec2ee66 100644
--- a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
@@ -140,7 +140,7 @@ class GlobalPushLossBase(CriticLossBase):
target: Optional[torch.Tensor], weight: float, info: Mapping[str, torch.Tensor]) -> LossResult:
raise NotImplementedError
- def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
+ def forward_one(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
return LossResult.combine(
{
name: self.compute_loss(data, critic_batch_info, zgoal, dist, target, weight, info)
@@ -148,6 +148,14 @@ class GlobalPushLossBase(CriticLossBase):
},
)
+ def forward(self, data: BatchData, critic_batch_infos: Sequence[CriticBatchInfo]) -> LossResult:
+ return LossResult.combine({
+ f"critic_{idx:02d}": self.forward_one(data, critic_batch_info)
+ for idx, critic_batch_info in enumerate(critic_batch_infos)
+ if not critic_batch_info.critic.borrowing_embedding
+ })
+
+
def extra_repr(self) -> str:
return '\n'.join([
f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}",
@@ -240,7 +248,8 @@ class GlobalPushLinearLoss(GlobalPushLossBase):
clamp_max: Optional[float]
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, regress_max_future_goal: bool,
+ def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool,
+ detach_qmet: bool, step_cost: float, regress_max_future_goal: bool,
clamp_max: Optional[float]):
super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal,
detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, regress_max_future_goal=regress_max_future_goal)
@@ -284,7 +293,7 @@ class GlobalPushLinearLoss(GlobalPushLossBase):
-class GlobalPushNextMSELoss(GlobalPushLossBase):
+class GlobalPushNextLoss(GlobalPushLossBase):
@attrs.define(kw_only=True)
class Conf(GlobalPushLossBase.Conf):
enabled: bool = False
@@ -295,11 +304,15 @@ class GlobalPushNextMSELoss(GlobalPushLossBase):
attrs.validators.gt(0),
attrs.validators.lt(1),
)))
+ # expectile: float = attrs.field(default=0.5, validator=attrs.validators.and_(
+ # attrs.validators.ge(0),
+ # attrs.validators.le(1),
+ # ))
- def make(self) -> Optional['GlobalPushNextMSELoss']:
+ def make(self) -> Optional['GlobalPushNextLoss']:
if not self.enabled:
return None
- return GlobalPushNextMSELoss(
+ return GlobalPushNextLoss(
weight=self.weight,
weight_future_goal=self.weight_future_goal,
detach_goal=self.detach_goal,
@@ -312,7 +325,8 @@ class GlobalPushNextMSELoss(GlobalPushLossBase):
gamma=self.gamma,
)
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, regress_max_future_goal: bool,
+ def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool,
+ detach_qmet: bool, step_cost: float, regress_max_future_goal: bool,
detach_target_dist: bool, allow_gt: bool, gamma: Optional[float]):
super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal,
detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, regress_max_future_goal=regress_max_future_goal)
@@ -327,7 +341,7 @@ class GlobalPushNextMSELoss(GlobalPushLossBase):
if target is None:
with torch.enable_grad(self.detach_target_dist):
# by tri-eq, the actual cost can't be larger than step_cost + d(s', g)
- next_dist = critic_batch_info.critic.quasimetric_model(
+ next_dist = critic_batch_info.critic.target_quasimetric_model(
critic_batch_info.zy, zgoal, proj_grad_enabled=(True, not self.detach_proj_goal)
)
if self.detach_target_dist:
@@ -382,7 +396,8 @@ class GlobalPushLogLoss(GlobalPushLossBase):
offset: float
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, regress_max_future_goal: bool,
+ def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool,
+ detach_qmet: bool, step_cost: float, regress_max_future_goal: bool,
offset: float):
super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal,
detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, regress_max_future_goal=regress_max_future_goal)
@@ -439,7 +454,8 @@ class GlobalPushRBFLoss(GlobalPushLossBase):
inv_scale: float
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, regress_max_future_goal: bool,
+ def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool,
+ detach_qmet: bool, step_cost: float, regress_max_future_goal: bool,
inv_scale: float):
super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal,
detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, regress_max_future_goal=regress_max_future_goal)
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py b/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py
index 19e4886..15213c6 100644
--- a/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py
@@ -88,7 +88,7 @@ class LatentDynamicsLoss(CriticLossBase):
self.raw_lagrange_multiplier = nn.Parameter(
torch.tensor(softplus_inv_float(init_lagrange_multiplier), dtype=torch.float32))
- def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
+ def forward_one(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
critic = critic_batch_info.critic
zx = critic_batch_info.zx
zy = critic_batch_info.zy
@@ -173,6 +173,11 @@ class LatentDynamicsLoss(CriticLossBase):
info=info, # type: ignore
)
+ def forward(self, data: BatchData, critic_batch_infos: Sequence[CriticBatchInfo]) -> LossResult:
+ return LossResult.combine({
+ f"critic_{idx:02d}": self.forward_one(data, critic_batch_info) for idx, critic_batch_info in enumerate(critic_batch_infos)
+ })
+
def extra_repr(self) -> str:
return '\n'.join([
f"kind={self.kind}, gamma={self.gamma!r}",
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py b/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py
index c3d6b38..30abbbe 100644
--- a/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py
@@ -111,7 +111,7 @@ class LocalConstraintLoss(CriticLossBase):
self.raw_lagrange_multiplier = nn.Parameter(
torch.tensor(softplus_inv_float(init_lagrange_multiplier), dtype=torch.float32)) # type: ignore
- def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
+ def forward_one(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult:
info: Dict[str, Union[float, torch.Tensor]] = {}
@@ -198,6 +198,12 @@ class LocalConstraintLoss(CriticLossBase):
return LossResult(loss=loss * lagrange_mult, info=info)
+ def forward(self, data: BatchData, critic_batch_infos: Sequence[CriticBatchInfo]) -> LossResult:
+ return LossResult.combine({
+ f"critic_{idx:02d}": self.forward_one(data, critic_batch_info) for idx, critic_batch_info in enumerate(critic_batch_infos)
+ if not critic_batch_info.critic.borrowing_embedding
+ })
+
def extra_repr(self) -> str:
return '\n'.join([
# f"kind={self.kind}, log={self.log}, batch_reduction={self.batch_reduction}",
diff --git a/quasimetric_rl/modules/utils.py b/quasimetric_rl/modules/utils.py
index 8fbe219..fa0a42c 100644
--- a/quasimetric_rl/modules/utils.py
+++ b/quasimetric_rl/modules/utils.py
@@ -31,6 +31,22 @@ InfoT = Union[NestedMapping[Union[float, torch.Tensor]], Mapping[str, 'InfoValT'
InfoValT = Union[InfoT, float, torch.Tensor]
+
+
+# merge kwargs into info, deep merge
+def deep_merge_dict(a: dict, b: Mapping, path=[]):
+ # https://stackoverflow.com/a/7205107
+ for key in b:
+ if key in a:
+ if isinstance(a[key], dict) and isinstance(b[key], Mapping):
+ deep_merge_dict(a[key], b[key], path + [str(key)])
+ elif a[key] != b[key]:
+ raise Exception('Conflict at ' + '.'.join(path + [str(key)]))
+ else:
+ a[key] = b[key]
+ return a
+
+
@attrs.define(kw_only=True)
class LossResult:
loss: InfoValT
@@ -84,11 +100,20 @@ class LossResult:
assert isinstance(l, torch.Tensor)
return l
+ def evolve_info(self, **kwargs) -> 'LossResult':
+ info = dict(self.info)
+ deep_merge_dict(info, kwargs)
+ return LossResult(
+ loss=self.loss,
+ info=info,
+ )
+
@classmethod
def combine(cls, results: Mapping[str, 'LossResult'], **kwargs) -> 'LossResult':
info = {k: r.info for k, r in results.items()}
- assert info.keys().isdisjoint(kwargs.keys())
- info.update(**kwargs)
+
+ deep_merge_dict(info, kwargs)
+
return LossResult(
loss={k: r.loss for k, r in results.items()},
info=info,
Submodule third_party/torch-quasimetric contains modified content
diff --git a/third_party/torch-quasimetric/torchqmet/iqe.py b/third_party/torch-quasimetric/torchqmet/iqe.py
index deeb541..4faa873 100644
--- a/third_party/torch-quasimetric/torchqmet/iqe.py
+++ b/third_party/torch-quasimetric/torchqmet/iqe.py
@@ -404,8 +404,8 @@ class IQE2(IQE):
components = components.reshape(*bshape, self.num_components)
- self.last_components = components
- return components
+ self.last_components = components # type: ignore
+ return components # type: ignore
def extra_repr(self) -> str:
return super().extra_repr() + rf"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment