Created
April 8, 2024 17:19
-
-
Save ssnl/dc9f03d282893cebad9273ec516188b6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py | |
index b353c39..1c95df3 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py | |
@@ -54,6 +54,22 @@ from . import CriticLossBase, CriticBatchInfo | |
# return f"weight={self.weight:g}, softplus_beta={self.softplus_beta:g}, softplus_offset={self.softplus_offset:g}" | |
+@attrs.define(kw_only=True) | |
+class TaskCriticInfo: | |
+ critic_batch_info: CriticBatchInfo | |
+ goal: torch.Tensor | |
+ zgoal: torch.Tensor | |
+ current_dist: torch.Tensor | |
+ target_dist: Optional[torch.Tensor] | |
+ weight: float | |
+ info: Mapping[str, Union[float, torch.Tensor]] | |
+ | |
+ def __attrs_post_init__(self): | |
+ self.info = dict(self.info) | |
+ assert 'weight' not in self.info | |
+ self.info['weight'] = self.weight | |
+ | |
+ | |
class GlobalPushLossBase(CriticLossBase): | |
@attrs.define(kw_only=True) | |
class Conf(abc.ABC): | |
@@ -107,7 +123,8 @@ class GlobalPushLossBase(CriticLossBase): | |
assert not regress_max_future_goal, "regress_max_future_goal is not compatible with mix_in_future_goal" | |
assert weight_future_goal > 0, "weight_future_goal must be positive when mix_in_future_goal is enabled" | |
- def generate_dist_weight(self, data: BatchData, critic_batch_infos: Sequence[CriticBatchInfo]): | |
+ | |
+ def generate_task_infos(self, data: BatchData, critic_batch_infos: Sequence[CriticBatchInfo]): | |
def get_dist(critic_batch_info: CriticBatchInfo, za: torch.Tensor, zb: torch.Tensor): | |
if self.detach_goal: | |
zb = zb.detach() | |
@@ -115,8 +132,9 @@ class GlobalPushLossBase(CriticLossBase): | |
with quasimetric_model.requiring_grad(not self.detach_qmet): | |
return quasimetric_model(za, zb, proj_grad_enabled=(True, not self.detach_proj_goal)) | |
- # (critic_batch_info, goal, zgoal, current_dist, target_dist, weight, info) | |
- tasks: Dict[str, Tuple[CriticBatchInfo, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], float, Dict[str, torch.Tensor]]] = {} | |
+ info: Dict[str, Union[float, torch.Tensor]] | |
+ | |
+ task_critic_infos: Dict[str, TaskCriticInfo] = {} | |
for idx, critic_batch_info in enumerate(critic_batch_infos): | |
if critic_batch_info.critic.borrowing_embedding: | |
continue | |
@@ -148,66 +166,67 @@ class GlobalPushLossBase(CriticLossBase): | |
zgoal, | |
) | |
key = 'mixed_goal' | |
- tasks[f'critic_{idx:02d}'] = ( | |
- critic_batch_info, | |
- goal, | |
- zgoal, | |
- get_dist(critic_batch_info, critic_batch_info.zx, zgoal), | |
- None, | |
- weight, | |
- {}, | |
+ current_dist = get_dist(critic_batch_info, critic_batch_info.zx, zgoal) | |
+ task_critic_infos[f'critic_{idx:02d}'] = TaskCriticInfo( | |
+ critic_batch_info=critic_batch_info, | |
+ goal=goal, | |
+ zgoal=zgoal, | |
+ current_dist=current_dist, | |
+ target_dist=None, | |
+ weight=weight, | |
+ info=dict( | |
+ dist=time_agg(current_dist, self.rho, data.horizon), | |
+ ), | |
) | |
- yield key, tasks | |
+ yield key, task_critic_infos | |
if self.weight_future_goal > 0 and not self.mix_in_future_goal: | |
- tasks = {} | |
+ task_critic_infos = {} | |
for idx, critic_batch_info in enumerate(critic_batch_infos): | |
if critic_batch_info.critic.borrowing_embedding: | |
continue | |
zgoal = critic_batch_info.critic.encoder( | |
data.future_observations, target=critic_batch_info.zy_from_target_encoder) | |
- dist = get_dist(critic_batch_info, critic_batch_info.zx, zgoal) | |
+ current_dist = get_dist(critic_batch_info, critic_batch_info.zx, zgoal) | |
if self.regress_max_future_goal: | |
observed_upper_bound = self.step_cost * data.future_tdelta | |
info = dict( | |
target=observed_upper_bound.mean(), | |
- ratio_future_observed_dist=(dist / observed_upper_bound).mean(), | |
- exceed_future_observed_dist_rate=(dist > observed_upper_bound).mean(dtype=torch.float32), | |
+ ratio_future_observed_dist=(current_dist / observed_upper_bound).mean(), | |
+ exceed_future_observed_dist_rate=(current_dist > observed_upper_bound).mean(dtype=torch.float32), | |
) | |
- target = observed_upper_bound | |
+ target_dist = observed_upper_bound | |
# dist = dist.clamp_max(self.step_cost * data.future_tdelta) | |
else: | |
info = {} | |
- target = None | |
- tasks[f'critic_{idx:02d}'] = ( | |
- critic_batch_info, | |
- data.future_observations, | |
- zgoal, | |
- dist, | |
- target, | |
- self.weight_future_goal, | |
- info, | |
+ target_dist = None | |
+ info.update( | |
+ dist=time_agg(current_dist, self.rho, data.horizon), | |
+ ) | |
+ task_critic_infos[f'critic_{idx:02d}'] = TaskCriticInfo( | |
+ critic_batch_info=critic_batch_info, | |
+ goal=data.future_observations, | |
+ zgoal=zgoal, | |
+ current_dist=current_dist, | |
+ target_dist=target_dist, | |
+ weight=self.weight_future_goal, | |
+ info=info, | |
) | |
- yield 'future_goal', tasks | |
+ yield 'future_goal', task_critic_infos | |
@abc.abstractmethod | |
- def compute_loss_one(self, data: BatchData, critic_batch_info: CriticBatchInfo, | |
- goal: torch.Tensor, zgoal: torch.Tensor, dist: torch.Tensor, | |
- target: Optional[torch.Tensor], weight: float, | |
- info: Mapping[str, torch.Tensor]) -> LossResult: | |
+ def compute_loss_one(self, data: BatchData, task_critic_info: TaskCriticInfo) -> LossResult: | |
raise NotImplementedError | |
- def compute_loss(self, data: BatchData, tasks: Dict[str, Tuple[CriticBatchInfo, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], float, Dict[str, torch.Tensor]]], | |
- **extra_kwargs) -> Dict[str, LossResult]: | |
+ def compute_loss(self, data: BatchData, task_critic_infos: Dict[str, TaskCriticInfo], **extra_kwargs) -> Dict[str, LossResult]: | |
return { | |
- key: self.compute_loss_one(data, critic_batch_info, goal, zgoal, dist, target, weight, info, **extra_kwargs) | |
- for key, (critic_batch_info, goal, zgoal, dist, target, weight, info) in tasks.items() | |
+ key: self.compute_loss_one(data, task, **extra_kwargs) for key, task in task_critic_infos.items() | |
} | |
def forward(self, data: BatchData, critic_batch_infos: Sequence[CriticBatchInfo]) -> LossResult: | |
loss_results: DefaultDict[str, Dict[str, LossResult]] = defaultdict(dict) | |
- for name, tasks in self.generate_dist_weight(data, critic_batch_infos): | |
+ for name, tasks in self.generate_task_infos(data, critic_batch_infos): | |
for critic_desc, loss_result in self.compute_loss(data, tasks).items(): | |
loss_results[critic_desc][name] = loss_result | |
@@ -265,34 +284,35 @@ class GlobalPushLoss(GlobalPushLossBase): | |
self.softplus_offset = softplus_offset | |
self.clamp_max = clamp_max | |
- def compute_loss_one(self, data: BatchData, critic_batch_info: CriticBatchInfo, | |
- goal: torch.Tensor, zgoal: torch.Tensor, dist: torch.Tensor, | |
- target: Optional[torch.Tensor], weight: float, | |
- info: Mapping[str, torch.Tensor]) -> LossResult: | |
- dict_info: Dict[str, torch.Tensor] = dict(info) | |
+ def compute_loss_one(self, data: BatchData, task_critic_info: TaskCriticInfo) -> LossResult: | |
+ dict_info = dict(task_critic_info.info) | |
# Sec 3.2. Transform so that we penalize large distances less. | |
- tsfm_dist: torch.Tensor = F.softplus(self.softplus_offset - dist, beta=self.softplus_beta) # type: ignore | |
+ tsfm_dist = task_critic_info.current_dist | |
+ target_dist = task_critic_info.target_dist | |
if self.clamp_max < float('inf'): | |
tsfm_dist = tsfm_dist.clamp_max(self.clamp_max) | |
dict_info.update( | |
- exceed_rate=(dist >= self.clamp_max).mean(dtype=torch.float32), | |
+ exceed_rate=time_agg(tsfm_dist >= self.clamp_max, self.rho, data.horizon), | |
) | |
- dict_info.update(dist=time_agg(dist, self.rho, data.horizon)) | |
+ tsfm_dist = F.softplus(self.softplus_offset - tsfm_dist, beta=self.softplus_beta) # type: ignore | |
agg_tsfm_dist = time_agg(tsfm_dist, self.rho, data.horizon) | |
dict_info.update(tsfm_dist=agg_tsfm_dist) | |
- if target is None: | |
+ if target_dist is None: | |
loss = agg_tsfm_dist | |
else: | |
- tsfm_target = F.softplus(self.softplus_offset - target, beta=self.softplus_beta) # type: ignore | |
+ tsfm_target = target_dist | |
+ if self.clamp_max < float('inf'): | |
+ tsfm_target = tsfm_target.clamp_max(self.clamp_max) | |
+ tsfm_target = F.softplus(self.softplus_offset - tsfm_target, beta=self.softplus_beta) # type: ignore | |
dict_info.update( | |
- target=time_agg(target, self.rho, data.horizon), | |
+ target=time_agg(target_dist, self.rho, data.horizon), | |
tsfm_target=time_agg(tsfm_target, self.rho, data.horizon), | |
) | |
loss = time_agg(F.l1_loss(tsfm_dist, tsfm_target, reduction='none'), self.rho, data.horizon) | |
- return LossResult(loss=loss * weight, info=dict_info) | |
+ return LossResult(loss=loss * task_critic_info.weight, info=dict_info) | |
def extra_repr(self) -> str: | |
return '\n'.join([ | |
@@ -306,7 +326,7 @@ class GlobalPushLinearLoss(GlobalPushLossBase): | |
class Conf(GlobalPushLossBase.Conf): | |
enabled: bool = False | |
- clamp_max: Optional[float] = attrs.field(default=None, validator=attrs.validators.optional(attrs.validators.gt(0))) | |
+ clamp_max: float = attrs.field(default=float('inf'), validator=attrs.validators.gt(0)) | |
def make(self) -> Optional['GlobalPushLinearLoss']: | |
if not self.enabled: | |
@@ -324,47 +344,44 @@ class GlobalPushLinearLoss(GlobalPushLossBase): | |
clamp_max=self.clamp_max, | |
) | |
- clamp_max: Optional[float] | |
+ clamp_max: float | |
def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, | |
detach_qmet: bool, step_cost: float, rho: float, regress_max_future_goal: bool, mix_in_future_goal: bool, | |
- clamp_max: Optional[float]): | |
+ clamp_max: float): | |
super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal, | |
detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, rho=rho, | |
regress_max_future_goal=regress_max_future_goal, mix_in_future_goal=mix_in_future_goal) | |
self.clamp_max = clamp_max | |
- def compute_loss_one(self, data: BatchData, critic_batch_info: CriticBatchInfo, | |
- goal: torch.Tensor, zgoal: torch.Tensor, dist: torch.Tensor, | |
- target: Optional[torch.Tensor], weight: float, | |
- info: Mapping[str, torch.Tensor]) -> LossResult: | |
- dict_info: Dict[str, torch.Tensor] = dict(info) | |
- dict_info.update(dist=time_agg(dist, self.rho, data.horizon)) | |
- if self.clamp_max is None: | |
- tsfm_dist = -dist | |
- else: | |
- tsfm_dist = -dist.clamp_max(self.clamp_max) | |
+ def compute_loss_one(self, data: BatchData, task_critic_info: TaskCriticInfo) -> LossResult: | |
+ dict_info = dict(task_critic_info.info) | |
+ tsfm_dist = task_critic_info.current_dist | |
+ target_dist = task_critic_info.target_dist | |
+ if self.clamp_max < float('inf'): | |
+ tsfm_dist = tsfm_dist.clamp_max(self.clamp_max) | |
dict_info.update( | |
- exceed_rate=time_agg(dist >= self.clamp_max, self.rho, data.horizon), | |
+ exceed_rate=time_agg(tsfm_dist >= self.clamp_max, self.rho, data.horizon), | |
) | |
+ tsfm_dist = -tsfm_dist | |
agg_tsfm_dist = time_agg(tsfm_dist, self.rho, data.horizon) | |
dict_info.update(tsfm_dist=agg_tsfm_dist) | |
- if target is None: | |
+ if target_dist is None: | |
loss = agg_tsfm_dist | |
else: | |
- if self.clamp_max is not None: | |
- tsfm_target = target.clamp_max(self.clamp_max) | |
- else: | |
- tsfm_target = target | |
+ tsfm_target = target_dist | |
+ if self.clamp_max < float('inf'): | |
+ tsfm_target = tsfm_target.clamp_max(self.clamp_max) | |
+ tsfm_target = -tsfm_target | |
dict_info.update( | |
- target=time_agg(target, self.rho, data.horizon), | |
+ target=time_agg(target_dist, self.rho, data.horizon), | |
tsfm_target=time_agg(tsfm_target, self.rho, data.horizon), | |
) | |
loss = time_agg(F.l1_loss(tsfm_dist, tsfm_target, reduction='none'), self.rho, data.horizon) | |
- return LossResult(loss=loss * weight, info=dict_info) | |
+ return LossResult(loss=loss * task_critic_info.weight, info=dict_info) | |
def extra_repr(self) -> str: | |
return '\n'.join([ | |
@@ -451,59 +468,80 @@ class GlobalPushNextLoss(GlobalPushLossBase): | |
self.regress_step_cost_mul = regress_step_cost_mul | |
self.worst_v_regularizer_weight_mul = worst_v_regularizer_weight_mul | |
- def compute_loss(self, data: BatchData, | |
- tasks: Dict[str, Tuple[CriticBatchInfo, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], float, Dict[str, torch.Tensor]]]) -> Dict[str, LossResult]: | |
+ def compute_loss(self, data: BatchData, task_critic_infos: Dict[str, TaskCriticInfo]) -> Dict[str, LossResult]: | |
# adv = min_Q - avg_V (HILP) | |
qs_for_adv: List[torch.Tensor] = [] # use slack | |
regress_target_qs: Dict[str, torch.Tensor] = {} # no slack | |
- next_dists: Dict[str, torch.Tensor] = {} | |
vs: List[torch.Tensor] = [] | |
- target_encoder_tasks = {} | |
- for critic_desc, (critic_batch_info, goal, zgoal, dist, target, weight, info) in tasks.items(): | |
+ | |
+ target_encoder_task_critic_infos: Dict[str, TaskCriticInfo] = {} | |
+ for critic_desc, task_critic_info in task_critic_infos.items(): | |
+ # expand things that will be overwritten | |
+ critic_batch_info = task_critic_info.critic_batch_info | |
+ info = dict(task_critic_info.info) | |
+ zgoal = task_critic_info.zgoal | |
+ | |
critic = critic_batch_info.critic | |
- info = dict(info) | |
with torch.no_grad(), critic.mode(False): | |
if not critic_batch_info.zx_from_target_encoder and critic.has_separate_target_encoder: | |
zx = critic.encoder(data.observations, target=True) | |
else: | |
zx = critic_batch_info.zx | |
if not critic_batch_info.zy_from_target_encoder and critic.has_separate_target_encoder: | |
- zgoal = critic.encoder(goal, target=True) # goal is encoded in the same way as zy | |
+ zgoal = critic.encoder(task_critic_info.goal, target=True) # goal is encoded in the same way as zy | |
- if target is not None: | |
- next_dist = target - self.step_cost # this is distance, not value | |
- q_dist_is_value = False | |
+ if task_critic_info.target_dist is not None: | |
+ next_est = task_critic_info.target_dist - self.step_cost # this is distance, not value | |
+ next_est_is_value = False | |
else: | |
if not critic_batch_info.zy_from_target_encoder and critic.has_separate_target_encoder: | |
zy = critic.encoder(data.next_observations, target=True) | |
else: | |
zy = critic_batch_info.zy | |
- next_dist = critic.quasimetric_model(zy, zgoal, target=True) | |
- q_dist_is_value = self.dist_is_value | |
+ next_est = critic.quasimetric_model(zy, zgoal, target=True) | |
+ next_est_is_value = self.dist_is_value | |
# Use slack in adv, allowing for some optimism, b/c q+1 <= v for qmet always, if | |
# local constraints are respected. | |
qs_for_adv.append(dist_to_value( | |
- next_dist, self.gamma, self.step_cost, | |
+ next_est, self.gamma, self.step_cost, | |
q_step_cost=self.step_cost - self.slack, | |
- dist_is_value=q_dist_is_value).detach()) | |
- next_dists[critic_desc] = next_dist.detach() | |
+ dist_is_value=next_est_is_value).detach()) | |
+ | |
+ | |
+ actual_dist = -time_agg(dist_to_value( # convert possibly value back to dist | |
+ dist=task_critic_info.current_dist, gamma=None, step_cost=self.step_cost, | |
+ dist_is_value=self.dist_is_value, | |
+ ), self.rho, data.horizon) | |
+ actual_q_dist = -time_agg(dist_to_value( # convert possibly value back to dist | |
+ dist=next_est, gamma=None, step_cost=self.step_cost, | |
+ q_step_cost=self.step_cost, | |
+ dist_is_value=next_est_is_value, | |
+ ), self.rho, data.horizon) | |
+ | |
+ info.update( | |
+ next_dist=time_agg( | |
+ dist_to_value(next_est, None, self.step_cost, dist_is_value=next_est_is_value), | |
+ self.rho, data.horizon, | |
+ ), | |
+ vq_dist_difference=actual_dist - actual_q_dist, | |
+ ) | |
# Regress to a more pessimistic value since qmet gives us the upper bound already | |
regress_target_qs[critic_desc] = dist_to_value( | |
- next_dist, self.gamma, self.step_cost, | |
+ next_est, self.gamma, self.step_cost, | |
q_step_cost=self.step_cost * self.regress_step_cost_mul, | |
- dist_is_value=q_dist_is_value).detach() | |
+ dist_is_value=next_est_is_value).detach() | |
v_dist = critic.quasimetric_model(zx, zgoal, target=True) | |
vs.append(dist_to_value( | |
v_dist, self.gamma, self.step_cost, | |
dist_is_value=self.dist_is_value).detach()) | |
- if FLAGS.DEBUG and critic.has_separate_target_encoder: | |
- non_target_next_dist = critic.quasimetric_model( | |
+ if FLAGS.DEBUG and task_critic_info.target_dist is None and critic.has_separate_target_encoder: | |
+ non_target_next_est = critic.quasimetric_model( | |
critic.encoder(data.next_observations), | |
- critic.encoder(goal), | |
+ critic.encoder(task_critic_info.goal), | |
) | |
( | |
info['tgt-ntgt_next_out_000'], | |
@@ -513,15 +551,18 @@ class GlobalPushNextLoss(GlobalPushLossBase): | |
info['tgt-ntgt_next_out_075'], | |
info['tgt-ntgt_next_out_090'], | |
info['tgt-ntgt_next_out_100'], | |
- ) = (next_dist - non_target_next_dist).float().quantile( | |
- non_target_next_dist.new_tensor([0, 0.1, 0.25, 0.5, 0.75, 0.9, 1], dtype=torch.float32) | |
+ ) = (next_est - non_target_next_est).float().quantile( | |
+ non_target_next_est.new_tensor([0, 0.1, 0.25, 0.5, 0.75, 0.9, 1], dtype=torch.float32) | |
).unbind() | |
- target_encoder_tasks[critic_desc] = ( | |
- attrs.evolve( | |
+ target_encoder_task_critic_infos[critic_desc] = attrs.evolve( | |
+ task_critic_info, | |
+ critic_batch_info=attrs.evolve( | |
critic_batch_info, zx=zx, zy=zy, | |
- zx_from_target_encoder=True, zy_from_target_encoder=True), | |
- goal, zgoal, dist, target, weight, info, | |
+ zx_from_target_encoder=True, zy_from_target_encoder=True, | |
+ ), | |
+ zgoal=zgoal, | |
+ info=info, | |
) | |
tensor_vs = torch.stack(vs, dim=-1) | |
@@ -542,20 +583,16 @@ class GlobalPushNextLoss(GlobalPushLossBase): | |
raise NotImplementedError(self.regress_target_kind) | |
return { | |
- key: self.compute_loss_one(data, critic_batch_info, goal, zgoal, dist, target, weight, info, | |
+ key: self.compute_loss_one(data, task, | |
advantage=advantage, | |
- next_dists=next_dists[key], | |
regress_target_qs=regress_target_qs[key], | |
worst_v=worst_v) | |
- for key, (critic_batch_info, goal, zgoal, dist, target, weight, info) in target_encoder_tasks.items() | |
+ for key, task in target_encoder_task_critic_infos.items() | |
} | |
- def compute_loss_one(self, data: BatchData, critic_batch_info: CriticBatchInfo, | |
- goal: torch.Tensor, zgoal: torch.Tensor, dist: torch.Tensor, | |
- target: Optional[torch.Tensor], weight: float, | |
- info: Mapping[str, torch.Tensor], advantage: torch.Tensor, next_dists: torch.Tensor, | |
+ def compute_loss_one(self, data: BatchData, task_critic_info: TaskCriticInfo, advantage: torch.Tensor, | |
regress_target_qs: torch.Tensor, worst_v: torch.Tensor) -> LossResult: | |
- dict_info: Dict[str, torch.Tensor] = dict(info) | |
+ dict_info = dict(task_critic_info.info) | |
# if target is None: | |
# with torch.set_grad_enabled(self.detach_target_dist), critic_batch_info.critic.mode(False): | |
@@ -588,13 +625,11 @@ class GlobalPushNextLoss(GlobalPushLossBase): | |
# ) | |
dict_info.update( | |
- next_dist=time_agg(next_dists, self.rho, data.horizon), | |
target_q=time_agg(regress_target_qs, self.rho, data.horizon), | |
) | |
target_q = regress_target_qs | |
- | |
- v = dist_to_value(dist, self.gamma, self.step_cost, | |
+ v = dist_to_value(task_critic_info.current_dist, self.gamma, self.step_cost, | |
dist_is_value=self.dist_is_value) | |
if self.allow_gt: | |
@@ -627,24 +662,13 @@ class GlobalPushNextLoss(GlobalPushLossBase): | |
) | |
loss += regularizer * self.worst_v_regularizer_weight_mul | |
- actual_dist = -time_agg(dist_to_value( # convert possibly value back to dist | |
- dist=dist, gamma=None, step_cost=self.step_cost, | |
- dist_is_value=self.dist_is_value, | |
- ), self.rho, data.horizon) | |
- actual_q_dist = -time_agg(dist_to_value( # convert possibly value back to dist | |
- dist=next_dists, gamma=None, step_cost=self.step_cost, | |
- q_step_cost=self.step_cost, | |
- dist_is_value=self.dist_is_value, | |
- ), self.rho, data.horizon) | |
dict_info.update( | |
iql_loss=iql_loss, | |
- dist=actual_dist, | |
v=time_agg(v, self.rho, data.horizon), | |
vq_difference=time_agg(q_difference, self.rho, data.horizon), # (v - q) | |
- vq_dist_difference=actual_dist - actual_q_dist, | |
) | |
- return LossResult(loss=loss * weight, info=dict_info) | |
+ return LossResult(loss=loss * task_critic_info.weight, info=dict_info) | |
def extra_repr(self) -> str: | |
return '\n'.join([ | |
@@ -695,27 +719,26 @@ class GlobalPushLogLoss(GlobalPushLossBase): | |
self.offset = offset | |
- def compute_loss_one(self, data: BatchData, critic_batch_info: CriticBatchInfo, | |
- goal: torch.Tensor, zgoal: torch.Tensor, dist: torch.Tensor, | |
- target: Optional[torch.Tensor], weight: float, | |
- info: Mapping[str, torch.Tensor]) -> LossResult: | |
- dict_info: Dict[str, torch.Tensor] = dict(info) | |
- tsfm_dist: torch.Tensor = -dist.add(self.offset).log() | |
- dict_info.update(dist=time_agg(dist, self.rho, data.horizon)) | |
+ def compute_loss_one(self, data: BatchData, task_critic_info: TaskCriticInfo) -> LossResult: | |
+ dict_info = dict(task_critic_info.info) | |
+ tsfm_dist = task_critic_info.current_dist | |
+ target_dist = task_critic_info.target_dist | |
+ | |
+ tsfm_dist: torch.Tensor = -tsfm_dist.add(self.offset).log() | |
agg_tsfm_dist = time_agg(tsfm_dist, self.rho, data.horizon) | |
dict_info.update(tsfm_dist=agg_tsfm_dist) | |
- if target is None: | |
+ if target_dist is None: | |
loss = agg_tsfm_dist | |
else: | |
- tsfm_target = -target.add(self.offset).log() | |
+ tsfm_target = -target_dist.add(self.offset).log() | |
dict_info.update( | |
- target=time_agg(target, self.rho, data.horizon), | |
+ target=time_agg(target_dist, self.rho, data.horizon), | |
tsfm_target=time_agg(tsfm_target, self.rho, data.horizon), | |
) | |
loss = time_agg(F.l1_loss(tsfm_dist, tsfm_target, reduction='none'), self.rho, data.horizon) | |
- return LossResult(loss=loss * weight, info=dict_info) | |
+ return LossResult(loss=loss * task_critic_info.weight, info=dict_info) | |
def extra_repr(self) -> str: | |
return '\n'.join([ | |
@@ -761,34 +784,39 @@ class GlobalPushRBFLoss(GlobalPushLossBase): | |
regress_max_future_goal=regress_max_future_goal, mix_in_future_goal=mix_in_future_goal) | |
self.inv_scale = inv_scale | |
- def compute_loss_one(self, data: BatchData, critic_batch_info: CriticBatchInfo, | |
- goal: torch.Tensor, zgoal: torch.Tensor, dist: torch.Tensor, | |
- target: Optional[torch.Tensor], weight: float, | |
- info: Mapping[str, torch.Tensor]) -> LossResult: | |
- dict_info: Dict[str, torch.Tensor] = dict(info) | |
- inv_scale = time_agg(dist.detach().square(), self.rho, data.horizon).div(2).sqrt().clamp(1e-3, self.inv_scale) # make E[d^2]/r^2 approx 2 | |
- tsfm_dist: torch.Tensor = time_agg((dist / inv_scale).square().neg().exp(), self.rho, data.horizon) | |
+ def compute_loss_one(self, data: BatchData, task_critic_info: TaskCriticInfo) -> LossResult: | |
+ dict_info = dict(task_critic_info.info) | |
+ current_dist = task_critic_info.current_dist | |
+ target_dist = task_critic_info.target_dist | |
+ | |
+ inv_scale = time_agg(current_dist.detach().square(), self.rho, data.horizon).div(2).sqrt().clamp(1e-3, self.inv_scale) # make E[d^2]/r^2 approx 2 | |
+ tsfm_dist: torch.Tensor = time_agg( | |
+ (current_dist / inv_scale).square().neg().exp(), | |
+ self.rho, data.horizon, | |
+ ) | |
rbf_potential = tsfm_dist.log() | |
dict_info.update( | |
- dist=time_agg(dist, self.rho, data.horizon), | |
inv_scale=inv_scale, | |
tsfm_dist=tsfm_dist, | |
rbf_potential=rbf_potential, | |
) | |
- if target is None: | |
+ if target_dist is None: | |
loss = rbf_potential | |
else: | |
- tsfm_target = time_agg((target / inv_scale).square().neg().exp(), self.rho, data.horizon) | |
+ tsfm_target = time_agg( | |
+ (target_dist / inv_scale).square().neg().exp(), | |
+ self.rho, data.horizon, | |
+ ) | |
target_rbf_potential = tsfm_target.log() | |
dict_info.update( | |
- target=time_agg(target, self.rho, data.horizon), | |
+ target=time_agg(target_dist, self.rho, data.horizon), | |
tsfm_target=tsfm_target, | |
target_rbf_potential=target_rbf_potential, | |
) | |
loss = F.l1_loss(rbf_potential, target_rbf_potential) | |
- return LossResult(loss=loss * weight, info=dict_info) | |
+ return LossResult(loss=loss * task_critic_info.weight, info=dict_info) | |
def extra_repr(self) -> str: | |
return '\n'.join([ |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment