Created
April 8, 2024 16:53
-
-
Save ssnl/89c9462279e4bdb39f65e51a93511ab6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py | |
index b353c39..1ffe375 100644 | |
--- a/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py | |
+++ b/quasimetric_rl/modules/quasimetric_critic/losses/global_push.py | |
@@ -54,6 +54,17 @@ from . import CriticLossBase, CriticBatchInfo | |
# return f"weight={self.weight:g}, softplus_beta={self.softplus_beta:g}, softplus_offset={self.softplus_offset:g}" | |
+@attrs.define(kw_only=True) | |
+class Task: | |
+ critic_batch_info: CriticBatchInfo | |
+ goal: torch.Tensor | |
+ zgoal: torch.Tensor | |
+ current_dist: torch.Tensor | |
+ target_dist: Optional[torch.Tensor] | |
+ weight: float | |
+ info: Mapping[str, Union[float, torch.Tensor]] | |
+ | |
+ | |
class GlobalPushLossBase(CriticLossBase): | |
@attrs.define(kw_only=True) | |
class Conf(abc.ABC): | |
@@ -107,7 +118,8 @@ class GlobalPushLossBase(CriticLossBase): | |
assert not regress_max_future_goal, "regress_max_future_goal is not compatible with mix_in_future_goal" | |
assert weight_future_goal > 0, "weight_future_goal must be positive when mix_in_future_goal is enabled" | |
- def generate_dist_weight(self, data: BatchData, critic_batch_infos: Sequence[CriticBatchInfo]): | |
+ | |
+ def generate_tasks(self, data: BatchData, critic_batch_infos: Sequence[CriticBatchInfo]): | |
def get_dist(critic_batch_info: CriticBatchInfo, za: torch.Tensor, zb: torch.Tensor): | |
if self.detach_goal: | |
zb = zb.detach() | |
@@ -115,8 +127,9 @@ class GlobalPushLossBase(CriticLossBase): | |
with quasimetric_model.requiring_grad(not self.detach_qmet): | |
return quasimetric_model(za, zb, proj_grad_enabled=(True, not self.detach_proj_goal)) | |
- # (critic_batch_info, goal, zgoal, current_dist, target_dist, weight, info) | |
- tasks: Dict[str, Tuple[CriticBatchInfo, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], float, Dict[str, torch.Tensor]]] = {} | |
+ info: Dict[str, Union[float, torch.Tensor]] | |
+ | |
+ tasks: Dict[str, Task] = {} | |
for idx, critic_batch_info in enumerate(critic_batch_infos): | |
if critic_batch_info.critic.borrowing_embedding: | |
continue | |
@@ -148,14 +161,17 @@ class GlobalPushLossBase(CriticLossBase): | |
zgoal, | |
) | |
key = 'mixed_goal' | |
- tasks[f'critic_{idx:02d}'] = ( | |
- critic_batch_info, | |
- goal, | |
- zgoal, | |
- get_dist(critic_batch_info, critic_batch_info.zx, zgoal), | |
- None, | |
- weight, | |
- {}, | |
+ current_dist = get_dist(critic_batch_info, critic_batch_info.zx, zgoal) | |
+ tasks[f'critic_{idx:02d}'] = Task( | |
+ critic_batch_info=critic_batch_info, | |
+ goal=goal, | |
+ zgoal=zgoal, | |
+ current_dist=current_dist, | |
+ target_dist=None, | |
+ weight=weight, | |
+ info=dict( | |
+ dist=time_agg(current_dist, self.rho, data.horizon), | |
+ ), | |
) | |
yield key, tasks | |
@@ -166,48 +182,46 @@ class GlobalPushLossBase(CriticLossBase): | |
continue | |
zgoal = critic_batch_info.critic.encoder( | |
data.future_observations, target=critic_batch_info.zy_from_target_encoder) | |
- dist = get_dist(critic_batch_info, critic_batch_info.zx, zgoal) | |
+ current_dist = get_dist(critic_batch_info, critic_batch_info.zx, zgoal) | |
if self.regress_max_future_goal: | |
observed_upper_bound = self.step_cost * data.future_tdelta | |
info = dict( | |
target=observed_upper_bound.mean(), | |
- ratio_future_observed_dist=(dist / observed_upper_bound).mean(), | |
- exceed_future_observed_dist_rate=(dist > observed_upper_bound).mean(dtype=torch.float32), | |
+ ratio_future_observed_dist=(current_dist / observed_upper_bound).mean(), | |
+ exceed_future_observed_dist_rate=(current_dist > observed_upper_bound).mean(dtype=torch.float32), | |
) | |
- target = observed_upper_bound | |
+ target_dist = observed_upper_bound | |
# dist = dist.clamp_max(self.step_cost * data.future_tdelta) | |
else: | |
info = {} | |
- target = None | |
- tasks[f'critic_{idx:02d}'] = ( | |
- critic_batch_info, | |
- data.future_observations, | |
- zgoal, | |
- dist, | |
- target, | |
- self.weight_future_goal, | |
- info, | |
+ target_dist = None | |
+ info.update( | |
+ dist=time_agg(current_dist, self.rho, data.horizon), | |
+ ) | |
+ tasks[f'critic_{idx:02d}'] = Task( | |
+ critic_batch_info=critic_batch_info, | |
+ goal=data.future_observations, | |
+ zgoal=zgoal, | |
+ current_dist=current_dist, | |
+ target_dist=target_dist, | |
+ weight=self.weight_future_goal, | |
+ info=info, | |
) | |
yield 'future_goal', tasks | |
@abc.abstractmethod | |
- def compute_loss_one(self, data: BatchData, critic_batch_info: CriticBatchInfo, | |
- goal: torch.Tensor, zgoal: torch.Tensor, dist: torch.Tensor, | |
- target: Optional[torch.Tensor], weight: float, | |
- info: Mapping[str, torch.Tensor]) -> LossResult: | |
+ def compute_loss_one(self, data: BatchData, task: Task) -> LossResult: | |
raise NotImplementedError | |
- def compute_loss(self, data: BatchData, tasks: Dict[str, Tuple[CriticBatchInfo, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], float, Dict[str, torch.Tensor]]], | |
- **extra_kwargs) -> Dict[str, LossResult]: | |
+ def compute_loss(self, data: BatchData, tasks: Dict[str, Task], **extra_kwargs) -> Dict[str, LossResult]: | |
return { | |
- key: self.compute_loss_one(data, critic_batch_info, goal, zgoal, dist, target, weight, info, **extra_kwargs) | |
- for key, (critic_batch_info, goal, zgoal, dist, target, weight, info) in tasks.items() | |
+ key: self.compute_loss_one(data, task, **extra_kwargs) for key, task in tasks.items() | |
} | |
def forward(self, data: BatchData, critic_batch_infos: Sequence[CriticBatchInfo]) -> LossResult: | |
loss_results: DefaultDict[str, Dict[str, LossResult]] = defaultdict(dict) | |
- for name, tasks in self.generate_dist_weight(data, critic_batch_infos): | |
+ for name, tasks in self.generate_tasks(data, critic_batch_infos): | |
for critic_desc, loss_result in self.compute_loss(data, tasks).items(): | |
loss_results[critic_desc][name] = loss_result | |
@@ -265,34 +279,35 @@ class GlobalPushLoss(GlobalPushLossBase): | |
self.softplus_offset = softplus_offset | |
self.clamp_max = clamp_max | |
- def compute_loss_one(self, data: BatchData, critic_batch_info: CriticBatchInfo, | |
- goal: torch.Tensor, zgoal: torch.Tensor, dist: torch.Tensor, | |
- target: Optional[torch.Tensor], weight: float, | |
- info: Mapping[str, torch.Tensor]) -> LossResult: | |
- dict_info: Dict[str, torch.Tensor] = dict(info) | |
+ def compute_loss_one(self, data: BatchData, task: Task) -> LossResult: | |
+ dict_info = dict(task.info) | |
# Sec 3.2. Transform so that we penalize large distances less. | |
- tsfm_dist: torch.Tensor = F.softplus(self.softplus_offset - dist, beta=self.softplus_beta) # type: ignore | |
+ tsfm_dist = task.current_dist | |
+ target_dist = task.target_dist | |
if self.clamp_max < float('inf'): | |
tsfm_dist = tsfm_dist.clamp_max(self.clamp_max) | |
dict_info.update( | |
- exceed_rate=(dist >= self.clamp_max).mean(dtype=torch.float32), | |
+ exceed_rate=time_agg(tsfm_dist >= self.clamp_max, self.rho, data.horizon), | |
) | |
- dict_info.update(dist=time_agg(dist, self.rho, data.horizon)) | |
+ tsfm_dist = F.softplus(self.softplus_offset - tsfm_dist, beta=self.softplus_beta) # type: ignore | |
agg_tsfm_dist = time_agg(tsfm_dist, self.rho, data.horizon) | |
dict_info.update(tsfm_dist=agg_tsfm_dist) | |
- if target is None: | |
+ if target_dist is None: | |
loss = agg_tsfm_dist | |
else: | |
- tsfm_target = F.softplus(self.softplus_offset - target, beta=self.softplus_beta) # type: ignore | |
+ tsfm_target = target_dist | |
+ if self.clamp_max < float('inf'): | |
+ tsfm_target = tsfm_target.clamp_max(self.clamp_max) | |
+ tsfm_target = F.softplus(self.softplus_offset - tsfm_target, beta=self.softplus_beta) # type: ignore | |
dict_info.update( | |
- target=time_agg(target, self.rho, data.horizon), | |
+ target=time_agg(target_dist, self.rho, data.horizon), | |
tsfm_target=time_agg(tsfm_target, self.rho, data.horizon), | |
) | |
loss = time_agg(F.l1_loss(tsfm_dist, tsfm_target, reduction='none'), self.rho, data.horizon) | |
- return LossResult(loss=loss * weight, info=dict_info) | |
+ return LossResult(loss=loss * task.weight, info=dict_info) | |
def extra_repr(self) -> str: | |
return '\n'.join([ | |
@@ -306,7 +321,7 @@ class GlobalPushLinearLoss(GlobalPushLossBase): | |
class Conf(GlobalPushLossBase.Conf): | |
enabled: bool = False | |
- clamp_max: Optional[float] = attrs.field(default=None, validator=attrs.validators.optional(attrs.validators.gt(0))) | |
+ clamp_max: float = attrs.field(default=float('inf'), validator=attrs.validators.gt(0)) | |
def make(self) -> Optional['GlobalPushLinearLoss']: | |
if not self.enabled: | |
@@ -324,47 +339,44 @@ class GlobalPushLinearLoss(GlobalPushLossBase): | |
clamp_max=self.clamp_max, | |
) | |
- clamp_max: Optional[float] | |
+ clamp_max: float | |
def __init__(self, *, weight: float, weight_future_goal: float, detach_goal: bool, detach_proj_goal: bool, | |
detach_qmet: bool, step_cost: float, rho: float, regress_max_future_goal: bool, mix_in_future_goal: bool, | |
- clamp_max: Optional[float]): | |
+ clamp_max: float): | |
super().__init__(weight=weight, weight_future_goal=weight_future_goal, detach_goal=detach_goal, | |
detach_proj_goal=detach_proj_goal, detach_qmet=detach_qmet, step_cost=step_cost, rho=rho, | |
regress_max_future_goal=regress_max_future_goal, mix_in_future_goal=mix_in_future_goal) | |
self.clamp_max = clamp_max | |
- def compute_loss_one(self, data: BatchData, critic_batch_info: CriticBatchInfo, | |
- goal: torch.Tensor, zgoal: torch.Tensor, dist: torch.Tensor, | |
- target: Optional[torch.Tensor], weight: float, | |
- info: Mapping[str, torch.Tensor]) -> LossResult: | |
- dict_info: Dict[str, torch.Tensor] = dict(info) | |
- dict_info.update(dist=time_agg(dist, self.rho, data.horizon)) | |
- if self.clamp_max is None: | |
- tsfm_dist = -dist | |
- else: | |
- tsfm_dist = -dist.clamp_max(self.clamp_max) | |
+ def compute_loss_one(self, data: BatchData, task: Task) -> LossResult: | |
+ dict_info = dict(task.info) | |
+ tsfm_dist = task.current_dist | |
+ target_dist = task.target_dist | |
+ if self.clamp_max < float('inf'): | |
+ tsfm_dist = tsfm_dist.clamp_max(self.clamp_max) | |
dict_info.update( | |
- exceed_rate=time_agg(dist >= self.clamp_max, self.rho, data.horizon), | |
+ exceed_rate=time_agg(tsfm_dist >= self.clamp_max, self.rho, data.horizon), | |
) | |
+ tsfm_dist = -tsfm_dist | |
agg_tsfm_dist = time_agg(tsfm_dist, self.rho, data.horizon) | |
dict_info.update(tsfm_dist=agg_tsfm_dist) | |
- if target is None: | |
+ if target_dist is None: | |
loss = agg_tsfm_dist | |
else: | |
- if self.clamp_max is not None: | |
- tsfm_target = target.clamp_max(self.clamp_max) | |
- else: | |
- tsfm_target = target | |
+ tsfm_target = target_dist | |
+ if self.clamp_max < float('inf'): | |
+ tsfm_target = tsfm_target.clamp_max(self.clamp_max) | |
+ tsfm_target = -tsfm_target | |
dict_info.update( | |
- target=time_agg(target, self.rho, data.horizon), | |
+ target=time_agg(target_dist, self.rho, data.horizon), | |
tsfm_target=time_agg(tsfm_target, self.rho, data.horizon), | |
) | |
loss = time_agg(F.l1_loss(tsfm_dist, tsfm_target, reduction='none'), self.rho, data.horizon) | |
- return LossResult(loss=loss * weight, info=dict_info) | |
+ return LossResult(loss=loss * task.weight, info=dict_info) | |
def extra_repr(self) -> str: | |
return '\n'.join([ | |
@@ -451,27 +463,32 @@ class GlobalPushNextLoss(GlobalPushLossBase): | |
self.regress_step_cost_mul = regress_step_cost_mul | |
self.worst_v_regularizer_weight_mul = worst_v_regularizer_weight_mul | |
- def compute_loss(self, data: BatchData, | |
- tasks: Dict[str, Tuple[CriticBatchInfo, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], float, Dict[str, torch.Tensor]]]) -> Dict[str, LossResult]: | |
+ def compute_loss(self, data: BatchData, tasks: Dict[str, Task]) -> Dict[str, LossResult]: | |
# adv = min_Q - avg_V (HILP) | |
qs_for_adv: List[torch.Tensor] = [] # use slack | |
regress_target_qs: Dict[str, torch.Tensor] = {} # no slack | |
next_dists: Dict[str, torch.Tensor] = {} | |
vs: List[torch.Tensor] = [] | |
- target_encoder_tasks = {} | |
- for critic_desc, (critic_batch_info, goal, zgoal, dist, target, weight, info) in tasks.items(): | |
+ | |
+ target_encoder_tasks: Dict[str, Task] = {} | |
+ for critic_desc, task in tasks.items(): | |
+ # unpack the task fields that may be replaced below (fed back via attrs.evolve) | |
+ critic_batch_info = task.critic_batch_info | |
+ info = dict(task.info) | |
+ target_dist = task.target_dist | |
+ zgoal = task.zgoal | |
+ | |
critic = critic_batch_info.critic | |
- info = dict(info) | |
with torch.no_grad(), critic.mode(False): | |
if not critic_batch_info.zx_from_target_encoder and critic.has_separate_target_encoder: | |
zx = critic.encoder(data.observations, target=True) | |
else: | |
zx = critic_batch_info.zx | |
if not critic_batch_info.zy_from_target_encoder and critic.has_separate_target_encoder: | |
- zgoal = critic.encoder(goal, target=True) # goal is encoded in the same way as zy | |
+ zgoal = critic.encoder(task.goal, target=True) # goal is encoded in the same way as zy | |
- if target is not None: | |
- next_dist = target - self.step_cost # this is distance, not value | |
+ if target_dist is not None: | |
+ next_dist = target_dist - self.step_cost # this is distance, not value | |
q_dist_is_value = False | |
else: | |
if not critic_batch_info.zy_from_target_encoder and critic.has_separate_target_encoder: | |
@@ -503,7 +520,7 @@ class GlobalPushNextLoss(GlobalPushLossBase): | |
if FLAGS.DEBUG and critic.has_separate_target_encoder: | |
non_target_next_dist = critic.quasimetric_model( | |
critic.encoder(data.next_observations), | |
- critic.encoder(goal), | |
+ critic.encoder(task.goal), | |
) | |
( | |
info['tgt-ntgt_next_out_000'], | |
@@ -517,11 +534,15 @@ class GlobalPushNextLoss(GlobalPushLossBase): | |
non_target_next_dist.new_tensor([0, 0.1, 0.25, 0.5, 0.75, 0.9, 1], dtype=torch.float32) | |
).unbind() | |
- target_encoder_tasks[critic_desc] = ( | |
- attrs.evolve( | |
+ target_encoder_tasks[critic_desc] = attrs.evolve( | |
+ task, | |
+ critic_batch_info=attrs.evolve( | |
critic_batch_info, zx=zx, zy=zy, | |
- zx_from_target_encoder=True, zy_from_target_encoder=True), | |
- goal, zgoal, dist, target, weight, info, | |
+ zx_from_target_encoder=True, zy_from_target_encoder=True, | |
+ ), | |
+ zgoal=zgoal, | |
+ target_dist=target_dist, | |
+ info=info, | |
) | |
tensor_vs = torch.stack(vs, dim=-1) | |
@@ -542,20 +563,17 @@ class GlobalPushNextLoss(GlobalPushLossBase): | |
raise NotImplementedError(self.regress_target_kind) | |
return { | |
- key: self.compute_loss_one(data, critic_batch_info, goal, zgoal, dist, target, weight, info, | |
+ key: self.compute_loss_one(data, task, | |
advantage=advantage, | |
next_dists=next_dists[key], | |
regress_target_qs=regress_target_qs[key], | |
worst_v=worst_v) | |
- for key, (critic_batch_info, goal, zgoal, dist, target, weight, info) in target_encoder_tasks.items() | |
+ for key, task in target_encoder_tasks.items() | |
} | |
- def compute_loss_one(self, data: BatchData, critic_batch_info: CriticBatchInfo, | |
- goal: torch.Tensor, zgoal: torch.Tensor, dist: torch.Tensor, | |
- target: Optional[torch.Tensor], weight: float, | |
- info: Mapping[str, torch.Tensor], advantage: torch.Tensor, next_dists: torch.Tensor, | |
+ def compute_loss_one(self, data: BatchData, task: Task, advantage: torch.Tensor, next_dists: torch.Tensor, | |
regress_target_qs: torch.Tensor, worst_v: torch.Tensor) -> LossResult: | |
- dict_info: Dict[str, torch.Tensor] = dict(info) | |
+ dict_info = dict(task.info) | |
# if target is None: | |
# with torch.set_grad_enabled(self.detach_target_dist), critic_batch_info.critic.mode(False): | |
@@ -593,8 +611,7 @@ class GlobalPushNextLoss(GlobalPushLossBase): | |
) | |
target_q = regress_target_qs | |
- | |
- v = dist_to_value(dist, self.gamma, self.step_cost, | |
+ v = dist_to_value(task.current_dist, self.gamma, self.step_cost, | |
dist_is_value=self.dist_is_value) | |
if self.allow_gt: | |
@@ -628,7 +645,7 @@ class GlobalPushNextLoss(GlobalPushLossBase): | |
loss += regularizer * self.worst_v_regularizer_weight_mul | |
actual_dist = -time_agg(dist_to_value( # convert possibly value back to dist | |
- dist=dist, gamma=None, step_cost=self.step_cost, | |
+ dist=task.current_dist, gamma=None, step_cost=self.step_cost, | |
dist_is_value=self.dist_is_value, | |
), self.rho, data.horizon) | |
actual_q_dist = -time_agg(dist_to_value( # convert possibly value back to dist | |
@@ -638,13 +655,13 @@ class GlobalPushNextLoss(GlobalPushLossBase): | |
), self.rho, data.horizon) | |
dict_info.update( | |
iql_loss=iql_loss, | |
- dist=actual_dist, | |
+ actual_dist=actual_dist, | |
v=time_agg(v, self.rho, data.horizon), | |
vq_difference=time_agg(q_difference, self.rho, data.horizon), # (v - q) | |
vq_dist_difference=actual_dist - actual_q_dist, | |
) | |
- return LossResult(loss=loss * weight, info=dict_info) | |
+ return LossResult(loss=loss * task.weight, info=dict_info) | |
def extra_repr(self) -> str: | |
return '\n'.join([ | |
@@ -695,27 +712,26 @@ class GlobalPushLogLoss(GlobalPushLossBase): | |
self.offset = offset | |
- def compute_loss_one(self, data: BatchData, critic_batch_info: CriticBatchInfo, | |
- goal: torch.Tensor, zgoal: torch.Tensor, dist: torch.Tensor, | |
- target: Optional[torch.Tensor], weight: float, | |
- info: Mapping[str, torch.Tensor]) -> LossResult: | |
- dict_info: Dict[str, torch.Tensor] = dict(info) | |
- tsfm_dist: torch.Tensor = -dist.add(self.offset).log() | |
- dict_info.update(dist=time_agg(dist, self.rho, data.horizon)) | |
+ def compute_loss_one(self, data: BatchData, task: Task) -> LossResult: | |
+ dict_info = dict(task.info) | |
+ tsfm_dist = task.current_dist | |
+ target_dist = task.target_dist | |
+ | |
+ tsfm_dist: torch.Tensor = -tsfm_dist.add(self.offset).log() | |
agg_tsfm_dist = time_agg(tsfm_dist, self.rho, data.horizon) | |
dict_info.update(tsfm_dist=agg_tsfm_dist) | |
- if target is None: | |
+ if target_dist is None: | |
loss = agg_tsfm_dist | |
else: | |
- tsfm_target = -target.add(self.offset).log() | |
+ tsfm_target = -target_dist.add(self.offset).log() | |
dict_info.update( | |
- target=time_agg(target, self.rho, data.horizon), | |
+ target=time_agg(target_dist, self.rho, data.horizon), | |
tsfm_target=time_agg(tsfm_target, self.rho, data.horizon), | |
) | |
loss = time_agg(F.l1_loss(tsfm_dist, tsfm_target, reduction='none'), self.rho, data.horizon) | |
- return LossResult(loss=loss * weight, info=dict_info) | |
+ return LossResult(loss=loss * task.weight, info=dict_info) | |
def extra_repr(self) -> str: | |
return '\n'.join([ | |
@@ -761,34 +777,39 @@ class GlobalPushRBFLoss(GlobalPushLossBase): | |
regress_max_future_goal=regress_max_future_goal, mix_in_future_goal=mix_in_future_goal) | |
self.inv_scale = inv_scale | |
- def compute_loss_one(self, data: BatchData, critic_batch_info: CriticBatchInfo, | |
- goal: torch.Tensor, zgoal: torch.Tensor, dist: torch.Tensor, | |
- target: Optional[torch.Tensor], weight: float, | |
- info: Mapping[str, torch.Tensor]) -> LossResult: | |
- dict_info: Dict[str, torch.Tensor] = dict(info) | |
- inv_scale = time_agg(dist.detach().square(), self.rho, data.horizon).div(2).sqrt().clamp(1e-3, self.inv_scale) # make E[d^2]/r^2 approx 2 | |
- tsfm_dist: torch.Tensor = time_agg((dist / inv_scale).square().neg().exp(), self.rho, data.horizon) | |
+ def compute_loss_one(self, data: BatchData, task: Task) -> LossResult: | |
+ dict_info = dict(task.info) | |
+ current_dist = task.current_dist | |
+ target_dist = task.target_dist | |
+ | |
+ inv_scale = time_agg(current_dist.detach().square(), self.rho, data.horizon).div(2).sqrt().clamp(1e-3, self.inv_scale) # make E[d^2]/r^2 approx 2 | |
+ tsfm_dist: torch.Tensor = time_agg( | |
+ (current_dist / inv_scale).square().neg().exp(), | |
+ self.rho, data.horizon, | |
+ ) | |
rbf_potential = tsfm_dist.log() | |
dict_info.update( | |
- dist=time_agg(dist, self.rho, data.horizon), | |
inv_scale=inv_scale, | |
tsfm_dist=tsfm_dist, | |
rbf_potential=rbf_potential, | |
) | |
- if target is None: | |
+ if target_dist is None: | |
loss = rbf_potential | |
else: | |
- tsfm_target = time_agg((target / inv_scale).square().neg().exp(), self.rho, data.horizon) | |
+ tsfm_target = time_agg( | |
+ (target_dist / inv_scale).square().neg().exp(), | |
+ self.rho, data.horizon, | |
+ ) | |
target_rbf_potential = tsfm_target.log() | |
dict_info.update( | |
- target=time_agg(target, self.rho, data.horizon), | |
+ target=time_agg(target_dist, self.rho, data.horizon), | |
tsfm_target=tsfm_target, | |
target_rbf_potential=target_rbf_potential, | |
) | |
loss = F.l1_loss(rbf_potential, target_rbf_potential) | |
- return LossResult(loss=loss * weight, info=dict_info) | |
+ return LossResult(loss=loss * task.weight, info=dict_info) | |
def extra_repr(self) -> str: | |
return '\n'.join([ |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment