Last active
August 5, 2020 16:15
-
-
Save proelbtn/3ec335ceaef91bdb9ce72b3980e8eb37 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, TypeVar
# Type variables naming the decoder's two symbol kinds: Input is what
# calc_input produces from a link, Output is a received symbol that
# calc_metric scores a link against.
Input, Output = TypeVar("Input"), TypeVar("Output")
@dataclass
class Link:
    """A directed edge of the decoder's trellis.

    ViterbiDecoder only reads the two state indices; ``data`` is passed
    through untouched to the user-supplied ``calc_metric`` / ``calc_input``
    callbacks.
    """

    from_state: int  # index of the state this edge leaves
    to_state: int  # index of the state this edge enters
    data: Any  # payload interpreted only by the callbacks
@dataclass
class ViterbiDecoderParams:
    """Everything ViterbiDecoder needs to describe and score a trellis.

    ``calc_metric`` scores one link against one received symbol. Lower is
    better, because the decoder keeps the minimum accumulated metric per
    state. ``calc_input`` maps a link on the winning path back to the input
    symbol it represents.
    """

    nstate: int  # number of states; valid indices are 0 .. nstate - 1
    initial_state: int  # state every path starts from
    final_state: int  # state every decoded path must end in
    initial_metric: int  # metric given to initial_state before any step
    links: List[Link]  # every allowed transition between states
    calc_metric: Callable[[Link, Output], int]  # (link, received) -> cost
    calc_input: Callable[[Link], Input]  # link -> decoded input symbol
class ViterbiDecoder:
    """Decode a received sequence by finding the minimum-metric trellis path.

    The trellis comes from a ``ViterbiDecoderParams``: ``nstate`` states
    numbered ``0 .. nstate - 1`` and a list of ``Link`` edges between them.
    Each received symbol advances the trellis one step. The path from
    ``initial_state`` to ``final_state`` with the smallest summed
    ``calc_metric`` wins, and ``calc_input`` maps its links back to inputs.
    """

    @dataclass
    class State:
        # Survivor of one trellis state: accumulated metric plus the links
        # taken to reach it. decode() keeps per-step back-pointers instead,
        # but this public nested type is retained for existing callers.
        metric: int
        links: List[Link]

    def __init__(self, params: ViterbiDecoderParams):
        """Validate *params* and store them on the decoder.

        Raises:
            ValueError: if ``nstate`` is not positive, or if
                ``initial_state``, ``final_state`` or any link endpoint lies
                outside ``[0, nstate)``. ValueError subclasses Exception, so
                handlers written against the former bare Exception still
                catch it.
        """
        self.nstate = params.nstate
        if self.nstate <= 0:
            raise ValueError("nstate must be greater than 0")
        self.initial_state = params.initial_state
        if not 0 <= self.initial_state < self.nstate:
            raise ValueError("initial_state is out of range")
        self.final_state = params.final_state
        if not 0 <= self.final_state < self.nstate:
            raise ValueError("final_state is out of range")
        self.initial_metric = params.initial_metric
        self.links = params.links
        for link in self.links:
            if not 0 <= link.from_state < self.nstate:
                raise ValueError("from_state is out of range")
            if not 0 <= link.to_state < self.nstate:
                raise ValueError("to_state is out of range")
        self.calc_metric = params.calc_metric
        self.calc_input = params.calc_input

    def decode(self, outputs: List[Output]) -> List[Input]:
        """Return the inputs along the best path that explains *outputs*.

        At every step, each link whose source state is reachable is scored
        with ``calc_metric``. Its target state keeps the smallest
        accumulated metric; on ties the link listed first in ``links`` wins.
        After the last symbol, the surviving path into ``final_state`` is
        traced back, and each of its links is mapped through ``calc_input``.

        Args:
            outputs: received symbols, one per trellis step.

        Returns:
            One decoded input per received symbol. The result is empty when
            *outputs* is empty and ``final_state == initial_state``.

        Raises:
            ValueError: if no path of ``len(outputs)`` links ends in
                ``final_state``.
        """
        # best[s]: smallest accumulated metric of state s at the current
        # step, or None while s is unreachable.
        best: List[Optional[int]] = [None] * self.nstate
        best[self.initial_state] = self.initial_metric
        # Back-pointers: history[k][s] is the link that entered state s at
        # step k. Storing one link per state per step keeps decoding linear
        # in len(outputs). Copying each survivor path would be quadratic.
        history: List[List[Optional[Link]]] = []
        for output in outputs:
            next_best: List[Optional[int]] = [None] * self.nstate
            entered: List[Optional[Link]] = [None] * self.nstate
            for link in self.links:
                base = best[link.from_state]
                if base is None:
                    continue
                candidate = base + self.calc_metric(link, output)
                target = link.to_state
                # Strict "<" keeps the earliest-listed link on ties.
                if next_best[target] is None or candidate < next_best[target]:
                    next_best[target] = candidate
                    entered[target] = link
            best = next_best
            history.append(entered)

        if best[self.final_state] is None:
            raise ValueError("decode failed")

        # Follow the back-pointers from final_state to rebuild the path.
        path: List[Link] = []
        state = self.final_state
        for entered in reversed(history):
            link = entered[state]
            path.append(link)
            state = link.from_state
        path.reverse()
        return [self.calc_input(link) for link in path]
@dataclass
class LinkData:
    """Payload stored in ``Link.data`` for the example trellis.

    ``calc_input`` reads ``input`` (the symbol the transition decodes to).
    ``calc_metric`` compares ``output`` (the symbol the transition emits)
    with the received symbol.
    """

    input: Input  # decoded input symbol for this transition
    output: Output  # emitted symbol this transition is scored against
def calc_metric(link: Link, output: Output) -> int:
    """Count the positions where *output* differs from the link's output.

    This is the Hamming distance when both symbols have the same length.
    ``zip`` stops at the shorter sequence, so any extra trailing positions on
    either side are ignored.
    """
    expected = link.data.output
    # Generator, not a list: sum() consumes it directly without a temporary.
    return sum(got != want for got, want in zip(output, expected))
def calc_input(link: Link) -> Input:
    """Return the input symbol stored in *link*'s payload."""
    payload = link.data
    return payload.input
# Trellis of the example code. Each row is
# (from_state, to_state, input, output). In every row the new state is
# (input << 1) | (from_state >> 1), and the first output element equals
# the input.
links = [
    Link(src, dst, LinkData(bit, symbol))
    for src, dst, bit, symbol in (
        (0, 0, 0, (0, 0)),
        (0, 2, 1, (1, 1)),
        (2, 1, 0, (0, 1)),
        (2, 3, 1, (1, 0)),
        (1, 0, 0, (0, 1)),
        (1, 2, 1, (1, 0)),
        (3, 1, 0, (0, 0)),
        (3, 3, 1, (1, 1)),
    )
]
# Decoder setup for the example trellis: four states, every path starts
# and ends in state 0, and the accumulated metric starts at 0.
params = ViterbiDecoderParams(
    nstate=4,
    links=links,
    initial_state=0,
    final_state=0,
    initial_metric=0,
    calc_input=calc_input,
    calc_metric=calc_metric,
)
decoder = ViterbiDecoder(params)
# Self-check: this received sequence must decode to the known inputs.
sample_received, sample_inputs = (
    [(1, 1), (1, 1), (0, 1), (1, 1), (0, 1), (1, 1)],
    [1, 0, 0, 1, 0, 0],
)
assert sample_inputs == decoder.decode(sample_received)

# Decode a second received sequence and print the recovered inputs.
received = [(0, 0), (1, 1), (1, 0), (1, 1), (0, 1), (1, 1)]
inputs = decoder.decode(received)
print(inputs)
# Re-encode the decoded inputs by walking the trellis from the decoder's
# start state. This shows the output symbols the decoded path emits.
state = params.initial_state
outputs = []
# "symbol", not "input": a module-level loop variable named input would
# shadow the builtin input() for the rest of the module.
for symbol in inputs:
    for link in links:
        if link.from_state == state and link.data.input == symbol:
            state = link.to_state
            outputs.append(link.data.output)
            break
    else:
        # Previously this was silently skipped, yielding a shorter outputs.
        raise ValueError(f"no link leaves state {state} for input {symbol!r}")
print(outputs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment