Skip to content

Instantly share code, notes, and snippets.

@autumnjolitz
Last active April 7, 2024 22:53
Show Gist options
  • Save autumnjolitz/08d0774826737dcf267f80991252e3ea to your computer and use it in GitHub Desktop.
Save autumnjolitz/08d0774826737dcf267f80991252e3ea to your computer and use it in GitHub Desktop.
Subscribe sockets to multicast groups, allow send/recv
#!/usr/bin/env python3
"""
multicast - subscribe/send/recv for IPv4/IPv6 multicast groups with UDP.
:author: Autumn Jolitz
:date: 2024-04-06
:license: BSD-2-Clause
:tags: networking, multicast, sockets, udp
Copyright (c) 2024, Autumn Jolitz
Redistribution and use in source and binary forms, with or without modification, are permitted
provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this list of
conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright notice, this list of
conditions and the following disclaimer in the documentation and/or other materials
provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS
OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY
AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
OF SUCH DAMAGE
"""
import ipaddress
import socket
import struct
import weakref
from abc import abstractmethod, abstractproperty
from collections.abc import Iterable, Buffer
from contextlib import ExitStack, suppress
from enum import IntEnum
from ipaddress import IPv4Address, IPv6Address
from typing import (
assert_never,
overload,
ClassVar,
Literal,
NamedTuple,
Sized as ImplementsSized,
TYPE_CHECKING,
)
IPAddress = IPv4Address | IPv6Address
MaybeIPAddress = str | int | bytes
HostAddress = Literal["", "::", "0.0.0.0"] | IPv4Address | IPv6Address | int | str
class SizedBuffer(ImplementsSized, Buffer):
pass
class MulticastFlags(IntEnum):
    """Option bits accepted by ``create_multicast_on``.

    Each member is a distinct power of two so members can be OR-ed together. Because this
    is an ``IntEnum`` (not an ``IntFlag``), such a combination is a plain ``int``.
    """

    NONE = 0
    MULTICAST_LOOP = 1 << 1  # enable IP_MULTICAST_LOOP / IPV6_MULTICAST_LOOP
    REUSEADDR = 1 << 2  # set SO_REUSEADDR before binding
    REUSEPORT = 1 << 3  # set SO_REUSEPORT before binding
class ReadOnlyMulticastConfig:
__slots__ = ()
@property
def family(self) -> int:
version = self.address.version
if version == 4:
return socket.AF_INET
elif version == 6:
return socket.AF_INET6
else:
assert_never()
@property
def version(self):
return self.address.version
if TYPE_CHECKING:
@property
def address(self) -> IPAddress:
pass
class IMulticastSocket:
__slots__ = ("__weakref__",)
def __init_subclass__(cls):
errors = []
for key in dir(cls):
value = getattr(cls, key, None)
if isinstance(value, abstractproperty):
errors.append(NotImplementedError(f"You must implement abstractproperty {key!r}"))
elif getattr(value, "__isabstractmethod__", False) is True:
errors.append(NotImplementedError(f"You must implement abstractmethod {key!r}"))
if errors:
if len(errors) > 1:
raise ExceptionGroup(f"Multiple issues on class {cls}", tuple(errors))
raise NotImplementedError(str(errors[0]) + f" on class {cls}")
return cls
@abstractproperty
def groups(self):
pass
@abstractproperty
def group_set(self):
pass
@abstractmethod
def add(self, group):
pass
@abstractmethod
def remove(self, group):
pass
@abstractproperty
def port(self):
pass
@abstractproperty
def flags(self):
pass
@abstractmethod
def __enter__(self):
pass
@abstractmethod
def __exit__(self, exc_type, exc_value, traceback):
pass
@abstractmethod
def shutdown(self, direction: int):
pass
@abstractmethod
def send(self, b: bytes, /, group) -> int:
pass
@abstractmethod
def recvfrom(self, n: int, /):
pass
@abstractmethod
def recvfrom_into(self, buffer: SizedBuffer, nbytes: int = -1, flags: int = 0, /):
pass
@abstractproperty
def closed(self):
pass
@abstractmethod
def close(self):
pass
@abstractmethod
def broadcast_send(self, b: bytes, /) -> tuple[int, ...]:
pass
class BaseMulticastSocket(IMulticastSocket):
    """Shared implementation of a multicast UDP socket for a single IP family.

    Concrete subclasses set ``AddressType`` (``IPv4Address`` or ``IPv6Address``); group
    arguments given as ``str``/``int`` are converted to that type. ``readonly`` holds the
    configuration the socket was created with.
    """

    __slots__ = ("readonly", "_socket", "_groups", "_group_set")

    _socket: socket.socket | None
    if TYPE_CHECKING:
        AddressType: ClassVar[type[IPv4Address] | type[IPv6Address]]

    def __init__(self, sock, /, multicast_config, groups):
        """Wrap *sock*, which has already joined every group in *groups*.

        :param sock: the underlying ``socket.socket``; its family must match the config's.
        :param multicast_config: the ``ReadOnlyMulticastConfig`` describing the socket.
        :param groups: tuple of distinct ``AddressType`` multicast groups.
        :raises TypeError: if an argument has the wrong type.
        :raises ValueError: on an address-family mismatch or duplicate groups.
        """
        if not isinstance(sock, socket.socket):
            raise TypeError(f"{sock!r} is not a socket!")
        if not (
            isinstance(groups, tuple)
            and all(isinstance(group, self.AddressType) for group in groups)
        ):
            raise TypeError(f"groups {groups!r} must be a tuple of {self.AddressType!r}")
        if not isinstance(multicast_config, ReadOnlyMulticastConfig):
            raise TypeError("multicast_config must be a ReadOnlyMulticastConfig!")
        self._socket = sock
        self.readonly = multicast_config
        if sock.family != self.readonly.family:
            raise ValueError(
                f"mismatch between socket family {sock.family!r} and {self.readonly.family!r}"
            )
        self._groups = groups
        self._group_set = set(groups)
        if len(self._group_set) != len(groups):
            raise ValueError("groups has duplicates!")

    @property
    def groups(self):
        """Subscribed groups as a tuple, in the order they were joined."""
        return self._groups

    @property
    def group_set(self):
        """Frozen copy of the subscribed groups."""
        return frozenset(self._group_set)

    @property
    def address(self):
        """Interface address from the configuration."""
        return self.readonly.address

    @property
    def port(self):
        """UDP port from the configuration."""
        return self.readonly.port

    @property
    def flags(self):
        """Creation flags from the configuration."""
        return self.readonly.flags

    @property
    def closed(self):
        """``True`` once ``close()`` has been called."""
        return self._socket is None

    def _coerce_group(self, group):
        """Return *group* as an ``AddressType`` multicast address.

        :raises TypeError: if *group* is not a ``str``, ``int`` or ``AddressType``.
        :raises ValueError: if it does not parse or is not a multicast address.
        """
        ip_cls = self.AddressType
        if isinstance(group, ip_cls):
            group_addr = group
        elif isinstance(group, (str, int)):
            group_addr = ip_cls(group)
        else:
            raise TypeError(f"{group!r} is not a str, int or {ip_cls.__name__}")
        if not group_addr.is_multicast:
            raise ValueError(f"{group!s} is not a multicast address!")
        return group_addr

    def add(self, group):
        """Join *group* (no-op if already joined); returns ``self``.

        :raises ValueError: if the socket is closed or *group* is not multicast.
        :raises TypeError: if *group* has an unsupported type.
        """
        if self.closed:
            raise ValueError("I/O operation on closed socket")
        group_addr = self._coerce_group(group)
        if group_addr not in self._group_set:
            self._add(group_addr)
        return self

    def _add(self, group):
        """Join an already-validated *group* and record it."""
        assert self._socket is not None
        assert isinstance(group, self.AddressType), f"{group!r} is not a {self.AddressType!r}"
        assert group.is_multicast, f"{group!r} not a multicast address!"
        assert (
            group.version == self.readonly.version
        ), f"ip version mismatch between socket {self.readonly.version} and group {group.version}!"
        _add_group_to(self._socket, group, self.readonly)
        self._groups = (*self._groups, group)
        self._group_set.add(group)

    def remove(self, group):
        """Leave *group* (no-op if not joined); returns ``self``.

        *group* may be given as ``str``/``int``; it is converted before the membership test,
        so e.g. ``remove("224.0.1.187")`` really leaves that group.

        :raises ValueError: if the socket is closed or *group* is not multicast.
        :raises TypeError: if *group* has an unsupported type.
        """
        if self.closed:
            raise ValueError("I/O operation on closed socket")
        group_addr = self._coerce_group(group)
        if group_addr in self._group_set:
            self._remove(group_addr)
        return self

    def _remove(self, group):
        """Leave an already-validated, currently joined *group* and forget it."""
        assert self._socket is not None
        assert isinstance(group, self.AddressType), f"{group!r} is not a {self.AddressType!r}"
        assert group.is_multicast, f"{group!r} not a multicast address!"
        assert (
            group.version == self.readonly.version
        ), f"ip version mismatch between socket {self.readonly.version} and group {group.version}!"
        _remove_group_from(self._socket, group, self.readonly)
        self._groups = tuple(member for member in self._groups if member != group)
        self._group_set.remove(group)

    def send(self, b: bytes, /, group="") -> int:
        """Send *b* to a subscribed *group* on ``port``; returns the bytes sent.

        When *group* is empty (the default) the socket's single group is used.

        :raises ValueError: if closed, if *group* is omitted while several groups are
            joined, or if it is not multicast.
        :raises TypeError: if *group* has an unsupported type.
        :raises LookupError: if the socket is not subscribed to *group*.
        """
        if self.closed:
            raise ValueError("operation on closed socket!")
        assert self._socket is not None
        ip_cls = self.AddressType
        if not isinstance(group, (ip_cls, str, int)):
            raise TypeError(f"group must be a str, int or {ip_cls.__name__}")
        dest: IPAddress
        if not group:
            try:
                (dest,) = self._groups
            except ValueError:
                raise ValueError("group must be specified!") from None
        elif isinstance(group, (str, int)):
            dest = ip_cls(group)
        else:
            dest = group
        if not dest.is_multicast:
            raise ValueError("Not a multicast address!")
        if dest not in self._group_set:
            raise LookupError(f"Not subscribed to {dest!s}")
        return self._socket.sendto(b, (str(dest), self.port))

    def broadcast_send(self, b: bytes, /) -> tuple[int, ...]:
        """Send *b* to every subscribed group; returns the per-group byte counts."""
        return tuple(self.send(b, group) for group in self._groups)

    def recvfrom(self, n: int, /):
        """Receive one datagram of up to *n* bytes (4096 if *n* is negative)."""
        if self.closed:
            raise ValueError("operation on closed socket!")
        assert self._socket is not None
        if n < 0:
            n = 4096
        return self._socket.recvfrom(n)

    def recvfrom_into(self, buffer: SizedBuffer, nbytes: int = -1, flags: int = 0, /):
        """Receive one datagram into *buffer*; returns ``(nbytes_received, address)``.

        An *nbytes* of ``-1`` (the default) means "up to the size of *buffer*" and may be
        combined with *flags*.
        """
        if self.closed:
            raise ValueError("operation on closed socket!")
        assert self._socket is not None
        # socket.recvfrom_into treats nbytes == 0 as "the whole buffer".
        return self._socket.recvfrom_into(buffer, 0 if nbytes == -1 else nbytes, flags)

    def shutdown(self, direction: int):
        """Shut down the socket in *direction*; a no-op once closed.

        :raises ValueError: if *direction* is not a ``socket.SHUT_*`` constant.
        """
        if direction not in (socket.SHUT_RD, socket.SHUT_RDWR, socket.SHUT_WR):
            raise ValueError(f"invalid shutdown direction {direction!r}")
        if self.closed:
            return None
        assert self._socket is not None
        return self._socket.shutdown(direction)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        """Close the underlying socket; safe to call repeatedly."""
        if self._socket is not None:
            self._socket.close()
            self._socket = None

    @property
    def socket(self):
        """The underlying ``socket.socket``, or ``None`` once closed."""
        return self._socket
class _MulticastIPv4(NamedTuple):
    """Field layout of an IPv4 multicast socket configuration."""

    address: IPv4Address  # local interface address used for group membership
    port: int  # UDP port
    flags: MulticastFlags | int  # creation flags (a member, or an int combination)


class _MulticastIPv6(NamedTuple):
    """Field layout of an IPv6 multicast socket configuration."""

    address: IPv6Address  # local address (carrying a scope id)
    port: int  # UDP port
    flags: MulticastFlags | int  # creation flags (a member, or an int combination)
    scope_id: int  # interface index used for group membership
    device: str  # interface name corresponding to scope_id
class ReadOnlyMulticastIPv4Config(_MulticastIPv4, ReadOnlyMulticastConfig):
    """Immutable IPv4 configuration (a tuple) with derived ``family`` and ``version``."""

    __slots__ = ()


class ReadOnlyMulticastIPv6Config(_MulticastIPv6, ReadOnlyMulticastConfig):
    """Immutable IPv6 configuration (a tuple) with derived ``family`` and ``version``."""

    __slots__ = ()
class MulticastIPv4Socket(BaseMulticastSocket):
    """Multicast UDP socket whose groups are ``IPv4Address`` objects."""

    __slots__ = ()
    AddressType: ClassVar[type[IPv4Address]] = IPv4Address


class MulticastIPv6Socket(BaseMulticastSocket):
    """Multicast UDP socket whose groups are ``IPv6Address`` objects.

    Also exposes the interface index and name stored in its IPv6 configuration.
    """

    __slots__ = ()
    AddressType: ClassVar[type[IPv6Address]] = IPv6Address

    @property
    def scope_id(self) -> int:
        """Interface index from the socket's configuration."""
        config = self.readonly
        return config.scope_id

    @property
    def device(self) -> str:
        """Interface name from the socket's configuration."""
        config = self.readonly
        return config.device
class SocketGroupRead(NamedTuple):
socket_index: int
value: bytes
socket_group_ref: weakref.ReferenceType["MulticastSocketGroup"]
@property
def socket_group(self) -> "None | MulticastSocketGroup":
return self.socket_group_ref()
@property
def socket(self) -> MulticastIPv6Socket | MulticastIPv4Socket | None:
g = self.socket_group
if g is not None:
return g.sockets[self.socket_index]
return None
class SocketGroupBufferRead(NamedTuple):
socket_index: int
span: tuple[int, int]
from_addr: tuple[str, int]
socket_group_ref: weakref.ReferenceType["MulticastSocketGroup"]
@property
def length(self) -> int:
index, endex = self.span
return endex - index
@property
def socket_group(self) -> "None | MulticastSocketGroup":
return self.socket_group_ref()
@property
def socket(self) -> MulticastIPv6Socket | MulticastIPv4Socket | None:
g = self.socket_group
if g is not None:
return g.sockets[self.socket_index]
return None
class MulticastSocketGroup(IMulticastSocket):
    """Treat several multicast sockets (e.g. one per IP family) as a single object.

    Member sockets are switched to non-blocking mode so ``recvfrom``/``recvfrom_into``
    can poll every member without stalling on an idle one.
    """

    __slots__ = ("sockets", "readonly", "in_ctx")

    def __init__(self, sock: MulticastIPv4Socket | MulticastIPv6Socket, *sockets):
        """:param sock: first member socket; further members follow as *sockets*."""
        self.sockets = (sock, *sockets)
        self.in_ctx = None
        for member in self.sockets:
            member.socket.setblocking(False)

    @property
    def groups(self):
        """All members' groups, member by member."""
        return tuple(group for member in self.sockets for group in member.groups)

    @property
    def group_set(self):
        """Frozen set of all members' groups."""
        return frozenset(self.groups)

    @property
    def group_types(self):
        """Address classes (``IPv4Address``/``IPv6Address``) the members accept."""
        return frozenset(member.AddressType for member in self.sockets)

    @property
    def port(self):
        """Port of the first member."""
        return self.sockets[0].port

    @property
    def flags(self) -> tuple[MulticastFlags | int, ...]:
        """Creation flags of each member, in member order."""
        return tuple(member.flags for member in self.sockets)

    def _coerce_group(self, group):
        """Return *group* as an address object of a family some member handles.

        :raises TypeError: for an unsupported type or a family no member handles.
        :raises ValueError: if a ``str``/``int`` does not parse as an IP address.
        """
        if isinstance(group, (int, str)):
            group_addr = ipaddress.ip_address(group)
        elif isinstance(group, (IPv4Address, IPv6Address)):
            group_addr = group
        else:
            raise TypeError(f"{group!r} is not a str, int, IPv4Address or IPv6Address")
        types = tuple(self.group_types)
        if not isinstance(group_addr, types):
            raise TypeError("{} is not a {}".format(group, " or ".join(repr(x) for x in types)))
        return group_addr

    def add(self, group):
        """Join *group* on the first member of its family; returns ``self``."""
        group_addr = self._coerce_group(group)
        if group_addr in self.group_set:
            return self
        for member in self.sockets:
            if isinstance(group_addr, member.AddressType):
                member.add(group_addr)
                return self
        raise LookupError(f"Unable to find a group for {group_addr!r}")

    def remove(self, group):
        """Leave *group* on whichever member joined it (no-op if none); returns ``self``."""
        group_addr = self._coerce_group(group)
        for member in self.sockets:
            if group_addr in member.group_set:
                member.remove(group_addr)
        return self

    def __enter__(self):
        if self.in_ctx is not None:
            raise ValueError("reentrancy NOT supported")
        with ExitStack() as stack:
            for member in self.sockets:
                stack.enter_context(member)
            # Every member entered: keep their exit callbacks for __exit__.
            self.in_ctx = stack.pop_all()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        stack, self.in_ctx = self.in_ctx, None
        if stack is None:
            # __exit__ without a matching __enter__: still release the members.
            self.close()
            return None
        return stack.__exit__(exc_type, exc_value, traceback)

    def send(self, b: bytes, /, group="") -> int:
        """Send *b* to *group* through the member subscribed to it.

        With no *group*, the first member sends to its only group (it raises ``ValueError``
        if it has several).

        :raises LookupError: if no member is subscribed to *group*.
        """
        if not group:
            return self.sockets[0].send(b, group)
        group_addr = self._coerce_group(group)
        for member in self.sockets:
            if group_addr in member.group_set:
                return member.send(b, group_addr)
        raise LookupError(group)

    def broadcast_send(self, b: bytes, /) -> tuple[int, ...]:
        """Send *b* to every group of every member; returns all byte counts in order."""
        return tuple(count for member in self.sockets for count in member.broadcast_send(b))

    def close(self):
        """Close every member, ignoring ``OSError`` from individual closes."""
        for member in self.sockets:
            with suppress(OSError):
                member.close()

    def shutdown(self, direction: int):
        """Shut down every member; failures are collected into one ``ExceptionGroup``."""
        errors = []
        for member in self.sockets:
            try:
                member.shutdown(direction)
            except Exception as e:
                errors.append(e)
        if errors:
            raise ExceptionGroup("Unable to shutdown", tuple(errors))

    @property
    def closed(self):
        """``True`` when every member is closed."""
        return all(member.closed for member in self.sockets)

    def recvfrom(self, n: int, /) -> tuple[SocketGroupRead, ...]:
        """Poll each member once for a datagram of up to *n* bytes.

        Members with nothing pending (``BlockingIOError``) are skipped.
        """
        reads = []
        for index, member in enumerate(self.sockets):
            try:
                value = member.recvfrom(n)
            except BlockingIOError:
                continue
            reads.append(SocketGroupRead(index, value, weakref.ref(self)))
        return tuple(reads)

    def recvfrom_into(
        self, buffer: SizedBuffer, nbytes: int = -1, flags: int = 0, /
    ) -> tuple[SocketGroupBufferRead, ...]:
        """Pack at most one pending datagram per member back to back into *buffer*.

        :param nbytes: total byte budget; ``-1`` (or ``0``, as with
            ``socket.recvfrom_into``) means the whole buffer.
        :returns: one ``SocketGroupBufferRead`` per datagram, giving its byte ``span``.
        :raises ValueError: if *nbytes* is negative (other than ``-1``).
        """
        if nbytes < -1:
            raise ValueError("negative buffersize in recvfrom_into")
        spans = []
        # Work in bytes regardless of the buffer's item format.
        with memoryview(buffer) as raw, raw.cast("B") as view:
            remaining = len(view) if nbytes in (-1, 0) else nbytes
            offset = 0
            for index, member in enumerate(self.sockets):
                if remaining <= 0:
                    break
                # Release each slice promptly so the caller's buffer is not left exported.
                with view[offset : offset + remaining] as target:
                    try:
                        length, from_addr = member.recvfrom_into(target, remaining, flags)
                    except BlockingIOError:
                        continue
                spans.append(
                    SocketGroupBufferRead(
                        index, (offset, offset + length), from_addr, weakref.ref(self)
                    )
                )
                offset += length
                remaining -= length
        return tuple(spans)
def as_multicast_address(address: IPAddress | MaybeIPAddress) -> IPAddress | None:
    """Return *address* as an IP address object if it is a multicast address.

    :param address: an ``IPv4Address``/``IPv6Address``, or a ``str``, ``int`` or packed
        ``bytes`` value accepted by ``ipaddress.ip_address``.
    :returns: the address when it is multicast; ``None`` when it is not multicast or
        cannot be parsed as an IP address at all.
    :raises TypeError: if *address* has an unsupported type.
    """
    if isinstance(address, (IPv4Address, IPv6Address)):
        addr = address
    elif isinstance(address, (str, int, bytes)):
        try:
            addr = ipaddress.ip_address(address)
        except ValueError:
            # Unparseable input is "not a multicast address"; create_multicast_on
            # collects these None results and reports them together.
            return None
    else:
        raise TypeError(f"{address!r} is not a str, int, bytes, IPv4Address or IPv6Address")
    if addr.is_multicast:
        return addr
    return None
def _set_membership(sock: socket.socket, group, config, *, join: bool) -> None:
    """Join (``join=True``) or leave multicast *group* on *sock*.

    IPv4 uses ``IP_ADD_MEMBERSHIP``/``IP_DROP_MEMBERSHIP`` with the interface address from
    ``config.address``; IPv6 uses ``IPV6_JOIN_GROUP``/``IPV6_LEAVE_GROUP`` with the
    interface index ``config.scope_id``.

    :raises NotImplementedError: for a socket family other than AF_INET/AF_INET6.
    """
    assert isinstance(group, (IPv4Address, IPv6Address))
    assert isinstance(config, ReadOnlyMulticastConfig), f"{config!r}"
    assert group.is_multicast
    if sock.family == socket.AddressFamily.AF_INET:
        assert isinstance(config, ReadOnlyMulticastIPv4Config)
        assert isinstance(group, IPv4Address)
        option = socket.IP_ADD_MEMBERSHIP if join else socket.IP_DROP_MEMBERSHIP
        # struct ip_mreq: group address followed by the local interface address.
        request = group.packed + config.address.packed
        sock.setsockopt(socket.IPPROTO_IP, option, request)
    elif sock.family == socket.AddressFamily.AF_INET6:
        assert isinstance(config, ReadOnlyMulticastIPv6Config)
        assert isinstance(group, IPv6Address)
        option = socket.IPV6_JOIN_GROUP if join else socket.IPV6_LEAVE_GROUP
        # struct ipv6_mreq: group address followed by the unsigned interface index.
        request = group.packed + struct.pack("I", config.scope_id)
        sock.setsockopt(socket.IPPROTO_IPV6, option, request)
    else:
        raise NotImplementedError(f"unsupported socket family {sock.family!r}")


def _add_group_to(sock: socket.socket, group, config):
    """Subscribe *sock* to multicast *group* on the interface described by *config*."""
    _set_membership(sock, group, config, join=True)


def _remove_group_from(sock: socket.socket, group, config):
    """Unsubscribe *sock* from multicast *group* on the interface described by *config*."""
    _set_membership(sock, group, config, join=False)
if TYPE_CHECKING:
    # Typing-only overloads for create_multicast_on (never executed at runtime): the
    # declared result narrows to a single-family socket when all groups share a family.

    @overload
    def create_multicast_on(
        multicast_addresses: Iterable[IPv4Address],
        port: int,
        /,
        bind_on: Literal["0.0.0.0", ""] | IPv4Address = "",
        flags: MulticastFlags | int = MulticastFlags.NONE,
        *,
        hop_limit: int = -1,
    ) -> MulticastIPv4Socket: ...

    @overload
    def create_multicast_on(
        multicast_addresses: Iterable[IPv6Address],
        port: int,
        /,
        bind_on: Literal["::", ""] | IPv6Address = "",
        flags: MulticastFlags | int = MulticastFlags.NONE,
        *,
        hop_limit: int = -1,
    ) -> MulticastIPv6Socket: ...

    @overload
    def create_multicast_on(
        multicast_addresses: Iterable[IPAddress],
        port: int,
        /,
        bind_on: HostAddress = "",
        flags: MulticastFlags | int = MulticastFlags.NONE,
        *,
        hop_limit: int = -1,
    ) -> MulticastIPv6Socket | MulticastIPv4Socket | MulticastSocketGroup: ...

    @overload
    def create_multicast_on(
        multicast_addresses: str,
        port: int,
        /,
        bind_on: HostAddress = "",
        flags: MulticastFlags | int = MulticastFlags.NONE,
        *,
        hop_limit: int = -1,
    ) -> MulticastIPv6Socket | MulticastIPv4Socket | MulticastSocketGroup: ...

    @overload
    def create_multicast_on(
        multicast_addresses: Iterable[str],
        port: int,
        /,
        bind_on: HostAddress = "",
        flags: MulticastFlags | int = MulticastFlags.NONE,
        *,
        hop_limit: int = -1,
    ) -> MulticastIPv6Socket | MulticastIPv4Socket | MulticastSocketGroup: ...
def create_multicast_on(
    multicast_addresses,
    port: int,
    /,
    bind_on: HostAddress = "",
    flags=MulticastFlags.NONE,
    *,
    hop_limit: int = -1,
) -> MulticastSocketGroup | MulticastIPv4Socket | MulticastIPv6Socket:
    """Create UDP socket(s) bound to *port* and subscribed to the given multicast groups.

    :param multicast_addresses: one group (``str``, ``IPv4Address`` or ``IPv6Address``) or
        an iterable of groups; duplicates are ignored.
    :param port: UDP port (1-65535) to bind to and send to.
    :param bind_on: a wildcard (``""``, ``"::"``, ``"0.0.0.0"``, ``0``), an IP address
        (``str``, ``int`` or address object; IPv6 optionally with ``%scope``), or a DNS
        name resolved with ``socket.getaddrinfo``.
    :param flags: ``MulticastFlags`` bits, or an ``int`` combination of them.
    :param hop_limit: multicast TTL / hop limit, 1-255; negative selects the default of 1.
    :returns: a ``MulticastIPv4Socket`` or ``MulticastIPv6Socket``, or a
        ``MulticastSocketGroup`` when several sockets are needed (mixed families with a
        wildcard bind, or a DNS name resolving to both families).
    :raises TypeError: for arguments of the wrong type.
    :raises ValueError: for out-of-range arguments, an invalid or non-multicast group,
        no groups for the chosen family, or an unusable *bind_on*.
    :raises ExceptionGroup: when several groups are invalid.
    """
    if not isinstance(port, int):
        raise TypeError(f"port must be an int, not {type(port).__name__}")
    if not 0 < port <= 65535:
        raise ValueError(f"port {port} is out of range (1-65535)")
    if hop_limit == 0 or hop_limit > 255:
        raise ValueError(f"hop_limit {hop_limit} must be 1-255 (or negative for the default)")
    if hop_limit < 0:
        hop_limit = 1
    create_flags = _coerce_multicast_flags(flags)
    ipv4_groups, ipv6_groups = _partition_multicast_groups(multicast_addresses)

    bind_address: IPv4Address | IPv6Address
    if bind_on in ("", "::", "0.0.0.0", 0):
        if ipv4_groups and ipv6_groups:
            # A socket only joins groups of its own family: one socket per family.
            return _open_socket_group(
                (IPv4Address("0.0.0.0"), IPv6Address("::")),
                port,
                create_flags,
                hop_limit,
                ipv4_groups,
                ipv6_groups,
            )
        bind_address = IPv4Address("0.0.0.0") if ipv4_groups else IPv6Address("::")
    elif isinstance(bind_on, str) and _is_dns_name(bind_on):
        resolved = _resolve_bind_addresses(bind_on, port)
        if len({address.version for address in resolved}) > 1:
            return _open_socket_group(
                resolved, port, create_flags, hop_limit, ipv4_groups, ipv6_groups
            )
        # Single address family: bind on the first resolved address.
        bind_address = resolved[0]
    else:
        bind_address = _parse_bind_address(bind_on)
    return _open_multicast_socket(
        bind_address, port, create_flags, hop_limit, ipv4_groups, ipv6_groups
    )


def _coerce_multicast_flags(flags) -> MulticastFlags | int:
    """Validate *flags*, returning the ``MulticastFlags`` member when one matches exactly.

    A combination of several flags has no ``IntEnum`` member and is returned as an ``int``.

    :raises TypeError: if *flags* is not an ``int``.
    :raises ValueError: if *flags* contains bits that are not ``MulticastFlags`` values.
    """
    if isinstance(flags, MulticastFlags):
        return flags
    if not isinstance(flags, int):
        raise TypeError("flags are not an int or MulticastFlags")
    try:
        return MulticastFlags(flags)
    except ValueError:
        known_bits = 0
        for flag in MulticastFlags:
            known_bits |= flag
        if flags & ~known_bits:
            raise
        return flags


def _partition_multicast_groups(
    multicast_addresses,
) -> tuple[tuple[IPv4Address, ...], tuple[IPv6Address, ...]]:
    """Parse, validate and de-duplicate groups into ``(ipv4, ipv6)`` tuples in input order.

    :raises TypeError: if *multicast_addresses* is neither an address nor an iterable.
    :raises ValueError: for a single invalid entry, or when no groups are given.
    :raises ExceptionGroup: when several entries are invalid.
    """
    if isinstance(multicast_addresses, (str, IPv4Address, IPv6Address)):
        multicast_addresses = (multicast_addresses,)
    elif not isinstance(multicast_addresses, Iterable):
        raise TypeError(f"{multicast_addresses!r} is not an address or an iterable of addresses")
    errors: list[Exception] = []
    # Dicts double as insertion-ordered sets, so duplicates are dropped per family.
    ipv4: dict[IPv4Address, None] = {}
    ipv6: dict[IPv6Address, None] = {}
    for addr in multicast_addresses:
        group = as_multicast_address(addr)
        if group is None:
            errors.append(ValueError(f"{addr!r} is not a valid multicast address!"))
        elif isinstance(group, IPv4Address):
            ipv4[group] = None
        else:
            ipv6[group] = None
    if len(errors) == 1:
        raise errors[0]
    if errors:
        raise ExceptionGroup("Multiple invalid multicast addresses given", tuple(errors))
    if not ipv4 and not ipv6:
        raise ValueError("No multicast addresses given!")
    return tuple(ipv4), tuple(ipv6)


def _is_dns_name(host: str) -> bool:
    """Heuristic: *host* is a name to resolve unless it looks like an IPv6/dotted-IPv4 literal."""
    if ":" in host:
        return False
    looks_like_ipv4 = 0 < host.count(".") < 4 and all(
        part.isdigit() for part in host.split(".")
    )
    return not looks_like_ipv4


def _resolve_bind_addresses(host: str, port: int) -> tuple[IPv4Address | IPv6Address, ...]:
    """Resolve *host* to its distinct IPv4/IPv6 addresses (IPv6 ones carry a scope id).

    :raises ValueError: if the name resolves to no IPv4 or IPv6 address.
    """
    resolved: dict[IPv4Address | IPv6Address, None] = {}
    for family, _, _, _, sockaddr in socket.getaddrinfo(
        host, port, socket.AF_UNSPEC, socket.SOCK_DGRAM, socket.IPPROTO_UDP
    ):
        if family == socket.AddressFamily.AF_INET:
            resolved[IPv4Address(sockaddr[0])] = None
        elif family == socket.AddressFamily.AF_INET6:
            # sockaddr is (host, port, flowinfo, scope_id); keep the scope on the address.
            host_text, *_, scope_id = sockaddr
            bare_host = str(host_text).partition("%")[0]
            resolved[IPv6Address(f"{bare_host}%{scope_id}")] = None
    if not resolved:
        raise ValueError("no interfaces to bind to")
    return tuple(resolved)


def _parse_bind_address(bind_on) -> IPv4Address | IPv6Address:
    """Convert an explicit (non-wildcard, non-DNS) *bind_on* into an address object.

    :raises TypeError: if *bind_on* is not an address object, ``str`` or ``int``.
    :raises ValueError: if it does not parse as an IP address.
    """
    if isinstance(bind_on, (IPv4Address, IPv6Address)):
        return bind_on
    if isinstance(bind_on, (str, int)):
        try:
            return ipaddress.ip_address(bind_on)
        except ValueError:
            raise ValueError(f"{bind_on!r} is not a valid IP address to bind on") from None
    raise TypeError(f"{bind_on!r} is not an IPv4Address or IPv6Address!")


def _open_socket_group(
    bind_addresses, port, flags, hop_limit, ipv4_groups, ipv6_groups
) -> MulticastSocketGroup:
    """Open one socket per bind address whose family has groups; close all on failure.

    :raises ValueError: if no bind address matches the family of any requested group.
    """
    sockets: list[MulticastIPv4Socket | MulticastIPv6Socket] = []
    try:
        for bind_address in bind_addresses:
            if not (ipv4_groups if bind_address.version == 4 else ipv6_groups):
                continue  # nothing to join for this family
            sockets.append(
                _open_multicast_socket(
                    bind_address, port, flags, hop_limit, ipv4_groups, ipv6_groups
                )
            )
        if not sockets:
            raise ValueError("No multicast groups match the families of the bind addresses")
        return MulticastSocketGroup(*sockets)
    except BaseException:
        for opened in sockets:
            opened.close()
        raise


def _open_multicast_socket(
    bind_address, port, flags, hop_limit, ipv4_groups, ipv6_groups
) -> MulticastIPv4Socket | MulticastIPv6Socket:
    """Open a socket for *bind_address*'s family that joins that family's groups.

    :raises ValueError: if there are no groups of that family.
    """
    if isinstance(bind_address, IPv4Address):
        if not ipv4_groups:
            raise ValueError(f"No IPv4 multicast groups to join on {bind_address!s}")
        return _open_ipv4_socket(bind_address, port, flags, hop_limit, ipv4_groups)
    if isinstance(bind_address, IPv6Address):
        if not ipv6_groups:
            raise ValueError(f"No IPv6 multicast groups to join on {bind_address!s}")
        return _open_ipv6_socket(bind_address, port, flags, hop_limit, ipv6_groups)
    assert_never(bind_address)


def _bound_udp_socket(family, port: int, flags) -> socket.socket:
    """Create a UDP socket of *family* bound to the wildcard address on *port*.

    ``REUSEADDR``/``REUSEPORT`` from *flags* are applied before binding. IPv6 sockets are
    made IPv6-only: they only ever join IPv6 groups, and on dual-stack hosts an IPv6
    wildcard bind would otherwise also claim the IPv4 port.
    """
    sock = socket.socket(family, socket.SOCK_DGRAM)
    try:
        if family == socket.AF_INET6 and hasattr(socket, "IPV6_V6ONLY"):
            sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
        if MulticastFlags.REUSEADDR & flags:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        if MulticastFlags.REUSEPORT & flags:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(("", port))
    except BaseException:
        sock.close()
        raise
    return sock


def _open_ipv4_socket(
    bind_address: IPv4Address, port: int, flags, hop_limit: int, groups
) -> MulticastIPv4Socket:
    """Open an IPv4 socket joining *groups* on the interface with address *bind_address*."""
    config = ReadOnlyMulticastIPv4Config(bind_address, port, flags)
    sock = _bound_udp_socket(socket.AF_INET, port, flags)
    try:
        for group in groups:
            _add_group_to(sock, group, config)
        if MulticastFlags.MULTICAST_LOOP & flags:
            sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, True)
        sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, struct.pack("B", hop_limit))
        sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, bind_address.packed)
        return MulticastIPv4Socket(sock, config, groups)
    except BaseException:
        sock.close()
        raise


def _open_ipv6_socket(
    bind_address: IPv6Address, port: int, flags, hop_limit: int, groups
) -> MulticastIPv6Socket:
    """Open an IPv6 socket joining *groups* on the interface named by the address scope.

    :raises ValueError: for an IPv4-mapped address or a scope that names no interface.
    """
    if bind_address.ipv4_mapped is not None:
        raise ValueError("IPv4-mapped IPv6 bind addresses are not supported")
    if bind_address.scope_id is None:
        # Ask the resolver for the scope (interface index) of this address.
        infos = socket.getaddrinfo(
            str(bind_address), port, socket.AF_INET6, socket.SOCK_DGRAM, socket.IPPROTO_UDP
        )
        *_, resolved_scope = infos[0][4]
        bind_address = IPv6Address(f"{bind_address!s}%{resolved_scope}")
    scope = bind_address.scope_id
    assert scope is not None
    try:
        if scope.isdigit():
            scope_id = int(scope)
            # Index 0 selects no particular interface and has no name to look up.
            device = socket.if_indextoname(scope_id) if scope_id else ""
        else:
            scope_id = socket.if_nametoindex(scope)
            device = scope
    except OSError as exc:
        raise ValueError(f"IPv6 scope_id ({scope!r}) is invalid!") from exc
    config = ReadOnlyMulticastIPv6Config(bind_address, port, flags, scope_id, device)
    sock = _bound_udp_socket(socket.AF_INET6, port, flags)
    try:
        for group in groups:
            _add_group_to(sock, group, config)
        sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, struct.pack("I", scope_id))
        if MulticastFlags.MULTICAST_LOOP & flags:
            sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, True)
        sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, hop_limit)
        return MulticastIPv6Socket(sock, config, groups)
    except BaseException:
        sock.close()
        raise
if __name__ == "__main__":
    import argparse

    # Demo: join groups with loopback enabled, send to them and read the datagrams back.
    parser = argparse.ArgumentParser()
    parser.set_defaults(use_ipv4=True, use_ipv6=False)
    excl_ipv4 = parser.add_mutually_exclusive_group()
    excl_ipv4.add_argument("-4", "--ipv4", dest="use_ipv4", action="store_true")
    excl_ipv4.add_argument("--no-ipv4", dest="use_ipv4", action="store_false")
    excl_ipv6 = parser.add_mutually_exclusive_group()
    excl_ipv6.add_argument("-6", "--ipv6", dest="use_ipv6", action="store_true")
    excl_ipv6.add_argument("--no-ipv6", dest="use_ipv6", action="store_false")
    args = parser.parse_args()
    if args.use_ipv4:
        with create_multicast_on(
            {"224.0.1.187"}, 8885, flags=MulticastFlags.MULTICAST_LOOP
        ) as conn1, create_multicast_on(
            {"224.0.1.188"}, 8886, flags=MulticastFlags.MULTICAST_LOOP
        ) as conn2:
            for conn in conn1, conn2:
                for name in conn.readonly._fields:
                    print(repr(getattr(conn, name)))
            conn = MulticastSocketGroup(conn1, conn2)
            print(conn.send(b"1234", "224.0.1.187"))
            print(conn.send(b"1234", "224.0.1.188"))
            print(conn.broadcast_send(b"1234519"))
            print(conn.recvfrom(10))
            b = bytearray(20)
            print(conn.recvfrom_into(b))
            print(b)
    if args.use_ipv6:
        with create_multicast_on("FF00::FD", 8885, flags=MulticastFlags.MULTICAST_LOOP) as conn:
            # Name the destination group explicitly when sending from a single socket.
            print(conn.send(b"1234", "FF00::FD"))
            print(conn.recvfrom(4))
@autumnjolitz
Copy link
Author

(cpython312) InvincibleReason:~$ mypy multicast.py
Success: no issues found in 1 source file
(cpython312) InvincibleReason:~$ black -l 101 multicast.py
All done! ✨ 🍰 ✨
1 file left unchanged.
(cpython312) InvincibleReason:~$ 

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment