-
-
Save allenap/5d00abe7a7bd7603f1d5 to your computer and use it in GitHub Desktop.
Patch: get_reader() and get_writer() can now defer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# tftp/backend.py: the IBackend docstrings now allow get_reader() and
# get_writer() to return a Deferred that fires with the IReader/IWriter,
# and FilesystemReader.size now calls fstat() on the already-open file
# object (returning None once that file object is closed) instead of
# FilePath.getsize().
# NOTE(review): the IReader.size contract is not part of this patch; confirm
# that it permits None for a finished reader.
# NOTE(review): this copy of the patch has lost its indentation and some
# whitespace-only context lines (the "-32,7 +33,8" hunk shows only four
# old-side lines), so it will not apply as-is; apply the raw gist instead.
diff --git a/tftp/backend.py b/tftp/backend.py | |
index 76d056b..eb992ec 100644 | |
--- a/tftp/backend.py | |
+++ b/tftp/backend.py | |
@@ -1,6 +1,7 @@ | |
''' | |
@author: shylent | |
''' | |
+from os import fstat | |
from tftp.errors import Unsupported, FileExists, AccessViolation, FileNotFound | |
from twisted.python.filepath import FilePath, InsecurePath | |
import shutil | |
@@ -32,7 +33,8 @@ class IBackend(interface.Interface): | |
@raise BackendError: for any other errors, that were encountered while | |
attempting to construct a reader | |
- @return: an object, that provides L{IReader} | |
+ @return: an object, that provides L{IReader}, or a L{Deferred} that | |
+ will fire with an L{IReader} | |
""" | |
@@ -55,7 +57,8 @@ class IBackend(interface.Interface): | |
@raise BackendError: for any other errors, that were encountered while | |
attempting to construct a writer | |
- @return: an object, that provides L{IWriter} | |
+ @return: an object, that provides L{IWriter}, or a L{Deferred} that | |
+ will fire with an L{IWriter} | |
""" | |
@@ -139,7 +142,10 @@ class FilesystemReader(object): | |
@see: L{IReader.size} | |
""" | |
- return self.file_path.getsize() | |
+ if self.file_obj.closed: | |
+ return None | |
+ else: | |
+ return fstat(self.file_obj.fileno()).st_size | |
def read(self, size): | |
""" | |
# tftp/protocol.py: after rejecting an unknown transfer mode, the request
# handler no longer opens the backend inline. It schedules _startSession()
# with self._clock.callLater(0, ...). _startSession() is an inlineCallbacks
# method that waits on get_writer()/get_reader() (which may now return a
# Deferred), answers Unsupported/AccessViolation/FileExists/FileNotFound/
# BackendError with an ERROR datagram, and otherwise wraps the interface in
# a netascii proxy when needed, starts the session with
# reactor.listenUDP(0, session), and returns it via returnValue().
# NOTE(review): callLater() discards the Deferred returned by
# _startSession(). Any exception not caught there (i.e. not one of the
# tftp.errors types above) is only reported as an unhandled Deferred error,
# typically when the Deferred is garbage-collected; consider attaching
# log.err as an errback.
diff --git a/tftp/protocol.py b/tftp/protocol.py | |
index 9566560..a226b48 100644 | |
--- a/tftp/protocol.py | |
+++ b/tftp/protocol.py | |
@@ -9,6 +9,7 @@ from tftp.errors import (FileExists, Unsupported, AccessViolation, BackendError, | |
FileNotFound) | |
from tftp.netascii import NetasciiReceiverProxy, NetasciiSenderProxy | |
from twisted.internet import reactor | |
+from twisted.internet.defer import inlineCallbacks, returnValue | |
from twisted.internet.protocol import DatagramProtocol | |
from twisted.python import log | |
@@ -42,34 +43,39 @@ class TFTP(DatagramProtocol): | |
return self.transport.write(ERRORDatagram.from_code(ERR_ILLEGAL_OP, | |
"Unknown transfer mode %s, - expected " | |
"'netascii' or 'octet' (case-insensitive)" % mode).to_wire(), addr) | |
+ | |
+ self._clock.callLater(0, self._startSession, datagram, addr, mode) | |
+ | |
+ @inlineCallbacks | |
+ def _startSession(self, datagram, addr, mode): | |
try: | |
if datagram.opcode == OP_WRQ: | |
- fs_interface = self.backend.get_writer(datagram.filename) | |
+ fs_interface = yield self.backend.get_writer(datagram.filename) | |
elif datagram.opcode == OP_RRQ: | |
- fs_interface = self.backend.get_reader(datagram.filename) | |
+ fs_interface = yield self.backend.get_reader(datagram.filename) | |
except Unsupported, e: | |
- return self.transport.write(ERRORDatagram.from_code(ERR_ILLEGAL_OP, | |
+ self.transport.write(ERRORDatagram.from_code(ERR_ILLEGAL_OP, | |
str(e)).to_wire(), addr) | |
except AccessViolation: | |
- return self.transport.write(ERRORDatagram.from_code(ERR_ACCESS_VIOLATION).to_wire(), addr) | |
+ self.transport.write(ERRORDatagram.from_code(ERR_ACCESS_VIOLATION).to_wire(), addr) | |
except FileExists: | |
- return self.transport.write(ERRORDatagram.from_code(ERR_FILE_EXISTS).to_wire(), addr) | |
+ self.transport.write(ERRORDatagram.from_code(ERR_FILE_EXISTS).to_wire(), addr) | |
except FileNotFound: | |
- return self.transport.write(ERRORDatagram.from_code(ERR_FILE_NOT_FOUND).to_wire(), addr) | |
+ self.transport.write(ERRORDatagram.from_code(ERR_FILE_NOT_FOUND).to_wire(), addr) | |
except BackendError, e: | |
- return self.transport.write(ERRORDatagram.from_code(ERR_NOT_DEFINED, str(e)).to_wire(), addr) | |
- | |
- if datagram.opcode == OP_WRQ: | |
- if mode == 'netascii': | |
- fs_interface = NetasciiReceiverProxy(fs_interface) | |
- session = RemoteOriginWriteSession(addr, fs_interface, | |
- datagram.options, _clock=self._clock) | |
- reactor.listenUDP(0, session) | |
- return session | |
- elif datagram.opcode == OP_RRQ: | |
- if mode == 'netascii': | |
- fs_interface = NetasciiSenderProxy(fs_interface) | |
- session = RemoteOriginReadSession(addr, fs_interface, | |
- datagram.options, _clock=self._clock) | |
- reactor.listenUDP(0, session) | |
- return session | |
+ self.transport.write(ERRORDatagram.from_code(ERR_NOT_DEFINED, str(e)).to_wire(), addr) | |
+ else: | |
+ if datagram.opcode == OP_WRQ: | |
+ if mode == 'netascii': | |
+ fs_interface = NetasciiReceiverProxy(fs_interface) | |
+ session = RemoteOriginWriteSession(addr, fs_interface, | |
+ datagram.options, _clock=self._clock) | |
+ reactor.listenUDP(0, session) | |
+ returnValue(session) | |
+ elif datagram.opcode == OP_RRQ: | |
+ if mode == 'netascii': | |
+ fs_interface = NetasciiSenderProxy(fs_interface) | |
+ session = RemoteOriginReadSession(addr, fs_interface, | |
+ datagram.options, _clock=self._clock) | |
+ reactor.listenUDP(0, session) | |
+ returnValue(session) | |
# tftp/test/test_backend.py: two new FilesystemReader.size tests. size is
# None after finish(), and it is still correct after the file is removed,
# because size is now taken from fstat() on the open file.
# NOTE(review): the removal test deletes self.existing_file_name but reads
# self.temp_dir.child('foo'); confirm these name the same file, otherwise
# the test does not exercise removal.
diff --git a/tftp/test/test_backend.py b/tftp/test/test_backend.py | |
index ce4a899..ce7f58e 100644 | |
--- a/tftp/test/test_backend.py | |
+++ b/tftp/test/test_backend.py | |
@@ -77,6 +77,18 @@ line3 | |
r = FilesystemReader(self.temp_dir.child('foo')) | |
self.assertEqual(len(self.test_data), r.size) | |
+ def test_size_when_reader_finished(self): | |
+ r = FilesystemReader(self.temp_dir.child('foo')) | |
+ r.finish() | |
+ self.assertIsNone(r.size) | |
+ | |
+ def test_size_when_file_removed(self): | |
+ # FilesystemReader.size uses fstat() to discover the file's size, so | |
+ # the absence of the file does not matter. | |
+ r = FilesystemReader(self.temp_dir.child('foo')) | |
+ self.existing_file_name.remove() | |
+ self.assertEqual(len(self.test_data), r.size) | |
+ | |
def test_cancel(self): | |
r = FilesystemReader(self.temp_dir.child('foo')) | |
r.read(3) | |
# tftp/test/test_protocol.py (first hunks):
#   * The DispatchErrors tests advance self.clock after datagramReceived(),
#     because session start-up is now scheduled through the clock.
#   * TFTPWrapper records the session by adding a callback to the Deferred
#     returned by _startSession(), rather than using datagramReceived()'s
#     return value.
#   * failUnless(isinstance(...)) checks become assertIsInstance().
diff --git a/tftp/test/test_protocol.py b/tftp/test/test_protocol.py | |
index e6ff434..d383618 100644 | |
--- a/tftp/test/test_protocol.py | |
+++ b/tftp/test/test_protocol.py | |
@@ -1,7 +1,7 @@ | |
''' | |
@author: shylent | |
''' | |
-from tftp.backend import FilesystemSynchronousBackend | |
+from tftp.backend import FilesystemSynchronousBackend, IReader, IWriter | |
from tftp.bootstrap import RemoteOriginWriteSession, RemoteOriginReadSession | |
from tftp.datagram import (WRQDatagram, TFTPDatagramFactory, split_opcode, | |
ERR_ILLEGAL_OP, RRQDatagram, ERR_ACCESS_VIOLATION, ERR_FILE_EXISTS, | |
@@ -72,12 +72,14 @@ class DispatchErrors(unittest.TestCase): | |
tftp.transport = self.transport | |
wrq_datagram = WRQDatagram('foobar', 'netascii', {}) | |
tftp.datagramReceived(wrq_datagram.to_wire(), ('127.0.0.1', 1111)) | |
+ self.clock.advance(1) | |
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value())) | |
self.assertEqual(error_datagram.errorcode, ERR_ILLEGAL_OP) | |
self.transport.clear() | |
rrq_datagram = RRQDatagram('foobar', 'octet', {}) | |
tftp.datagramReceived(rrq_datagram.to_wire(), ('127.0.0.1', 1111)) | |
+ self.clock.advance(1) | |
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value())) | |
self.assertEqual(error_datagram.errorcode, ERR_ILLEGAL_OP) | |
@@ -86,12 +88,14 @@ class DispatchErrors(unittest.TestCase): | |
tftp.transport = self.transport | |
wrq_datagram = WRQDatagram('foobar', 'netascii', {}) | |
tftp.datagramReceived(wrq_datagram.to_wire(), ('127.0.0.1', 1111)) | |
+ self.clock.advance(1) | |
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value())) | |
self.assertEqual(error_datagram.errorcode, ERR_ACCESS_VIOLATION) | |
self.transport.clear() | |
rrq_datagram = RRQDatagram('foobar', 'octet', {}) | |
tftp.datagramReceived(rrq_datagram.to_wire(), ('127.0.0.1', 1111)) | |
+ self.clock.advance(1) | |
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value())) | |
self.assertEqual(error_datagram.errorcode, ERR_ACCESS_VIOLATION) | |
@@ -100,6 +104,7 @@ class DispatchErrors(unittest.TestCase): | |
tftp.transport = self.transport | |
wrq_datagram = WRQDatagram('foobar', 'netascii', {}) | |
tftp.datagramReceived(wrq_datagram.to_wire(), ('127.0.0.1', 1111)) | |
+ self.clock.advance(1) | |
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value())) | |
self.assertEqual(error_datagram.errorcode, ERR_FILE_EXISTS) | |
@@ -108,6 +113,7 @@ class DispatchErrors(unittest.TestCase): | |
tftp.transport = self.transport | |
rrq_datagram = RRQDatagram('foobar', 'netascii', {}) | |
tftp.datagramReceived(rrq_datagram.to_wire(), ('127.0.0.1', 1111)) | |
+ self.clock.advance(1) | |
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value())) | |
self.assertEqual(error_datagram.errorcode, ERR_FILE_NOT_FOUND) | |
@@ -116,12 +122,14 @@ class DispatchErrors(unittest.TestCase): | |
tftp.transport = self.transport | |
rrq_datagram = RRQDatagram('foobar', 'netascii', {}) | |
tftp.datagramReceived(rrq_datagram.to_wire(), ('127.0.0.1', 1111)) | |
+ self.clock.advance(1) | |
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value())) | |
self.assertEqual(error_datagram.errorcode, ERR_NOT_DEFINED) | |
self.transport.clear() | |
rrq_datagram = RRQDatagram('foobar', 'octet', {}) | |
tftp.datagramReceived(rrq_datagram.to_wire(), ('127.0.0.1', 1111)) | |
+ self.clock.advance(1) | |
error_datagram = TFTPDatagramFactory(*split_opcode(self.transport.value())) | |
self.assertEqual(error_datagram.errorcode, ERR_NOT_DEFINED) | |
@@ -135,8 +143,15 @@ class DummyClient(DatagramProtocol): | |
class TFTPWrapper(TFTP): | |
- def datagramReceived(self, *args, **kwargs): | |
- self.session = TFTP.datagramReceived(self, *args, **kwargs) | |
+ def _startSession(self, *args, **kwargs): | |
+ d = TFTP._startSession(self, *args, **kwargs) | |
+ | |
+ def save_session(session): | |
+ self.session = session | |
+ return session | |
+ | |
+ d.addCallback(save_session) | |
+ return d | |
class SuccessfulDispatch(unittest.TestCase): | |
@@ -156,8 +171,8 @@ class SuccessfulDispatch(unittest.TestCase): | |
self.client.transport.write(WRQDatagram('foobar', 'NetASCiI', {}).to_wire(), ('127.0.0.1', 1069)) | |
d = Deferred() | |
def cb(ign): | |
- self.failUnless(isinstance(self.tftp.session, RemoteOriginWriteSession)) | |
- self.failUnless(isinstance(self.tftp.session.backend, NetasciiReceiverProxy)) | |
+ self.assertIsInstance(self.tftp.session, RemoteOriginWriteSession) | |
+ self.assertIsInstance(self.tftp.session.backend, NetasciiReceiverProxy) | |
self.tftp.session.cancel() | |
d.addCallback(cb) | |
reactor.callLater(0.5, d.callback, None) | |
@@ -167,8 +182,8 @@ class SuccessfulDispatch(unittest.TestCase): | |
self.client.transport.write(RRQDatagram('nonempty', 'NetASCiI', {}).to_wire(), ('127.0.0.1', 1069)) | |
d = Deferred() | |
def cb(ign): | |
- self.failUnless(isinstance(self.tftp.session, RemoteOriginReadSession)) | |
- self.failUnless(isinstance(self.tftp.session.backend, NetasciiSenderProxy)) | |
+ self.assertIsInstance(self.tftp.session, RemoteOriginReadSession) | |
+ self.assertIsInstance(self.tftp.session.backend, NetasciiSenderProxy) | |
self.tftp.session.cancel() | |
d.addCallback(cb) | |
reactor.callLater(0.5, d.callback, None) | |
@@ -177,3 +192,72 @@ class SuccessfulDispatch(unittest.TestCase):
     def tearDown(self):
         self.tftp.transport.stopListening()
         self.client.transport.stopListening()
+
+
+class FilesystemAsyncBackend(FilesystemSynchronousBackend):
+    """Filesystem backend whose readers and writers arrive via Deferreds.
+
+    The reader or writer is built synchronously by the superclass, but the
+    returned Deferred only fires when C{clock} runs its next scheduled call,
+    so tests can exercise a backend that genuinely defers.
+    """
+
+    def __init__(self, base_path, clock):
+        super(FilesystemAsyncBackend, self).__init__(
+            base_path, can_read=True, can_write=True)
+        self.clock = clock
+
+    def get_reader(self, file_name):
+        reader = super(FilesystemAsyncBackend, self).get_reader(file_name)
+        d = Deferred()
+        self.clock.callLater(0, d.callback, reader)
+        return d
+
+    def get_writer(self, file_name):
+        writer = super(FilesystemAsyncBackend, self).get_writer(file_name)
+        d = Deferred()
+        self.clock.callLater(0, d.callback, writer)
+        return d
+
+
+class SuccessfulAsyncDispatch(unittest.TestCase):
+
+    def setUp(self):
+        self.clock = Clock()
+        self.tmp_dir_path = tempfile.mkdtemp()
+        # Remove the scratch directory, and anything written into it.
+        self.addCleanup(FilePath(self.tmp_dir_path).remove)
+        with FilePath(self.tmp_dir_path).child('nonempty').open('w') as fd:
+            fd.write('Something uninteresting')
+        self.backend = FilesystemAsyncBackend(self.tmp_dir_path, self.clock)
+        self.tftp = TFTP(self.backend, self.clock)
+
+    def test_get_reader_can_defer(self):
+        rrq_datagram = RRQDatagram('nonempty', 'octet', {})
+        rrq_addr = ('127.0.0.1', 1069)
+        rrq_mode = "octet"
+        d = self.tftp._startSession(rrq_datagram, rrq_addr, rrq_mode)
+        self.assertFalse(d.called)
+        self.clock.advance(1)
+        self.assertTrue(d.called)
+        session = d.result
+        self.assertIsInstance(session, RemoteOriginReadSession)
+        # _startSession binds the session to a real UDP port with
+        # reactor.listenUDP(); cancel it so trial sees a clean reactor.
+        self.addCleanup(session.cancel)
+        self.assertTrue(IReader.providedBy(session.backend))
+
+    def test_get_writer_can_defer(self):
+        wrq_datagram = WRQDatagram('foobar', 'octet', {})
+        wrq_addr = ('127.0.0.1', 1069)
+        wrq_mode = "octet"
+        d = self.tftp._startSession(wrq_datagram, wrq_addr, wrq_mode)
+        self.assertFalse(d.called)
+        self.clock.advance(1)
+        self.assertTrue(d.called)
+        session = d.result
+        self.assertIsInstance(session, RemoteOriginWriteSession)
+        # _startSession binds the session to a real UDP port with
+        # reactor.listenUDP(); cancel it so trial sees a clean reactor.
+        self.addCleanup(session.cancel)
+        self.assertTrue(IWriter.providedBy(session.backend))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.