@ssnl
Created April 8, 2024 16:53
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
index b353c39..1ffe375 100644
--- a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py
@@ -54,6 +54,17 @@ from . import CriticLossBase, CriticBatchInfo
# return f"weight={self.weight:g}, softplus_beta={self.softplus_beta:g}, softplus_offset={self.softplus_offset:g}"
+@attrs.define(kw_only=True)
+class Task:
+ critic_batch_info: CriticBatchInfo
+ goal: torch.Tensor
+ zgoal: torch.Tensor
+ current_dist: torch.Tensor
+ target_dist: Optional[torch.Tensor]
+ weight: float
+ info: Mapping[str, Union[float, torch.Tensor]]
+
+
class GlobalPushLossBase(CriticLossBase):
@attrs.define(kw_only=True)
class Conf(abc.ABC):
@@ -107,7 +118,8 @@ class GlobalPushLossBase(CriticLossBase):
assert not regress_max_future_goal, "regress_max_future_goal is not compatible with mix_in_future_goal"
assert weight_future_goal > 0, "weight_future_goal must be positive when mix_in_future_goal is enabled"
- def generate_dist_weight(self, data: BatchData, critic_batch_infos: Sequence[CriticBatchInfo]):
+
+ def generate_tasks(self, data: BatchData, critic_batch_infos: Sequence[CriticBatchInfo]):
def get_dist(critic_batch_info: CriticBatchInfo, za: torch.Tensor, zb: torch.Tensor):
if self.detach_goal:
zb = zb.detach()
@@ -115,8 +127,9 @@ class GlobalPushLossBase(CriticLossBase):
with quasimetric_model.requiring_grad(not self.detach_qmet):
return quasimetric_model(za, zb, proj_grad_enabled=(True, not self.detach_proj_goal))
- # (critic_batch_info, goal, zgoal, current_dist, target_dist, weight, info)
- tasks: Dict[str, Tuple[CriticBatchInfo, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], float, Dict[str, torch.Tensor]]] = {}
+ info: Dict[str, Union[float, torch.Tensor]]
+
+ tasks: Dict[str, Task] = {}
for idx, critic_batch_info in enumerate(critic_batch_infos):
if critic_batch_info.critic.borrowing_embedding:
continue
@@ -148,14 +161,17 @@ class GlobalPushLossBase(CriticLossBase):
zgoal,
)
key = 'mixed_goal'
- tasks[f'critic_{idx:02d}'] = (
- critic_batch_info,
- goal,
- zgoal,
- get_dist(critic_batch_info, critic_batch_info.zx, zgoal),
- None,
- weight,
- {},
+ current_dist = get_dist(critic_batch_info, critic_batch_info.zx, zgoal)
+ tasks[f'critic_{idx:02d}'] = Task(
+ critic_batch_info=critic_batch_info,
+ goal=goal,
+ zgoal=zgoal,
+ current_dist=current_dist,
+ target_dist=None,
+ weight=weight,
+ info=dict(
+ dist=time_agg(current_dist, self.rho, data.horizon),
+ ),
)
yield key, tasks
@@ -166,48 +182,46 @@ class GlobalPushLossBase(CriticLossBase):
continue
zgoal = critic_batch_info.critic.encoder(
data.future_observations, target=critic_batch_info.zy_from_target_encoder)
- dist = get_dist(critic_batch_info, critic_batch_info.zx, zgoal)
+ current_dist = get_dist(critic_batch_info, critic_batch_info.zx, zgoal)
if self.regress_max_future_goal:
observed_upper_bound = self.step_cost * data.future_tdelta
info = dict(
target=observed_upper_bound.mean(),
- ratio_future_observed_dist=(dist / observed_upper_bound).mean(),
- exceed_future_observed_dist_rate=(dist > observed_upper_bound).mean(dtype=torch.float32),
+ ratio_future_observed_dist=(current_dist / observed_upper_bound).mean(),
+ exceed_future_observed_dist_rate=(current_dist > observed_upper_bound).mean(dtype=torch.float32),
)
- target = observed_upper_bound
+ target_dist = observed_upper_bound
# dist = dist.clamp_max(self.step_cost * data.future_tdelta)
else:
info = {}
- target = None
- tasks[f'critic_{idx:02d}'] = (
- critic_batch_info,
- data.future_observations,
- zgoal,
- dist,
- target,
- self.weight_future_goal,
- info,
+ target_dist = None
+ info.update(
+ dist=time_agg(current_dist, self.rho, data.horizon),
+ )
+ tasks[f'critic_{idx:02d}'] = Task(
+ critic_batch_info=critic_batch_info,
+ goal=data.future_observations,
+ zgoal=zgoal,
+ current_dist=current_dist,
+ target_dist=target_dist,
+ weight=self.weight_future_goal,
+ info=info,
)
yield 'future_goal', tasks
@abc.abstractmethod
- def compute_loss_one(self, data: BatchData, critic_batch_info: CriticBatchInfo,
- goal: torch.Tensor, zgoal: torch.Tensor, dist: torch.Tensor,
- target: Optional[torch.Tensor], weight: float,
- info: Mapping[str, torch.Tensor]) -> LossResult:
+ def compute_loss_one(self, data: BatchData, task: Task) -> LossResult:
raise NotImplementedError
- def compute_loss(self, data: BatchData, tasks: Dict[str, Tuple[CriticBatchInfo, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], float, Dict[str, torch.Tensor]]],
- **extra_kwargs) -> Dict[str, LossResult]:
+ def compute_loss(self, data: BatchData, tasks: Dict[str, Task], **extra_kwargs) -> Dict[str, LossResult]:
return {
- key: self.compute_loss_one(data, critic_batch_info, goal, zgoal, dist, target, weight, info, **extra_kwargs)
- for key, (critic_batch_info, goal, zgoal, dist, target, weight, info) in tasks.items()
+ key: self.compute_loss_one(data, task, **extra_kwargs) for key, task in tasks.items()
}
def forward(self, data: BatchData, critic_batch_infos: Sequence[CriticBatchInfo]) -> LossResult:
loss_results: DefaultDict[str, Dict[str, LossResult]] = defaultdict(dict)
- for name, tasks in self.generate_dist_weight(data, critic_batch_infos):
+ for name, tasks in self.generate_tasks(data, critic_batch_infos):
for critic_desc, loss_result in self.compute_loss(data, tasks).items():
loss_results[critic_desc][name] = loss_result
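
Note: the hunks above replace the anonymous 7-tuple with a named attrs class. Below is a minimal, self-contained sketch of that pattern; it omits the critic_batch_info field the real Task carries, and the field values are illustrative, not taken from the repo.

    from typing import Mapping, Optional, Union

    import attrs
    import torch


    @attrs.define(kw_only=True)
    class Task:
        goal: torch.Tensor
        zgoal: torch.Tensor
        current_dist: torch.Tensor
        target_dist: Optional[torch.Tensor]
        weight: float
        info: Mapping[str, Union[float, torch.Tensor]]


    def compute_loss_one(task: Task) -> torch.Tensor:
        # Fields are read by name, so adding or reordering fields cannot
        # silently shift positional arguments the way the old tuple could.
        loss = task.current_dist
        if task.target_dist is not None:
            loss = (task.current_dist - task.target_dist).abs()
        return loss.mean() * task.weight


    task = Task(
        goal=torch.randn(4, 3),
        zgoal=torch.randn(4, 8),
        current_dist=torch.rand(4),
        target_dist=None,
        weight=0.5,
        info={},
    )
    print(compute_loss_one(task))

In the actual diff, generate_tasks builds one such Task per critic and compute_loss simply maps compute_loss_one over them.
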
@@ -265,34 +279,35 @@ class GlobalPushLoss(GlobalPushLossBase):
self.softplus_offset = softplus_offset
self.clamp_max = clamp_max
- def compute_loss_one(self, data: BatchData, critic_batch_info: CriticBatchInfo,
- goal: torch.Tensor, zgoal: torch.Tensor, dist: torch.Tensor,
- target: Optional[torch.Tensor], weight: float,
- info: Mapping[str, torch.Tensor]) -> LossResult:
- dict_info: Dict[str, torch.Tensor] = dict(info)
+ def compute_loss_one(self, data: BatchData, task: Task) -> LossResult:
+ dict_info = dict(task.info)
# Sec 3.2. Transform so that we penalize large distances less.
- tsfm_dist: torch.Tensor = F.softplus(self.softplus_offset - dist, beta=self.softplus_beta) # type: ignore
+ tsfm_dist = task.current_dist
+ target_dist = task.target_dist
if self.clamp_max < float('inf'):
tsfm_dist = tsfm_dist.clamp_max(self.clamp_max)
dict_info.update(
- exceed_rate=(dist >= self.clamp_max).mean(dtype=torch.float32),
+ exceed_rate=time_agg(tsfm_dist >= self.clamp_max, self.rho, data.horizon),
)
- dict_info.update(dist=time_agg(dist, self.rho, data.horizon))
+ tsfm_dist = F.softplus(self.softplus_offset - tsfm_dist, beta=self.softplus_beta) # type: ignore
agg_tsfm_dist = time_agg(tsfm_dist, self.rho, data.horizon)
dict_info.update(tsfm_dist=agg_tsfm_dist)
- if target is None:
+ if target_dist is None:
loss = agg_tsfm_dist
else:
- tsfm_target = F.softplus(self.softplus_offset - target, beta=self.softplus_beta) # type: ignore
+ tsfm_target = target_dist
+ if self.clamp_max < float('inf'):
+ tsfm_target = tsfm_target.clamp_max(self.clamp_max)
+ tsfm_target = F.softplus(self.softplus_offset - tsfm_target, beta=self.softplus_beta) # type: ignore
dict_info.update(
- target=time_agg(target, self.rho, data.horizon),
+ target=time_agg(target_dist, self.rho, data.horizon),
tsfm_target=time_agg(tsfm_target, self.rho, data.horizon),
)
loss = time_agg(F.l1_loss(tsfm_dist, tsfm_target, reduction='none'), self.rho, data.horizon)
- return LossResult(loss=loss * weight, info=dict_info)
+ return LossResult(loss=loss * task.weight, info=dict_info)
def extra_repr(self) -> str:
return '\n'.join([
@@ -306,7 +321,7 @@ class GlobalPushLinearLoss(GlobalPushLossBase):
class Conf(GlobalPushLossBase.Conf):
enabled: bool = False
- clamp_max: Optional[float] = attrs.field(default=None, validator=attrs.validators.optional(attrs.validators.gt(0)))
+ clamp_max: float = attrs.field(default=float('inf'), validator=attrs.validators.gt(0))
def make(self) -> Optional['GlobalPushLinearLoss']:
if not self.enabled:
@@ -324,47 +339,44 @@ class GlobalPushLinearLoss(GlobalPushLossBase):
clamp_max=self.clamp_max,
)
- clamp_max: Optional[float]
+ clamp_max: float
def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool,
detach_qmet: bool, step_cost: float, rho: float, regress_max_future_goal: bool, mix_in_future_goal: bool,
- clamp_max: Optional[float]):
+ clamp_max: float):
super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal,
detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, rho=rho,
regress_max_future_goal=regress_max_future_goal, mix_in_future_goal=mix_in_future_goal)
self.clamp_max = clamp_max
- def compute_loss_one(self, data: BatchData, critic_batch_info: CriticBatchInfo,
- goal: torch.Tensor, zgoal: torch.Tensor, dist: torch.Tensor,
- target: Optional[torch.Tensor], weight: float,
- info: Mapping[str, torch.Tensor]) -> LossResult:
- dict_info: Dict[str, torch.Tensor] = dict(info)
- dict_info.update(dist=time_agg(dist, self.rho, data.horizon))
- if self.clamp_max is None:
- tsfm_dist = -dist
- else:
- tsfm_dist = -dist.clamp_max(self.clamp_max)
+ def compute_loss_one(self, data: BatchData, task: Task) -> LossResult:
+ dict_info = dict(task.info)
+ tsfm_dist = task.current_dist
+ target_dist = task.target_dist
+ if self.clamp_max < float('inf'):
+ tsfm_dist = tsfm_dist.clamp_max(self.clamp_max)
dict_info.update(
- exceed_rate=time_agg(dist >= self.clamp_max, self.rho, data.horizon),
+ exceed_rate=time_agg(tsfm_dist >= self.clamp_max, self.rho, data.horizon),
)
+ tsfm_dist = -tsfm_dist
agg_tsfm_dist = time_agg(tsfm_dist, self.rho, data.horizon)
dict_info.update(tsfm_dist=agg_tsfm_dist)
- if target is None:
+ if target_dist is None:
loss = agg_tsfm_dist
else:
- if self.clamp_max is not None:
- tsfm_target = target.clamp_max(self.clamp_max)
- else:
- tsfm_target = target
+ tsfm_target = target_dist
+ if self.clamp_max < float('inf'):
+ tsfm_target = tsfm_target.clamp_max(self.clamp_max)
+ tsfm_target = -tsfm_target
dict_info.update(
- target=time_agg(target, self.rho, data.horizon),
+ target=time_agg(target_dist, self.rho, data.horizon),
tsfm_target=time_agg(tsfm_target, self.rho, data.horizon),
)
loss = time_agg(F.l1_loss(tsfm_dist, tsfm_target, reduction='none'), self.rho, data.horizon)
- return LossResult(loss=loss * weight, info=dict_info)
+ return LossResult(loss=loss * task.weight, info=dict_info)
def extra_repr(self) -> str:
return '\n'.join([
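
Note on the ordering fix in the two hunks above: the clamp is now applied to the raw distance before the monotone transform, and the identical clamp + transform is applied to the regression target, so both arguments of the L1 loss live on the same scale; exceed_rate is likewise measured on the clamped raw distance. The log-loss hunk further down follows the same pattern. A small sketch of the softplus variant, with hypothetical offset/beta/clamp values rather than the repo defaults:

    import torch
    import torch.nn.functional as F

    softplus_offset, softplus_beta, clamp_max = 500.0, 0.1, 1000.0

    def transform(dist: torch.Tensor) -> torch.Tensor:
        # Clamp the raw distance first, then squash it; softplus of
        # (offset - dist) is monotone decreasing in dist, penalizing
        # large distances less (the "Sec 3.2" transform in the comment).
        dist = dist.clamp_max(clamp_max)
        return F.softplus(softplus_offset - dist, beta=softplus_beta)

    dist = torch.tensor([10.0, 800.0, 5000.0])    # hypothetical distances
    target = torch.tensor([12.0, 750.0, 4000.0])  # hypothetical targets
    print(F.l1_loss(transform(dist), transform(target)))

Clamping before the transform also zeroes gradients for distances past the cap on both sides of the loss, rather than only on the prediction.
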
@@ -451,27 +463,32 @@ class GlobalPushNextLoss(GlobalPushLossBase):
self.regress_step_cost_mul = regress_step_cost_mul
self.worst_v_regularizer_weight_mul = worst_v_regularizer_weight_mul
- def compute_loss(self, data: BatchData,
- tasks: Dict[str, Tuple[CriticBatchInfo, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], float, Dict[str, torch.Tensor]]]) -> Dict[str, LossResult]:
+ def compute_loss(self, data: BatchData, tasks: Dict[str, Task]) -> Dict[str, LossResult]:
# adv = min_Q - avg_V (HILP)
qs_for_adv: List[torch.Tensor] = [] # use slack
regress_target_qs: Dict[str, torch.Tensor] = {} # no slack
next_dists: Dict[str, torch.Tensor] = {}
vs: List[torch.Tensor] = []
- target_encoder_tasks = {}
- for critic_desc, (critic_batch_info, goal, zgoal, dist, target, weight, info) in tasks.items():
+
+ target_encoder_tasks: Dict[str, Task] = {}
+ for critic_desc, task in tasks.items():
+ # unpack fields that may be overwritten below
+ critic_batch_info = task.critic_batch_info
+ info = dict(task.info)
+ target_dist = task.target_dist
+ zgoal = task.zgoal
+
critic = critic_batch_info.critic
- info = dict(info)
with torch.no_grad(), critic.mode(False):
if not critic_batch_info.zx_from_target_encoder and critic.has_separate_target_encoder:
zx = critic.encoder(data.observations, target=True)
else:
zx = critic_batch_info.zx
if not critic_batch_info.zy_from_target_encoder and critic.has_separate_target_encoder:
- zgoal = critic.encoder(goal, target=True) # goal is encoded in the same way as zy
+ zgoal = critic.encoder(task.goal, target=True) # goal is encoded in the same way as zy
- if target is not None:
- next_dist = target - self.step_cost # this is distance, not value
+ if target_dist is not None:
+ next_dist = target_dist - self.step_cost # this is distance, not value
q_dist_is_value = False
else:
if not critic_batch_info.zy_from_target_encoder and critic.has_separate_target_encoder:
@@ -503,7 +520,7 @@ class GlobalPushNextLoss(GlobalPushLossBase):
if FLAGS.DEBUG and critic.has_separate_target_encoder:
non_target_next_dist = critic.quasimetric_model(
critic.encoder(data.next_observations),
- critic.encoder(goal),
+ critic.encoder(task.goal),
)
(
info['tgt-ntgt_next_out_000'],
@@ -517,11 +534,15 @@ class GlobalPushNextLoss(GlobalPushLossBase):
non_target_next_dist.new_tensor([0, 0.1, 0.25, 0.5, 0.75, 0.9, 1], dtype=torch.float32)
).unbind()
- target_encoder_tasks[critic_desc] = (
- attrs.evolve(
+ target_encoder_tasks[critic_desc] = attrs.evolve(
+ task,
+ critic_batch_info=attrs.evolve(
critic_batch_info, zx=zx, zy=zy,
- zx_from_target_encoder=True, zy_from_target_encoder=True),
- goal, zgoal, dist, target, weight, info,
+ zx_from_target_encoder=True, zy_from_target_encoder=True,
+ ),
+ zgoal=zgoal,
+ target_dist=target_dist,
+ info=info,
)
tensor_vs = torch.stack(vs, dim=-1)
@@ -542,20 +563,17 @@ class GlobalPushNextLoss(GlobalPushLossBase):
raise NotImplementedError(self.regress_target_kind)
return {
- key: self.compute_loss_one(data, critic_batch_info, goal, zgoal, dist, target, weight, info,
+ key: self.compute_loss_one(data, task,
advantage=advantage,
next_dists=next_dists[key],
regress_target_qs=regress_target_qs[key],
worst_v=worst_v)
- for key, (critic_batch_info, goal, zgoal, dist, target, weight, info) in target_encoder_tasks.items()
+ for key, task in target_encoder_tasks.items()
}
- def compute_loss_one(self, data: BatchData, critic_batch_info: CriticBatchInfo,
- goal: torch.Tensor, zgoal: torch.Tensor, dist: torch.Tensor,
- target: Optional[torch.Tensor], weight: float,
- info: Mapping[str, torch.Tensor], advantage: torch.Tensor, next_dists: torch.Tensor,
+ def compute_loss_one(self, data: BatchData, task: Task, advantage: torch.Tensor, next_dists: torch.Tensor,
regress_target_qs: torch.Tensor, worst_v: torch.Tensor) -> LossResult:
- dict_info: Dict[str, torch.Tensor] = dict(info)
+ dict_info = dict(task.info)
# if target is None:
# with torch.set_grad_enabled(self.detach_target_dist), critic_batch_info.critic.mode(False):
@@ -593,8 +611,7 @@ class GlobalPushNextLoss(GlobalPushLossBase):
)
target_q = regress_target_qs
-
- v = dist_to_value(dist, self.gamma, self.step_cost,
+ v = dist_to_value(task.current_dist, self.gamma, self.step_cost,
dist_is_value=self.dist_is_value)
if self.allow_gt:
@@ -628,7 +645,7 @@ class GlobalPushNextLoss(GlobalPushLossBase):
loss += regularizer * self.worst_v_regularizer_weight_mul
actual_dist = -time_agg(dist_to_value( # convert possibly value back to dist
- dist=dist, gamma=None, step_cost=self.step_cost,
+ dist=task.current_dist, gamma=None, step_cost=self.step_cost,
dist_is_value=self.dist_is_value,
), self.rho, data.horizon)
actual_q_dist = -time_agg(dist_to_value( # convert possibly value back to dist
@@ -638,13 +655,13 @@ class GlobalPushNextLoss(GlobalPushLossBase):
), self.rho, data.horizon)
dict_info.update(
iql_loss=iql_loss,
- dist=actual_dist,
+ actual_dist=actual_dist,
v=time_agg(v, self.rho, data.horizon),
vq_difference=time_agg(q_difference, self.rho, data.horizon), # (v - q)
vq_dist_difference=actual_dist - actual_q_dist,
)
- return LossResult(loss=loss * weight, info=dict_info)
+ return LossResult(loss=loss * task.weight, info=dict_info)
def extra_repr(self) -> str:
return '\n'.join([
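
The GlobalPushNextLoss hunks above lean on attrs.evolve to derive updated copies of a task, including its nested CriticBatchInfo. A stripped-down sketch of that pattern; the classes here are simplified stand-ins, not the real definitions:

    from typing import Optional

    import attrs
    import torch


    @attrs.define(kw_only=True)
    class CriticBatchInfo:
        zx: torch.Tensor
        zx_from_target_encoder: bool = False


    @attrs.define(kw_only=True)
    class Task:
        critic_batch_info: CriticBatchInfo
        target_dist: Optional[torch.Tensor] = None


    task = Task(critic_batch_info=CriticBatchInfo(zx=torch.zeros(2)))
    # evolve() copies an instance, overriding only the named fields; the
    # nested attrs instance is evolved the same way, leaving the original
    # task untouched.
    new_task = attrs.evolve(
        task,
        critic_batch_info=attrs.evolve(
            task.critic_batch_info,
            zx=torch.ones(2),
            zx_from_target_encoder=True,
        ),
        target_dist=torch.tensor([1.0, 2.0]),
    )
    print(task.critic_batch_info.zx_from_target_encoder)      # False
    print(new_task.critic_batch_info.zx_from_target_encoder)  # True
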
@@ -695,27 +712,26 @@ class GlobalPushLogLoss(GlobalPushLossBase):
self.offset = offset
- def compute_loss_one(self, data: BatchData, critic_batch_info: CriticBatchInfo,
- goal: torch.Tensor, zgoal: torch.Tensor, dist: torch.Tensor,
- target: Optional[torch.Tensor], weight: float,
- info: Mapping[str, torch.Tensor]) -> LossResult:
- dict_info: Dict[str, torch.Tensor] = dict(info)
- tsfm_dist: torch.Tensor = -dist.add(self.offset).log()
- dict_info.update(dist=time_agg(dist, self.rho, data.horizon))
+ def compute_loss_one(self, data: BatchData, task: Task) -> LossResult:
+ dict_info = dict(task.info)
+ tsfm_dist = task.current_dist
+ target_dist = task.target_dist
+
+ tsfm_dist: torch.Tensor = -tsfm_dist.add(self.offset).log()
agg_tsfm_dist = time_agg(tsfm_dist, self.rho, data.horizon)
dict_info.update(tsfm_dist=agg_tsfm_dist)
- if target is None:
+ if target_dist is None:
loss = agg_tsfm_dist
else:
- tsfm_target = -target.add(self.offset).log()
+ tsfm_target = -target_dist.add(self.offset).log()
dict_info.update(
- target=time_agg(target, self.rho, data.horizon),
+ target=time_agg(target_dist, self.rho, data.horizon),
tsfm_target=time_agg(tsfm_target, self.rho, data.horizon),
)
loss = time_agg(F.l1_loss(tsfm_dist, tsfm_target, reduction='none'), self.rho, data.horizon)
- return LossResult(loss=loss * weight, info=dict_info)
+ return LossResult(loss=loss * task.weight, info=dict_info)
def extra_repr(self) -> str:
return '\n'.join([
@@ -761,34 +777,39 @@ class GlobalPushRBFLoss(GlobalPushLossBase):
regress_max_future_goal=regress_max_future_goal, mix_in_future_goal=mix_in_future_goal)
self.inv_scale = inv_scale
- def compute_loss_one(self, data: BatchData, critic_batch_info: CriticBatchInfo,
- goal: torch.Tensor, zgoal: torch.Tensor, dist: torch.Tensor,
- target: Optional[torch.Tensor], weight: float,
- info: Mapping[str, torch.Tensor]) -> LossResult:
- dict_info: Dict[str, torch.Tensor] = dict(info)
- inv_scale = time_agg(dist.detach().square(), self.rho, data.horizon).div(2).sqrt().clamp(1e-3, self.inv_scale) # make E[d^2]/r^2 approx 2
- tsfm_dist: torch.Tensor = time_agg((dist / inv_scale).square().neg().exp(), self.rho, data.horizon)
+ def compute_loss_one(self, data: BatchData, task: Task) -> LossResult:
+ dict_info = dict(task.info)
+ current_dist = task.current_dist
+ target_dist = task.target_dist
+
+ inv_scale = time_agg(current_dist.detach().square(), self.rho, data.horizon).div(2).sqrt().clamp(1e-3, self.inv_scale) # make E[d^2]/r^2 approx 2
+ tsfm_dist: torch.Tensor = time_agg(
+ (current_dist / inv_scale).square().neg().exp(),
+ self.rho, data.horizon,
+ )
rbf_potential = tsfm_dist.log()
dict_info.update(
- dist=time_agg(dist, self.rho, data.horizon),
inv_scale=inv_scale,
tsfm_dist=tsfm_dist,
rbf_potential=rbf_potential,
)
- if target is None:
+ if target_dist is None:
loss = rbf_potential
else:
- tsfm_target = time_agg((target / inv_scale).square().neg().exp(), self.rho, data.horizon)
+ tsfm_target = time_agg(
+ (target_dist / inv_scale).square().neg().exp(),
+ self.rho, data.horizon,
+ )
target_rbf_potential = tsfm_target.log()
dict_info.update(
- target=time_agg(target, self.rho, data.horizon),
+ target=time_agg(target_dist, self.rho, data.horizon),
tsfm_target=tsfm_target,
target_rbf_potential=target_rbf_potential,
)
loss = F.l1_loss(rbf_potential, target_rbf_potential)
- return LossResult(loss=loss * weight, info=dict_info)
+ return LossResult(loss=loss * task.weight, info=dict_info)
def extra_repr(self) -> str:
return '\n'.join([
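
For completeness, the RBF transform kept by the final hunk picks its length scale from the batch so that E[d^2] / scale^2 is roughly 2, maps distances through a Gaussian kernel, and compares potentials in log space. A sketch using a plain mean where the diff uses time_agg(·, rho, horizon); the cap value is hypothetical:

    import torch

    inv_scale_cap = 100.0            # hypothetical cap (self.inv_scale)
    d = torch.rand(256) * 10         # hypothetical batch of distances

    # Length scale chosen so that E[d^2] / inv_scale^2 is approximately 2.
    inv_scale = d.detach().square().mean().div(2).sqrt().clamp(1e-3, inv_scale_cap)
    tsfm = (d / inv_scale).square().neg().exp().mean()  # mean RBF kernel value
    rbf_potential = tsfm.log()
    print(inv_scale.item(), rbf_potential.item())
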