#
# Stackless compatible socket module:
#
# Author: Richard Tew <richard.m.tew@gmail.com>
#
# Source: http://stacklessexamples.googlecode.com/svn/trunk/examples/networking/stacklesssocket.py
#
# This code was written to serve as an example of Stackless Python usage.
# Feel free to email me with any questions, comments, or suggestions for
# improvement.
#
# This wraps the asyncore module and the dispatcher class it provides in order
# to write a socket module replacement that uses channels to allow calls to it
# to block until a delayed event occurs.
#
# Not all aspects of the socket module are provided by this file. Examples of
# it in use can be seen at the bottom of this file.
#
# NOTE: Versions of the asyncore module from Python 2.4 or later include bug
# fixes and earlier versions will not guarantee correct behaviour.
# Specifically, it monitors for errors on sockets where the version in
# Python 2.3.3 does not.
#
# Possible improvements:
# - More correct error handling. When poll finds an error on a socket, there
#   is no way to tell what the error actually is.
# - Launching a separate tasklet for each chunk of incoming data, just to send
#   it on the recvChannel, is a little over the top. It should be possible to
#   append it to the rest of the queued data instead.
import stackless
import asyncore, weakref
import socket as stdsocket # We need the "socket" name for the function we export.
# If we are to masquerade as the socket module, we need to provide the constants.
if "__all__" in stdsocket.__dict__:
    # Export exactly the names the real socket module exports. This must be
    # stdsocket.__all__ (a list of names), not stdsocket.__dict__: with the
    # dict every key counts as exported, so the loop below would also copy
    # module internals such as __name__, __file__ and __doc__ over this
    # module's own globals (renaming this module to "socket", which among
    # other things stops the __name__ == '__main__' self-test from running).
    __all__ = stdsocket.__all__
    for k, v in stdsocket.__dict__.iteritems():
        if k in __all__:
            globals()[k] = v
        elif k == "EBADF":
            # Copied explicitly because check_still_connected raises it.
            globals()[k] = v
else:
    # No __all__ to go on: copy the upper-case constants.
    for k, v in stdsocket.__dict__.iteritems():
        if k.upper() == k:
            globals()[k] = v

error = stdsocket.error
timeout = stdsocket.timeout
# WARNING: this function blocks and is not thread safe.
# The only solution is to spawn a thread to handle all
# getaddrinfo requests. Implementing a stackless DNS
# lookup service is only second best as getaddrinfo may
# use other methods.
getaddrinfo = stdsocket.getaddrinfo

# urllib2 apparently uses this directly. We need to cater for that.
_fileobject = stdsocket._fileobject
# Someone needs to invoke asyncore.poll() regularly to keep the socket
# data moving. The "ManageSockets" function here is a simple example
# of such a function. It is started by StartManager(), which uses the
# global "managerRunning" to ensure that no more than one copy is
# running.
#
# If you think you can do this better, register an alternative to
# StartManager using stacklesssocket_manager(). Your function will be
# called every time a new socket is created; it's your responsibility
# to ensure it doesn't start multiple copies of itself unnecessarily.
#
managerRunning = False


def ManageSockets():
    """Poll asyncore until no sockets remain registered, then exit.

    Meant to run as a tasklet (StartManager launches it). Between polls it
    yields so that tasklets waiting on socket channels get to run.
    """
    global managerRunning

    try:
        while asyncore.socket_map:
            # Check the sockets for activity.
            asyncore.poll(0.05)
            # Yield to give other tasklets a chance to be scheduled.
            stackless.schedule()
    finally:
        # Clear the flag even if poll() raised. Otherwise StartManager()
        # would believe a manager is still running and never start a new
        # one, leaving every socket without anyone polling it.
        managerRunning = False
def StartManager():
    """Launch the ManageSockets tasklet unless one is already running.

    This is the default socket manager hook, called each time a new
    stackless socket is created.
    """
    global managerRunning
    if managerRunning:
        return
    managerRunning = True
    stackless.tasklet(ManageSockets)()
# Hook called whenever a new stackless socket is created; it must make sure
# something keeps calling asyncore.poll(). Replace it with
# stacklesssocket_manager().
_manage_sockets_func = StartManager


def stacklesssocket_manager(mgr):
    """Register mgr as the function that keeps asyncore polling.

    mgr is called with no arguments every time a new socket is created and
    must avoid starting redundant copies of itself.

    Raises TypeError if mgr is not callable. This surfaces the mistake here,
    rather than on every later socket creation.
    """
    global _manage_sockets_func
    if not callable(mgr):
        raise TypeError("socket manager must be callable, got %r" % (mgr,))
    _manage_sockets_func = mgr
def socket(*args, **kwargs):
    """Guard against swapping this module in for sys.modules["socket"].

    install() stores this function as socket._realsocket and compares against
    it to detect a second install(). Calling it returns None, unless the
    "socket" entry in sys.modules is something other than the real socket
    module, in which case RuntimeError points the caller at install().
    """
    import sys
    registered = sys.modules.get("socket", stdsocket)
    if registered is not stdsocket:
        raise RuntimeError("Use 'stacklesssocket.install' instead of replacing the 'socket' module")
# The originals, saved so that uninstall() can put them back.
_realsocket_old = stdsocket._realsocket
_socketobject_old = stdsocket._socketobject


class _socketobject_new(_socketobject_old):
    """Stand-in for socket._socketobject whose _sock is a _fakesocket.

    install() publishes this class as socket.socket, socket.SocketType and
    socket._socketobject.
    """

    def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, _sock=None):
        # Build the fake socket here, before the base initialiser runs, so
        # the base class is handed a _fakesocket to delegate to.
        if _sock is None:
            _sock = _fakesocket(_realsocket_old(family, type, proto))
            # A socket was just registered with asyncore; make sure
            # something is polling it.
            _manage_sockets_func()
        _socketobject_old.__init__(self, family, type, proto, _sock)
        if not isinstance(self._sock, _fakesocket):
            raise RuntimeError("bad socket")

    def accept(self):
        conn, address = self._sock.accept()
        wrapped = _fakesocket(conn)
        # Accepted connections start out connected.
        wrapped.wasConnected = True
        return _socketobject_new(_sock=wrapped), address

    accept.__doc__ = _socketobject_old.accept.__doc__
def check_still_connected(f):
    """Decorate socket functions to check they are still connected.

    The wrapper raises error(10057, ...) when the socket was never connected
    (10057 is the Winsock WSAENOTCONN code). It raises error(EBADF, ...) when
    the socket was connected but has since been closed. Otherwise it calls
    f. functools.wraps keeps f's name and docstring on the wrapper.
    """
    import functools

    @functools.wraps(f)
    def new_f(self, *args, **kwds):
        if not self.connected:
            # The socket was never connected.
            if not self.wasConnected:
                raise error(10057, "Socket is not connected")
            # The socket has been closed already.
            raise error(EBADF, 'Bad file descriptor')
        return f(self, *args, **kwds)
    return new_f
def install():
    """Patch the socket module so newly created sockets are stackless-aware.

    Raises StandardError if install() was already called without a matching
    uninstall().
    """
    if stdsocket._realsocket is socket:
        raise StandardError("Still installed")
    stdsocket._realsocket = socket
    # Same targets, in the same order, as a chained assignment.
    for name in ("socket", "SocketType", "_socketobject"):
        setattr(stdsocket, name, _socketobject_new)
def uninstall():
    """Put back the socket module objects saved when this module loaded."""
    stdsocket._realsocket = _realsocket_old
    for name in ("socket", "SocketType", "_socketobject"):
        setattr(stdsocket, name, _socketobject_old)
class _fakesocket(asyncore.dispatcher):
connectChannel = None
acceptChannel = None
recvChannel = None
wasConnected = False
def __init__(self, realSocket):
# This is worth doing. I was passing in an invalid socket which
# was an instance of _fakesocket and it was causing tasklet death.
if not isinstance(realSocket, _realsocket_old):
raise StandardError("An invalid socket passed to fakesocket %s" % realSocket.__class__)
# This will register the real socket in the internal socket map.
asyncore.dispatcher.__init__(self, realSocket)
self.socket = realSocket
self.recvChannel = stackless.channel()
self.readString = ''
self.readIdx = 0
self.sendBuffer = ''
self.sendToBuffers = []
def __del__(self):
# There are no more users (sockets or files) of this fake socket, we
# are safe to close it fully. If we don't, asyncore will choke on
# the weakref failures.
self.close()
# The asyncore version of this function depends on socket being set
# which is not the case when this fake socket has been closed.
def __getattr__(self, attr):
if not hasattr(self, "socket"):
raise AttributeError("socket attribute unset on '"+ attr +"' lookup")
return getattr(self.socket, attr)
def add_channel(self, map=None):
if map is None:
map = self._map
map[self._fileno] = weakref.proxy(self)
def writable(self):
if self.socket.type != SOCK_DGRAM and not self.connected:
return True
return len(self.sendBuffer) or len(self.sendToBuffers)
def accept(self):
if not self.acceptChannel:
self.acceptChannel = stackless.channel()
return self.acceptChannel.receive()
def connect(self, address):
asyncore.dispatcher.connect(self, address)
# UDP sockets do not connect.
if self.socket.type != SOCK_DGRAM and not self.connected:
if not self.connectChannel:
self.connectChannel = stackless.channel()
# Prefer the sender. Do not block when sending, given that
# there is a tasklet known to be waiting, this will happen.
self.connectChannel.preference = 1
self.connectChannel.receive()
@check_still_connected
def send(self, data, flags=0):
self.sendBuffer += data
stackless.schedule()
return len(data)
@check_still_connected
def sendall(self, data, flags=0):
# WARNING: this will busy wait until all data is sent
# It should be possible to do away with the busy wait with
# the use of a channel.
self.sendBuffer += data
while self.sendBuffer:
stackless.schedule()
return len(data)
def sendto(self, sendData, sendArg1=None, sendArg2=None):
# sendto(data, address)
# sendto(data [, flags], address)
if sendArg2 is not None:
flags = sendArg1
sendAddress = sendArg2
else:
flags = 0
sendAddress = sendArg1
waitChannel = None
for idx, (data, address, channel, sentBytes) in enumerate(self.sendToBuffers):
if address == sendAddress:
self.sendToBuffers[idx] = (data + sendData, address, channel, sentBytes)
waitChannel = channel
break
if waitChannel is None:
waitChannel = stackless.channel()
self.sendToBuffers.append((sendData, sendAddress, waitChannel, 0))
return waitChannel.receive()
# Read at most byteCount bytes.
def recv(self, byteCount, flags=0):
# recv() must not concatenate two or more data fragments sent with
# send() on the remote side. Single fragment sent with single send()
# call should be split into strings of length less than or equal
# to 'byteCount', and returned by one or more recv() calls.
remainingBytes = self.readIdx != len(self.readString)
# TODO: Verify this connectivity behaviour.
if not self.connected:
# Sockets which have never been connected do this.
if not self.wasConnected:
raise error(10057, 'Socket is not connected')
# Sockets which were connected, but no longer are, use
# up the remaining input. Observed this with urllib.urlopen
# where it closes the socket and then allows the caller to
# use a file to access the body of the web page.
elif not remainingBytes:
self.readString = self.recvChannel.receive()
self.readIdx = 0
remainingBytes = len(self.readString)
if byteCount == 1 and remainingBytes:
ret = self.readString[self.readIdx]
self.readIdx += 1
elif self.readIdx == 0 and byteCount >= len(self.readString):
ret = self.readString
self.readString = ""
else:
idx = self.readIdx + byteCount
ret = self.readString[self.readIdx:idx]
self.readString = self.readString[idx:]
self.readIdx = 0
# ret will be '' when EOF.
return ret
def recvfrom(self, byteCount, flags=0):
if self.socket.type == SOCK_STREAM:
return self.recv(byteCount), None
# recvfrom() must not concatenate two or more packets.
# Each call should return the first 'byteCount' part of the packet.
data, address = self.recvChannel.receive()
return data[:byteCount], address
def close(self):
asyncore.dispatcher.close(self)
self.connected = False
self.accepting = False
self.sendBuffer = None # breaks the loop in sendall
# Clear out all the channels with relevant errors.
while self.acceptChannel and self.acceptChannel.balance < 0:
self.acceptChannel.send_exception(error, 9, 'Bad file descriptor')
while self.connectChannel and self.connectChannel.balance < 0:
self.connectChannel.send_exception(error, 10061, 'Connection refused')
while self.recvChannel and self.recvChannel.balance < 0:
# The closing of a socket is indicted by receiving nothing. The
# exception would have been sent if the server was killed, rather
# than closed down gracefully.
self.recvChannel.send("")
#self.recvChannel.send_exception(error, 10054, 'Connection reset by peer')
# asyncore doesn't support this. Why not?
def fileno(self):
return self.socket.fileno()
def handle_accept(self):
if self.acceptChannel and self.acceptChannel.balance < 0:
t = asyncore.dispatcher.accept(self)
if t is None:
return
t[0].setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
stackless.tasklet(self.acceptChannel.send)(t)
# Inform the blocked connect call that the connection has been made.
def handle_connect(self):
if self.socket.type != SOCK_DGRAM:
self.wasConnected = True
self.connectChannel.send(None)
# Asyncore says its done but self.readBuffer may be non-empty
# so can't close yet. Do nothing and let 'recv' trigger the close.
def handle_close(self):
pass
# Some error, just close the channel and let that raise errors to
# blocked calls.
def handle_expt(self):
self.close()
def handle_read(self):
try:
if self.socket.type == SOCK_DGRAM:
ret = self.socket.recvfrom(20000)
else:
ret = asyncore.dispatcher.recv(self, 20000)
# Not sure this is correct, but it seems to give the
# right behaviour. Namely removing the socket from
# asyncore.
if not ret:
self.close()
stackless.tasklet(self.recvChannel.send)(ret)
except stdsocket.error, err:
# If there's a read error assume the connection is
# broken and drop any pending output
if self.sendBuffer:
self.sendBuffer = ""
self.recvChannel.send_exception(stdsocket.error, err)
def handle_write(self):
if len(self.sendBuffer):
sentBytes = asyncore.dispatcher.send(self, self.sendBuffer[:512])
self.sendBuffer = self.sendBuffer[sentBytes:]
elif len(self.sendToBuffers):
data, address, channel, oldSentBytes = self.sendToBuffers[0]
sentBytes = self.socket.sendto(data, address)
totalSentBytes = oldSentBytes + sentBytes
if len(data) > sentBytes:
self.sendToBuffers[0] = data[sentBytes:], address, channel, totalSentBytes
else:
del self.sendToBuffers[0]
stackless.tasklet(channel.send)(totalSentBytes)
if __name__ == '__main__':
import sys
import struct
# Test code goes here.
testAddress = "127.0.0.1", 3000
info = -12345678
data = struct.pack("i", info)
dataLength = len(data)
def TestTCPServer(address):
global info, data, dataLength
print "server listen socket creation"
listenSocket = stdsocket.socket(AF_INET, SOCK_STREAM)
listenSocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
listenSocket.bind(address)
listenSocket.listen(5)
NUM_TESTS = 2
i = 1
while i < NUM_TESTS + 1:
# No need to schedule this tasklet as the accept should yield most
# of the time on the underlying channel.
print "server connection wait", i
currentSocket, clientAddress = listenSocket.accept()
print "server", i, "listen socket", currentSocket.fileno(), "from", clientAddress
if i == 1:
print "server closing (a)", i, "fd", currentSocket.fileno(), "id", id(currentSocket)
currentSocket.close()
print "server closed (a)", i
elif i == 2:
print "server test", i, "send"
currentSocket.send(data)
print "server test", i, "recv"
if currentSocket.recv(4) != "":
print "server recv(1)", i, "FAIL"
break
# multiple empty recvs are fine
if currentSocket.recv(4) != "":
print "server recv(2)", i, "FAIL"
break
else:
print "server closing (b)", i, "fd", currentSocket.fileno(), "id", id(currentSocket)
currentSocket.close()
print "server test", i, "OK"
i += 1
if i != NUM_TESTS+1:
print "server: FAIL", i
else:
print "server: OK", i
print "Done server"
def TestTCPClient(address):
global info, data, dataLength
# Attempt 1:
clientSocket = stdsocket.socket()
clientSocket.connect(address)
print "client connection (1) fd", clientSocket.fileno(), "id", id(clientSocket._sock), "waiting to recv"
if clientSocket.recv(5) != "":
print "client test", 1, "FAIL"
else:
print "client test", 1, "OK"
# Attempt 2:
clientSocket = stdsocket.socket()
clientSocket.connect(address)
print "client connection (2) fd", clientSocket.fileno(), "id", id(clientSocket._sock), "waiting to recv"
s = clientSocket.recv(dataLength)
if s == "":
print "client test", 2, "FAIL (disconnect)"
else:
t = struct.unpack("i", s)
if t[0] == info:
print "client test", 2, "OK"
else:
print "client test", 2, "FAIL (wrong data)"
print "client exit"
def TestMonkeyPatchUrllib(uri):
# replace the system socket with this module
#oldSocket = sys.modules["socket"]
#sys.modules["socket"] = __import__(__name__)
install()
try:
import urllib # must occur after monkey-patching!
f = urllib.urlopen(uri)
if not isinstance(f.fp._sock, _fakesocket):
raise AssertionError("failed to apply monkeypatch, got %s" % f.fp._sock.__class__)
s = f.read()
if len(s) != 0:
print "Fetched", len(s), "bytes via replaced urllib"
else:
raise AssertionError("no text received?")
finally:
#sys.modules["socket"] = oldSocket
uninstall()
def TestMonkeyPatchUDP(address):
# replace the system socket with this module
#oldSocket = sys.modules["socket"]
#sys.modules["socket"] = __import__(__name__)
install()
try:
def UDPServer(address):
listenSocket = stdsocket.socket(AF_INET, SOCK_DGRAM)
listenSocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
listenSocket.bind(address)
# Apparently each call to recvfrom maps to an incoming
# packet and if we only ask for part of that packet, the
# rest is lost. We really need a proper unittest suite
# which tests this module against the normal socket
# module.
print "waiting to receive"
data, address = listenSocket.recvfrom(256)
print "received", data, len(data)
if len(data) != 256:
raise StandardError("Unexpected UDP packet size")
def UDPClient(address):
clientSocket = stdsocket.socket(AF_INET, SOCK_DGRAM)
# clientSocket.connect(address)
print "sending 512 byte packet"
sentBytes = clientSocket.sendto("-"+ ("*" * 510) +"-", address)
print "sent 512 byte packet", sentBytes
stackless.tasklet(UDPServer)(address)
stackless.tasklet(UDPClient)(address)
stackless.run()
finally:
#sys.modules["socket"] = oldSocket
uninstall()
if len(sys.argv) == 2:
if sys.argv[1] == "client":
print "client started"
TestTCPClient(testAddress)
print "client exited"
elif sys.argv[1] == "slpclient":
print "client started"
stackless.tasklet(TestTCPClient)(testAddress)
stackless.run()
print "client exited"
elif sys.argv[1] == "server":
print "server started"
TestTCPServer(testAddress)
print "server exited"
elif sys.argv[1] == "slpserver":
print "server started"
stackless.tasklet(TestTCPServer)(testAddress)
stackless.run()
print "server exited"
else:
print "Usage:", sys.argv[0], "[client|server|slpclient|slpserver]"
sys.exit(1)
else:
print "* Running client/server test"
install()
try:
stackless.tasklet(TestTCPServer)(testAddress)
stackless.tasklet(TestTCPClient)(testAddress)
stackless.run()
finally:
uninstall()
print "* Running urllib test"
stackless.tasklet(TestMonkeyPatchUrllib)("http://python.org/")
stackless.run()
print "* Running udp test"
TestMonkeyPatchUDP(testAddress)
print "result: SUCCESS"