Skip to content

Instantly share code, notes, and snippets.

@zaitcev
Created June 15, 2017 04:31
Show Gist options
  • Save zaitcev/53b2c2b2b16f59387ca43d3a09872cfe to your computer and use it in GitHub Desktop.
Save zaitcev/53b2c2b2b16f59387ca43d3a09872cfe to your computer and use it in GitHub Desktop.
commit 5b4e5617096d984e2633f4697002a9592f564d12
Author: Pete Zaitcev <zaitcev@kotori.zaitcev.us>
Date: Wed Jun 14 22:28:13 2017 -0600
Buffer reads from disk (alt 1)
Otherwise, Python defaults to 8 KiB reads, which is needlessly small when streaming object data from disk.
As a side effect, be stricter about the types of input we wrap. Unicode
that can be encoded as ASCII is still accepted, but any other unicode
now raises a TypeError in LengthWrapper instead of making it all the
way down to httplib.
Change-Id: I3160626e947083af487fd1c3cb0aa6a62646527b
Closes-Bug: #1671621
diff --git a/swiftclient/utils.py b/swiftclient/utils.py
index 8afcde9..f38c7f5 100644
--- a/swiftclient/utils.py
+++ b/swiftclient/utils.py
@@ -20,6 +20,7 @@ import hashlib
import hmac
import json
import logging
+import os
import six
import time
import traceback
@@ -301,7 +302,7 @@ class LengthWrapper(object):
Fix for https://github.com/kennethreitz/requests/issues/1648.
It is recommended to use this class only on files opened in binary mode.
"""
- def __init__(self, readable, length, md5=False):
+ def __init__(self, readable, length, md5=False, read_size=65536):
"""
:param readable: The filelike object to read from.
:param length: The maximum amount of content that can be read from
@@ -309,11 +310,16 @@ class LengthWrapper(object):
empty.
:param md5: Flag to enable calculating the MD5 of the content
as it is read.
+ :param read_size: The number of bytes that should be read from
+ ``readable`` at a time.
"""
self._md5 = md5
self._reset_md5()
self._length = self._remaining = length
+ self._buffer = six.BytesIO()
+ self._buffer_size = 0
self._readable = readable
+ self.read_size = read_size
self._can_reset = all(hasattr(readable, attr)
for attr in ('seek', 'tell'))
if self._can_reset:
@@ -328,12 +334,36 @@ class LengthWrapper(object):
def get_md5sum(self):
return self.md5sum.hexdigest()
+    def _look_ahead(self, remaining):
+        """Refill the internal buffer once it has been fully consumed.
+
+        Reads at most ``min(remaining, self.read_size)`` bytes from the
+        wrapped readable. Does nothing while buffered data is still unread,
+        or when the readable returns no data.
+
+        :param remaining: Upper bound on the number of bytes to fetch.
+        :raises TypeError: if the readable returns text that cannot be
+                           encoded as ASCII.
+        """
+        if self._buffer.tell() < self._buffer_size:
+            return
+
+        data = self._readable.read(min(remaining, self.read_size))
+        if not data:
+            return
+        if not isinstance(data, bytes):
+            try:
+                data = data.encode('ascii')
+            except UnicodeEncodeError:
+                raise TypeError('Object data must be bytes, not %s' %
+                                type(data).__name__)
+
+        # Rewind BEFORE truncating: truncate() with no argument cuts at the
+        # current position, which here is the end of the previous chunk, so
+        # truncating first would keep every stale byte. read() drains the
+        # buffer to its end, so if the new chunk were shorter (a short read)
+        # the leftover tail of the old chunk would be returned as data.
+        self._buffer.seek(0, os.SEEK_SET)
+        self._buffer.truncate()
+
+        self._buffer.write(data)
+        self._buffer.seek(0, os.SEEK_SET)
+        self._buffer_size = len(data)
+
def read(self, size=-1):
if self._remaining <= 0:
- return ''
+ return b''
to_read = self._remaining if size < 0 else min(size, self._remaining)
- chunk = self._readable.read(to_read)
+ self._look_ahead(self._remaining)
+ chunk = self._buffer.read(to_read)
self._remaining -= len(chunk)
try:
@@ -354,6 +384,9 @@ class LengthWrapper(object):
if not self._can_reset:
raise TypeError('%r object cannot be reset; needs both seek and '
'tell methods' % type(self._readable).__name__)
+ self._buffer.seek(0)
+ self._buffer.truncate()
+ self._buffer_size = 0
self._readable.seek(self._start)
self._reset_md5()
self._remaining = self._length
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index adead00..e5814dc 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -17,6 +17,7 @@ import gzip
import unittest
import mock
import six
+import string
import tempfile
from time import gmtime, localtime, mktime, strftime, strptime
from hashlib import md5, sha1
@@ -441,33 +442,98 @@ class TestLengthWrapper(unittest.TestCase):
contents.seek(22)
data = u.LengthWrapper(contents, 42, True)
s = u'a' * 28 + u'b' * 14
- read_data = u''.join(iter(data.read, ''))
+ read_data = b''.join(iter(data.read, b''))
self.assertEqual(42, len(data))
self.assertEqual(42, len(read_data))
- self.assertEqual(s, read_data)
+ self.assertEqual(s.encode(), read_data)
self.assertEqual(md5(s.encode()).hexdigest(), data.get_md5sum())
data.reset()
self.assertEqual(md5().hexdigest(), data.get_md5sum())
- read_data = u''.join(iter(data.read, ''))
+ read_data = b''.join(iter(data.read, b''))
self.assertEqual(42, len(read_data))
- self.assertEqual(s, read_data)
+ self.assertEqual(s.encode(), read_data)
self.assertEqual(md5(s.encode()).hexdigest(), data.get_md5sum())
+ def test_nonascii_stringio(self):
+ # Text that cannot be encoded as ASCII must be rejected inside
+ # LengthWrapper with a TypeError instead of being passed further down.
+ contents = six.StringIO(u'\xaa' * 50)
+ contents.seek(22)
+ data = u.LengthWrapper(contents, 42, True)
+ with self.assertRaises(TypeError):
+ data.read()
+
def test_bytesio(self):
contents = six.BytesIO(b'a' * 50 + b'b' * 50)
contents.seek(22)
data = u.LengthWrapper(contents, 42, True)
s = b'a' * 28 + b'b' * 14
- read_data = b''.join(iter(data.read, ''))
+ read_data = b''.join(iter(data.read, b''))
self.assertEqual(42, len(data))
self.assertEqual(42, len(read_data))
self.assertEqual(s, read_data)
self.assertEqual(md5(s).hexdigest(), data.get_md5sum())
+ def test_size_mismatch(self):
+ # The declared length (400) exceeds the 78 bytes left after seeking
+ # to offset 22: len() still reports the declared 400, but reading
+ # stops at EOF and the MD5 covers only the bytes actually returned.
+ contents = six.BytesIO(b'a' * 50 + b'b' * 50)
+ contents.seek(22)
+ data = u.LengthWrapper(contents, 400, True)
+ s = b'a' * 28 + b'b' * 50
+ read_data = b''.join(iter(data.read, b''))
+
+ self.assertEqual(400, len(data))
+ self.assertEqual(78, len(read_data))
+ self.assertEqual(s, read_data)
+ self.assertEqual(md5(s).hexdigest(), data.get_md5sum())
+
+ def test_buffering(self):
+ # LocalContents wraps a BytesIO and records every read/seek/tell call,
+ # so the test can check exactly how LengthWrapper uses the readable.
+ class LocalContents(object):
+ def __init__(self, tell_value=0):
+ self.data = six.BytesIO(string.ascii_letters.encode() * 10)
+ self.data.seek(tell_value)
+ self.reads = []
+ self.seeks = []
+ self.tells = []
+
+ def tell(self):
+ self.tells.append(self.data.tell())
+ return self.tells[-1]
+
+ def seek(self, position, mode=0):
+ self.seeks.append((position, mode))
+ self.data.seek(position, mode)
+
+ def read(self, size=-1):
+ read_data = self.data.read(size)
+ self.reads.append((size, read_data))
+ return read_data
+
+ contents = LocalContents()
+ # length=42, md5=False, read_size=10
+ data = u.LengthWrapper(contents, 42, False, 10)
+
+ self.assertEqual(b'abcd', data.read(4))
+ self.assertEqual([(10, b'abcdefghij')], contents.reads)
+ # Served from the already-filled buffer; no new underlying read.
+ self.assertEqual(b'efgh', data.read(4))
+ self.assertEqual([(10, b'abcdefghij')], contents.reads)
+ # We get fewer bytes than requested at a buffer boundary
+ self.assertEqual(b'ij', data.read(4))
+ self.assertEqual(b'kl', data.read(2))
+ self.assertEqual([(10, b'abcdefghij'), (10, b'klmnopqrst')],
+ contents.reads)
+ self.assertEqual(b'mnopqrst', data.read(20))
+ self.assertEqual(b'uvwxyzABCD', data.read(20))
+ self.assertEqual(b'EFGHIJKLMN', data.read(20))
+ self.assertEqual([(10, b'abcdefghij'), (10, b'klmnopqrst'),
+ (10, b'uvwxyzABCD'), (10, b'EFGHIJKLMN')],
+ contents.reads)
+ # Only 2 of the 42 bytes remain, so the final fetch asks for 2, not 10.
+ self.assertEqual(b'OP', data.read(20))
+ self.assertEqual([(10, b'abcdefghij'), (10, b'klmnopqrst'),
+ (10, b'uvwxyzABCD'), (10, b'EFGHIJKLMN'),
+ (2, b'OP')],
+ contents.reads)
+
def test_tempfile(self):
with tempfile.NamedTemporaryFile(mode='wb') as f:
f.write(b'a' * 100)
@@ -475,7 +541,7 @@ class TestLengthWrapper(unittest.TestCase):
contents = open(f.name, 'rb')
data = u.LengthWrapper(contents, 42, True)
s = b'a' * 42
- read_data = b''.join(iter(data.read, ''))
+ read_data = b''.join(iter(data.read, b''))
self.assertEqual(42, len(data))
self.assertEqual(42, len(read_data))
@@ -493,7 +559,7 @@ class TestLengthWrapper(unittest.TestCase):
contents = open(f.name, 'rb')
contents.seek(i * segment_length)
data = u.LengthWrapper(contents, segment_length, True)
- read_data = b''.join(iter(data.read, ''))
+ read_data = b''.join(iter(data.read, b''))
s = (c * segment_length).encode()
self.assertEqual(segment_length, len(data))
@@ -503,7 +569,7 @@ class TestLengthWrapper(unittest.TestCase):
data.reset()
self.assertEqual(md5().hexdigest(), data.get_md5sum())
- read_data = b''.join(iter(data.read, ''))
+ read_data = b''.join(iter(data.read, b''))
self.assertEqual(segment_length, len(data))
self.assertEqual(segment_length, len(read_data))
self.assertEqual(s, read_data)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment