Skip to content

Instantly share code, notes, and snippets.

@proelbtn
Last active August 5, 2020 16:15
Show Gist options
  • Save proelbtn/3ec335ceaef91bdb9ce72b3980e8eb37 to your computer and use it in GitHub Desktop.
Save proelbtn/3ec335ceaef91bdb9ce72b3980e8eb37 to your computer and use it in GitHub Desktop.
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Callable, List, TypeVar
# Generic symbol types: ``Input`` is what ``calc_input`` recovers from a link,
# ``Output`` is one element of the received sequence passed to ``decode``.
Input, Output = TypeVar("Input"), TypeVar("Output")
@dataclass(init=True, repr=True, eq=True)
class Link:
    """A directed trellis edge from ``from_state`` to ``to_state``.

    The decoder only looks at the two state indices; ``data`` is an opaque
    payload interpreted by the user-supplied metric/input callbacks.
    """

    from_state: int  # index of the state the edge leaves
    to_state: int  # index of the state the edge enters
    data: Any  # payload handed to calc_metric / calc_input via the link
@dataclass
class ViterbiDecoderParams:
    """Settings consumed by ``ViterbiDecoder``.

    ``calc_metric(link, output)`` is the cost of taking ``link`` when
    ``output`` was received (the decoder keeps the path with the smallest
    total); ``calc_input(link)`` maps a link on the chosen path back to the
    input symbol it stands for.
    """

    nstate: int  # number of trellis states; valid indices are 0..nstate-1
    initial_state: int  # state every path starts in
    final_state: int  # state the decoded path must end in
    initial_metric: int  # accumulated metric assigned to the start state
    links: List[Link]  # every edge of the trellis
    calc_metric: Callable[[Link, Output], int]  # branch metric callback
    calc_input: Callable[[Link], Input]  # link -> decoded input symbol
class ViterbiDecoder:
    """Viterbi decoder over a trellis described by ``ViterbiDecoderParams``.

    For each received symbol every link is scored with ``calc_metric``; for
    each destination state only the incoming path with the smallest
    accumulated metric survives. After the last symbol, the survivor ending
    in ``final_state`` is mapped back to input symbols with ``calc_input``.
    """

    @dataclass
    class State:
        """Survivor path that currently ends in one trellis state."""

        metric: int  # accumulated metric of the path
        links: List[Link]  # links along the path, oldest first

    def __init__(self, params: ViterbiDecoderParams):
        """Copy and validate the settings from ``params``.

        Raises:
            ValueError: if ``nstate`` is not positive, or if ``initial_state``,
                ``final_state`` or any link endpoint is outside
                ``range(nstate)``. ``ValueError`` subclasses ``Exception``,
                so existing ``except Exception`` handlers still match.
        """
        self.nstate = params.nstate
        if self.nstate <= 0:
            raise ValueError("nstate must be greater than 0")

        self.initial_state = params.initial_state
        if not 0 <= self.initial_state < self.nstate:
            raise ValueError("initial_state is out of range")

        self.final_state = params.final_state
        if not 0 <= self.final_state < self.nstate:
            raise ValueError("final_state is out of range")

        self.initial_metric = params.initial_metric

        self.links = params.links
        for link in self.links:
            if not 0 <= link.from_state < self.nstate:
                raise ValueError("from_state is out of range")
            if not 0 <= link.to_state < self.nstate:
                raise ValueError("to_state is out of range")

        self.calc_metric = params.calc_metric
        self.calc_input = params.calc_input

    def decode(self, outputs: List[Output]) -> List[Input]:
        """Return the input sequence of the minimum-metric path for ``outputs``.

        Every path starts in ``initial_state`` with ``initial_metric``, takes
        exactly one link per received symbol, and must end in ``final_state``.
        When two candidates reach the same state with equal metric, the one
        processed first (earlier in ``links``) is kept.

        Raises:
            ValueError: if no path of ``len(outputs)`` links leads from the
                initial state to the final state.
        """
        # states[i] is the survivor ending in state i, or None while state i
        # is unreachable with the symbols consumed so far.
        states = [None] * self.nstate
        states[self.initial_state] = self.State(self.initial_metric, [])

        for output in outputs:
            new_states = [None] * self.nstate
            for link in self.links:
                source = states[link.from_state]
                if source is None:
                    continue
                metric = source.metric + self.calc_metric(link, output)
                target = new_states[link.to_state]
                # Strict "<" keeps the first candidate on ties.
                if target is None or metric < target.metric:
                    # ``+`` already builds a fresh list, so survivors never
                    # share a mutable path list. The decoder never modifies
                    # Link objects, so they are shared rather than deep-copied
                    # (a deep copy re-copied every link and its data on each
                    # update).
                    new_states[link.to_state] = self.State(
                        metric, source.links + [link]
                    )
            states = new_states

        survivor = states[self.final_state]
        if survivor is None:
            raise ValueError("decode failed")
        return [self.calc_input(link) for link in survivor.links]
@dataclass
class LinkData:
    """Payload stored in ``Link.data`` for the example trellis below."""

    input: Input  # input symbol that selects this link
    output: Output  # symbol emitted while taking this link
def calc_metric(link: Link, output: Output) -> int:
    """Return the Hamming distance between ``output`` and the link's output.

    Both symbols are compared element by element; the result is the number of
    positions that differ.

    Raises:
        ValueError: if the received symbol and the link's output symbol have
            different lengths (``zip`` would otherwise silently ignore the
            extra elements and under-count the distance).
    """
    received = tuple(output)
    expected = tuple(link.data.output)
    if len(received) != len(expected):
        raise ValueError(
            f"received symbol {received!r} has {len(received)} elements, "
            f"link output {expected!r} has {len(expected)}"
        )
    # Generator instead of a temporary list; booleans sum as 0/1.
    return sum(r != e for r, e in zip(received, expected))
def calc_input(link: Link) -> Input:
    """Return the input symbol recorded in the link's ``LinkData`` payload."""
    payload = link.data
    return payload.input
# Example trellis: 4 states, two outgoing links per state, each labelled with
# the input bit that selects it and the 2-bit symbol it emits. In every row
# the first output bit equals the input bit.
# Row layout: (from_state, to_state, input bit, output symbol)
links = [
    Link(src, dst, LinkData(bit, symbol))
    for src, dst, bit, symbol in (
        (0, 0, 0, (0, 0)),
        (0, 2, 1, (1, 1)),
        (2, 1, 0, (0, 1)),
        (2, 3, 1, (1, 0)),
        (1, 0, 0, (0, 1)),
        (1, 2, 1, (1, 0)),
        (3, 1, 0, (0, 0)),
        (3, 3, 1, (1, 1)),
    )
]
# Paths start and end in state 0 with a zero starting metric; the Hamming
# metric and payload lookup defined above are used as callbacks.
params = ViterbiDecoderParams(
    links=links,
    nstate=4,
    initial_metric=0,
    initial_state=0,
    final_state=0,
    calc_input=calc_input,
    calc_metric=calc_metric,
)
decoder = ViterbiDecoder(params)
# Self-check: this received sequence must decode to ``sample_inputs``.
sample_received = [
    (1, 1), (1, 1), (0, 1),
    (1, 1), (0, 1), (1, 1),
]
sample_inputs = [1, 0, 0, 1, 0, 0]
assert sample_inputs == decoder.decode(sample_received)
# Decode another received sequence and show the recovered input bits.
received = [
    (0, 0),
    (1, 1),
    (1, 0),
    (1, 1),
    (0, 1),
    (1, 1),
]
inputs = decoder.decode(received)
print(inputs)
def _reencode(trellis, bits, start_state):
    """Walk ``trellis`` from ``start_state``, emitting one symbol per bit.

    For each bit, the first link (in ``trellis`` order) leaving the current
    state whose ``data.input`` equals the bit is taken and its
    ``data.output`` recorded.

    Returns:
        ``(final_state, symbols)``: the state reached after the last bit and
        the list of emitted output symbols.

    Raises:
        ValueError: if the current state has no link for a bit, instead of
            silently skipping that bit and returning a shorter output.
    """
    current = start_state
    symbols = []
    for bit in bits:
        edge = next(
            (
                candidate
                for candidate in trellis
                if candidate.from_state == current and candidate.data.input == bit
            ),
            None,
        )
        if edge is None:
            raise ValueError(f"no link from state {current} for input {bit!r}")
        symbols.append(edge.data.output)
        current = edge.to_state
    return current, symbols


# Re-encode the decoded bits from the decoder's start state so the emitted
# symbols can be compared with ``received``. (The loop variable no longer
# shadows the builtin ``input``.)
state, outputs = _reencode(links, inputs, decoder.initial_state)
print(outputs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment