Skip to content

Instantly share code, notes, and snippets.

@DiogoRibeiro7
Last active September 28, 2023 09:11
Show Gist options
  • Save DiogoRibeiro7/c5d2306dde70df7c85b2337605540d98 to your computer and use it in GitHub Desktop.
Save DiogoRibeiro7/c5d2306dde70df7c85b2337605540d98 to your computer and use it in GitHub Desktop.
from typing import Callable, List, Tuple
import numpy as np
from scipy.stats import norm
class SegmentChangePointDetector:
    """
    Detect change points in a 1-D series modelled as piecewise Gaussian.

    Every candidate segment is scored by the log-likelihood of a Gaussian
    whose mean and standard deviation are fitted to that segment. The score
    of a segmentation is the sum of its segment log-likelihoods plus the log
    of the prior at each change point; the best-scoring segmentation is found
    by dynamic programming.
    """

    # A fitted Gaussian needs at least two observations for a variance
    # estimate, so shorter segments are never considered.
    _MIN_SEGMENT_LENGTH = 2
    # Floor on the fitted standard deviation. A constant segment has zero
    # spread, which would otherwise give a NaN density.
    _MIN_STD = 1e-8

    def __init__(self, prior: Callable[[int], float]) -> None:
        """
        Initialize the SegmentChangePointDetector model.
        Parameters:
        - prior (Callable[[int], float]): Prior weight of a change point at a
          given index (it need not be normalised). Its logarithm is added to
          the score once per change point; a value <= 0 forbids that index.
        """
        self.prior = prior

    def _log_likelihood(self, data_segment: List[float]) -> float:
        """Return the Gaussian log-likelihood of a segment under its own fit."""
        values = np.asarray(data_segment, dtype=float)
        if values.size == 0:
            # Empty product of densities: likelihood 1, log-likelihood 0.
            return 0.0
        scale = max(float(np.std(values)), self._MIN_STD)
        return float(np.sum(norm.logpdf(values, np.mean(values), scale)))

    def _log_prior(self, index: int) -> float:
        """Return log(prior(index)), or -inf when the prior is not positive."""
        weight = float(self.prior(index))
        return float(np.log(weight)) if weight > 0.0 else -np.inf

    def calculate_likelihood(self, data_segment: List[float]) -> float:
        """
        Calculate the likelihood of a data segment using Gaussian distribution.
        The mean and standard deviation are estimated from the segment itself.
        The value is computed as exp(sum of log-densities) rather than as a
        product of densities, and the standard deviation is floored so that a
        constant segment yields a finite number instead of NaN.
        Parameters:
        - data_segment (List[float]): A segment of the data
        Returns:
        - float: The likelihood value (1.0 for an empty segment)
        """
        return float(np.exp(self._log_likelihood(data_segment)))

    def find_optimal_segments(self, data: List[float]) -> List[int]:
        """
        Identify the most likely segments using dynamic programming.
        A change point is the index at which a new segment starts, so a
        result of [4, 7] describes the segments data[0:4], data[4:7] and
        data[7:]. Segments never overlap and each holds at least two points.
        Parameters:
        - data (List[float]): The observed data
        Returns:
        - List[int]: The most likely change points in increasing order; empty
          when the data is too short to be split into two segments
        """
        n = len(data)
        min_len = self._MIN_SEGMENT_LENGTH
        if n < 2 * min_len:
            # Not enough points for two valid segments: nothing to split.
            return []

        values = np.asarray(data, dtype=float)
        # Evaluate the prior once per admissible change point rather than
        # once per (start, end) pair.
        log_priors = {
            index: self._log_prior(index)
            for index in range(min_len, n - min_len + 1)
        }

        # best[end]: top score of any segmentation of data[:end].
        # previous[end]: start index of the last segment in that segmentation.
        best = [-np.inf] * (n + 1)
        best[0] = 0.0
        previous = [0] * (n + 1)

        for end in range(min_len, n + 1):
            # The last segment either covers the whole prefix (start 0) or
            # begins at a change point leaving both sides long enough.
            for start in (0, *range(min_len, end - min_len + 1)):
                score = best[start] + self._log_likelihood(values[start:end])
                if start > 0:
                    # Only real change points pay the prior; index 0 is the
                    # start of the data, not a change.
                    score += log_priors[start]
                if score > best[end]:
                    best[end] = score
                    previous[end] = start

        # Walk back from the end of the data through the segment starts.
        change_points: List[int] = []
        position = previous[n]
        while position > 0:
            change_points.append(position)
            position = previous[position]
        change_points.reverse()
        return change_points
# Usage
def uniform_prior(x: int) -> float:
    """Return the same weight, 1.0, for every index ``x``."""
    constant_weight = 1.0
    return constant_weight
# Build a detector whose prior weights every index equally.
detector = SegmentChangePointDetector(prior=uniform_prior)
# Toy series: a low run, a jump to a high run, then a drop back down.
data = [
    1.0, 2.0, 3.0, 4.0,
    10.0, 11.0, 12.0,
    2.0, 3.0, 4.0,
]
# Run the search and print the change points it reports.
optimal_segments = detector.find_optimal_segments(data)
print("Optimal Change Points:", optimal_segments)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment