Skip to content

Instantly share code, notes, and snippets.

@SteelPh0enix
Last active July 22, 2022 14:31
Show Gist options
Save SteelPh0enix/10635d2d8f2b316ddf256edce7e75b80 to your computer and use it in GitHub Desktop.
Simple 1D Kalman filter implementation
from dataclasses import dataclass
import pandas as pd
@dataclass
class KalmanFilterUpdateResult:
    """Snapshot of everything computed during one filter update step.

    Fields are positional in this order: the time of the measurement, the
    Kalman gain, then three states and three uncertainties, each triple being
    (previous, current, predicted).
    """

    # Absolute time of the measurement (running sum of time deltas).
    time_point: float
    # Kalman gain used to blend the measurement into the estimate.
    gain: float
    # Prior state estimate the update started from.
    previous_state: float
    # State estimate after incorporating the measurement.
    current_state: float
    # State predicted for the next update.
    predicted_state: float
    # Uncertainty of the prior estimate.
    previous_uncertainty: float
    # Uncertainty after incorporating the measurement.
    current_uncertainty: float
    # Uncertainty predicted for the next update.
    predicted_uncertainty: float
@dataclass
class KalmanFilterInitializationResult:
    """Prediction produced when the filter is seeded with initial values."""

    # State predicted for the first measurement.
    predicted_state: float
    # Uncertainty predicted for the first measurement.
    predicted_uncertainty: float
class StaticSystemKalmanFilter:
    """One-dimensional Kalman filter for a static (constant) system.

    The filter records the full history of measurements, estimates,
    predictions and gains so it can be exported with `get_dataframe`.

    The prediction step (`_calculate_state_prediction` and
    `_calculate_uncertainty_prediction`) is the identity, which models a
    system whose true state does not change. Override those two methods to
    model a dynamic system.
    """

    def __init__(
        self,
        initial_state: float | None = None,
        initial_uncertainty: float | None = None,
    ):
        """Create the filter, optionally initializing it right away.

        Args:
            initial_state: Initial estimate of the state, or ``None``.
            initial_uncertainty: Uncertainty of the initial estimate, or ``None``.

        If both arguments are ``None`` (the default), the filter is left
        uninitialized and `initialize` must be called before `update`.

        Raises:
            ValueError: If exactly one of the two arguments is ``None``.
        """
        self._gains: list[float] = []
        self._time_points: list[float] = []
        self._measurements: list[float] = []
        self._measurement_uncertainties: list[float] = []
        self._states: list[float] = []
        self._state_uncertainties: list[float] = []
        self._predicted_states: list[float] = []
        self._predicted_state_uncertainties: list[float] = []
        self._initial_state: float | None = None
        self._initial_uncertainty: float | None = None
        # Initializing with None would store a None prediction, which makes
        # `update` skip its "not initialized" check and fail later in the math.
        if initial_state is None and initial_uncertainty is None:
            return
        self.initialize(initial_state, initial_uncertainty)

    def initialize(
        self, initial_state: float, initial_uncertainty: float
    ) -> KalmanFilterInitializationResult:
        """Clear all history and seed the filter with an initial estimate.

        Args:
            initial_state: Initial estimate of the state.
            initial_uncertainty: Uncertainty of the initial estimate.

        Returns:
            The state and uncertainty predicted for the first measurement.

        Raises:
            ValueError: If either argument is ``None``.
        """
        if initial_state is None or initial_uncertainty is None:
            raise ValueError(
                "Both initial_state and initial_uncertainty are required "
                "to initialize the Kalman filter."
            )
        # Clear filter's data
        self.reset()
        # Store the values
        self._initial_state = initial_state
        self._initial_uncertainty = initial_uncertainty
        # Predict the next state; this prediction is the prior for the first update.
        predicted_state = self._calculate_state_prediction(initial_state)
        predicted_uncertainty = self._calculate_uncertainty_prediction(
            initial_uncertainty
        )
        self._predicted_states.append(predicted_state)
        self._predicted_state_uncertainties.append(predicted_uncertainty)
        return KalmanFilterInitializationResult(predicted_state, predicted_uncertainty)

    def update(
        self, measurement: float, measurement_uncertainty: float, time_delta: float
    ) -> KalmanFilterUpdateResult:
        """Incorporate one measurement and predict the next state.

        Args:
            measurement: The measured value.
            measurement_uncertainty: Uncertainty of the measurement.
            time_delta: Time since the last measurement (for the first
                measurement, it is used directly as the first time point).
                It is only used to build the recorded time axis.

        Returns:
            A snapshot of every value computed during this step.

        Raises:
            RuntimeError: If the filter has not been initialized.
            ZeroDivisionError: If the prior uncertainty and the measurement
                uncertainty sum to zero.
        """
        if not self._predicted_states:
            raise RuntimeError("Kalman filter is not initialized!")
        self._measurements.append(measurement)
        self._measurement_uncertainties.append(measurement_uncertainty)
        time_point = (
            time_delta
            if not self._time_points
            else self._time_points[-1] + time_delta
        )
        self._time_points.append(time_point)
        # The prior for this step is the most recent prediction (made either
        # by `initialize` or by the previous `update`).
        previous_state = self._predicted_states[-1]
        previous_uncertainty = self._predicted_state_uncertainties[-1]
        # Calculate Kalman Gain
        gain = self._calculate_kalman_gain(
            previous_uncertainty, measurement_uncertainty
        )
        self._gains.append(gain)
        # Estimate current state and uncertainty
        current_state = self._calculate_state_update(previous_state, measurement, gain)
        current_uncertainty = self._calculate_uncertainty_update(
            gain, previous_uncertainty
        )
        self._states.append(current_state)
        self._state_uncertainties.append(current_uncertainty)
        # Predict the next state
        predicted_state = self._calculate_state_prediction(current_state)
        predicted_uncertainty = self._calculate_uncertainty_prediction(
            current_uncertainty
        )
        self._predicted_states.append(predicted_state)
        self._predicted_state_uncertainties.append(predicted_uncertainty)
        return KalmanFilterUpdateResult(
            time_point,
            gain,
            previous_state,
            current_state,
            predicted_state,
            previous_uncertainty,
            current_uncertainty,
            predicted_uncertainty,
        )

    def reset(self):
        """Clear all recorded history and the initial values.

        After a reset, the filter is uninitialized and `update` raises until
        `initialize` is called again.
        """
        self._gains.clear()
        self._time_points.clear()
        self._measurements.clear()
        self._measurement_uncertainties.clear()
        self._states.clear()
        self._predicted_states.clear()
        self._state_uncertainties.clear()
        self._predicted_state_uncertainties.clear()
        self._initial_state = None
        self._initial_uncertainty = None

    # The history properties below return the internal lists, not copies.

    @property
    def initial_state(self) -> float | None:
        """Initial state passed to `initialize`, or ``None`` if uninitialized."""
        return self._initial_state

    @property
    def initial_state_uncertainty(self) -> float | None:
        """Initial uncertainty passed to `initialize`, or ``None`` if uninitialized."""
        return self._initial_uncertainty

    @property
    def gains(self) -> list[float]:
        """Kalman gain computed at each update."""
        return self._gains

    @property
    def measurements(self) -> list[float]:
        """Measurements passed to `update`, in order."""
        return self._measurements

    @property
    def measurements_uncertainties(self) -> list[float]:
        """Measurement uncertainties passed to `update`, in order."""
        return self._measurement_uncertainties

    @property
    def states(self) -> list[float]:
        """State estimate after each update."""
        return self._states

    @property
    def uncertainties(self) -> list[float]:
        """State uncertainty after each update."""
        return self._state_uncertainties

    @property
    def predicted_states(self) -> list[float]:
        """All predictions; index 0 comes from `initialize`, the rest from `update`."""
        return self._predicted_states

    @property
    def predicted_state_uncertainties(self) -> list[float]:
        """Uncertainties matching `predicted_states`, element for element."""
        return self._predicted_state_uncertainties

    def get_dataframe(self) -> pd.DataFrame:
        """Return the filter history as a DataFrame with one row per measurement.

        The "Predicted ..." columns hold the prediction that served as the
        prior for that row's measurement. The newest prediction has not been
        used by any measurement yet, so it is omitted.
        """
        measurements_amount = len(self._measurements)
        dataframe_dict = {
            "n": list(range(measurements_amount)),
            "Time": self._time_points,
            "Measurements": self._measurements,
            "Measurement uncertainties": self._measurement_uncertainties,
            "States": self._states,
            "Predicted states": self._predicted_states[:-1],
            "State uncertainties": self._state_uncertainties,
            "Predicted state uncertainties": self._predicted_state_uncertainties[:-1],
            "Kalman gain": self._gains,
        }
        return pd.DataFrame(dataframe_dict)

    # These are Kalman equations

    def _calculate_state_update(
        self, previous_state: float, measurement: float, kalman_gain: float
    ) -> float:
        """Move the prior state toward the measurement by a fraction `kalman_gain`."""
        return previous_state + (kalman_gain * (measurement - previous_state))

    def _calculate_kalman_gain(
        self, previous_uncertainty: float, measurement_uncertainty: float
    ) -> float:
        """Return the share of the total uncertainty that comes from the prior.

        Raises ZeroDivisionError when both uncertainties sum to zero.
        """
        return previous_uncertainty / (previous_uncertainty + measurement_uncertainty)

    def _calculate_uncertainty_update(
        self, kalman_gain: float, previous_uncertainty: float
    ) -> float:
        """Shrink the prior uncertainty by the fraction of it the gain removes."""
        return (1.0 - kalman_gain) * previous_uncertainty

    # These two are case-specific, for a static system these are constant.
    # For dynamic, you have to modify them accordingly.

    def _calculate_state_prediction(self, current_state: float) -> float:
        """Predict the next state; identity because the system is static."""
        return current_state

    def _calculate_uncertainty_prediction(self, current_uncertainty: float) -> float:
        """Predict the next uncertainty; identity because the system is static.

        Nothing is added between steps (no process-noise term), so with
        non-negative uncertainties the uncertainty never grows between updates.
        """
        return current_uncertainty
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment