Skip to content

Instantly share code, notes, and snippets.

@allenap
Created July 11, 2012 17:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save allenap/5d00abe7a7bd7603f1d5 to your computer and use it in GitHub Desktop.
Save allenap/5d00abe7a7bd7603f1d5 to your computer and use it in GitHub Desktop.
Patch: get_reader() and get_writer() can now return a Deferred
diff --git a/tftp/backend.py b/tftp/backend.py
index 76d056b..eb992ec 100644
--- a/tftp/backend.py
+++ b/tftp/backend.py
@@ -1,6 +1,7 @@
'''
@author: shylent
'''
+from os import fstat
from tftp.errors import Unsupported, FileExists, AccessViolation, FileNotFound
from twisted.python.filepath import FilePath, InsecurePath
import shutil
@@ -32,7 +33,8 @@ class IBackend(interface.Interface):
@raise BackendError: for any other errors, that were encountered while
attempting to construct a reader
- @return: an object, that provides L{IReader}
+ @return: an object, that provides L{IReader}, or a L{Deferred} that
+ will fire with an L{IReader}
"""
@@ -55,7 +57,8 @@ class IBackend(interface.Interface):
@raise BackendError: for any other errors, that were encountered while
attempting to construct a writer
- @return: an object, that provides L{IWriter}
+ @return: an object, that provides L{IWriter}, or a L{Deferred} that
+ will fire with an L{IWriter}
"""
@@ -139,7 +142,10 @@ class FilesystemReader(object):
@see: L{IReader.size}
"""
- return self.file_path.getsize()
+ if self.file_obj.closed:
+ return None
+ else:
+ return fstat(self.file_obj.fileno()).st_size
def read(self, size):
"""
diff --git a/tftp/protocol.py b/tftp/protocol.py
index 9566560..a226b48 100644
--- a/tftp/protocol.py
+++ b/tftp/protocol.py
@@ -9,6 +9,7 @@ from tftp.errors import (FileExists, Unsupported, AccessViolation, BackendError,
FileNotFound)
from tftp.netascii import NetasciiReceiverProxy, NetasciiSenderProxy
from twisted.internet import reactor
+from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.protocol import DatagramProtocol
from twisted.python import log
@@ -42,34 +43,39 @@ class TFTP(DatagramProtocol):
return self.transport.write(ERRORDatagram.from_code(ERR_ILLEGAL_OP,
"Unknown transfer mode %s, - expected "
"'netascii' or 'octet' (case-insensitive)" % mode).to_wire(), addr)
+
+ self._clock.callLater(0, self._startSession, datagram, addr, mode)
+
+ @inlineCallbacks
+ def _startSession(self, datagram, addr, mode):
try:
if datagram.opcode == OP_WRQ:
- fs_interface = self.backend.get_writer(datagram.filename)
+ fs_interface = yield self.backend.get_writer(datagram.filename)
elif datagram.opcode == OP_RRQ:
- fs_interface = self.backend.get_reader(datagram.filename)
+ fs_interface = yield self.backend.get_reader(datagram.filename)
except Unsupported, e:
- return self.transport.write(ERRORDatagram.from_code(ERR_ILLEGAL_OP,
+ self.transport.write(ERRORDatagram.from_code(ERR_ILLEGAL_OP,
str(e)).to_wire(), addr)
except AccessViolation:
- return self.transport.write(ERRORDatagram.from_code(ERR_ACCESS_VIOLATION).to_wire(), addr)
+ self.transport.write(ERRORDatagram.from_code(ERR_ACCESS_VIOLATION).to_wire(), addr)
except FileExists:
- return self.transport.write(ERRORDatagram.from_code(ERR_FILE_EXISTS).to_wire(), addr)
+ self.transport.write(ERRORDatagram.from_code(ERR_FILE_EXISTS).to_wire(), addr)
except FileNotFound:
- return self.transport.write(ERRORDatagram.from_code(ERR_FILE_NOT_FOUND).to_wire(), addr)
+ self.transport.write(ERRORDatagram.from_code(ERR_FILE_NOT_FOUND).to_wire(), addr)
except BackendError, e:
- return self.transport.write(ERRORDatagram.from_code(ERR_NOT_DEFINED, str(e)).to_wire(), addr)
-
- if datagram.opcode == OP_WRQ:
- if mode == 'netascii':
- fs_interface = NetasciiReceiverProxy(fs_interface)
- session = RemoteOriginWriteSession(addr, fs_interface,
- datagram.options, _clock=self._clock)
- reactor.listenUDP(0, session)
- return session
- elif datagram.opcode == OP_RRQ:
- if mode == 'netascii':
- fs_interface = NetasciiSenderProxy(fs_interface)
- session = RemoteOriginReadSession(addr, fs_interface,
- datagram.options, _clock=self._clock)
- reactor.listenUDP(0, session)
- return session
+ self.transport.write(ERRORDatagram.from_code(ERR_NOT_DEFINED, str(e)).to_wire(), addr)
+ else:
+ if datagram.opcode == OP_WRQ:
+ if mode == 'netascii':
+ fs_interface = NetasciiReceiverProxy(fs_interface)
+ session = RemoteOriginWriteSession(addr, fs_interface,
+ datagram.options, _clock=self._clock)
+ reactor.listenUDP(0, session)
+ returnValue(session)
+ elif datagram.opcode == OP_RRQ:
+ if mode == 'netascii':
+ fs_interface = NetasciiSenderProxy(fs_interface)
+ session = RemoteOriginReadSession(addr, fs_interface,
+ datagram.options, _clock=self._clock)
+ reactor.listenUDP(0, session)
+ returnValue(session)
diff --git a/tftp/test/test_backend.py b/tftp/test/test_backend.py
index ce4a899..ce7f58e 100644
--- a/tftp/test/test_backend.py
+++ b/tftp/test/test_backend.py
@@ -77,6 +77,18 @@ line3
r = FilesystemReader(self.temp_dir.child('foo'))
self.assertEqual(len(self.test_data), r.size)
+ def test_size_when_reader_finished(self):
+ r = FilesystemReader(self.temp_dir.child('foo'))
+ r.finish()
+ self.assertIsNone(r.size)
+
+ def test_size_when_file_removed(self):
+ # FilesystemReader.size uses fstat() to discover the file's size, so
+ # the absence of the file does not matter.
+ r = FilesystemReader(self.temp_dir.child('foo'))
+ self.existing_file_name.remove()
+ self.assertEqual(len(self.test_data), r.size)
+
def test_cancel(self):
r = FilesystemReader(self.temp_dir.child('foo'))
r.read(3)
diff --git a/tftp/test/test_protocol.py b/tftp/test/test_protocol.py
index e6ff434..d383618 100644
--- a/tftp/test/test_protocol.py
+++ b/tftp/test/test_protocol.py
@@ -1,7 +1,7 @@
'''
@author: shylent
'''
-from tftp.backend import FilesystemSynchronousBackend
+from tftp.backend import FilesystemSynchronousBackend, IReader, IWriter
from tftp.bootstrap import RemoteOriginWriteSession, RemoteOriginReadSession
from tftp.datagram import (WRQDatagram, TFTPDatagramFactory, split_opcode,
ERR_ILLEGAL_OP, RRQDatagram, ERR_ACCESS_VIOLATION, ERR_FILE_EXISTS,
@@ -72,12 +72,14 @@ class DispatchErrors(unittest.TestCase):
tftp.transport = self.transport
wrq_datagram = WRQDatagram('foobar', 'netascii', {})
tftp.datagramReceived(wrq_datagram.to_wire(), ('127.0.0.1', 1111))
+ self.clock.advance(1)
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value()))
self.assertEqual(error_datagram.errorcode, ERR_ILLEGAL_OP)
self.transport.clear()
rrq_datagram = RRQDatagram('foobar', 'octet', {})
tftp.datagramReceived(rrq_datagram.to_wire(), ('127.0.0.1', 1111))
+ self.clock.advance(1)
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value()))
self.assertEqual(error_datagram.errorcode, ERR_ILLEGAL_OP)
@@ -86,12 +88,14 @@ class DispatchErrors(unittest.TestCase):
tftp.transport = self.transport
wrq_datagram = WRQDatagram('foobar', 'netascii', {})
tftp.datagramReceived(wrq_datagram.to_wire(), ('127.0.0.1', 1111))
+ self.clock.advance(1)
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value()))
self.assertEqual(error_datagram.errorcode, ERR_ACCESS_VIOLATION)
self.transport.clear()
rrq_datagram = RRQDatagram('foobar', 'octet', {})
tftp.datagramReceived(rrq_datagram.to_wire(), ('127.0.0.1', 1111))
+ self.clock.advance(1)
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value()))
self.assertEqual(error_datagram.errorcode, ERR_ACCESS_VIOLATION)
@@ -100,6 +104,7 @@ class DispatchErrors(unittest.TestCase):
tftp.transport = self.transport
wrq_datagram = WRQDatagram('foobar', 'netascii', {})
tftp.datagramReceived(wrq_datagram.to_wire(), ('127.0.0.1', 1111))
+ self.clock.advance(1)
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value()))
self.assertEqual(error_datagram.errorcode, ERR_FILE_EXISTS)
@@ -108,6 +113,7 @@ class DispatchErrors(unittest.TestCase):
tftp.transport = self.transport
rrq_datagram = RRQDatagram('foobar', 'netascii', {})
tftp.datagramReceived(rrq_datagram.to_wire(), ('127.0.0.1', 1111))
+ self.clock.advance(1)
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value()))
self.assertEqual(error_datagram.errorcode, ERR_FILE_NOT_FOUND)
@@ -116,12 +122,14 @@ class DispatchErrors(unittest.TestCase):
tftp.transport = self.transport
rrq_datagram = RRQDatagram('foobar', 'netascii', {})
tftp.datagramReceived(rrq_datagram.to_wire(), ('127.0.0.1', 1111))
+ self.clock.advance(1)
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value()))
self.assertEqual(error_datagram.errorcode, ERR_NOT_DEFINED)
self.transport.clear()
rrq_datagram = RRQDatagram('foobar', 'octet', {})
tftp.datagramReceived(rrq_datagram.to_wire(), ('127.0.0.1', 1111))
+ self.clock.advance(1)
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value()))
self.assertEqual(error_datagram.errorcode, ERR_NOT_DEFINED)
@@ -135,8 +143,15 @@ class DummyClient(DatagramProtocol):
class TFTPWrapper(TFTP):
- def datagramReceived(self, *args, **kwargs):
- self.session = TFTP.datagramReceived(self, *args, **kwargs)
+ def _startSession(self, *args, **kwargs):
+ d = TFTP._startSession(self, *args, **kwargs)
+
+ def save_session(session):
+ self.session = session
+ return session
+
+ d.addCallback(save_session)
+ return d
class SuccessfulDispatch(unittest.TestCase):
@@ -156,8 +171,8 @@ class SuccessfulDispatch(unittest.TestCase):
self.client.transport.write(WRQDatagram('foobar', 'NetASCiI', {}).to_wire(), ('127.0.0.1', 1069))
d = Deferred()
def cb(ign):
- self.failUnless(isinstance(self.tftp.session, RemoteOriginWriteSession))
- self.failUnless(isinstance(self.tftp.session.backend, NetasciiReceiverProxy))
+ self.assertIsInstance(self.tftp.session, RemoteOriginWriteSession)
+ self.assertIsInstance(self.tftp.session.backend, NetasciiReceiverProxy)
self.tftp.session.cancel()
d.addCallback(cb)
reactor.callLater(0.5, d.callback, None)
@@ -167,8 +182,8 @@ class SuccessfulDispatch(unittest.TestCase):
self.client.transport.write(RRQDatagram('nonempty', 'NetASCiI', {}).to_wire(), ('127.0.0.1', 1069))
d = Deferred()
def cb(ign):
- self.failUnless(isinstance(self.tftp.session, RemoteOriginReadSession))
- self.failUnless(isinstance(self.tftp.session.backend, NetasciiSenderProxy))
+ self.assertIsInstance(self.tftp.session, RemoteOriginReadSession)
+ self.assertIsInstance(self.tftp.session.backend, NetasciiSenderProxy)
self.tftp.session.cancel()
d.addCallback(cb)
reactor.callLater(0.5, d.callback, None)
@@ -177,3 +192,54 @@ class SuccessfulDispatch(unittest.TestCase):
def tearDown(self):
self.tftp.transport.stopListening()
self.client.transport.stopListening()
+
+
+class FilesystemAsyncBackend(FilesystemSynchronousBackend):
+
+ def __init__(self, base_path, clock):
+ super(FilesystemAsyncBackend, self).__init__(
+ base_path, can_read=True, can_write=True)
+ self.clock = clock
+
+ def get_reader(self, file_name):
+ reader = super(FilesystemAsyncBackend, self).get_reader(file_name)
+ d = Deferred()
+ self.clock.callLater(0, d.callback, reader)
+ return d
+
+ def get_writer(self, file_name):
+ writer = super(FilesystemAsyncBackend, self).get_writer(file_name)
+ d = Deferred()
+ self.clock.callLater(0, d.callback, writer)
+ return d
+
+
+class SuccessfulAsyncDispatch(unittest.TestCase):
+
+ def setUp(self):
+ self.clock = Clock()
+ self.tmp_dir_path = tempfile.mkdtemp()
+ with FilePath(self.tmp_dir_path).child('nonempty').open('w') as fd:
+ fd.write('Something uninteresting')
+ self.backend = FilesystemAsyncBackend(self.tmp_dir_path, self.clock)
+ self.tftp = TFTP(self.backend, self.clock)
+
+ def test_get_reader_can_defer(self):
+ rrq_datagram = RRQDatagram('nonempty', 'NetASCiI', {})
+ rrq_addr = ('127.0.0.1', 1069)
+ rrq_mode = "octet"
+ d = self.tftp._startSession(rrq_datagram, rrq_addr, rrq_mode)
+ self.assertFalse(d.called)
+ self.clock.advance(1)
+ self.assertTrue(d.called)
+ self.assertTrue(IReader.providedBy(d.result.backend))
+
+ def test_get_writer_can_defer(self):
+ wrq_datagram = WRQDatagram('foobar', 'NetASCiI', {})
+ wrq_addr = ('127.0.0.1', 1069)
+ wrq_mode = "octet"
+ d = self.tftp._startSession(wrq_datagram, wrq_addr, wrq_mode)
+ self.assertFalse(d.called)
+ self.clock.advance(1)
+ self.assertTrue(d.called)
+ self.assertTrue(IWriter.providedBy(d.result.backend))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment