Last active
July 22, 2022 14:31
-
-
Save SteelPh0enix/10635d2d8f2b316ddf256edce7e75b80 to your computer and use it in GitHub Desktop.
Simple 1D Kalman filter implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass | |
import pandas as pd | |
@dataclass
class KalmanFilterUpdateResult:
    """Outcome of one ``StaticSystemKalmanFilter.update`` step.

    Fields are positional, in the order declared below.
    """

    # Time point assigned to the measurement (running sum of the time deltas).
    time_point: float
    # Kalman gain computed for this step.
    gain: float
    # Prior state estimate, i.e. the prediction made by the previous step.
    previous_state: float
    # State estimate after fusing the measurement.
    current_state: float
    # Prediction of the state for the next measurement.
    predicted_state: float
    # Uncertainty of the prior state estimate.
    previous_uncertainty: float
    # Uncertainty after fusing the measurement.
    current_uncertainty: float
    # Predicted uncertainty for the next measurement.
    predicted_uncertainty: float
@dataclass
class KalmanFilterInitializationResult:
    """Prediction produced by ``StaticSystemKalmanFilter.initialize``.

    It is the prior the filter uses for its first measurement.
    """

    # State predicted from the initial estimate.
    predicted_state: float
    # Uncertainty predicted from the initial uncertainty.
    predicted_uncertainty: float
class StaticSystemKalmanFilter:
    """One-dimensional Kalman filter for a static system.

    The filter records the history of every step: time points, measurements,
    Kalman gains, state estimates and predictions. The history is exposed
    through properties and through :meth:`get_dataframe`.

    The prediction step is the identity, so the system is assumed not to
    change between measurements. To model a dynamic system, override
    :meth:`_calculate_state_prediction` and
    :meth:`_calculate_uncertainty_prediction`.
    """

    def __init__(
        self,
        initial_state: float | None = None,
        initial_uncertainty: float | None = None,
    ):
        """Create the filter and, if initial values are given, initialize it.

        Args:
            initial_state: Initial state estimate.
            initial_uncertainty: Uncertainty of the initial state estimate.

        If both values are omitted, the filter starts uninitialized.
        :meth:`initialize` must then be called before :meth:`update`.

        Raises:
            ValueError: If only one of the two initial values is given.
        """
        self._gains: list[float] = []
        self._time_points: list[float] = []
        self._measurements: list[float] = []
        self._measurement_uncertainties: list[float] = []
        self._states: list[float] = []
        self._state_uncertainties: list[float] = []
        self._predicted_states: list[float] = []
        self._predicted_state_uncertainties: list[float] = []
        self._initial_state: float | None = None
        self._initial_uncertainty: float | None = None

        # With no initial values the prediction lists stay empty, so `update`
        # raises its "not initialized" RuntimeError. A partial pair is
        # rejected by `initialize`.
        if initial_state is not None or initial_uncertainty is not None:
            self.initialize(initial_state, initial_uncertainty)

    def initialize(
        self, initial_state: float, initial_uncertainty: float
    ) -> KalmanFilterInitializationResult:
        """Clear all history and restart the filter from an initial estimate.

        Args:
            initial_state: Initial state estimate.
            initial_uncertainty: Uncertainty of the initial state estimate.

        Returns:
            The prediction made from the initial estimate. It is the prior
            used by the first :meth:`update`.

        Raises:
            ValueError: If either argument is None. The filter is left
                unchanged in that case.
        """
        # Validate before `reset` so a bad call does not wipe existing history.
        if initial_state is None or initial_uncertainty is None:
            raise ValueError(
                "initial_state and initial_uncertainty must both be provided"
            )

        self.reset()
        self._initial_state = initial_state
        self._initial_uncertainty = initial_uncertainty

        predicted_state = self._calculate_state_prediction(initial_state)
        predicted_uncertainty = self._calculate_uncertainty_prediction(
            initial_uncertainty
        )
        self._predicted_states.append(predicted_state)
        self._predicted_state_uncertainties.append(predicted_uncertainty)
        return KalmanFilterInitializationResult(predicted_state, predicted_uncertainty)

    def update(
        self, measurement: float, measurement_uncertainty: float, time_delta: float
    ) -> KalmanFilterUpdateResult:
        """Fuse one measurement into the estimate and predict the next state.

        Args:
            measurement: Measured value.
            measurement_uncertainty: Uncertainty of ``measurement``. It is
                added to the predicted uncertainty in the gain formula
                ``p / (p + r)``, so it must be on the same scale as the
                state uncertainty (presumably a variance).
            time_delta: Time since the last measurement. For the first
                measurement it is used directly as the time point.

        Returns:
            A :class:`KalmanFilterUpdateResult` describing this step.

        Raises:
            RuntimeError: If the filter has not been initialized.
            ZeroDivisionError: If the predicted and measurement uncertainties
                sum to zero. The recorded history is left unchanged.
        """
        if not self._predicted_states:
            raise RuntimeError("Kalman filter is not initialized!")

        # The prior is the prediction made by the previous step, or by
        # `initialize` for the first measurement.
        previous_state = self._predicted_states[-1]
        previous_uncertainty = self._predicted_state_uncertainties[-1]

        time_point = (
            self._time_points[-1] + time_delta if self._time_points else time_delta
        )

        # Evaluate the update equations before recording anything. A failure
        # here (e.g. division by zero in the gain) then cannot leave the
        # history lists with different lengths, which would break
        # `get_dataframe`.
        gain = self._calculate_kalman_gain(
            previous_uncertainty, measurement_uncertainty
        )
        current_state = self._calculate_state_update(previous_state, measurement, gain)
        current_uncertainty = self._calculate_uncertainty_update(
            gain, previous_uncertainty
        )

        self._measurements.append(measurement)
        self._measurement_uncertainties.append(measurement_uncertainty)
        self._time_points.append(time_point)
        self._gains.append(gain)
        self._states.append(current_state)
        self._state_uncertainties.append(current_uncertainty)

        # The current step is already recorded at this point, so an
        # overridden prediction can read it (e.g. the latest time point).
        predicted_state = self._calculate_state_prediction(current_state)
        predicted_uncertainty = self._calculate_uncertainty_prediction(
            current_uncertainty
        )
        self._predicted_states.append(predicted_state)
        self._predicted_state_uncertainties.append(predicted_uncertainty)

        return KalmanFilterUpdateResult(
            time_point,
            gain,
            previous_state,
            current_state,
            predicted_state,
            previous_uncertainty,
            current_uncertainty,
            predicted_uncertainty,
        )

    def reset(self):
        """Clear all recorded history and forget the initial values.

        Afterwards the filter is uninitialized: :meth:`update` raises until
        :meth:`initialize` is called again.
        """
        for history in (
            self._gains,
            self._time_points,
            self._measurements,
            self._measurement_uncertainties,
            self._states,
            self._predicted_states,
            self._state_uncertainties,
            self._predicted_state_uncertainties,
        ):
            history.clear()
        self._initial_state = None
        self._initial_uncertainty = None

    @property
    def initial_state(self) -> float | None:
        """Initial state passed to :meth:`initialize`, or None if uninitialized."""
        return self._initial_state

    @property
    def initial_state_uncertainty(self) -> float | None:
        """Initial uncertainty passed to :meth:`initialize`, or None if uninitialized."""
        return self._initial_uncertainty

    # The list properties below return the filter's internal storage (not a
    # copy); callers should not modify it.

    @property
    def gains(self) -> list[float]:
        """Kalman gain of every update, in order."""
        return self._gains

    @property
    def measurements(self) -> list[float]:
        """Every measurement passed to :meth:`update`, in order."""
        return self._measurements

    @property
    def measurements_uncertainties(self) -> list[float]:
        """Uncertainty of every measurement, in order."""
        return self._measurement_uncertainties

    @property
    def states(self) -> list[float]:
        """State estimate after every update, in order."""
        return self._states

    @property
    def uncertainties(self) -> list[float]:
        """State uncertainty after every update, in order."""
        return self._state_uncertainties

    @property
    def predicted_states(self) -> list[float]:
        """All predicted states.

        One entry comes from :meth:`initialize`, followed by one per update.
        """
        return self._predicted_states

    @property
    def predicted_state_uncertainties(self) -> list[float]:
        """All predicted uncertainties.

        One entry comes from :meth:`initialize`, followed by one per update.
        """
        return self._predicted_state_uncertainties

    def get_dataframe(self) -> pd.DataFrame:
        """Return the recorded history as a DataFrame, one row per measurement.

        The "Predicted states" and "Predicted state uncertainties" columns
        hold the prior used for that row's measurement. The trailing
        prediction (for a measurement not yet received) is omitted.
        """
        return pd.DataFrame(
            {
                "n": list(range(len(self._measurements))),
                "Time": self._time_points,
                "Measurements": self._measurements,
                "Measurement uncertainties": self._measurement_uncertainties,
                "States": self._states,
                "Predicted states": self._predicted_states[:-1],
                "State uncertainties": self._state_uncertainties,
                "Predicted state uncertainties": self._predicted_state_uncertainties[:-1],
                "Kalman gain": self._gains,
            }
        )

    # These are the Kalman update equations.
    def _calculate_state_update(
        self, previous_state: float, measurement: float, kalman_gain: float
    ) -> float:
        """Move the prior toward the measurement by the gain: x + K * (z - x)."""
        return previous_state + (kalman_gain * (measurement - previous_state))

    def _calculate_kalman_gain(
        self, previous_uncertainty: float, measurement_uncertainty: float
    ) -> float:
        """Return K = p / (p + r).

        Raises ZeroDivisionError if p + r == 0.
        """
        return previous_uncertainty / (previous_uncertainty + measurement_uncertainty)

    def _calculate_uncertainty_update(
        self, kalman_gain: float, previous_uncertainty: float
    ) -> float:
        """Shrink the prior uncertainty by the gain: (1 - K) * p."""
        return (1.0 - kalman_gain) * previous_uncertainty

    # These two are case-specific. For a static system they are the identity;
    # for a dynamic system, override them accordingly.
    def _calculate_state_prediction(self, current_state: float) -> float:
        """Predict the next state (identity for a static system)."""
        return current_state

    def _calculate_uncertainty_prediction(self, current_uncertainty: float) -> float:
        """Predict the next uncertainty (identity for a static system)."""
        return current_uncertainty
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment