Skip to content

Instantly share code, notes, and snippets.

@meyt
Created October 5, 2021 02:05
Show Gist options
  • Save meyt/552ffa81cdddb74afec6f157262df0ac to your computer and use it in GitHub Desktop.
Save meyt/552ffa81cdddb74afec6f157262df0ac to your computer and use it in GitHub Desktop.
Watch active postgres connections and client processes info (name, pid, cmdline)
"""
Watch active postgres connections and client processes info (name, pid, cmdline)
Note: the postgres server and its clients must run on the same machine.
https://gist.github.com/meyt
requirements:
- psycopg2 >= 2.8.5
- psutil >= 5.5.1
"""
import curses
import signal
import threading
import psutil
import psycopg2
class Database:
    """Thin wrapper around one psycopg2 connection and its cursor."""

    def __init__(self, dsn):
        """Open a connection with the connection string *dsn*."""
        self.con = psycopg2.connect(dsn)
        self.cur = self.con.cursor()

    def query(self, sql):
        """Execute *sql* and yield each result row as a ``{column: value}`` dict.

        All rows are fetched and the transaction is committed before the first
        row is yielded. A caller that stops iterating early therefore never
        leaves the connection idle inside an open transaction. This matters for
        a poller: PostgreSQL serves pg_stat_activity from a snapshot that stays
        fixed until the transaction ends.

        Statements that produce no result set (``cursor.description is None``)
        yield nothing instead of failing.
        """
        self.cur.execute(sql)
        description = self.cur.description
        # fetchall() is only valid when the statement produced a result set.
        rows = self.cur.fetchall() if description is not None else []
        self.con.commit()
        columns = [column[0] for column in description or ()]
        for row in rows:
            yield dict(zip(columns, row))

    def close(self):
        """Close the cursor and the underlying connection."""
        self.cur.close()
        self.con.close()
class Output:
    """Full-screen curses renderer that writes one text row per call.

    Per frame: call ``refresh()`` first (it shows the previously drawn frame
    and rewinds to the top row), then ``addrow``/``addline`` once per row.
    """

    def __init__(self):
        self.rowidx = 0  # screen row the next addline()/addrow() writes to
        self.stdscr = None  # curses window; set by start()

    def start(self):
        """Initialise curses and keep the whole-screen window."""
        self.stdscr = curses.initscr()

    def end(self, *_, **__):
        """Restore the terminal to normal mode.

        Does nothing before start() or once curses has already been ended, so
        it is safe to call more than once. Extra arguments are ignored, which
        lets it be installed directly as a signal handler.
        """
        if not self.stdscr or curses.isendwin():
            return
        self.stdscr.keypad(0)
        curses.echo()
        curses.nocbreak()
        curses.endwin()

    def addline(self, ch):
        """Write a horizontal rule made of *ch* across the current row."""
        _, width = self.stdscr.getmaxyx()
        self._putrow(ch * width)

    def addrow(self, val, colwidths=None):
        """Write *val* on the current row.

        With *colwidths*, *val* is a sequence laid out in fixed-width columns
        (see ``_format_columns``); without it, *val* must be a string.
        """
        if colwidths:
            val = self._format_columns(val, colwidths)
        self._putrow(val)

    def fillrows(self):
        """Blank every row from the current row up to, not including, the last line."""
        height, width = self.stdscr.getmaxyx()
        for y in range(self.rowidx, height - 1):
            self.stdscr.addstr(y, 0, " " * width)

    def refresh(self):
        """Finish the current frame and rewind to the top row.

        Rows left over from a longer previous frame are blanked *before* the
        screen is updated, so stale rows are never shown.
        """
        if self.stdscr is None or curses.isendwin():
            # Not started yet, or end() already restored the terminal (possibly
            # from a signal handler on another thread). A refresh() here would
            # switch the terminal back into curses mode.
            self.rowidx = 0
            return
        self.fillrows()
        self.stdscr.refresh()
        self.rowidx = 0

    @staticmethod
    def _format_columns(values, colwidths):
        """Lay *values* out in fixed-width columns.

        Each value is converted with str(), truncated to its column width and
        left-justified to width + 1 (one separator space).
        """
        return "".join(
            str(value)[:width].ljust(width + 1)
            for value, width in zip(values, colwidths)
        )

    def _putrow(self, text):
        """Write *text*, truncated/padded to the screen width, on the current row.

        Rows that would land on the last screen line or below are skipped.
        addstr() raises curses.error outside the window and when it writes the
        bottom-right cell, and fillrows() leaves the last line alone as well.
        The row counter still advances, so fillrows() sees that the frame is full.
        """
        height, width = self.stdscr.getmaxyx()
        if self.rowidx < height - 1:
            self.stdscr.addstr(self.rowidx, 0, text[:width].ljust(width))
        self.rowidx += 1
class App:
    """Live view of pg_stat_activity annotated with local client processes.

    Three daemon threads cooperate:
    - worker1 maps socket ports to pids.
    - worker2 maps pids to process info.
    - worker3 queries PostgreSQL and draws the table.
    Each loops until ``exitevent`` is set.
    """

    def __init__(
        self,
        dsn="dbname=postgres user=postgres password=postgres",
        tgap=0.5,
    ) -> None:
        """Connect to PostgreSQL and prepare the shared state.

        :param dsn: connection string handed to ``Database``.
        :param tgap: seconds each worker waits between refreshes.
        """
        self.exitevent = threading.Event()  # set by kill_signal_handler
        self.db = Database(dsn)
        self.out = Output()
        self.port_pid_map = {}  # port -> pid, rebuilt by worker1
        self.pid_info_map = {}  # pid -> {"pid", "name", "cmdline"}, rebuilt by worker2
        self.tgap = tgap

    def run(self):
        """Install signal handlers, start the workers and block until a signal."""
        signal.signal(signal.SIGINT, self.kill_signal_handler)
        signal.signal(signal.SIGTERM, self.kill_signal_handler)
        for w in (self.worker1, self.worker2, self.worker3):
            t = threading.Thread(
                target=w,
                daemon=True,
            )
            t.start()
        # Returns once a handled signal arrives; the daemon workers are then
        # dropped when the interpreter exits.
        signal.pause()

    def kill_signal_handler(self, signal_number, _):
        """Restore the terminal, report the signal and stop the workers."""
        self.out.end()
        print(
            {
                signal.SIGINT: "Killed by SIGINT",
                signal.SIGTERM: "Killed by SIGTERM",
            }[signal_number]
        )
        self.exitevent.set()

    def worker1(self):
        """Periodically rebuild ``port_pid_map`` from the host's inet sockets."""
        while not self.exitevent.is_set():
            self.port_pid_map = self._port_pid_map(psutil.net_connections())
            self.exitevent.wait(self.tgap)

    def worker2(self):
        """Periodically rebuild ``pid_info_map`` with each process's pid/name/cmdline."""
        while not self.exitevent.is_set():
            self.pid_info_map = {
                proc.pid: proc.info
                for proc in psutil.process_iter(["pid", "name", "cmdline"])
            }
            self.exitevent.wait(self.tgap)

    def worker3(self):
        """Own the screen: poll pg_stat_activity and draw one row per connection."""
        out = self.out
        out.start()
        colwidths = (10, 10, 20, 20, 20, 20)
        while not self.exitevent.is_set():
            out.refresh()
            pgcons = list(self.db.query("Select * from pg_stat_activity;"))
            out.addrow("total: %s" % len(pgcons))
            out.addline("=")
            out.addrow(
                ("db", "pid", "state", "client", "process", "cmd"),
                colwidths=colwidths,
            )
            # Read each shared map once so every row of a frame uses the same data.
            port_pid_map = self.port_pid_map
            pid_info_map = self.pid_info_map
            for activity in pgcons:
                out.addrow(
                    self._connection_row(activity, port_pid_map, pid_info_map),
                    colwidths=colwidths,
                )
            self.exitevent.wait(self.tgap)

    @staticmethod
    def _port_pid_map(connections):
        """Map port numbers to the pid owning the socket that uses them.

        *connections* are objects with ``laddr``, ``raddr`` (empty or an
        ``(ip, port)`` pair) and ``pid``. A socket's local port maps to its
        own pid and takes precedence over the remote-port fallback. When both
        ends of a client <-> backend connection live on this host, both sockets
        mention the client's port: as laddr on the client, as raddr on the
        backend. The client's own socket must win regardless of enumeration
        order. Sockets whose owner is unknown (pid None) are skipped so they
        cannot hide a known owner.
        """
        by_local, by_remote = {}, {}
        for conn in connections:
            if conn.pid is None:
                continue
            if conn.laddr:
                by_local[conn.laddr[1]] = conn.pid
            if conn.raddr:
                by_remote[conn.raddr[1]] = conn.pid
        return {**by_remote, **by_local}

    @staticmethod
    def _connection_row(activity, port_pid_map, pid_info_map):
        """Build the display tuple for one pg_stat_activity row (a dict).

        Returns (db, backend pid, state, "client:port", "pid:process name",
        cmdline); parts that cannot be resolved are None.
        """
        client_port = activity["client_port"]
        pid = port_pid_map.get(int(client_port)) if client_port else None
        psinfo = pid_info_map.get(pid) if pid else None
        if psinfo:
            # psutil stores None for attributes it was not allowed to read
            # (e.g. another user's cmdline), so guard the join.
            cmdline = " ".join(psinfo["cmdline"] or ())
            psname = psinfo["name"]
        else:
            psname = None
            cmdline = None
        # client_hostname is only non-null when the server has log_hostname
        # enabled; fall back to the client IP address.
        client = activity["client_hostname"] or activity.get("client_addr")
        return (
            activity["datname"],
            activity["pid"],
            activity["state"],
            "%s:%s" % (client, client_port),
            "%s:%s" % (pid, psname),
            cmdline,
        )
if __name__ == "__main__":
    # Script entry point: build the monitor and run it until SIGINT/SIGTERM.
    App().run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment