Skip to content

Instantly share code, notes, and snippets.

@zzeleznick
Created March 11, 2018 22:59
Show Gist options
  • Save zzeleznick/8e757ac8a211744856296b50a8313d7e to your computer and use it in GitHub Desktop.
Save zzeleznick/8e757ac8a211744856296b50a8313d7e to your computer and use it in GitHub Desktop.
Unix Socket Server
import click
# Add this to avoid the annoying warning: http://click.pocoo.org/5/python3/
click.disable_unicode_literals_warning = True
import sys
from select import select
# Internal modules
from sockets import SocketServer, ThreadedSocketServer, SocketClient
def main():
    """Build the ``cli`` click group and run it on the command-line arguments.

    Sub-commands are the nested functions below:
      * ``server``          -- single-client SocketServer that logs what it reads
      * ``threaded_server`` -- ThreadedSocketServer serving many clients
      * ``client``          -- SocketClient; sends its arguments once, or runs an
                               interactive prompt when given none

    ``standalone_mode=False`` lets click exceptions such as ``Abort``
    propagate to the caller instead of exiting the process.
    """
    @click.group(help="Test sockets")
    @click.pass_context
    def cli(ctx):
        pass

    @cli.command(help="Launch Server")
    def server():
        sock_server = SocketServer()
        while True:
            try:
                msg = sock_server.read()
            except Exception as e:
                print(e)
                break
            if not msg:
                # An empty read means the client closed the connection; stop
                # serving instead of polling a dead connection.
                break
        del sock_server

    @cli.command(help="Launch Threaded Server")
    def threaded_server():
        sock_server = ThreadedSocketServer()
        sock_server.listen_for_connections()

    @cli.command(help="Launch Client")
    @click.argument('cmd_args', nargs=-1)
    def client(cmd_args):
        if cmd_args:
            click.echo("args: {}".format(cmd_args))
        try:
            sock_client = SocketClient()
        except (IOError, OSError) as e:
            # connect() failed, e.g. nothing is listening on the socket file.
            print("Couldn't Connect!")
            sys.stderr.write("{}\n".format(e))
            return
        sock_client.launch_reader_thread()
        try:
            if cmd_args:
                # One-shot mode: send all arguments as a single message.
                msg = " ".join(cmd_args)
                print("SENDING: {}".format(msg))
                sock_client.write(msg)
                print("Shutting down.")
                return
            timeout = 5
            click.echo("> ", nl=False)
            while True:
                try:
                    # Poll stdin with a timeout instead of blocking in readline(),
                    # so the loop regularly returns to Python code and can react
                    # when the reader thread interrupts the main thread after the
                    # server shuts down (max delay `timeout` seconds).
                    rlist, _, _ = select([sys.stdin], [], [], timeout)
                    if not rlist:
                        continue
                    line = sys.stdin.readline()
                    if not line:
                        # EOF on stdin: select() would keep reporting stdin as
                        # readable, so stop instead of spinning.
                        print("Shutting down.")
                        break
                    s = line.strip()
                    if s:
                        print("SENDING: {}".format(s))
                        sock_client.write(s)
                        if s == "DONE":
                            break
                    click.echo("> ", nl=False)
                except KeyboardInterrupt:
                    print("Shutting down.")
                    break
        finally:
            # Runs on every exit path (one-shot, DONE, EOF, Ctrl-C, errors).
            sock_client.close()
        print("Done")

    cli(obj={}, standalone_mode=False)
# Script entry point: run the CLI and report click's Abort with a short
# "Aborted" line rather than a traceback.
try:
    main()
except click.Abort:
    print('Aborted')
import errno
import os
import signal
import socket
import sys
import threading
import time
import weakref
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from itertools import count
from os.path import abspath, dirname, join
from Queue import Queue
from thread import interrupt_main
class Socket(object):
    """Abstract base for the Unix-domain socket endpoints in this module.

    Stores the socket file path and installs a SIGINT handler that tears the
    endpoint down (``tearDown`` -> ``close``) before raising
    KeyboardInterrupt. The same teardown also runs from ``__del__``.
    Subclasses implement ``read``, ``write`` and ``close``.
    """
    __metaclass__ = ABCMeta

    # Used when no path is given: "test.socket" next to this module.
    DEFAULT_NAME = join(abspath(dirname(__file__)), 'test.socket')

    def __init__(self, name=""):
        """Store the socket path and route SIGINT to :meth:`sig_handler`.

        :param name: path of the socket file; empty/falsy selects
            ``DEFAULT_NAME``.
        """
        self.socket_file = name or self.DEFAULT_NAME
        # The signal module keeps its handler for the rest of the process.
        # Registering the bound method self.sig_handler would therefore keep
        # this object alive forever and __del__ (with its cleanup) could never
        # run, so dispatch through a weak reference instead.
        owner_ref = weakref.ref(self)

        def _on_sigint(sig, frame):
            owner = owner_ref()
            if owner is None:
                # Endpoint already collected: just interrupt as usual.
                raise KeyboardInterrupt()
            owner.sig_handler(sig, frame)

        signal.signal(signal.SIGINT, _on_sigint)

    @abstractmethod
    def read(self):
        """Read from the socket (implemented by subclasses)."""
        pass

    @abstractmethod
    def close(self):
        """Release the endpoint's sockets (implemented by subclasses)."""
        pass

    @abstractmethod
    def write(self, *args):
        """Send data; the exact arguments are defined by each subclass."""
        pass

    def tearDown(self):
        """Clean up the endpoint; by default this just calls :meth:`close`."""
        self.close()

    def _remove_socket_file(self):
        """Delete the socket file; a file that does not exist is ignored."""
        try:
            os.remove(self.socket_file)
        except OSError as e:
            # Only "already gone" is expected; anything else is a real error.
            if e.errno != errno.ENOENT:
                raise

    def sig_handler(self, sig, frame):
        """SIGINT handler: tear the endpoint down, then raise KeyboardInterrupt."""
        sys.stderr.write("** Signal Received **\n")
        self.tearDown()
        raise KeyboardInterrupt()

    def __del__(self):
        sys.stderr.write("=== Delete Called ===\n")
        self.tearDown()
        sys.stderr.write("=== Delete Finished ===\n")
class SocketServer(Socket):
    """Single-client Unix-domain stream server.

    The first ``read``/``write`` blocks in ``accept`` until a client
    connects. When that client disconnects, ``read`` returns an empty string
    and releases the connection, so the next call accepts a new client.
    """

    def __init__(self, name=""):
        super(SocketServer, self).__init__(name)
        # Set before anything that can fail, so close() (also reached from
        # __del__ via tearDown) always finds these attributes.
        self.server = None
        self.connection = None
        self._listening = False  # True while a client connection is open
        # A leftover socket file from an earlier run would make bind() fail.
        self._remove_socket_file()
        sys.stderr.write("=== Opening socket ===\n")
        self.server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.server.bind(self.socket_file)

    def listen(self):
        """Block until a client connects and make it the current connection."""
        sys.stderr.write("=== Listening ===\n")
        # Listen for incoming connections
        self.server.listen(1)
        sys.stderr.write("=== Waiting for a Connection ===\n")
        self.connection, _ = self.server.accept()
        sys.stderr.write("*** Accepted connection ***\n")
        self._listening = True

    def read(self):
        """Receive up to 1024 bytes from the current client.

        Accepts a client first if none is connected. An empty return value
        means the client closed the connection; the connection is then
        released so that the next call waits for a new client instead of
        reading the dead one again.
        """
        if not self._listening:
            self.listen()
        sys.stderr.write("=== Waiting for input === \n")
        msg = self.connection.recv(1024)
        sys.stderr.write("Received {}\n".format(msg))
        if not msg:
            sys.stderr.write("Empty message\n")
            self._drop_connection()
        return msg

    def _drop_connection(self):
        """Close the current client connection, if any, and forget it."""
        if self.connection is not None:
            self.connection.close()
        self.connection = None
        self._listening = False

    def close(self):
        """Close the client connection and the listening socket."""
        sys.stderr.write("=== Cleaning up the connection ===\n")
        self._drop_connection()
        if self.server is not None:
            self.server.close()

    def write(self, res):
        """Send ``"{}".format(res)`` to the client (best effort).

        Accepts a client first if none is connected. Errors are reported on
        stderr and otherwise ignored.
        """
        try:
            if not self._listening:
                self.listen()
            # send() may transmit only part of the buffer; sendall() keeps
            # going until everything has been written.
            self.connection.sendall("{}".format(res))
        except Exception as e:
            sys.stderr.write("{}\n".format(e))
class ThreadedSocketServer(Socket):
    """Unix-domain stream server that serves several clients at once.

    ``listen_for_connections`` accepts clients in a loop and starts one daemon
    reader thread (``read``) per client. Messages containing ``echo`` are sent
    straight back. Messages containing ``task`` are acknowledged and queued
    for a single daemon worker thread (``process_queue``), which reports
    completion to the client that submitted the task.
    """

    def __init__(self, name=""):
        super(ThreadedSocketServer, self).__init__(name)
        # Bookkeeping first, so close()/tearDown() (also run from __del__)
        # work even if creating or binding the socket below fails.
        self.server = None
        self.counter = count(1)           # source of connection ids: 1, 2, 3, ...
        self.connections = OrderedDict()  # conn_idx -> accepted client socket
        self.threads = OrderedDict()      # conn_idx -> reader thread serving it
        self._listening = False           # the loops in this class run while True
        self.queue = Queue()              # (task_id, task) pairs for process_queue
        self._remove_socket_file()
        sys.stderr.write("=== Opening socket ===\n")
        self.server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.server.bind(self.socket_file)

    def read(self, connection, conn_idx=None):
        """A threaded read: serve ``connection`` until the client goes away.

        :param connection: accepted client socket.
        :param conn_idx: id under which ``connection`` is registered in
            ``self.connections``/``self.threads``. It is used as the prefix of
            task ids, and both entries are removed when this loop exits.
        """
        current_thread = threading.current_thread()
        task_numbers = count(1)  # per-connection task counter
        try:
            while self._listening:
                sys.stderr.write("=== {} Waiting for input === \n".format(current_thread))
                try:
                    msg = connection.recv(1024)
                except socket.error as e:
                    # e.g. connection reset by the peer, or already closed by close()
                    sys.stderr.write("** {} recv failed: {} **\n".format(current_thread, e))
                    break
                sys.stderr.write("** {} Received {} **\n".format(current_thread, msg))
                if not msg:
                    # recv() returns an empty string only once the peer has
                    # closed its end, so there is nothing left to wait for.
                    sys.stderr.write("** {} Empty message **\n".format(current_thread))
                    sys.stderr.write("** Closing {} due to empty messages **\n".format(current_thread))
                    break
                self._dispatch(connection, conn_idx, msg, task_numbers)
        finally:
            self._release(conn_idx)

    def _dispatch(self, connection, conn_idx, msg, task_numbers):
        """Handle one message: echo it back, or queue it as a task and ack it."""
        current_thread = threading.current_thread()
        if 'echo' in msg:
            sys.stderr.write('** {} Writing back message {} **\n'.format(current_thread, msg))
            self.write(connection, msg)
        elif 'task' in msg:
            # Remove every occurrence of "task"; the rest is the description.
            task = " ".join(msg.split('task')).strip()
            task_id = "{}-{}".format(conn_idx, next(task_numbers))
            resp = "Added task ({}): '{}' - Qsize: {}".format(task_id, task, self.queue.qsize())
            sys.stderr.write('** {} {} **\n'.format(current_thread, resp))
            self.queue.put((task_id, task))
            self.write(connection, resp)

    def _release(self, conn_idx):
        """Close and forget the connection and thread registered as ``conn_idx``."""
        print("** {} exits **".format(threading.current_thread()))
        connection = self.connections.pop(conn_idx, None)
        if connection is not None:
            print("** Removing connection {} **".format(conn_idx))
            connection.close()
        if self.threads.pop(conn_idx, None) is not None:
            # Cannot join current thread (https://github.com/python/cpython/blob/2.7/Lib/threading.py#L931)
            # Just remove reference to it
            print("** Releasing thread {} **".format(conn_idx))

    def process_queue(self):
        """Worker loop: run queued tasks one at a time and report completion.

        ``_listening`` is checked between tasks (``queue.get`` itself blocks).
        Task ids look like ``"<conn_idx>-<n>"``. The completion message goes to
        that connection if it is still registered.
        """
        sys.stderr.write("=== Starting process_queue ===\n")
        while self._listening:
            task_id, task = self.queue.get()
            try:
                sys.stderr.write('** Executing task ({}): "{}" **\n'.format(task_id, task))
                time.sleep(3)  # add fake delay
                prefix = task_id.split("-")[0]
                # read() called without a conn_idx produces a "None-<n>" id.
                conn_idx = int(prefix) if prefix.isdigit() else prefix
                resp = "Completed task ({}): '{}'".format(task_id, task)
                # A single .get() instead of "in" followed by [], because the
                # reader thread may remove the connection in between.
                connection = self.connections.get(conn_idx)
                if connection is not None:
                    self.write(connection, resp)
                else:
                    # Possibly client disconnected while we were working
                    sys.stderr.write("Could not send client {} done status\n".format(conn_idx))
            finally:
                # Mark the task done only once its work and notification are over.
                self.queue.task_done()
        sys.stderr.write("=== End of process_queue ===\n")

    def do_work(self):
        """Start :meth:`process_queue` on a daemon thread."""
        worker = threading.Thread(target=self.process_queue)
        worker.daemon = True
        worker.start()

    def listen_for_connections(self):
        """Accept clients until accept() fails, one reader thread per client.

        Starts the queue worker first. ``_listening`` is cleared on the way
        out, including when a KeyboardInterrupt (raised by ``sig_handler``)
        escapes the loop.
        """
        sys.stderr.write("=== Listening ===\n")
        self._listening = True
        self.do_work()
        try:
            while self._listening:
                try:
                    self.server.listen(1)
                    sys.stderr.write("=== Waiting for a Connection ===\n")
                    connection, _ = self.server.accept()
                    conn_idx = next(self.counter)
                    sys.stderr.write('*** Accepted connection {} ***\n'.format(conn_idx))
                    self.connections[conn_idx] = connection
                    thread = threading.Thread(target=self.read, args=(connection, conn_idx))
                    # Register before start() so the thread's own cleanup finds it.
                    self.threads[conn_idx] = thread
                    thread.daemon = True
                    thread.start()
                except Exception as e:
                    sys.stderr.write("{}\n".format(e))
                    break
        finally:
            self._listening = False
        sys.stderr.write("=== End of listen_for_connections ===\n")

    def close(self):
        """Close all client connections and the listening socket.

        Each reader thread is given up to one second to finish. The socket
        file is removed afterwards.
        """
        sys.stderr.write("=== Cleaning up the connection ===\n")
        print(self.connections)
        # Iterate over snapshots: reader threads delete their own entries
        # from these dicts as their connections go away.
        for idx, connection in list(self.connections.items()):
            sys.stderr.write("** Closing connection {} **\n".format(idx))
            connection.close()
        print(self.threads)
        for thread in list(self.threads.values()):
            thread.join(1)
        if self.server is not None:
            self.server.close()
        self._remove_socket_file()
        sys.stderr.write("=== Finished close ===\n")

    def tearDown(self):
        """Close everything and make sure the socket file is gone."""
        self.close()
        self._remove_socket_file()

    def write(self, connection, res):
        """Send ``"{}".format(res)`` on ``connection``; errors go to stderr."""
        try:
            # send() may transmit only part of the buffer; sendall() does not.
            connection.sendall("{}".format(res))
        except Exception as e:
            sys.stderr.write("{}\n".format(e))
class SocketClient(Socket):
    """Unix-domain stream client with an optional background reader thread.

    Connects in ``__init__``. ``launch_reader_thread`` logs everything the
    server sends and interrupts the main thread if the server goes away.
    """

    def __init__(self, name=""):
        super(SocketClient, self).__init__(name)
        self._listening = False  # True while connected and not yet close()d
        self.reader_thread = None
        self.client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.client.connect(self.socket_file)
        self._listening = True

    def read_loop(self):
        """Log server messages until the server disconnects or close() is called.

        If the loop ends while ``_listening`` is still set (the server went
        away rather than a local close()), ``interrupt_main`` is called so the
        main thread shuts down too.
        """
        while self._listening:
            try:
                msg = self.read()
            except socket.error as e:
                sys.stderr.write("{}\n".format(e))
                sys.stderr.write("Closing reader thread\n")
                break
            if not msg:
                # recv() returns '' only once the server has closed its end;
                # reading again would just return '' again.
                sys.stderr.write("Empty message\n")
                sys.stderr.write("Closing reader thread\n")
                break
        # close() clears _listening, so a shutdown started locally normally
        # does not inject a KeyboardInterrupt into the main thread.
        if self._listening:
            sys.stderr.write("read_loop completed -- server exited?\n")
            interrupt_main()

    def read(self):
        """Block for up to 1024 bytes from the server and return them."""
        sys.stderr.write("Waiting for input\n")
        msg = self.client.recv(1024)
        sys.stderr.write("** Received {} **\n".format(msg))
        return msg

    def launch_reader_thread(self):
        """Start :meth:`read_loop` on a daemon thread."""
        sys.stderr.write("Launching reader_thread\n")
        if not self.client:
            raise Exception("Not connected")
        self.reader_thread = threading.Thread(target=self.read_loop)
        self.reader_thread.daemon = True
        self.reader_thread.start()

    def write(self, res):
        """Send ``"{}".format(res)`` to the server; errors are logged and re-raised."""
        try:
            # send() may transmit only part of the buffer; sendall() does not.
            self.client.sendall("{}".format(res))
        except Exception as e:
            sys.stderr.write("{}\n".format(e))
            raise

    def close(self):
        """Stop the reader loop, wait up to 1s for its thread, close the socket."""
        self._listening = False
        if self.reader_thread:
            self.reader_thread.join(1)
        if self.client:
            self.client.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment