Created
March 19, 2024 16:31
-
-
Save ssnl/e57b14adc51edee30122cc8d07ce3734 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/quasimetric_rl/base_conf.py b/quasimetric_rl/base_conf.py | |
index 32de62f..df6afea 100644 | |
--- a/quasimetric_rl/base_conf.py | |
+++ b/quasimetric_rl/base_conf.py | |
@@ -140,8 +140,8 @@ class BaseConf(abc.ABC): | |
] | |
if self.agent.quasimetric_critic.losses.dynamics_lagrange_mult_optim.lr > 0: | |
specs[-1] += '-opt' | |
- if self.agent.num_critics > 1: | |
- specs.append(f'{self.agent.num_critics}critic') | |
+ if self.agent.quasimetric_critic.num_critics > 1: | |
+ specs.append(f'{self.agent.quasimetric_critic.num_critics}critic') | |
if self.agent.actor is not None: | |
aspecs = [] | |
if self.agent.actor.losses.min_dist.add_goal_as_future_state: | |
diff --git a/quasimetric_rl/modules/__init__.py b/quasimetric_rl/modules/__init__.py | |
index 81b5443..c2e05ba 100644 | |
--- a/quasimetric_rl/modules/__init__.py | |
+++ b/quasimetric_rl/modules/__init__.py | |
@@ -24,52 +24,47 @@ class QRLAgent(Module): | |
class QRLLosses(Module): | |
actor_loss: Optional[actor.ActorLosses] | |
- critic_losses: Collection[quasimetric_critic.QuasimetricCriticLosses] | |
+ critic_loss: quasimetric_critic.QuasimetricCriticLosses | |
critics_total_grad_clip_norm: Optional[float] | |
recompute_critic_for_actor_loss: bool | |
- critics_share_embedding: bool | |
- def __init__(self, actor_loss: Optional['actor.ActorLosses'], | |
- critic_losses: Collection[quasimetric_critic.QuasimetricCriticLosses], | |
+ def __init__(self, | |
+ actor_loss: Optional['actor.ActorLosses'], | |
+ critic_loss: quasimetric_critic.QuasimetricCriticLosses, | |
critics_total_grad_clip_norm: Optional[float], | |
- recompute_critic_for_actor_loss: bool, | |
- critics_share_embedding: bool, | |
- critic_losses_use_target_encoder: bool): | |
+ recompute_critic_for_actor_loss: bool): | |
super().__init__() | |
self.add_module('actor_loss', actor_loss) | |
- self.critic_losses = torch.nn.ModuleList(critic_losses) # type: ignore | |
+ self.critic_loss = critic_loss | |
self.critics_total_grad_clip_norm = critics_total_grad_clip_norm | |
self.recompute_critic_for_actor_loss = recompute_critic_for_actor_loss | |
- self.critics_share_embedding = critics_share_embedding | |
- self.critic_losses_use_target_encoder = critic_losses_use_target_encoder | |
- def forward(self, agent: QRLAgent, data: BatchData, *, optimize: bool = True) -> LossResult: | |
- # compute CriticBatchInfo | |
+ def compute_critic_batch_infos(self, agent: QRLAgent, data: BatchData, *, | |
+ x_use_target_encoder: bool, y_use_target_encoder: bool) -> List[quasimetric_critic.CriticBatchInfo]: | |
critic_batch_infos: List[quasimetric_critic.CriticBatchInfo] = [] | |
- loss_results: Dict[str, LossResult] = {} | |
- | |
- with contextlib.ExitStack() as stack: | |
- for idx, (critic, critic_loss) in enumerate(zip(agent.critics, self.critic_losses)): | |
- stack.enter_context(critic_loss.optim_update_context(optimize=optimize)) | |
- | |
- if self.critics_share_embedding and idx > 0: | |
- critic_batch_info = attrs.evolve(critic_batch_infos[0], critic=critic) | |
- else: | |
- zx = critic.encoder(data.observations) | |
- zy = critic.get_encoder(target=self.critic_losses_use_target_encoder)(data.next_observations) | |
- if critic.has_separate_target_encoder and self.critic_losses_use_target_encoder: | |
- assert not zy.requires_grad | |
- critic_batch_info = quasimetric_critic.CriticBatchInfo( | |
- critic=critic, | |
- zx=zx, | |
- zy=zy, | |
- zy_from_target_encoder=self.critic_losses_use_target_encoder, | |
- ) | |
+ for idx, critic in enumerate(agent.critics): | |
+ zx = critic.get_encoder(target=x_use_target_encoder)(data.observations) | |
+ with torch.set_grad_enabled(not y_use_target_encoder): | |
+ zy = critic.get_encoder(target=y_use_target_encoder)(data.next_observations) | |
+ critic_batch_infos.append(quasimetric_critic.CriticBatchInfo( | |
+ critic=critic, | |
+ zx=zx, | |
+ zy=zy, | |
+ zx_from_target_encoder=x_use_target_encoder, | |
+ zy_from_target_encoder=y_use_target_encoder, | |
+ )) | |
+ return critic_batch_infos | |
- loss_results[f"critic_{idx:02d}"] = critic_loss(data, critic_batch_info) # we update together to handle shared embedding | |
- critic_batch_infos.append(critic_batch_info) | |
+ def forward(self, agent: QRLAgent, data: BatchData, *, optimize: bool = True) -> LossResult: | |
+ loss_results: Dict[str, LossResult] = dict() | |
+ | |
+ # compute CriticBatchInfo | |
+ critic_batch_infos: List[quasimetric_critic.CriticBatchInfo] = self.compute_critic_batch_infos( | |
+ agent, data, x_use_target_encoder=False, y_use_target_encoder=self.critic_loss.goal_use_target_encoder) | |
+ critic_loss_result = self.critic_loss(data, critic_batch_infos) | |
- critic_grad_norm: InfoValT = {} | |
+ with self.critic_loss.optim_update_context(optimize=optimize): | |
+ critic_grad_norm: Dict[str, InfoValT] = {} | |
if FLAGS.DEBUG: | |
def get_grad(loss: Union[torch.Tensor, float]) -> Union[torch.Tensor, float]: | |
@@ -86,14 +81,9 @@ class QRLLosses(Module): | |
sum(pg.pow(2).sum() for pg in loss_grads if pg is not None), | |
).sqrt() | |
- for k, loss_r in loss_results.items(): | |
- critic_grad_norm.update({ | |
- k: loss_r.map_losses(get_grad), | |
- }) | |
+ critic_grad_norm = critic_loss_result.map_losses(get_grad) # type: ignore lol | |
- torch.stack( | |
- [loss_r.total_loss for loss_r in loss_results.values()] | |
- ).sum().backward() | |
+ critic_loss_result.total_loss.backward() | |
if self.critics_total_grad_clip_norm is not None: | |
critic_grad_norm['total'] = torch.nn.utils.clip_grad_norm_( | |
@@ -105,10 +95,14 @@ class QRLLosses(Module): | |
sum(p.grad.pow(2).sum() for p in cast(torch.nn.ModuleList, agent.critics).parameters() if p.grad is not None), | |
).sqrt() | |
+ critic_loss_result = critic_loss_result.evolve_info(grad_norm=critic_grad_norm) | |
+ | |
if optimize: | |
for critic in agent.critics: | |
critic.update_target_models_() | |
+ loss_results['critic'] = critic_loss_result | |
+ | |
actor_grad_norm: InfoValT = {} | |
if self.actor_loss is not None: | |
assert agent.actor is not None | |
@@ -125,7 +119,7 @@ class QRLLosses(Module): | |
zy_from_target_encoder=self.actor_loss.use_target_encoder, | |
) | |
with self.actor_loss.optim_update_context(optimize=optimize): | |
- loss_results['actor'] = loss_r = self.actor_loss(agent.actor, critic_batch_infos, data) | |
+ actor_loss_result = self.actor_loss(agent.actor, critic_batch_infos, data) | |
if FLAGS.DEBUG: | |
def get_grad(loss: Union[torch.Tensor, float]) -> Union[torch.Tensor, float]: | |
@@ -143,16 +137,19 @@ class QRLLosses(Module): | |
sum(pg.pow(2).sum() for pg in loss_grads if pg is not None), | |
).sqrt() | |
- actor_grad_norm.update(cast(Mapping, loss_r.map_losses(get_grad))) | |
+ actor_grad_norm.update(cast(Mapping, actor_loss_result.map_losses(get_grad))) | |
- loss_r.total_loss.backward() | |
+ actor_loss_result.total_loss.backward() | |
actor_grad_norm['total'] = cast( | |
torch.Tensor, | |
sum(p.grad.pow(2).sum() for p in agent.actor.parameters() if p.grad is not None), | |
).sqrt() | |
- return LossResult.combine(loss_results, grad_norm=dict(critic=critic_grad_norm, actor=actor_grad_norm)) | |
+ actor_loss_result = actor_loss_result.evolve_info(grad_norm=actor_grad_norm) | |
+ loss_results['actor'] = actor_loss_result | |
+ | |
+ return LossResult.combine(loss_results) | |
# for type hints | |
def __call__(self, agent: QRLAgent, data: BatchData, *, optimize: bool = True) -> LossResult: | |
@@ -167,15 +164,14 @@ class QRLLosses(Module): | |
entropy_weight_optim=self.actor_loss.entropy_weight_optim.state_dict(), | |
entropy_weight_sched=self.actor_loss.entropy_weight_sched.state_dict(), | |
) | |
- for idx, critic_loss in enumerate(self.critic_losses): | |
- optim_scheds[f"critic_{idx:02d}"] = dict( | |
- critic_optim=critic_loss.critic_optim.state_dict(), | |
- critic_sched=critic_loss.critic_sched.state_dict(), | |
- local_lagrange_mult_optim=critic_loss.local_lagrange_mult_optim.state_dict(), | |
- local_lagrange_mult_sched=critic_loss.local_lagrange_mult_sched.state_dict(), | |
- dynamics_lagrange_mult_optim=critic_loss.dynamics_lagrange_mult_optim.state_dict(), | |
- dynamics_lagrange_mult_sched=critic_loss.dynamics_lagrange_mult_sched.state_dict(), | |
- ) | |
+ optim_scheds["critic"] = dict( | |
+ critic_optim=self.critic_loss.critic_optim.state_dict(), | |
+ critic_sched=self.critic_loss.critic_sched.state_dict(), | |
+ local_lagrange_mult_optim=self.critic_loss.local_lagrange_mult_optim.state_dict(), | |
+ local_lagrange_mult_sched=self.critic_loss.local_lagrange_mult_sched.state_dict(), | |
+ dynamics_lagrange_mult_optim=self.critic_loss.dynamics_lagrange_mult_optim.state_dict(), | |
+ dynamics_lagrange_mult_sched=self.critic_loss.dynamics_lagrange_mult_sched.state_dict(), | |
+ ) | |
return dict( | |
module=super().state_dict(), | |
optim_scheds=optim_scheds, | |
@@ -189,20 +185,17 @@ class QRLLosses(Module): | |
self.actor_loss.actor_sched.load_state_dict(optim_scheds['actor']['actor_sched']) | |
self.actor_loss.entropy_weight_optim.load_state_dict(optim_scheds['actor']['entropy_weight_optim']) | |
self.actor_loss.entropy_weight_sched.load_state_dict(optim_scheds['actor']['entropy_weight_sched']) | |
- for idx, critic_loss in enumerate(self.critic_losses): | |
- critic_loss.critic_optim.load_state_dict(optim_scheds[f"critic_{idx:02d}"]['critic_optim']) | |
- critic_loss.critic_sched.load_state_dict(optim_scheds[f"critic_{idx:02d}"]['critic_sched']) | |
- critic_loss.local_lagrange_mult_optim.load_state_dict(optim_scheds[f"critic_{idx:02d}"]['local_lagrange_mult_optim']) | |
- critic_loss.local_lagrange_mult_sched.load_state_dict(optim_scheds[f"critic_{idx:02d}"]['local_lagrange_mult_sched']) | |
- critic_loss.dynamics_lagrange_mult_optim.load_state_dict(optim_scheds[f"critic_{idx:02d}"]['dynamics_lagrange_mult_optim']) | |
- critic_loss.dynamics_lagrange_mult_sched.load_state_dict(optim_scheds[f"critic_{idx:02d}"]['dynamics_lagrange_mult_sched']) | |
+ self.critic_loss.critic_optim.load_state_dict(optim_scheds['critic']['critic_optim']) | |
+ self.critic_loss.critic_sched.load_state_dict(optim_scheds['critic']['critic_sched']) | |
+ self.critic_loss.local_lagrange_mult_optim.load_state_dict(optim_scheds['critic']['local_lagrange_mult_optim']) | |
+ self.critic_loss.local_lagrange_mult_sched.load_state_dict(optim_scheds['critic']['local_lagrange_mult_sched']) | |
+ self.critic_loss.dynamics_lagrange_mult_optim.load_state_dict(optim_scheds['critic']['dynamics_lagrange_mult_optim']) | |
+ self.critic_loss.dynamics_lagrange_mult_sched.load_state_dict(optim_scheds['critic']['dynamics_lagrange_mult_sched']) | |
def extra_repr(self) -> str: | |
return '\n'.join([ | |
f'recompute_critic_for_actor_loss={self.recompute_critic_for_actor_loss}', | |
- f'critics_share_embedding={self.critics_share_embedding}', | |
f'critics_total_grad_clip_norm={self.critics_total_grad_clip_norm}', | |
- f'critic_losses_use_target_encoder={self.critic_losses_use_target_encoder}', | |
]) | |
@@ -210,13 +203,10 @@ class QRLLosses(Module): | |
class QRLConf: | |
actor: Optional['actor.ActorConf'] = actor.ActorConf() | |
quasimetric_critic: 'quasimetric_critic.QuasimetricCriticConf' = quasimetric_critic.QuasimetricCriticConf() | |
- num_critics: int = attrs.field(default=2, validator=attrs.validators.gt(0)) # NB that TD-MPC2 uses 5 # type: ignore | |
- critics_share_embedding: bool = False | |
critics_total_grad_clip_norm: Optional[float] = attrs.field( | |
default=None, validator=attrs.validators.optional(attrs.validators.gt(0)), | |
) | |
recompute_critic_for_actor_loss: bool = False | |
- critic_losses_use_target_encoder: bool = True | |
def make(self, *, env_spec: EnvSpec, total_optim_steps: int) -> Tuple[QRLAgent, QRLLosses]: | |
if self.actor is None: | |
@@ -224,24 +214,12 @@ class QRLConf: | |
else: | |
actor, actor_losses = self.actor.make(env_spec=env_spec, total_optim_steps=total_optim_steps) | |
- # first critic | |
- critic, critic_loss = self.quasimetric_critic.make(env_spec=env_spec, total_optim_steps=total_optim_steps) | |
- | |
- critics: List[quasimetric_critic.QuasimetricCritic] = [critic] | |
- critic_losses: List[quasimetric_critic.QuasimetricCriticLosses] = [critic_loss] | |
- | |
- for _ in range(self.num_critics - 1): | |
- critic, critic_loss = self.quasimetric_critic.make(env_spec=env_spec, total_optim_steps=total_optim_steps, | |
- share_embedding_from=(critics[0] if self.critics_share_embedding else None)) | |
- critics.append(critic) | |
- critic_losses.append(critic_loss) | |
+ critics, critic_loss = self.quasimetric_critic.make(env_spec=env_spec, total_optim_steps=total_optim_steps) | |
return QRLAgent(actor=actor, critics=critics), QRLLosses( | |
- actor_loss=actor_losses, critic_losses=critic_losses, | |
- critics_share_embedding=self.critics_share_embedding, | |
+ actor_loss=actor_losses, critic_loss=critic_loss, | |
critics_total_grad_clip_norm=self.critics_total_grad_clip_norm, | |
recompute_critic_for_actor_loss=self.recompute_critic_for_actor_loss, | |
- critic_losses_use_target_encoder=self.critic_losses_use_target_encoder, | |
) | |
__all__ = ['QRLAgent', 'QRLLosses', 'QRLConf', 'InfoT', 'InfoValT'] | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/__init__.py b/quasimetric_rl/modules/quasimetric_critic/__init__.py | |
index 6986704..dcbb2aa 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/__init__.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/__init__.py | |
@@ -11,12 +11,17 @@ from ...data import EnvSpec | |
@attrs.define(kw_only=True) | |
class QuasimetricCriticConf: | |
model: QuasimetricCritic.Conf = QuasimetricCritic.Conf() | |
+ num_critics: int = attrs.field(default=2, validator=attrs.validators.gt(0)) # NB that TD-MPC2 uses 5 # type: ignore | |
+ critics_share_embedding: bool = False | |
losses: QuasimetricCriticLosses.Conf = QuasimetricCriticLosses.Conf() | |
- def make(self, *, env_spec: EnvSpec, total_optim_steps: int, | |
- share_embedding_from: Optional[QuasimetricCritic] = None) -> Tuple[QuasimetricCritic, QuasimetricCriticLosses]: | |
- critic = self.model.make(env_spec=env_spec, share_embedding_from=share_embedding_from) | |
- return critic, self.losses.make(critic, total_optim_steps, share_embedding_from=share_embedding_from) | |
+ def make(self, *, env_spec: EnvSpec, total_optim_steps: int) -> Tuple[Sequence[QuasimetricCritic], QuasimetricCriticLosses]: | |
+ critics = [self.model.make(env_spec=env_spec)] | |
+ for _ in range(self.num_critics - 1): | |
+ critics.append( | |
+ self.model.make(env_spec=env_spec, share_embedding_from=critics[0] if self.critics_share_embedding else None) | |
+ ) | |
+ return critics, self.losses.make(critics, total_optim_steps) | |
__all__ = ['QuasimetricCritic', 'QuasimetricCriticLosses', 'CriticBatchInfo', 'QuasimetricCriticConf'] | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py b/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py | |
index 5487442..cfaaa7f 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/__init__.py | |
@@ -2,6 +2,7 @@ from typing import * | |
import abc | |
import attrs | |
+from collections import defaultdict | |
import torch | |
@@ -25,15 +26,15 @@ class CriticBatchInfo: | |
class CriticLossBase(LossBase): | |
@abc.abstractmethod | |
- def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
+ def forward(self, data: BatchData, critic_batch_infos: Sequence[CriticBatchInfo]) -> LossResult: | |
pass | |
# for type hints | |
- def __call__(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
- return super().__call__(data, critic_batch_info) | |
+ def __call__(self, data: BatchData, critic_batch_infos: Sequence[CriticBatchInfo]) -> LossResult: | |
+ return super().__call__(data, critic_batch_infos) | |
-from .global_push import GlobalPushLoss, GlobalPushLinearLoss, GlobalPushLogLoss, GlobalPushRBFLoss, GlobalPushNextMSELoss | |
+from .global_push import GlobalPushLoss, GlobalPushLinearLoss, GlobalPushLogLoss, GlobalPushRBFLoss, GlobalPushNextLoss | |
from .local_constraint import LocalConstraintLoss | |
from .latent_dynamics import LatentDynamicsLoss | |
@@ -43,12 +44,14 @@ class QuasimetricCriticLosses(CriticLossBase): | |
class Conf: | |
global_push: GlobalPushLoss.Conf = GlobalPushLoss.Conf() | |
global_push_linear: GlobalPushLinearLoss.Conf = GlobalPushLinearLoss.Conf() | |
- global_push_next_mse: GlobalPushNextMSELoss.Conf = GlobalPushNextMSELoss.Conf() | |
+ global_push_next: GlobalPushNextLoss.Conf = GlobalPushNextLoss.Conf() | |
global_push_log: GlobalPushLogLoss.Conf = GlobalPushLogLoss.Conf() | |
global_push_rbf: GlobalPushRBFLoss.Conf = GlobalPushRBFLoss.Conf() | |
local_constraint: LocalConstraintLoss.Conf = LocalConstraintLoss.Conf() | |
latent_dynamics: LatentDynamicsLoss.Conf = LatentDynamicsLoss.Conf() | |
+ goal_use_target_encoder: bool = False | |
+ | |
critic_optim: AdamWSpec.Conf = AdamWSpec.Conf(lr=1e-4) | |
latent_dynamics_lr_mul: float = attrs.field(default=1., validator=attrs.validators.ge(0)) | |
quasimetric_model_lr_mul: float = attrs.field(default=1., validator=attrs.validators.ge(0)) | |
@@ -62,22 +65,22 @@ class QuasimetricCriticLosses(CriticLossBase): | |
['best_local_fit', 'best_local_fit_clip5', 'best_local_fit_clip10', | |
'best_local_fit_detach']))) # type: ignore | |
- def make(self, critic: QuasimetricCritic, total_optim_steps: int, | |
- share_embedding_from: Optional[QuasimetricCritic] = None) -> 'QuasimetricCriticLosses': | |
+ def make(self, critics: Sequence[QuasimetricCritic], total_optim_steps: int) -> 'QuasimetricCriticLosses': | |
return QuasimetricCriticLosses( | |
- critic, | |
+ critics, | |
total_optim_steps=total_optim_steps, | |
- share_embedding_from=share_embedding_from, | |
# global losses | |
global_push=self.global_push.make(), | |
global_push_linear=self.global_push_linear.make(), | |
- global_push_next_mse=self.global_push_next_mse.make(), | |
+ global_push_next=self.global_push_next.make(), | |
global_push_log=self.global_push_log.make(), | |
global_push_rbf=self.global_push_rbf.make(), | |
# local loss | |
local_constraint=self.local_constraint.make(), | |
# dyn loss | |
latent_dynamics=self.latent_dynamics.make(), | |
+ # | |
+ goal_use_target_encoder=self.goal_use_target_encoder, | |
# critic optim | |
critic_optim_spec=self.critic_optim.make(), | |
# lr mult | |
@@ -91,15 +94,16 @@ class QuasimetricCriticLosses(CriticLossBase): | |
quasimetric_scale=self.quasimetric_scale, | |
) | |
- borrowing_embedding: bool | |
global_push: Optional[GlobalPushLoss] | |
global_push_linear: Optional[GlobalPushLinearLoss] | |
- global_push_next_mse: Optional[GlobalPushNextMSELoss] | |
+ global_push_next: Optional[GlobalPushNextLoss] | |
global_push_log: Optional[GlobalPushLogLoss] | |
global_push_rbf: Optional[GlobalPushRBFLoss] | |
- local_constraint: Optional[LocalConstraintLoss] | |
+ local_constraint: LocalConstraintLoss | |
latent_dynamics: LatentDynamicsLoss | |
+ goal_use_target_encoder: bool | |
+ | |
critic_optim: OptimWrapper | |
critic_sched: LRScheduler | |
local_lagrange_mult_optim: OptimWrapper | |
@@ -108,12 +112,12 @@ class QuasimetricCriticLosses(CriticLossBase): | |
dynamics_lagrange_mult_sched: LRScheduler | |
quasimetric_scale: Optional[str] | |
- def __init__(self, critic: QuasimetricCritic, *, total_optim_steps: int, | |
- share_embedding_from: Optional[QuasimetricCritic] = None, | |
+ def __init__(self, critics: Sequence[QuasimetricCritic], *, total_optim_steps: int, | |
global_push: Optional[GlobalPushLoss], global_push_linear: Optional[GlobalPushLinearLoss], | |
- global_push_next_mse: Optional[GlobalPushNextMSELoss], global_push_log: Optional[GlobalPushLogLoss], | |
+ global_push_next: Optional[GlobalPushNextLoss], global_push_log: Optional[GlobalPushLogLoss], | |
global_push_rbf: Optional[GlobalPushRBFLoss], | |
- local_constraint: Optional[LocalConstraintLoss], latent_dynamics: LatentDynamicsLoss, | |
+ local_constraint: LocalConstraintLoss, latent_dynamics: LatentDynamicsLoss, | |
+ goal_use_target_encoder: bool, | |
critic_optim_spec: AdamWSpec, | |
latent_dynamics_lr_mul: float, | |
quasimetric_model_lr_mul: float, | |
@@ -123,31 +127,34 @@ class QuasimetricCriticLosses(CriticLossBase): | |
dynamics_lagrange_mult_optim_spec: AdamWSpec, | |
quasimetric_scale: Optional[str]): | |
super().__init__() | |
- self.borrowing_embedding = share_embedding_from is not None | |
- if self.borrowing_embedding: | |
- global_push = None | |
- global_push_linear = None | |
- global_push_log = None | |
- global_push_rbf = None | |
- local_constraint = None | |
+ # if self.borrowing_embedding: | |
+ # global_push = None | |
+ # global_push_linear = None | |
+ # global_push_log = None | |
+ # global_push_rbf = None | |
+ # local_constraint = None | |
self.add_module('global_push', global_push) | |
self.add_module('global_push_linear', global_push_linear) | |
- self.add_module('global_push_next_mse', global_push_next_mse) | |
+ self.add_module('global_push_next', global_push_next) | |
self.add_module('global_push_log', global_push_log) | |
self.add_module('global_push_rbf', global_push_rbf) | |
self.add_module('local_constraint', local_constraint) | |
self.latent_dynamics = latent_dynamics | |
- critic_param_groups = [ | |
- dict(params=critic.latent_dynamics.parameters(), lr_mul=latent_dynamics_lr_mul), | |
- ] | |
- if not self.borrowing_embedding: | |
- # add encoder and quasimetric head | |
+ self.goal_use_target_encoder = goal_use_target_encoder | |
+ | |
+ critic_param_groups = [] | |
+ for critic in critics: | |
critic_param_groups += [ | |
- dict(params=critic.quasimetric_model.parameters(include_head=False), lr_mul=quasimetric_model_lr_mul), | |
- dict(params=critic.quasimetric_model.quasimetric_head.parameters(), lr_mul=quasimetric_model_lr_mul * quasimetric_head_lr_mul), | |
- dict(params=critic.encoder.parameters(), lr_mul=encoder_lr_mul), | |
+ dict(params=critic.latent_dynamics.parameters(), lr_mul=latent_dynamics_lr_mul), | |
] | |
+ if not critic.borrowing_embedding: | |
+ # add encoder and quasimetric head | |
+ critic_param_groups += [ | |
+ dict(params=critic.quasimetric_model.parameters(include_head=False), lr_mul=quasimetric_model_lr_mul), | |
+ dict(params=critic.quasimetric_model.quasimetric_head.parameters(), lr_mul=quasimetric_model_lr_mul * quasimetric_head_lr_mul), | |
+ dict(params=critic.encoder.parameters(), lr_mul=encoder_lr_mul), | |
+ ] | |
self.critic_optim, self.critic_sched = critic_optim_spec.create_optim_scheduler(critic_param_groups, total_optim_steps) | |
self.local_lagrange_mult_optim, self.local_lagrange_mult_sched = local_lagrange_mult_optim_spec.create_optim_scheduler( | |
@@ -166,48 +173,49 @@ class QuasimetricCriticLosses(CriticLossBase): | |
return [self.critic_sched, self.local_lagrange_mult_sched, self.dynamics_lagrange_mult_sched] | |
def compute_best_quasimetric_scale(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> Tuple[torch.Tensor, torch.Tensor]: | |
- assert self.local_constraint is not None and not self.borrowing_embedding | |
critic_batch_info.critic.quasimetric_model.quasimetric_head.scale.detach_().fill_(1) # reset | |
dist = critic_batch_info.critic.quasimetric_model(critic_batch_info.zx, critic_batch_info.zy) | |
return dist, (self.local_constraint.step_cost * (dist.mean() / dist.square().mean().clamp_min(1e-12))) # .detach().clamp_(1e-1, 1e1) | |
- def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
- extra_info: Dict[str, torch.Tensor] = {} | |
- if self.quasimetric_scale is not None and not self.borrowing_embedding: | |
- unscaled_dist, scale = self.compute_best_quasimetric_scale(data, critic_batch_info) | |
- assert scale.grad_fn is not None # allow bp | |
- if self.quasimetric_scale == 'best_local_fit_detach': | |
- scale = scale.detach() | |
- elif self.quasimetric_scale == 'best_local_fit_clip5': | |
- scale = scale.clamp(1 / 5, 5) | |
- elif self.quasimetric_scale == 'best_local_fit_clip10': | |
- scale = scale.clamp(1 / 10, 10) | |
- extra_info['unscaled_dist'] = unscaled_dist | |
- extra_info['quasimetric_autoscale'] = scale | |
- critic_batch_info.critic.quasimetric_model.quasimetric_head.scale = scale | |
+ def forward(self, data: BatchData, critic_batch_infos: Sequence[CriticBatchInfo]) -> LossResult: | |
+ extra_info: DefaultDict[str, Dict[str, torch.Tensor]] = defaultdict(dict) | |
+ if self.quasimetric_scale is not None: | |
+ for idx, critic_batch_info in enumerate(critic_batch_infos): | |
+ if critic_batch_info.critic.borrowing_embedding: | |
+ continue | |
+ unscaled_dist, scale = self.compute_best_quasimetric_scale(data, critic_batch_info) | |
+ assert scale.grad_fn is not None # allow bp | |
+ if self.quasimetric_scale == 'best_local_fit_detach': | |
+ scale = scale.detach() | |
+ elif self.quasimetric_scale == 'best_local_fit_clip5': | |
+ scale = scale.clamp(1 / 5, 5) | |
+ elif self.quasimetric_scale == 'best_local_fit_clip10': | |
+ scale = scale.clamp(1 / 10, 10) | |
+ extra_info[f"critic_{idx:02d}"]['unscaled_dist'] = unscaled_dist | |
+ extra_info[f"critic_{idx:02d}"]['quasimetric_autoscale'] = scale | |
+ critic_batch_info.critic.quasimetric_model.quasimetric_head.scale = scale | |
loss_results: Dict[str, LossResult] = {} | |
if self.global_push is not None: | |
- loss_results.update(global_push=self.global_push(data, critic_batch_info)) | |
+ loss_results.update(global_push=self.global_push(data, critic_batch_infos)) | |
if self.global_push_linear is not None: | |
- loss_results.update(global_push_linear=self.global_push_linear(data, critic_batch_info)) | |
- if self.global_push_next_mse is not None: | |
- loss_results.update(global_push_next_mse=self.global_push_next_mse(data, critic_batch_info)) | |
+ loss_results.update(global_push_linear=self.global_push_linear(data, critic_batch_infos)) | |
+ if self.global_push_next is not None: | |
+ loss_results.update(global_push_next=self.global_push_next(data, critic_batch_infos)) | |
if self.global_push_log is not None: | |
- loss_results.update(global_push_log=self.global_push_log(data, critic_batch_info)) | |
+ loss_results.update(global_push_log=self.global_push_log(data, critic_batch_infos)) | |
if self.global_push_rbf is not None: | |
- loss_results.update(global_push_rbf=self.global_push_rbf(data, critic_batch_info)) | |
- if self.local_constraint is not None: | |
- loss_results.update(local_constraint=self.local_constraint(data, critic_batch_info)) | |
+ loss_results.update(global_push_rbf=self.global_push_rbf(data, critic_batch_infos)) | |
loss_results.update( | |
- latent_dynamics=self.latent_dynamics(data, critic_batch_info), | |
+ local_constraint=self.local_constraint(data, critic_batch_infos), | |
+ latent_dynamics=self.latent_dynamics(data, critic_batch_infos), | |
) | |
result = LossResult.combine(loss_results, **extra_info) | |
return result | |
# for type hints | |
- def __call__(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
- return torch.nn.Module.__call__(self, data, critic_batch_info) | |
+ def __call__(self, data: BatchData, critic_batch_infos: Sequence[CriticBatchInfo]) -> LossResult: | |
+ return torch.nn.Module.__call__(self, data, critic_batch_infos) | |
def extra_repr(self) -> str: | |
- return f"borrowing_embedding={self.borrowing_embedding}, quasimetric_scale={self.quasimetric_scale!r}" | |
+ return f"goal_use_target_encoder={self.goal_use_target_encoder}, quasimetric_scale={self.quasimetric_scale!r}" | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py | |
index fedce27..ec2ee66 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py | |
@@ -140,7 +140,7 @@ class GlobalPushLossBase(CriticLossBase): | |
target: Optional[torch.Tensor], weight: float, info: Mapping[str, torch.Tensor]) -> LossResult: | |
raise NotImplementedError | |
- def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
+ def forward_one(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
return LossResult.combine( | |
{ | |
name: self.compute_loss(data, critic_batch_info, zgoal, dist, target, weight, info) | |
@@ -148,6 +148,14 @@ class GlobalPushLossBase(CriticLossBase): | |
}, | |
) | |
+ def forward(self, data: BatchData, critic_batch_infos: Sequence[CriticBatchInfo]) -> LossResult: | |
+ return LossResult.combine({ | |
+ f"critic_{idx:02d}": self.forward_one(data, critic_batch_info) | |
+ for idx, critic_batch_info in enumerate(critic_batch_infos) | |
+ if not critic_batch_info.critic.borrowing_embedding | |
+ }) | |
+ | |
+ | |
def extra_repr(self) -> str: | |
return '\n'.join([ | |
f"weight={self.weight:g}, detach_proj_goal={self.detach_proj_goal}", | |
@@ -240,7 +248,8 @@ class GlobalPushLinearLoss(GlobalPushLossBase): | |
clamp_max: Optional[float] | |
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, regress_max_future_goal: bool, | |
+ def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, | |
+ detach_qmet: bool, step_cost: float, regress_max_future_goal: bool, | |
clamp_max: Optional[float]): | |
super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal, | |
detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, regress_max_future_goal=regress_max_future_goal) | |
@@ -284,7 +293,7 @@ class GlobalPushLinearLoss(GlobalPushLossBase): | |
-class GlobalPushNextMSELoss(GlobalPushLossBase): | |
+class GlobalPushNextLoss(GlobalPushLossBase): | |
@attrs.define(kw_only=True) | |
class Conf(GlobalPushLossBase.Conf): | |
enabled: bool = False | |
@@ -295,11 +304,15 @@ class GlobalPushNextMSELoss(GlobalPushLossBase): | |
attrs.validators.gt(0), | |
attrs.validators.lt(1), | |
))) | |
+ # expectile: float = attrs.field(default=0.5, validator=attrs.validators.and_( | |
+ # attrs.validators.ge(0), | |
+ # attrs.validators.le(1), | |
+ # )) | |
- def make(self) -> Optional['GlobalPushNextMSELoss']: | |
+ def make(self) -> Optional['GlobalPushNextLoss']: | |
if not self.enabled: | |
return None | |
- return GlobalPushNextMSELoss( | |
+ return GlobalPushNextLoss( | |
weight=self.weight, | |
weight_future_goal=self.weight_future_goal, | |
detach_goal=self.detach_goal, | |
@@ -312,7 +325,8 @@ class GlobalPushNextMSELoss(GlobalPushLossBase): | |
gamma=self.gamma, | |
) | |
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, regress_max_future_goal: bool, | |
+ def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, | |
+ detach_qmet: bool, step_cost: float, regress_max_future_goal: bool, | |
detach_target_dist: bool, allow_gt: bool, gamma: Optional[float]): | |
super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal, | |
detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, regress_max_future_goal=regress_max_future_goal) | |
@@ -327,7 +341,7 @@ class GlobalPushNextMSELoss(GlobalPushLossBase): | |
if target is None: | |
with torch.enable_grad(self.detach_target_dist): | |
# by tri-eq, the actual cost can't be larger than step_cost + d(s', g) | |
- next_dist = critic_batch_info.critic.quasimetric_model( | |
+ next_dist = critic_batch_info.critic.target_quasimetric_model( | |
critic_batch_info.zy, zgoal, proj_grad_enabled=(True, not self.detach_proj_goal) | |
) | |
if self.detach_target_dist: | |
@@ -382,7 +396,8 @@ class GlobalPushLogLoss(GlobalPushLossBase): | |
offset: float | |
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, regress_max_future_goal: bool, | |
+ def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, | |
+ detach_qmet: bool, step_cost: float, regress_max_future_goal: bool, | |
offset: float): | |
super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal, | |
detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, regress_max_future_goal=regress_max_future_goal) | |
@@ -439,7 +454,8 @@ class GlobalPushRBFLoss(GlobalPushLossBase): | |
inv_scale: float | |
- def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, detach_qmet: bool, step_cost: float, regress_max_future_goal: bool, | |
+ def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, | |
+ detach_qmet: bool, step_cost: float, regress_max_future_goal: bool, | |
inv_scale: float): | |
super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal, | |
detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, regress_max_future_goal=regress_max_future_goal) | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py b/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py | |
index 19e4886..15213c6 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/latent_dynamics.py | |
@@ -88,7 +88,7 @@ class LatentDynamicsLoss(CriticLossBase): | |
self.raw_lagrange_multiplier = nn.Parameter( | |
torch.tensor(softplus_inv_float(init_lagrange_multiplier), dtype=torch.float32)) | |
- def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
+ def forward_one(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
critic = critic_batch_info.critic | |
zx = critic_batch_info.zx | |
zy = critic_batch_info.zy | |
@@ -173,6 +173,11 @@ class LatentDynamicsLoss(CriticLossBase): | |
info=info, # type: ignore | |
) | |
+ def forward(self, data: BatchData, critic_batch_infos: Sequence[CriticBatchInfo]) -> LossResult: | |
+ return LossResult.combine({ | |
+ f"critic_{idx:02d}": self.forward_one(data, critic_batch_info) for idx, critic_batch_info in enumerate(critic_batch_infos) | |
+ }) | |
+ | |
def extra_repr(self) -> str: | |
return '\n'.join([ | |
f"kind={self.kind}, gamma={self.gamma!r}", | |
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py b/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py | |
index c3d6b38..30abbbe 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/local_constraint.py | |
@@ -111,7 +111,7 @@ class LocalConstraintLoss(CriticLossBase): | |
self.raw_lagrange_multiplier = nn.Parameter( | |
torch.tensor(softplus_inv_float(init_lagrange_multiplier), dtype=torch.float32)) # type: ignore | |
- def forward(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
+ def forward_one(self, data: BatchData, critic_batch_info: CriticBatchInfo) -> LossResult: | |
info: Dict[str, Union[float, torch.Tensor]] = {} | |
@@ -198,6 +198,12 @@ class LocalConstraintLoss(CriticLossBase): | |
return LossResult(loss=loss * lagrange_mult, info=info) | |
+ def forward(self, data: BatchData, critic_batch_infos: Sequence[CriticBatchInfo]) -> LossResult: | |
+ return LossResult.combine({ | |
+ f"critic_{idx:02d}": self.forward_one(data, critic_batch_info) for idx, critic_batch_info in enumerate(critic_batch_infos) | |
+ if not critic_batch_info.critic.borrowing_embedding | |
+ }) | |
+ | |
def extra_repr(self) -> str: | |
return '\n'.join([ | |
# f"kind={self.kind}, log={self.log}, batch_reduction={self.batch_reduction}", | |
diff --git a/quasimetric_rl/modules/utils.py b/quasimetric_rl/modules/utils.py | |
index 8fbe219..fa0a42c 100644 | |
--- a/quasimetric_rl/modules/utils.py | |
+++ b/quasimetric_rl/modules/utils.py | |
@@ -31,6 +31,22 @@ InfoT = Union[NestedMapping[Union[float, torch.Tensor]], Mapping[str, 'InfoValT' | |
InfoValT = Union[InfoT, float, torch.Tensor] | |
+ | |
+ | |
+# merge kwargs into info, deep merge | |
+def deep_merge_dict(a: dict, b: Mapping, path=[]): | |
+ # https://stackoverflow.com/a/7205107 | |
+ for key in b: | |
+ if key in a: | |
+ if isinstance(a[key], dict) and isinstance(b[key], Mapping): | |
+ deep_merge_dict(a[key], b[key], path + [str(key)]) | |
+ elif a[key] != b[key]: | |
+ raise Exception('Conflict at ' + '.'.join(path + [str(key)])) | |
+ else: | |
+ a[key] = b[key] | |
+ return a | |
+ | |
+ | |
@attrs.define(kw_only=True) | |
class LossResult: | |
loss: InfoValT | |
@@ -84,11 +100,20 @@ class LossResult: | |
assert isinstance(l, torch.Tensor) | |
return l | |
+ def evolve_info(self, **kwargs) -> 'LossResult': | |
+ info = dict(self.info) | |
+ deep_merge_dict(info, kwargs) | |
+ return LossResult( | |
+ loss=self.loss, | |
+ info=info, | |
+ ) | |
+ | |
@classmethod | |
def combine(cls, results: Mapping[str, 'LossResult'], **kwargs) -> 'LossResult': | |
info = {k: r.info for k, r in results.items()} | |
- assert info.keys().isdisjoint(kwargs.keys()) | |
- info.update(**kwargs) | |
+ | |
+ deep_merge_dict(info, kwargs) | |
+ | |
return LossResult( | |
loss={k: r.loss for k, r in results.items()}, | |
info=info, | |
Submodule third_party/torch-quasimetric contains modified content | |
diff --git a/third_party/torch-quasimetric/torchqmet/iqe.py b/third_party/torch-quasimetric/torchqmet/iqe.py | |
index deeb541..4faa873 100644 | |
--- a/third_party/torch-quasimetric/torchqmet/iqe.py | |
+++ b/third_party/torch-quasimetric/torchqmet/iqe.py | |
@@ -404,8 +404,8 @@ class IQE2(IQE): | |
components = components.reshape(*bshape, self.num_components) | |
- self.last_components = components | |
- return components | |
+ self.last_components = components # type: ignore | |
+ return components # type: ignore | |
def extra_repr(self) -> str: | |
return super().extra_repr() + rf""" |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.