Created
June 15, 2017 04:31
-
-
Save zaitcev/53b2c2b2b16f59387ca43d3a09872cfe to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
commit 5b4e5617096d984e2633f4697002a9592f564d12 | |
Author: Pete Zaitcev <zaitcev@kotori.zaitcev.us> | |
Date: Wed Jun 14 22:28:13 2017 -0600 | |
Buffer reads from disk (alt 1) | |
Otherwise, Python defaults to 8 KiB reads, which performs poorly.
As a side effect, be stricter about the types of input we wrap. Unicode
that can be encoded as ASCII is still accepted, but any other unicode now
raises a TypeError in LengthWrapper rather than making it all the way down
to httplib.
Change-Id: I3160626e947083af487fd1c3cb0aa6a62646527b | |
Closes-Bug: #1671621 | |
diff --git a/swiftclient/utils.py b/swiftclient/utils.py | |
index 8afcde9..f38c7f5 100644 | |
--- a/swiftclient/utils.py | |
+++ b/swiftclient/utils.py | |
@@ -20,6 +20,7 @@ import hashlib | |
import hmac | |
import json | |
import logging | |
+import os | |
import six | |
import time | |
import traceback | |
@@ -301,7 +302,7 @@ class LengthWrapper(object): | |
Fix for https://github.com/kennethreitz/requests/issues/1648. | |
It is recommended to use this class only on files opened in binary mode. | |
""" | |
- def __init__(self, readable, length, md5=False): | |
+ def __init__(self, readable, length, md5=False, read_size=65536): | |
""" | |
:param readable: The filelike object to read from. | |
:param length: The maximum amount of content that can be read from | |
@@ -309,11 +310,16 @@ class LengthWrapper(object): | |
empty. | |
:param md5: Flag to enable calculating the MD5 of the content | |
as it is read. | |
+ :param read_size: The number of bytes that should be read from | |
+ ``readable`` at a time. | |
""" | |
self._md5 = md5 | |
self._reset_md5() | |
self._length = self._remaining = length | |
+ self._buffer = six.BytesIO() | |
+ self._buffer_size = 0 | |
self._readable = readable | |
+ self.read_size = read_size | |
self._can_reset = all(hasattr(readable, attr) | |
for attr in ('seek', 'tell')) | |
if self._can_reset: | |
@@ -328,12 +334,36 @@ class LengthWrapper(object): | |
def get_md5sum(self): | |
return self.md5sum.hexdigest() | |
+    def _look_ahead(self, remaining):
+        """Refill the buffer from ``readable`` once it is fully consumed."""
+        if self._buffer.tell() < self._buffer_size:
+            return
+
+        data = self._readable.read(min(remaining, self.read_size))
+        if not data:
+            return
+        if not isinstance(data, bytes):
+            try:
+                data = data.encode('ascii')
+            except UnicodeEncodeError:
+                raise TypeError('Object data must be bytes, not %s' %
+                                type(data).__name__)
+
+        # Rewind *before* truncating: truncate() cuts at the current position,
+        # so this order discards stale bytes a shorter refill would expose.
+        self._buffer.seek(0, os.SEEK_SET)
+        self._buffer.truncate()
+        self._buffer.write(data)
+        self._buffer.seek(0, os.SEEK_SET)
+        self._buffer_size = len(data)
+ | |
def read(self, size=-1): | |
if self._remaining <= 0: | |
- return '' | |
+ return b'' | |
to_read = self._remaining if size < 0 else min(size, self._remaining) | |
- chunk = self._readable.read(to_read) | |
+ self._look_ahead(self._remaining) | |
+ chunk = self._buffer.read(to_read) | |
self._remaining -= len(chunk) | |
try: | |
@@ -354,6 +384,9 @@ class LengthWrapper(object): | |
if not self._can_reset: | |
raise TypeError('%r object cannot be reset; needs both seek and ' | |
'tell methods' % type(self._readable).__name__) | |
+ self._buffer.seek(0) | |
+ self._buffer.truncate() | |
+ self._buffer_size = 0 | |
self._readable.seek(self._start) | |
self._reset_md5() | |
self._remaining = self._length | |
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py | |
index adead00..e5814dc 100644 | |
--- a/tests/unit/test_utils.py | |
+++ b/tests/unit/test_utils.py | |
@@ -17,6 +17,7 @@ import gzip | |
import unittest | |
import mock | |
import six | |
+import string | |
import tempfile | |
from time import gmtime, localtime, mktime, strftime, strptime | |
from hashlib import md5, sha1 | |
@@ -441,33 +442,98 @@ class TestLengthWrapper(unittest.TestCase): | |
contents.seek(22) | |
data = u.LengthWrapper(contents, 42, True) | |
s = u'a' * 28 + u'b' * 14 | |
- read_data = u''.join(iter(data.read, '')) | |
+ read_data = b''.join(iter(data.read, b'')) | |
self.assertEqual(42, len(data)) | |
self.assertEqual(42, len(read_data)) | |
- self.assertEqual(s, read_data) | |
+ self.assertEqual(s.encode(), read_data) | |
self.assertEqual(md5(s.encode()).hexdigest(), data.get_md5sum()) | |
data.reset() | |
self.assertEqual(md5().hexdigest(), data.get_md5sum()) | |
- read_data = u''.join(iter(data.read, '')) | |
+ read_data = b''.join(iter(data.read, b'')) | |
self.assertEqual(42, len(read_data)) | |
- self.assertEqual(s, read_data) | |
+ self.assertEqual(s.encode(), read_data) | |
self.assertEqual(md5(s.encode()).hexdigest(), data.get_md5sum()) | |
+ def test_nonascii_stringio(self): | |
+ contents = six.StringIO(u'\xaa' * 50) | |
+ contents.seek(22) | |
+ data = u.LengthWrapper(contents, 42, True) | |
+ with self.assertRaises(TypeError): | |
+ data.read() | |
+ | |
def test_bytesio(self): | |
contents = six.BytesIO(b'a' * 50 + b'b' * 50) | |
contents.seek(22) | |
data = u.LengthWrapper(contents, 42, True) | |
s = b'a' * 28 + b'b' * 14 | |
- read_data = b''.join(iter(data.read, '')) | |
+ read_data = b''.join(iter(data.read, b'')) | |
self.assertEqual(42, len(data)) | |
self.assertEqual(42, len(read_data)) | |
self.assertEqual(s, read_data) | |
self.assertEqual(md5(s).hexdigest(), data.get_md5sum()) | |
+ def test_size_mismatch(self): | |
+ contents = six.BytesIO(b'a' * 50 + b'b' * 50) | |
+ contents.seek(22) | |
+ data = u.LengthWrapper(contents, 400, True) | |
+ s = b'a' * 28 + b'b' * 50 | |
+ read_data = b''.join(iter(data.read, b'')) | |
+ | |
+ self.assertEqual(400, len(data)) | |
+ self.assertEqual(78, len(read_data)) | |
+ self.assertEqual(s, read_data) | |
+ self.assertEqual(md5(s).hexdigest(), data.get_md5sum()) | |
+ | |
+ def test_buffering(self): | |
+ class LocalContents(object): | |
+ def __init__(self, tell_value=0): | |
+ self.data = six.BytesIO(string.ascii_letters.encode() * 10) | |
+ self.data.seek(tell_value) | |
+ self.reads = [] | |
+ self.seeks = [] | |
+ self.tells = [] | |
+ | |
+ def tell(self): | |
+ self.tells.append(self.data.tell()) | |
+ return self.tells[-1] | |
+ | |
+ def seek(self, position, mode=0): | |
+ self.seeks.append((position, mode)) | |
+ self.data.seek(position, mode) | |
+ | |
+ def read(self, size=-1): | |
+ read_data = self.data.read(size) | |
+ self.reads.append((size, read_data)) | |
+ return read_data | |
+ | |
+ contents = LocalContents() | |
+ data = u.LengthWrapper(contents, 42, False, 10) | |
+ | |
+ self.assertEqual(b'abcd', data.read(4)) | |
+ self.assertEqual([(10, b'abcdefghij')], contents.reads) | |
+ self.assertEqual(b'efgh', data.read(4)) | |
+ self.assertEqual([(10, b'abcdefghij')], contents.reads) | |
+ # We get less than was requested at the border of a buffer | |
+ self.assertEqual(b'ij', data.read(4)) | |
+ self.assertEqual(b'kl', data.read(2)) | |
+ self.assertEqual([(10, b'abcdefghij'), (10, b'klmnopqrst')], | |
+ contents.reads) | |
+ self.assertEqual(b'mnopqrst', data.read(20)) | |
+ self.assertEqual(b'uvwxyzABCD', data.read(20)) | |
+ self.assertEqual(b'EFGHIJKLMN', data.read(20)) | |
+ self.assertEqual([(10, b'abcdefghij'), (10, b'klmnopqrst'), | |
+ (10, b'uvwxyzABCD'), (10, b'EFGHIJKLMN')], | |
+ contents.reads) | |
+ self.assertEqual(b'OP', data.read(20)) | |
+ self.assertEqual([(10, b'abcdefghij'), (10, b'klmnopqrst'), | |
+ (10, b'uvwxyzABCD'), (10, b'EFGHIJKLMN'), | |
+ (2, b'OP')], | |
+ contents.reads) | |
+ | |
def test_tempfile(self): | |
with tempfile.NamedTemporaryFile(mode='wb') as f: | |
f.write(b'a' * 100) | |
@@ -475,7 +541,7 @@ class TestLengthWrapper(unittest.TestCase): | |
contents = open(f.name, 'rb') | |
data = u.LengthWrapper(contents, 42, True) | |
s = b'a' * 42 | |
- read_data = b''.join(iter(data.read, '')) | |
+ read_data = b''.join(iter(data.read, b'')) | |
self.assertEqual(42, len(data)) | |
self.assertEqual(42, len(read_data)) | |
@@ -493,7 +559,7 @@ class TestLengthWrapper(unittest.TestCase): | |
contents = open(f.name, 'rb') | |
contents.seek(i * segment_length) | |
data = u.LengthWrapper(contents, segment_length, True) | |
- read_data = b''.join(iter(data.read, '')) | |
+ read_data = b''.join(iter(data.read, b'')) | |
s = (c * segment_length).encode() | |
self.assertEqual(segment_length, len(data)) | |
@@ -503,7 +569,7 @@ class TestLengthWrapper(unittest.TestCase): | |
data.reset() | |
self.assertEqual(md5().hexdigest(), data.get_md5sum()) | |
- read_data = b''.join(iter(data.read, '')) | |
+ read_data = b''.join(iter(data.read, b'')) | |
self.assertEqual(segment_length, len(data)) | |
self.assertEqual(segment_length, len(read_data)) | |
self.assertEqual(s, read_data) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment