Skip to content

Instantly share code, notes, and snippets.

@rjurney
Created November 15, 2020 21:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rjurney/e419af8fabd39d605ef0a87eb8f1562c to your computer and use it in GitHub Desktop.
Save rjurney/e419af8fabd39d605ef0a87eb8f1562c to your computer and use it in GitHub Desktop.
A TensorFlow/Keras implementation of adjusted R-squared
import typing
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow_addons.utils.types import AcceptableDTypes
from typeguard import typechecked
class AdjustedRSquared(tfa.metrics.RSquare):
    """Adjusted R^2 metric layered on top of ``tfa.metrics.RSquare``.

    The plain R^2 score is corrected for the number of predictors::

        adjusted = 1 - (1 - R^2) * (n - 1) / (n - p - 1)

    where ``n`` is ``X_shape[0]`` (number of data points) and ``p`` is
    ``X_shape[1]`` (number of features).

    NOTE: ``n`` is the fixed value supplied at construction time; it is not
    derived from the number of samples the metric has actually been updated
    with.
    """

    # Modes handled by ``result``; kept locally so the error message does not
    # depend on a constant that the ``tfa.metrics`` package does not export.
    _MULTIOUTPUT_MODES = ("raw_values", "uniform_average", "variance_weighted")

    @typechecked
    def __init__(
        self,
        name: str = "adjusted_r2",
        dtype: AcceptableDTypes = None,
        y_shape: typing.Tuple[int, ...] = (),
        multioutput: str = "uniform_average",
        X_shape: typing.Tuple[int, ...] = (),
        **kwargs
    ):
        """Create the metric.

        Args:
            name: Name of the metric instance.
            dtype: Data type forwarded to the parent metric.
            y_shape: Shape of the targets, forwarded to the parent metric.
            multioutput: How per-output scores are combined by ``result``:
                ``"raw_values"``, ``"uniform_average"`` or
                ``"variance_weighted"``.
            X_shape: Shape of the feature matrix as
                ``(data_points, features, ...)``. Only the first two entries
                are used. Although it has a default for signature
                compatibility, a real shape is required.
            **kwargs: Extra keyword arguments forwarded to the parent metric.

        Raises:
            ValueError: If ``X_shape`` has fewer than two entries, or if
                ``data_points - features - 1`` is not positive (the
                correction factor would divide by zero or flip sign).
        """
        super().__init__(
            name=name,
            dtype=dtype,
            y_shape=y_shape,
            multioutput=multioutput,
            **kwargs
        )

        if len(X_shape) < 2:
            raise ValueError(
                "X_shape must be (data_points, features), but was: {}".format(
                    X_shape
                )
            )

        data_points, features = X_shape[0], X_shape[1]
        if data_points - features - 1 <= 0:
            raise ValueError(
                "Adjusted R^2 requires data_points > features + 1, but got "
                "data_points={} and features={}".format(data_points, features)
            )

        # Set the X shape to compute Adjusted R^2
        self.data_points = data_points
        self.features = features

    def result(self) -> tf.Tensor:
        """Compute adjusted R^2 from the state accumulated by the parent.

        Returns:
            The per-output adjusted scores for ``"raw_values"``, their plain
            mean for ``"uniform_average"``, or their mean weighted by each
            output's total sum of squares for ``"variance_weighted"``.

        Raises:
            RuntimeError: If ``self.multioutput`` is not a supported mode.
        """
        # self.sum / self.count / self.squared_sum / self.res are the
        # accumulators maintained by the parent class.
        mean = self.sum / self.count
        total = self.squared_sum - self.sum * mean
        r_squared = 1 - (self.res / total)

        # Degrees-of-freedom correction; the denominator is validated to be
        # positive in __init__.
        correction = (self.data_points - 1) / (
            self.data_points - self.features - 1
        )
        raw_scores = 1 - (1 - r_squared) * correction

        if self.multioutput == "raw_values":
            return raw_scores
        if self.multioutput == "uniform_average":
            return tf.reduce_mean(raw_scores)
        if self.multioutput == "variance_weighted":
            # Weighted average computed with public TF ops instead of the
            # private helper, which is not reachable as
            # ``tfa.metrics._reduce_average``.
            return tf.reduce_sum(total * raw_scores) / tf.reduce_sum(total)

        raise RuntimeError(
            "The multioutput attribute must be one of {}, but was: {}".format(
                self._MULTIOUTPUT_MODES, self.multioutput
            )
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment