Skip to content

Instantly share code, notes, and snippets.

@HirbodBehnam
Last active March 8, 2024 07:23
Show Gist options
  • Save HirbodBehnam/54b708cd1ac21c0606c631e8fe957a87 to your computer and use it in GitHub Desktop.
Save HirbodBehnam/54b708cd1ac21c0606c631e8fe957a87 to your computer and use it in GitHub Desktop.
Convert NFA to DFA and visualize it with networkx
from collections.abc import Iterable

import matplotlib.pyplot as plt
import networkx as nx
EPSILON = "e"
class NFA:
    """Nondeterministic finite automaton with epsilon moves, convertible to a DFA.

    States are integer indices into ``adjacency_list``; epsilon transitions are
    stored under the module-level ``EPSILON`` key.
    """

    def __init__(self, alphabet: list[str], adjacency_list: list[dict[str, list[int]]]):
        """Store the automaton.

        Args:
            alphabet: Input symbols to build DFA transitions for.
            adjacency_list: ``adjacency_list[i]`` maps a symbol (or ``EPSILON``)
                to the list of states reachable from state ``i`` on that
                symbol. The example tables in this file leave index 0 as an
                empty placeholder so that states can be numbered from 1.
        """
        self.adjacency_list = adjacency_list
        self.alphabet = alphabet

    def to_dfa(
        self,
        start: int = 1,
    ) -> tuple[frozenset[int], dict[frozenset[int], dict[str, frozenset[int]]]]:
        """Build the equivalent DFA by subset construction.

        Args:
            start: The NFA start state. Defaults to 1, matching the
                placeholder-at-index-0 numbering of the tables in this file.

        Returns:
            ``(start_state, transitions)`` where ``start_state`` is the epsilon
            closure of ``start`` and ``transitions`` maps every DFA state
            reachable from it to a ``{symbol: next_dfa_state}`` dict. Symbols
            whose target is the empty set are omitted, so the dead state is
            left implicit (the DFA is partial).
        """
        start_state = self.calculate_epsilon_moves((start,))
        result: dict[frozenset[int], dict[str, frozenset[int]]] = {}
        # `discovered` holds explored AND queued states, so each DFA state is
        # explored exactly once. Checking `result` alone would miss the state
        # currently being explored and re-queue it whenever it has a self-loop.
        discovered: set[frozenset[int]] = {start_state}
        worklist: list[frozenset[int]] = [start_state]
        while worklist:
            current = worklist.pop()
            current_transitions: dict[str, frozenset[int]] = {}
            for symbol in self.alphabet:
                target = self.next_states(current, symbol)
                if not target:
                    continue  # empty set = dead state; kept implicit
                current_transitions[symbol] = target
                if target not in discovered:
                    discovered.add(target)
                    worklist.append(target)
            result[current] = current_transitions
        return start_state, result

    def calculate_epsilon_moves(self, current_states: Iterable[int]) -> frozenset[int]:
        """Return the epsilon closure of ``current_states``.

        The closure is every state reachable from the given states using only
        ``EPSILON`` transitions, including the given states themselves.
        """
        closure = set(current_states)
        # Worklist traversal: each state's epsilon edges are followed once.
        pending = list(closure)
        while pending:
            state = pending.pop()
            for target in self.adjacency_list[state].get(EPSILON, ()):
                if target not in closure:
                    closure.add(target)
                    pending.append(target)
        return frozenset(closure)

    def next_states(
        self, current_states: frozenset[int], transition: str
    ) -> frozenset[int]:
        """Return the DFA successor of ``current_states`` on symbol ``transition``.

        This is the union of the NFA moves on ``transition`` from every member
        state, followed by the epsilon closure of that union. The result is
        empty when no member state has a ``transition`` move.
        """
        result: set[int] = set()
        for state in current_states:
            if transition in self.adjacency_list[state]:
                result.update(self.adjacency_list[state][transition])
        return self.calculate_epsilon_moves(result)
# Transition tables: entry i maps an input symbol (or EPSILON) to the list of
# successor states of state i. Entry 0 is an unused placeholder so that states
# can be numbered from 1 (the default start state of NFA.to_dfa).

# Q1
nfa_q1 = [
    {},  # Zeroth state is dummy
    {"a": [2], EPSILON: [5, 17]},  # 1
    {"a": [3]},  # 2
    {"a": [4]},  # 3
    {EPSILON: [1, 5]},  # 4
    {EPSILON: [6, 9, 13]},  # 5
    {"a": [7]},  # 6
    {"b": [8]},  # 7
    {"a": [12]},  # 8
    {"b": [10]},  # 9
    {"a": [11]},  # 10
    {"b": [12]},  # 11
    {EPSILON: [5, 13]},  # 12
    {"b": [14], EPSILON: [17]},  # 13
    {"b": [15]},  # 14
    {"b": [16]},  # 15
    {EPSILON: [13, 17]},  # 16
    {EPSILON: [1]},  # 17
]

# Q2
nfa_q2 = [
    {},  # Dummy
    {"a": [2], "b": [2], EPSILON: [2]},  # 1
    {EPSILON: [3, 5, 7]},  # 2
    {"a": [4]},  # 3
    {"a": [7]},  # 4
    {"b": [6]},  # 5
    {"b": [7]},  # 6
    {EPSILON: [8, 10, 12]},  # 7
    {"a": [9]},  # 8
    {"b": [12]},  # 9
    {"b": [11]},  # 10
    {"a": [12]},  # 11
    {EPSILON: [7]},  # 12
]

# Select the exercise to convert; switch to nfa_q1 to convert Q1 instead.
nfa_stuff = nfa_q2
nfa = NFA(["a", "b"], nfa_stuff)
start_state, dfa = nfa.to_dfa()
print(dfa)
# Visualize
def my_draw_networkx_edge_labels(
    G,
    pos,
    edge_labels=None,
    label_pos=0.5,
    font_size=10,
    font_color="k",
    font_family="sans-serif",
    font_weight="normal",
    alpha=None,
    bbox=None,
    horizontalalignment="center",
    verticalalignment="center",
    ax=None,
    rotate=True,
    clip_on=True,
    rad=0,
):
    """Draw edge labels on arc-shaped (curved) edges.

    Variant of ``networkx.draw_networkx_edge_labels`` for edges drawn with
    ``connectionstyle="arc3, rad=<rad>"``. Each label is placed on the
    quadratic Bezier curve that the arc3 style draws, instead of on the
    straight segment between the two nodes. With ``rad=0`` the curve
    degenerates to that straight segment.

    Parameters
    ----------
    G : graph
        A networkx graph
    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.
    edge_labels : dictionary (default={})
        Edge labels in a dictionary of labels keyed by edge two-tuple.
        Only labels for the keys in the dictionary are drawn.
    label_pos : float (default=0.5)
        Position of edge label along the curved edge (0=head, 0.5=center,
        1=tail), measured in the curve's Bezier parameter.
    font_size : int (default=10)
        Font size for text labels
    font_color : string (default='k' black)
        Font color string
    font_weight : string (default='normal')
        Font weight
    font_family : string (default='sans-serif')
        Font family
    alpha : float or None (default=None)
        The text transparency
    bbox : Matplotlib bbox, optional
        Specify text box properties (e.g. shape, color etc.) for edge labels.
        Default is {boxstyle='round', ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0)}.
    horizontalalignment : string (default='center')
        Horizontal alignment {'center', 'right', 'left'}
    verticalalignment : string (default='center')
        Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'}
    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.
    rotate : bool (default=True)
        Rotate edge labels to lie parallel to the straight line joining the
        edge's endpoints.
    clip_on : bool (default=True)
        Turn on clipping of edge labels at axis boundaries
    rad : float (default=0)
        Curvature of the arc. Should equal the ``rad`` passed to the
        ``arc3`` connection style when the edges were drawn, so that the
        labels sit on the drawn curves.

    Returns
    -------
    dict
        `dict` of labels keyed by edge

    See Also
    --------
    networkx.draw_networkx_edge_labels
    """
    import matplotlib.pyplot as plt
    import numpy as np

    if ax is None:
        ax = plt.gca()
    if edge_labels is None:
        labels = {(u, v): d for u, v, d in G.edges(data=True)}
    else:
        labels = edge_labels
    # use default box of white with white border
    if bbox is None:
        bbox = dict(boxstyle="round", ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0))
    # Maps (dx, dy) to (dy, -dx). This is the control-point offset that
    # matplotlib's arc3 connection style applies to the chord midpoint.
    rotation_matrix = np.array([(0, 1), (-1, 0)])
    # The Bezier parameter t runs from n1 (t=0) to n2 (t=1). label_pos uses the
    # opposite convention (1=tail=n1, 0=head=n2), like networkx's straight edges.
    t = 1.0 - label_pos
    text_items = {}
    for (n1, n2), label in labels.items():
        (x1, y1) = pos[n1]
        (x2, y2) = pos[n2]
        # arc3 builds its curve in display coordinates, so do the same here.
        pos_1 = ax.transData.transform(np.array(pos[n1]))
        pos_2 = ax.transData.transform(np.array(pos[n2]))
        linear_mid = 0.5 * pos_1 + 0.5 * pos_2
        d_pos = pos_2 - pos_1
        ctrl_1 = linear_mid + rad * rotation_matrix @ d_pos
        # De Casteljau evaluation of the quadratic Bezier (pos_1, ctrl_1, pos_2)
        # at parameter t.
        ctrl_mid_1 = (1.0 - t) * pos_1 + t * ctrl_1
        ctrl_mid_2 = (1.0 - t) * ctrl_1 + t * pos_2
        bezier_point = (1.0 - t) * ctrl_mid_1 + t * ctrl_mid_2
        (x, y) = ax.transData.inverted().transform(bezier_point)
        if rotate:
            # in degrees
            angle = np.arctan2(y2 - y1, x2 - x1) / (2.0 * np.pi) * 360
            # make label orientation "right-side-up"
            if angle > 90:
                angle -= 180
            if angle < -90:
                angle += 180
            # transform data coordinate angle to screen coordinate angle
            xy = np.array((x, y))
            trans_angle = ax.transData.transform_angles(
                np.array((angle,)), xy.reshape((1, 2))
            )[0]
        else:
            trans_angle = 0.0
        if not isinstance(label, str):
            label = str(label)  # this makes "1" and 1 labeled the same
        t_item = ax.text(
            x,
            y,
            label,
            size=font_size,
            color=font_color,
            family=font_family,
            weight=font_weight,
            alpha=alpha,
            horizontalalignment=horizontalalignment,
            verticalalignment=verticalalignment,
            rotation=trans_angle,
            transform=ax.transData,
            bbox=bbox,
            zorder=1,
            clip_on=clip_on,
        )
        text_items[(n1, n2)] = t_item
    ax.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )
    return text_items
# Build a directed graph whose nodes are the DFA states (frozensets of NFA
# states), each labelled with its sorted list of member states.
G = nx.DiGraph()
for state in dfa:
    G.add_node(state, label=str(sorted(state)))
for current_state, transitions in dfa.items():
    for transition, next_state in transitions.items():
        # A DiGraph stores one edge per (u, v) pair. Symbols that share a target
        # must share that edge, so append to its label; a second add_edge would
        # overwrite the first symbol.
        if G.has_edge(current_state, next_state):
            G.edges[current_state, next_state]["label"] += ", " + transition
        else:
            G.add_edge(current_state, next_state, label=transition)
# Start state in green, all other states in blue. Computed after the edges are
# added so that every node in G gets a colour.
color_map = ["green" if node == start_state else "blue" for node in G]
pos = nx.spring_layout(G)
nx.draw_networkx_nodes(G, pos, node_color=color_map)
# Curve an edge when the opposite-direction edge also exists, so the pair
# doesn't overlap (self-loops also pass this test). Draw the rest straight.
curved_edges = [(u, v) for u, v in G.edges() if G.has_edge(v, u)]
straight_edges = list(set(G.edges()) - set(curved_edges))
nx.draw_networkx_edges(G, pos, edgelist=straight_edges)
nx.draw_networkx_edges(G, pos, edgelist=curved_edges, connectionstyle="arc3, rad = 0.1")
nx.draw_networkx_labels(G, pos, labels=nx.get_node_attributes(G, "label"))
edge_weights = nx.get_edge_attributes(G, "label")
curved_edge_labels = {edge: edge_weights[edge] for edge in curved_edges}
straight_edge_labels = {edge: edge_weights[edge] for edge in straight_edges}
# rad must match the arc3 rad used above so the labels sit on the curves.
my_draw_networkx_edge_labels(
    G, pos, edge_labels=curved_edge_labels, rotate=False, rad=0.1
)
nx.draw_networkx_edge_labels(G, pos, edge_labels=straight_edge_labels, rotate=False)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment