Skip to content

Instantly share code, notes, and snippets.

@peci1
Last active April 13, 2022 02:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save peci1/a0346538acc6c289b2c6d596b184ad21 to your computer and use it in GitHub Desktop.
Save peci1/a0346538acc6c289b2c6d596b184ad21 to your computer and use it in GitHub Desktop.
from math import ceil
from typing import Tuple, List, Optional
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, ImageMagickWriter
from matplotlib.collections import Collection
from matplotlib.container import Container
from matplotlib.patches import Patch
class BitSize:
    """A data size stored as a bit count and a byte count.

    The two counts are kept side by side and are added/subtracted
    independently; every comparison looks at the bit count only. Instances
    are mutable (``+=`` / ``-=`` update them in place) and therefore
    unhashable.
    """

    # Defining __eq__ already disables hashing; spelled out because the
    # in-place operators mutate the object.
    __hash__ = None

    def __init__(self, bits: int, bytes: int):
        self._bits = bits
        self._bytes = bytes

    def bits(self) -> int:
        """Return the size in bits."""
        return self._bits

    def bytes(self) -> int:
        """Return the size in bytes."""
        return self._bytes

    def __repr__(self) -> str:
        return f"<{type(self).__name__}: {self._bits} bits, {self._bytes} bytes>"

    # Arithmetic: non-BitSize operands yield NotImplemented so Python raises
    # TypeError (the previous `assert` checks vanished under `python -O`).
    def __add__(self, other):
        if not isinstance(other, BitSize):
            return NotImplemented
        return BitSize(self.bits() + other.bits(), self.bytes() + other.bytes())

    def __iadd__(self, other):
        if not isinstance(other, BitSize):
            return NotImplemented
        self._bits += other.bits()
        self._bytes += other.bytes()
        return self

    def __sub__(self, other):
        if not isinstance(other, BitSize):
            return NotImplemented
        return BitSize(self.bits() - other.bits(), self.bytes() - other.bytes())

    def __isub__(self, other):
        if not isinstance(other, BitSize):
            return NotImplemented
        self._bits -= other.bits()
        self._bytes -= other.bytes()
        return self

    # Comparisons: NotImplemented for foreign types lets Python fall back
    # (False for ==, True for !=, TypeError for ordering) instead of raising
    # AttributeError on `other.bits()`.
    def __ge__(self, other):
        if not isinstance(other, BitSize):
            return NotImplemented
        return self.bits() >= other.bits()

    def __le__(self, other):
        if not isinstance(other, BitSize):
            return NotImplemented
        return self.bits() <= other.bits()

    def __gt__(self, other):
        if not isinstance(other, BitSize):
            return NotImplemented
        return self.bits() > other.bits()

    def __lt__(self, other):
        if not isinstance(other, BitSize):
            return NotImplemented
        return self.bits() < other.bits()

    def __eq__(self, other):
        if not isinstance(other, BitSize):
            return NotImplemented
        return self.bits() == other.bits()

    def __ne__(self, other):
        if not isinstance(other, BitSize):
            return NotImplemented
        return self.bits() != other.bits()


class Bits(BitSize):
    """A size given in bits; the byte count is rounded up to whole bytes."""

    def __init__(self, bits: int):
        nbytes = int(ceil(bits / 8.0))
        super().__init__(bits, nbytes)

    @classmethod
    def kilo(cls, kilobits):
        """Return ``1024 * kilobits`` bits."""
        return cls(1024 * kilobits)

    @classmethod
    def mega(cls, megabits):
        """Return ``1024**2 * megabits`` bits."""
        return cls.kilo(1024 * megabits)

    @classmethod
    def giga(cls, gigabits):
        """Return ``1024**3 * gigabits`` bits."""
        return cls.mega(1024 * gigabits)


class Bytes(BitSize):
    """A size given in whole bytes (8 bits each)."""

    def __init__(self, bytes: int):
        super().__init__(bytes * 8, bytes)

    @classmethod
    def kilo(cls, kilobytes):
        """Return ``1024 * kilobytes`` bytes."""
        return cls(1024 * kilobytes)

    @classmethod
    def mega(cls, megabytes):
        """Return ``1024**2 * megabytes`` bytes."""
        return cls.kilo(1024 * megabytes)

    @classmethod
    def giga(cls, gigabytes):
        """Return ``1024**3 * gigabytes`` bytes."""
        return cls.mega(1024 * gigabytes)
class Drawable:
    """Base class for scene objects rendered as a group of matplotlib artists.

    Subclasses implement ``_create_patches()`` (build the artists) and
    ``_set_pos()`` (move them to a new position). Artists are created lazily
    on first access and, once ``add_to()`` has been called, are registered
    with that Axes.
    """

    def __init__(self):
        self._patches: List[plt.Artist] = list()
        self._axes: Optional[plt.Axes] = None
        self._pos: Tuple[float, float] = (0, 0)
        self._last_update_time = 0

    def get_patches(self) -> List[plt.Artist]:
        """Return the artists, creating them first if none exist yet."""
        if not self._patches:
            self.update_patches()
        return self._patches

    def update_patches(self):
        """(Re)build the artists.

        When attached to an Axes, the old artists are removed from it, and the
        new ones are added and moved to the last position given to ``set_pos()``.
        """
        if self._patches and self._axes is not None:
            for artist in self._patches:
                artist.remove()
        self._patches = self._create_patches()
        if self._axes is not None:
            self._add_artists(self._axes)
            # Freshly created artists are not yet at the stored position.
            self.set_pos(self._pos)

    def _create_patches(self) -> List[plt.Artist]:
        """Build and return this object's artists (subclass hook)."""
        raise NotImplementedError()

    def set_pos(self, pos: Tuple[float, float]) -> None:
        """Remember *pos* and move the artists there."""
        self._pos = pos
        self._set_pos(pos)

    def _set_pos(self, pos: Tuple[float, float]) -> None:
        """Move the artists to *pos* (subclass hook)."""
        raise NotImplementedError()

    def add_to(self, axes: plt.Axes) -> None:
        """Attach this object's artists to *axes*, creating them if needed."""
        self._axes = axes
        if not self._patches:
            # With _axes set, update_patches() registers the fresh artists
            # itself; adding them again here would put each one on the Axes
            # twice.
            self.update_patches()
            return
        self._add_artists(axes)

    def _add_artists(self, axes: plt.Axes) -> None:
        """Register every current artist with *axes* via the matching ``Axes.add_*`` call."""
        for patch in self._patches:
            if isinstance(patch, Patch):
                axes.add_patch(patch)
            elif isinstance(patch, Collection):
                axes.add_collection(patch)
            elif isinstance(patch, plt.Artist):
                axes.add_artist(patch)
            elif isinstance(patch, Container):
                axes.add_container(patch)

    def advance_time(self, time_ns: int) -> None:
        """Advance the simulation state to *time_ns* (base: only record the time)."""
        self._last_update_time = time_ns

    def update_visuals(self, time_ns: int) -> bool:
        """Refresh the artists for *time_ns*; return False once the object can be discarded."""
        return True
class Rectangle(Drawable):
    """A named, translucent axis-aligned box that is positioned by its centre.

    Subclasses may append further artists (e.g. labels) in
    ``_create_patches()``; ``_set_pos()`` translates all of them together.
    """

    def __init__(self, name: str, color: str):
        super().__init__()
        self._name = name
        self._width = 0.1
        self._height = 0.1
        self._color = color
        self._edgecolor = None
        self._alpha = 0.2
        # Lower-left corner the artists are currently translated to.
        self._x = 0.0
        self._y = 0.0
        self._rectangle_patch: Optional[plt.Rectangle] = None

    def _create_patches(self) -> List[plt.Artist]:
        """Build the box with its lower-left corner at the origin."""
        # New artists start at the origin, so reset the tracked offset.
        self._x = self._y = 0
        self._rectangle_patch = plt.Rectangle((0, 0), self._width, self._height,
                                              facecolor=self._color, edgecolor=self._edgecolor,
                                              alpha=self._alpha)
        return [self._rectangle_patch]

    def _set_pos(self, pos: Tuple[float, float]) -> None:
        """Translate every artist so the box is centred on *pos*."""
        # Fetch the artists BEFORE computing the offset: lazy creation resets
        # _x/_y (and may already reposition them), which would otherwise make
        # dx/dy stale and move the artists twice.
        patches = self.get_patches()
        x = pos[0] - self._width / 2
        y = pos[1] - self._height / 2
        dx = x - self._x
        dy = y - self._y
        for patch in patches:
            if isinstance(patch, plt.Rectangle):
                px, py = patch.get_x(), patch.get_y()
            else:
                px, py = patch.get_position()
            patch.set_x(px + dx)
            patch.set_y(py + dy)
        self._x = x
        self._y = y
class Port(Rectangle):
    """A switch port drawn as a small box labelled with its name.

    Input ports are green, output ports red. When *pos* is given the port is
    placed there immediately.
    """

    def __init__(self, name: str, is_input: bool = True, pos: Tuple[float, float] = None):
        fill = "green" if is_input else "red"
        super().__init__(name, fill)
        if pos is None:
            return
        self.set_pos(pos)

    def _create_patches(self) -> List[plt.Artist]:
        artists = super()._create_patches()
        box = self._rectangle_patch
        label = plt.Text(text=self._name)
        # Nudge the label so it sits roughly centred on the box.
        label.set_position((box.get_x() + self._width / 2 - 0.03,
                            box.get_y() + self._height / 2 - 0.02))
        artists.append(label)
        return artists
class Packet(Rectangle):
    """A single frame travelling from an input port, through the switch buffer, to an output port.

    Drawn as a box whose width scales with the frame length (a 1500-byte frame
    is 0.1 axis units wide) plus a size label above it. ``advance_time()``
    steps the packet's lifecycle (arriving -> stored in the buffer ->
    transmitted, or dropped when the buffer rejects it) and
    ``update_visuals()`` moves the box accordingly. Times are in nanoseconds.
    """
    def __init__(self, length: BitSize, color: str, switch: 'Switch', in_port: Port, out_port: Port,
                 receive_time_ns: int):
        super().__init__("packet", color)
        self.length = length
        # Box width proportional to frame length: 1500 bytes -> 0.1 axis units.
        self._width = length.bits() / 10 / Bytes(1500).bits()
        self._edgecolor = "black"
        # Lifecycle flags. advance_time() sets most of them; Buffer sets
        # _is_transmitted / _is_in_buffer when transmission completes.
        self._hidden = False
        self._is_dropped = False
        self._dropped_time: Optional[int] = None
        self._started_receiving = False
        self._is_received = False
        self._is_in_buffer = False
        self._started_transmitting = False
        self._is_transmitted = False
        self._started_processing = False
        self._in_port = in_port
        self._out_port = out_port
        self._switch = switch
        # Time at which reception starts (reception ends transmission_duration_ns() later).
        self._receive_time_ns = receive_time_ns
        self._processing_start_time_ns: Optional[int] = None
        self._transmission_start_time_ns: Optional[int] = None
        # Park far off-screen on the input port's row; update_visuals() takes
        # over 25000 ns before reception starts.
        self.set_pos((-100000, in_port._pos[1]))
    @property
    def color(self):
        """Fill colour the packet was created with (used by Buffer to colour its bins)."""
        return self._color
    def mark_dropped(self) -> None:
        """Blacken the box and flag the packet as dropped (Buffer calls this when it does not fit)."""
        self._rectangle_patch.set_facecolor("black")
        self._is_dropped = True
    def transmission_duration_ns(self) -> int:
        return self.length.bits() # one bit takes one ns on a Gigabit link
    def receive_end_time_ns(self) -> int:
        """Time at which the last bit has been received."""
        return self._receive_time_ns + self.transmission_duration_ns()
    def transmit_end_time_ns(self) -> int:
        """Time at which transmission finishes; only valid after start_transmission()."""
        return self._transmission_start_time_ns + self.transmission_duration_ns()
    def hide(self) -> None:
        """Make all artists fully transparent (idempotent)."""
        if self._hidden:
            return
        for patch in self._patches:
            patch.set_alpha(0)
        self._hidden = True
    def show(self) -> None:
        """Restore visibility: 0.2 alpha for the box, opaque for the other artists (idempotent)."""
        if not self._hidden:
            return
        for patch in self._patches:
            patch.set_alpha(0.2 if patch == self._rectangle_patch else 1)
        self._hidden = False
    def _get_position(self, time_ns: int) -> Tuple[float, float]:
        """Centre position of the box at *time_ns*.

        Until the packet counts as received it moves right along the input
        port's row; dropped packets never become received, so they keep
        following this path. While transmitting it moves right from the output
        port. In between (when advance_time() has hidden it) it sits at (0.5, 0.5).
        """
        # Axis units per ns: a 1500-byte (12000-bit) frame is 0.1 wide and takes
        # 12000 ns to transmit, so it travels its own width per transmission time.
        speed = 0.1 / 12000
        if not self._is_received:
            time_until_receive = self._receive_time_ns - time_ns
            x = self._in_port._pos[0] - self._width - time_until_receive * speed
            y = self._in_port._pos[1]
        elif self._started_transmitting:
            time_after_transmit = time_ns - self._transmission_start_time_ns
            x = self._out_port._pos[0] + time_after_transmit * speed
            y = self._out_port._pos[1]
        else:
            x = 0.5
            y = 0.5
        return x, y
    def _create_patches(self) -> List[plt.Artist]:
        """Build the box plus a size label (in kB) placed just above it."""
        patches = super()._create_patches()
        x = self._rectangle_patch.get_x()
        y = self._rectangle_patch.get_y()
        text = plt.Text(text=f"{self.length.bytes() / 1024.0 : .1f} kB")
        text_x = x + self._width / 2 - 0.04
        text_y = y + self._height + 0.02
        text.set_fontsize('small')
        text.set_position((text_x, text_y))
        patches.append(text)
        return patches
    def start_processing(self, time_ns: int) -> None:
        """Record that the buffer began processing this packet at *time_ns*."""
        self._started_processing = True
        self._processing_start_time_ns = time_ns
    def start_transmission(self, time_ns: int) -> None:
        """Record that the buffer began sending this packet on the output port at *time_ns*."""
        self._started_transmitting = True
        self._transmission_start_time_ns = time_ns
    def advance_time(self, time_ns: int) -> None:
        """Advance the packet's lifecycle to *time_ns*.

        Nothing happens for dropped packets or earlier than 25000 ns before
        reception starts. Otherwise the first matching case applies:
        - at/after the start of reception the packet is offered to the switch
          buffer, which either queues it or marks it dropped;
        - once fully received while still waiting in the buffer it is hidden
          and flagged received (Buffer.advance_time() only serves received packets);
        - while transmitting it is shown again.
        """
        if not self._is_dropped and time_ns > self._receive_time_ns - 25000:
            if not self._is_in_buffer and not self._is_transmitted and time_ns >= self._receive_time_ns:
                self._switch.buffer.add_packet(self)
                self._started_receiving = True
                # add_packet() calls mark_dropped() when the buffer is full.
                if not self._is_dropped:
                    self._is_in_buffer = True
                else:
                    self._dropped_time = time_ns
            elif self._is_in_buffer and not self._started_transmitting and time_ns > self.receive_end_time_ns():
                self.hide()
                self._is_received = True
            elif self._started_transmitting:
                self.show()
        self._last_update_time = time_ns
    def update_visuals(self, time_ns: int) -> bool:
        """Move the box for *time_ns*; return False once it is past x=2 so the caller can discard it."""
        if time_ns < self._receive_time_ns - 25000:
            return True
        pos = self._get_position(time_ns)
        self.set_pos(pos)
        return pos[0] < 2
class Buffer(Rectangle):
    """The switch's shared FIFO frame memory, drawn as a row of coloured bins.

    Args:
        capacity: Total memory. A packet that would not fit is marked dropped.
        unit_packet_size: Size of one memory page; the buffer is drawn as
            ceil(capacity / unit_packet_size) bins.
        processing_delay_ns: Time between the head packet starting processing
            and starting transmission.
    """

    def __init__(self, capacity: BitSize, unit_packet_size: BitSize, processing_delay_ns: int):
        super().__init__("buffer", "black")
        self._capacity: BitSize = capacity
        self._used_size: BitSize = Bytes(0)
        self._unit_packet_size: BitSize = unit_packet_size
        self._processing_delay_ns = processing_delay_ns
        self._num_bins = int(ceil(self._capacity.bits() / self._unit_packet_size.bits()))
        self._alpha = 0.1
        self._width = 0.5
        self._height = 0.5
        self._packets: List[Packet] = list()  # queued packets, FIFO order
        self._bin_patches: List[plt.Rectangle] = list()
        self.set_pos((0.5, 0.5))

    def _create_patches(self) -> List[plt.Artist]:
        """Build the frame, a capacity label above it and one bin per memory page."""
        patches = super()._create_patches()
        x = self._rectangle_patch.get_x()
        y = self._rectangle_patch.get_y()
        w, h = self._width, self._height
        text = plt.Text(text=f"{self._capacity.bytes() // 1024} kB buffer")
        text.set_position((x + w / 2 - 0.07, y + h + 0.02))
        patches.append(text)
        self._bin_patches.clear()
        bin_w = w / self._num_bins
        for i in range(self._num_bins):
            # Lay the bins out relative to the frame, like the label above.
            rect = plt.Rectangle((x + i * bin_w, y), bin_w, h, facecolor="blue", edgecolor=None)
            self._bin_patches.append(rect)
        return patches + self._bin_patches

    def update_buffer_content(self) -> None:
        """Recolour the bins to show the queued packets in FIFO order.

        Each packet fills max(1, ceil(length / unit_packet_size)) consecutive
        bins in its colour; the remaining bins are blue.
        """
        packet_idx = 0 if self._packets else None
        packet_bits_processed = Bits(0)
        for bin_patch in self._bin_patches:
            color = self._packets[packet_idx].color if packet_idx is not None else "blue"
            bin_patch.set_color(color)
            if packet_idx is None:
                continue
            if packet_bits_processed + self._unit_packet_size >= self._packets[packet_idx].length:
                # This bin holds the packet's last page: move on to the next packet.
                if packet_idx + 1 < len(self._packets):
                    packet_idx += 1
                    packet_bits_processed = Bits(0)
                else:
                    packet_idx = None
            else:
                packet_bits_processed += self._unit_packet_size

    def add_packet(self, packet: Packet) -> None:
        """Queue *packet* if it fits in the remaining capacity, otherwise mark it dropped."""
        if self._used_size + packet.length <= self._capacity:
            self._packets.append(packet)
            self._used_size += packet.length
            self.update_buffer_content()
        else:
            packet.mark_dropped()

    def remove_packet(self, packet: Packet) -> None:
        """Dequeue *packet* and free its space (raises ValueError if it is not queued)."""
        self._packets.remove(packet)
        self._used_size -= packet.length
        self.update_buffer_content()

    def advance_time(self, time_ns: int) -> None:
        """Serve the head-of-line packet: process -> transmit -> remove.

        Only the first queued packet is served, and only once it is fully
        received. Each call performs at most one state transition.
        """
        # Record the time on every call, as Drawable.advance_time() does.
        self._last_update_time = time_ns
        if not self._packets:
            return
        head = self._packets[0]
        if not head._is_received:
            return
        if not head._started_processing:
            head.start_processing(time_ns)
        elif not head._started_transmitting and time_ns >= head._processing_start_time_ns + self._processing_delay_ns:
            head.start_transmission(time_ns)
        elif head._started_transmitting and time_ns >= head.transmit_end_time_ns():
            head._is_transmitted = True
            head._is_in_buffer = False
            self.remove_packet(head)
class Switch(Rectangle):
    """The switch chassis: a large translucent box centred at (0.5, 0.5).

    Keeps a list of registered ports and a reference to its packet buffer,
    which packets reach through the ``buffer`` property.
    """

    def __init__(self):
        super().__init__("switch", "black")
        self._alpha = 0.1
        self._width = 0.8
        self._height = 0.8
        # Previously `self._buffer = Optional[Buffer]`, which stored the typing
        # object itself as the value. No buffer exists until one is assigned.
        self._buffer: Optional[Buffer] = None
        self._ports: List[Port] = list()
        self.set_pos((0.5, 0.5))

    def add_port(self, port: Port) -> None:
        """Register *port* with the switch (bookkeeping only; does not draw it)."""
        self._ports.append(port)

    @property
    def buffer(self):
        """The switch's packet buffer, or None if none has been assigned yet."""
        return self._buffer

    @buffer.setter
    def buffer(self, buf: Buffer):
        self._buffer = buf
# Drawing canvas: fixed data window, no frame and no ticks.
fig: plt.Figure = plt.figure(figsize=(10, 7))
ax: plt.Axes = fig.gca()
ax.set_xlim((-0.25, 1.25))
ax.set_ylim((0, 1))
ax.set_frame_on(False)
ax.set_xticks([])
ax.set_yticks([])
# The switch chassis and its shared packet buffer.
switch = Switch()
switch.add_to(ax)

# Maximum processing delay: the time to clock in a minimum 64-byte frame
# at one nanosecond per bit, i.e. 64 * 8 = 512 ns.
processing_delay = 64 * 8
memory_page_size = Bytes(1514)

# Buffer sizes of real switches (uncomment one to compare):
# Gigablox Rugged (chip VSC7512)
# buffer = Buffer(Bits.kilo(1750), memory_page_size, processing_delay)
# Gigablox (chip RTL8367N-VB-CG)
# buffer = Buffer(Bits.mega(2), memory_page_size, processing_delay)
# Netgear GS105 (chip BCM53115)
# buffer = Buffer(Bytes.kilo(128), memory_page_size, processing_delay)
# Netgear GS108 (chip BCM53118)
# buffer = Buffer(Bytes.kilo(192), memory_page_size, processing_delay)
# Zyxel XGS1210-12 (chip RTL9302B)
# buffer = Buffer(Bytes.kilo(1500), memory_page_size, processing_delay // 10)
# D-Link DIS-100G-5W (chip QCA8337N-AL3C)
# buffer = Buffer(Bits.mega(1), memory_page_size, processing_delay // 10)
buffer = Buffer(Bytes.kilo(24), memory_page_size, processing_delay)
buffer.add_to(ax)
switch.buffer = buffer
# Two ingress ports on the left, one egress port on the right.
in_ports = [
    Port("in1", True, (0.1, 0.8)),
    Port("in2", True, (0.1, 0.2)),
]
out_ports = [Port("out1", False, (0.9, 0.5))]
for port in (*in_ports, *out_ports):
    port.add_to(ax)
    switch.add_port(port)
packets: List[Packet] = list()


def _append_packet_train(packets: List[Packet], count: int, size_bytes: int, color: str,
                         in_port: Port, out_port: Port, start_time: int,
                         gap_durations: int = 0) -> int:
    """Append *count* packets of *size_bytes* bytes routed *in_port* -> *out_port*.

    The first packet starts arriving at *start_time*. Each following one starts
    when the previous one has been fully received plus *gap_durations* times
    that packet's transmission duration of idle line time.

    Returns:
        The time the next packet of the train would start arriving.
    """
    time_ns = start_time
    for _ in range(count):
        # A fresh Bytes per packet: BitSize is mutable (+=), so never share it.
        packet = Packet(Bytes(size_bytes), color, switch, in_port, out_port, time_ns)
        packets.append(packet)
        time_ns = packet.receive_end_time_ns() + gap_durations * packet.transmission_duration_ns()
    return time_ns


def add_ouster_batch(packets: List[Packet], start_time: int, *,
                     in_port: Optional[Port] = None, out_port: Optional[Port] = None):
    """Append one yellow burst: 16 back-to-back 1514-byte frames plus a 1258-byte frame.

    Defaults to ``in_ports[0]`` -> ``out_ports[0]``.
    """
    in_port = in_ports[0] if in_port is None else in_port
    out_port = out_ports[0] if out_port is None else out_port
    next_start = _append_packet_train(packets, 16, 1514, "yellow", in_port, out_port, start_time)
    _append_packet_train(packets, 1, 1258, "yellow", in_port, out_port, next_start)


def add_uniform_batch(packets: List[Packet], start_time: int, *,
                      in_port: Optional[Port] = None, out_port: Optional[Port] = None):
    """Append 129 green 1514-byte frames at 50 % line utilisation.

    Each frame is followed by an idle gap as long as its own transmission.
    Defaults to ``in_ports[1]`` -> ``out_ports[0]``.
    """
    in_port = in_ports[1] if in_port is None else in_port
    out_port = out_ports[0] if out_port is None else out_port
    _append_packet_train(packets, 129, 1514, "green", in_port, out_port, start_time, gap_durations=1)


def add_peaking_batch(packets: List[Packet], start_time: int, length: int, *,
                      in_port: Optional[Port] = None, out_port: Optional[Port] = None):
    """Append ``length + 1`` back-to-back green 1514-byte frames (line-rate burst).

    Defaults to ``in_ports[1]`` -> ``out_ports[0]``.
    """
    in_port = in_ports[1] if in_port is None else in_port
    out_port = out_ports[0] if out_port is None else out_port
    _append_packet_train(packets, length + 1, 1514, "green", in_port, out_port, start_time)
# Traffic scenario: two add_ouster_batch bursts plus two add_peaking_batch
# bursts, all built by the helpers above.
add_ouster_batch(packets, 12500)
# The second burst is keyed off packets[-16], i.e. the 2nd of the 17 packets
# the first add_ouster_batch call appended.
# NOTE(review): confirm -16 rather than -17 is intended.
add_ouster_batch(packets, packets[-16].receive_end_time_ns() + 1560000)
# add_uniform_batch(packets, 12500)
add_peaking_batch(packets, 12500, 42)
add_peaking_batch(packets, 12500 + 1000000, 42)

for packet in packets:
    packet.add_to(ax)

# update() advances objects in this order each simulated nanosecond.
to_update: List[Drawable] = [buffer, *packets]
# Simulated nanoseconds per animation frame.
time_step = 2000
def update(t):
    """FuncAnimation callback: simulate ``time_step`` ns starting at *t*, then redraw.

    Every object in ``to_update`` is advanced one nanosecond at a time. Visuals
    are refreshed once, at the last simulated nanosecond. Objects whose
    ``update_visuals()`` returns False are dropped from both ``to_update`` and
    ``packets``.
    """
    for time_ns in range(t, t + time_step):
        for drawable in to_update:
            drawable.advance_time(time_ns)
    finished = [drawable for drawable in to_update if not drawable.update_visuals(time_ns)]
    for drawable in finished:
        to_update.remove(drawable)
        packets.remove(drawable)
# Animate 2 ms of simulated time, one frame per `time_step` nanoseconds.
anim = FuncAnimation(
    fig,
    update,
    frames=range(0, 2000000, time_step),
    interval=1,
    blit=False,
    repeat=False,
)
# To export a GIF instead of opening a window, uncomment:
#writer = ImageMagickWriter(fps=1000)
#anim.save('/tmp/a.gif', writer=writer)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment