Skip to content

Instantly share code, notes, and snippets.

@nanthony007
Last active May 14, 2021 15:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nanthony007/37238c10ce43c0df5cdfd7a73cb1436d to your computer and use it in GitHub Desktop.
Save nanthony007/37238c10ce43c0df5cdfd7a73cb1436d to your computer and use it in GitHub Desktop.
from collections import defaultdict, Counter
from dataclasses import InitVar, dataclass, field
from typing import Union
import enum
class NodeType(enum.Enum):
    """Category of a graph vertex: patient, prescriber or pharmacy."""

    # Explicit, stable integer values (1-based, in declaration order).
    PATIENT = 1
    PRESCRIBER = 2
    PHARMACY = 3
@dataclass
class CustomCounter:
    """Tally unordered pairs of names.

    Each incoming pair is normalised by sorting its two members, so
    ``("a", "b")`` and ``("b", "a")`` are counted under the same key
    ``("a", "b")``.

    Args:
        lst: The pairs to count. Init-only; it is not stored on the instance.
    """

    lst: InitVar[list[tuple[str, str]]]
    counts: dict[tuple[str, str], int] = field(init=False)

    def __post_init__(self, lst):
        # Sorting each pair catches the 'reflexive tuples': (a, b) and (b, a)
        # collapse onto one key. Counter (a dict subclass) returns 0 for an
        # unseen pair WITHOUT inserting it, unlike defaultdict(int), so
        # probing counts can no longer add phantom zero-count entries.
        self.counts = Counter(tuple(sorted(pair)) for pair in lst)

    def most_common(
        self, limit: Union[int, None] = None
    ) -> list[tuple[tuple[str, str], int]]:
        """Return ``(pair, count)`` tuples ordered by descending count.

        Pairs with equal counts keep the order in which they were first seen,
        because ``sorted`` is stable even with ``reverse=True``.

        Args:
            limit: Maximum number of entries to return. ``None`` (default)
                returns all entries; ``0`` returns an empty list.

        Returns:
            A new list of ``(pair, count)`` tuples.
        """
        ranked = sorted(self.counts.items(), key=lambda item: item[1], reverse=True)
        # Compare against None explicitly: a truthiness test would treat
        # limit=0 as "no limit" and return everything.
        return ranked if limit is None else ranked[:limit]
@dataclass(unsafe_hash=True)
class Node:
    """A graph vertex described by a ``name`` and a ``kind``.

    ``unsafe_hash=True`` has the dataclass generate ``__hash__`` from the
    fields (alongside ``__eq__``), so nodes can be used as dict keys.
    ``str(node)`` is just the node's name.
    """

    # Slots instead of a per-instance __dict__.
    __slots__ = ("name", "kind")
    name: str
    kind: NodeType

    def __str__(self) -> str:
        label = self.name
        return label
@dataclass
class Edge:
    """A directed link between two nodes, rendered as ``"<src> -> <dest>"``."""

    # TODO: want weight to track number of connections on each edge
    __slots__ = ("src", "dest")
    src: Node
    dest: Node

    def __str__(self) -> str:
        tail, head = self.src.name, self.dest.name
        return f"{tail} -> {head}"
@dataclass
class Digraph:
    """Directed graph stored as an adjacency list.

    ``vertices`` maps each node to the list of nodes it has edges to.
    Repeated edges between the same two nodes are kept as repeated entries
    in that list.
    """

    vertices: dict[Node, list[Node]] = field(default_factory=dict)

    def add_node(self, node: Node):
        """Register ``node`` with no outgoing edges; no-op if already present."""
        # Use membership, not truthiness: the previous ``.get(node)`` check
        # treated a known node with an empty child list as missing and
        # replaced its list object.
        self.vertices.setdefault(node, [])

    def add_edge(self, edge: Edge):
        """Append ``edge.dest`` to the child list of ``edge.src``.

        ``edge.dest`` is not required to be a registered vertex.

        Returns:
            ``None`` on success, or the ``KeyError`` for ``edge.src`` when the
            source node was never added. The error is returned, not raised,
            and the graph is left unchanged in that case.
        """
        try:
            children = self.vertices[edge.src]
        except KeyError as err:
            return err
        children.append(edge.dest)
        return None

    def children(self, node: Node) -> list[Node]:
        """Return the child list of ``node`` (raises ``KeyError`` if unknown)."""
        return self.vertices[node]

    def has_node(self, node: Node) -> bool:
        """Return whether ``node`` is a registered vertex."""
        return node in self.vertices

    def get_node(self, name: str) -> Node:
        """Return the first registered vertex whose ``name`` matches.

        Raises:
            NameError: if no vertex has that name.
        """
        for node in self.vertices:
            if node.name == name:
                return node
        raise NameError(name)

    def node_pairs(self) -> list[tuple[str, str]]:
        """Return ``(src.name, dest.name)`` for every stored edge."""
        pairs = []
        for src, dests in self.vertices.items():
            pairs.extend((src.name, dest.name) for dest in dests)
        return pairs

    # * start analytics section
    # * basics
    def vertex_count(self, by_type: bool = True) -> Union[int, Counter]:
        """Count vertices.

        Args:
            by_type: If true (default), return a ``Counter`` keyed by each
                vertex's ``kind.name``; otherwise return the total as an int.
        """
        if by_type:
            return Counter(vertex.kind.name for vertex in self.vertices)
        return len(self.vertices)

    def edge_count(self) -> int:
        """Return the number of stored directed edges."""
        return sum(len(dests) for dests in self.vertices.values())

    def vertex_degree(
        self, vertex: Union[Node, None] = None
    ) -> Union[int, list[tuple[Node, int]]]:
        """Return out-degrees.

        Args:
            vertex: When given, return this vertex's out-degree as an int.

        Returns:
            Either a single out-degree, or ``(node, out_degree)`` for every
            vertex, highest degree first (ties keep insertion order).
        """
        if vertex is not None:
            return len(self.children(vertex))
        return sorted(
            ((node, len(dests)) for node, dests in self.vertices.items()),
            key=lambda item: item[1],
            reverse=True,
        )

    # * advanced
    def edge_rank(
        self, vertex: Union[Node, None] = None, limit: Union[int, None] = None
    ) -> list[tuple[tuple[str, str], int]]:
        """Rank name pairs by how many stored edges connect them.

        Pairs are treated as unordered (``CustomCounter`` sorts each pair).
        On a graph that stores every edge in both directions, each connection
        is therefore counted twice.

        Args:
            vertex: Restrict ranking to the edges leaving this vertex.
            limit: Maximum number of entries, passed to
                ``CustomCounter.most_common``.
        """
        if vertex is not None:
            pairs = [(vertex.name, partner.name) for partner in self.vertices[vertex]]
        else:
            pairs = self.node_pairs()
        return CustomCounter(pairs).most_common(limit)
@dataclass
class Graph(Digraph):
    """Undirected graph: every edge is stored in both directions."""

    def add_edge(self, edge: Edge):
        """Add ``edge`` and its reverse.

        Both endpoints must already be registered vertices. A self-loop
        (``src == dest``) is appended to that node's list twice.

        Returns:
            ``None`` on success, or a ``KeyError`` naming the first missing
            endpoint. As with ``Digraph.add_edge``, the error is returned,
            not raised, and the graph is left unchanged.
        """
        # Validate up front. Otherwise one of the two inserts below could
        # succeed while the other silently fails, leaving a one-way edge in
        # an undirected graph.
        for endpoint in (edge.src, edge.dest):
            if endpoint not in self.vertices:
                return KeyError(endpoint)
        Digraph.add_edge(self, edge)
        Digraph.add_edge(self, Edge(edge.dest, edge.src))
        return None
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment