Last active
May 14, 2021 15:47
-
-
Save nanthony007/37238c10ce43c0df5cdfd7a73cb1436d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict, Counter | |
from dataclasses import InitVar, dataclass, field | |
from typing import Union | |
import enum | |
class NodeType(enum.Enum):
    """Closed set of categories a graph node can belong to."""

    # auto() numbers members 1, 2, 3 in declaration order, so reordering
    # these lines would change the underlying values.
    PATIENT = enum.auto()
    PRESCRIBER = enum.auto()
    PHARMACY = enum.auto()
@dataclass
class CustomCounter:
    """Tally unordered pairs of names.

    A pair and its reverse, e.g. ``("a", "b")`` and ``("b", "a")``, are
    counted under one key: the pair sorted in ascending order.

    Args:
        lst: Init-only list of pairs to tally; it is not stored on the
            instance.

    Attributes:
        counts: Maps each sorted pair to the number of times it (or its
            reverse) appeared in ``lst``.
    """

    lst: InitVar[list[tuple[str, str]]]
    counts: dict[tuple[str, str], int] = field(init=False)

    def __post_init__(self, lst):
        """Build ``counts`` from the init-only ``lst`` argument."""
        # Sorting each pair normalises direction, so a pair and its
        # reverse land on the same key.
        self.counts = Counter(tuple(sorted(pair)) for pair in lst)

    def most_common(
        self, limit: Union[int, None] = None
    ) -> list[tuple[tuple[str, str], int]]:
        """Return ``(pair, count)`` tuples ordered by descending count.

        Pairs with equal counts keep the order in which they were first
        seen, because the sort is stable.

        Args:
            limit: Maximum number of entries to return. ``None`` (the
                default) returns every entry; ``0`` returns an empty list.

        Returns:
            The ranked list, sliced to ``limit`` entries when given.
        """
        ranked = sorted(
            self.counts.items(), key=lambda item: item[1], reverse=True
        )
        # Compare against None explicitly: a plain truthiness test would
        # treat limit=0 as "no limit" and return everything.
        return ranked if limit is None else ranked[:limit]
@dataclass(unsafe_hash=True)
class Node:
    """A named graph vertex tagged with a NodeType.

    ``unsafe_hash=True`` generates ``__hash__`` from the fields, which lets
    instances be used as dictionary keys. The instance is still mutable, so
    a node must not be modified while it is stored as a key.
    """

    # Slots keep instances free of a per-object __dict__.
    __slots__ = ("name", "kind")

    name: str
    kind: NodeType

    def __str__(self) -> str:
        """Return the node's name."""
        return self.name
@dataclass
class Edge:
    """A directed connection from ``src`` to ``dest``."""

    # TODO: want weight to track number of connections on each edge
    __slots__ = ("src", "dest")

    src: Node
    dest: Node

    def __str__(self) -> str:
        """Render the edge as ``"<src name> -> <dest name>"``."""
        origin, target = self.src.name, self.dest.name
        return f"{origin} -> {target}"
@dataclass
class Digraph:
    """Directed graph stored as an adjacency mapping.

    Attributes:
        vertices: Maps each node to the list of nodes its outgoing edges
            point at. Parallel edges appear as repeated list entries.
    """

    vertices: dict[Node, list[Node]] = field(default_factory=dict)

    def add_node(self, node: Node):
        """Register ``node`` with no outgoing edges; no-op if it exists."""
        # setdefault never replaces an existing adjacency list, so a list
        # previously handed out by children() stays live even when the
        # node currently has no children.
        self.vertices.setdefault(node, [])

    def add_edge(self, edge: Edge) -> Union[KeyError, None]:
        """Append ``edge.dest`` to the children of ``edge.src``.

        Only the source is checked; the destination is appended whether or
        not it is a registered node.

        Returns:
            ``None`` on success. If ``edge.src`` is not a registered node,
            the ``KeyError`` is returned (not raised) and nothing is added.
        """
        try:
            self.vertices[edge.src].append(edge.dest)
        except KeyError as err:
            return err
        return None

    def children(self, node: Node) -> list[Node]:
        """Return the live adjacency list of ``node``.

        Raises:
            KeyError: If ``node`` is not in the graph.
        """
        return self.vertices[node]

    def has_node(self, node: Node) -> bool:
        """Return whether ``node`` is registered in the graph."""
        return node in self.vertices

    def get_node(self, name: str) -> Node:
        """Return the first registered node whose ``name`` equals ``name``.

        Raises:
            NameError: If no node has that name.
        """
        for candidate in self.vertices:
            if candidate.name == name:
                return candidate
        raise NameError(name)

    def node_pairs(self) -> list[tuple[str, str]]:
        """Return one ``(src name, dest name)`` tuple per stored edge."""
        return [
            (src.name, dest.name)
            for src, targets in self.vertices.items()
            for dest in targets
        ]

    # * start analytics section
    # * basics
    def vertex_count(self, by_type: bool = True) -> Union[int, Counter]:
        """Count the vertices.

        Args:
            by_type: When true (the default), return a ``Counter`` keyed by
                each vertex's ``kind.name``; otherwise return the total.
        """
        if by_type:
            return Counter(vertex.kind.name for vertex in self.vertices)
        return len(self.vertices)

    def edge_count(self) -> int:
        """Return the total number of stored edges."""
        return sum(len(self.children(node)) for node in self.vertices)

    def vertex_degree(
        self, vertex: Union[Node, None] = None
    ) -> Union[int, list[tuple[Node, int]]]:
        """Return out-degrees.

        Args:
            vertex: Node to inspect, or ``None`` for every node.

        Returns:
            The number of children of ``vertex`` when one is given;
            otherwise ``(node, out-degree)`` tuples sorted by descending
            out-degree.

        Raises:
            KeyError: If ``vertex`` is given but not in the graph.
        """
        if vertex is not None:
            return len(self.children(vertex))
        return sorted(
            ((node, len(self.children(node))) for node in self.vertices),
            key=lambda entry: entry[1],
            reverse=True,
        )

    # * advanced
    def edge_rank(
        self, vertex: Union[Node, None] = None, limit: Union[int, None] = None
    ) -> list[tuple[tuple[str, str], int]]:
        """Rank name pairs by how many edges connect them.

        Direction is ignored: CustomCounter tallies a pair and its reverse
        under the same key.

        Args:
            vertex: Restrict the ranking to this node's outgoing edges, or
                ``None`` to rank every edge in the graph.
            limit: Forwarded to ``CustomCounter.most_common``.

        Raises:
            KeyError: If ``vertex`` is given but not in the graph.
        """
        if vertex is not None:
            pairs = [
                (vertex.name, partner.name)
                for partner in self.vertices[vertex]
            ]
        else:
            pairs = self.node_pairs()
        return CustomCounter(pairs).most_common(limit)
@dataclass
class Graph(Digraph):
    """Undirected graph: each edge is stored in both directions."""

    def add_edge(self, edge: Edge) -> Union[KeyError, None]:
        """Add ``edge`` and its reverse.

        Both endpoints are validated before anything is stored, so a
        failure never leaves one direction present without the other.

        Returns:
            ``None`` on success. If either endpoint is not a registered
            node, a ``KeyError`` for that endpoint is returned (not
            raised), matching ``Digraph.add_edge``, and the graph is left
            unchanged.
        """
        for endpoint in (edge.src, edge.dest):
            if endpoint not in self.vertices:
                return KeyError(endpoint)
        super().add_edge(edge)
        super().add_edge(Edge(edge.dest, edge.src))
        return None
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment