This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os, pickle, argparse | |
import torch | |
import numpy as np | |
import torch.nn as nn | |
import gymnasium as gym | |
import torch.nn.functional as F | |
from torch.distributions import Normal | |
from optim import ObGD as Optimizer | |
from time_wrapper import AddTimeInfo | |
from normalization_wrappers import NormalizeObservation, ScaleReward |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os, pickle, argparse | |
import torch | |
import numpy as np | |
import torch.nn as nn | |
import gymnasium as gym | |
import torch.nn.functional as F | |
from torch.distributions import Normal | |
from optim import ObGD as Optimizer | |
from time_wrapper import AddTimeInfo | |
from normalization_wrappers import NormalizeObservation, ScaleReward |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
from torch import Tensor | |
from torch.nn import functional as F | |
from torch.nn.modules.loss import _Loss | |
class GaussianNLLLoss(_Loss): | |
def __init__(self, full: bool = False, eps: float = 1e-6, reduction: str = 'mean', require_grad_var=False, require_grad_mean=True) -> None: | |
if require_grad_var == require_grad_mean: | |
raise ValueError("Either require_grad_var or require_grad_mean must be true") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.nn.modules.loss import _Loss | |
import torch | |
from torch import Tensor | |
from torch.nn import functional as F | |
class MultiLabelNLLoss(_Loss): | |
def __init__(self, reduction = 'mean'): | |
self.reduction = reduction | |
super(MultiLabelNLLoss, self).__init__(reduction=reduction) |