Просмотр исходного кода

refactor buffer code into helpers.Buffer class, add tests

Thomas Waldmann 8 лет назад
Родитель
Сommit
ef9e8a584b
4 измененных файлов с 112 добавлено и 23 удалено
  1. 42 0
      borg/helpers.py
  2. 57 1
      borg/testsuite/helpers.py
  3. 5 5
      borg/testsuite/xattr.py
  4. 8 17
      borg/xattr.py

+ 42 - 0
borg/helpers.py

@@ -10,6 +10,7 @@ import re
 from shutil import get_terminal_size
 import sys
 import platform
+import threading
 import time
 import unicodedata
 import io
@@ -655,6 +656,47 @@ def memoize(function):
     return decorated_function
 
 
+class Buffer:
+    """
+    provide a thread-local buffer
+    """
+    def __init__(self, allocator, size=4096, limit=None):
+        """
+        Initialize the buffer: use allocator(size) call to allocate a buffer.
+        Optionally, set the upper <limit> for the buffer size.
+        """
+        assert callable(allocator), 'must give alloc(size) function as first param'
+        assert limit is None or size <= limit, 'initial size must be <= limit'
+        self._thread_local = threading.local()
+        self.allocator = allocator
+        self.limit = limit
+        self.resize(size, init=True)
+
+    def __len__(self):
+        return len(self._thread_local.buffer)
+
+    def resize(self, size, init=False):
+        """
+        resize the buffer - to avoid frequent reallocation, we usually always grow (if needed).
+        giving init=True it is possible to first-time initialize or shrink the buffer.
+        if a buffer size beyond the limit is requested, raise ValueError.
+        """
+        size = int(size)
+        if self.limit is not None and size > self.limit:
+            raise ValueError('Requested buffer size %d is above the limit of %d.' % (size, self.limit))
+        if init or len(self) < size:
+            self._thread_local.buffer = self.allocator(size)
+
+    def get(self, size=None, init=False):
+        """
+        return a buffer of at least the requested size (None: any current size).
+        init=True can be given to trigger shrinking of the buffer to the given size.
+        """
+        if size is not None:
+            self.resize(size, init)
+        return self._thread_local.buffer
+
+
 @memoize
 def uid2user(uid, default=None):
     try:

+ 57 - 1
borg/testsuite/helpers.py

@@ -15,7 +15,8 @@ from ..helpers import Location, format_file_size, format_timedelta, format_line,
     yes, TRUISH, FALSISH, DEFAULTISH, \
     StableDict, int_to_bigint, bigint_to_int, parse_timestamp, CompressionSpec, ChunkerParams, \
     ProgressIndicatorPercent, ProgressIndicatorEndless, load_excludes, parse_pattern, \
-    PatternMatcher, RegexPattern, PathPrefixPattern, FnmatchPattern, ShellPattern
+    PatternMatcher, RegexPattern, PathPrefixPattern, FnmatchPattern, ShellPattern, \
+    Buffer
 from . import BaseTestCase, environment_variable, FakeInputs
 
 
@@ -714,6 +715,61 @@ def test_is_slow_msgpack():
     assert not is_slow_msgpack()
 
 
+class TestBuffer:
+    def test_type(self):
+        buffer = Buffer(bytearray)
+        assert isinstance(buffer.get(), bytearray)
+        buffer = Buffer(bytes)  # don't do that in practice
+        assert isinstance(buffer.get(), bytes)
+
+    def test_len(self):
+        buffer = Buffer(bytearray, size=0)
+        b = buffer.get()
+        assert len(buffer) == len(b) == 0
+        buffer = Buffer(bytearray, size=1234)
+        b = buffer.get()
+        assert len(buffer) == len(b) == 1234
+
+    def test_resize(self):
+        buffer = Buffer(bytearray, size=100)
+        assert len(buffer) == 100
+        b1 = buffer.get()
+        buffer.resize(200)
+        assert len(buffer) == 200
+        b2 = buffer.get()
+        assert b2 is not b1  # new, bigger buffer
+        buffer.resize(100)
+        assert len(buffer) >= 100
+        b3 = buffer.get()
+        assert b3 is b2  # still same buffer (200)
+        buffer.resize(100, init=True)
+        assert len(buffer) == 100  # except on init
+        b4 = buffer.get()
+        assert b4 is not b3  # new, smaller buffer
+
+    def test_limit(self):
+        buffer = Buffer(bytearray, size=100, limit=200)
+        buffer.resize(200)
+        assert len(buffer) == 200
+        with pytest.raises(ValueError):
+            buffer.resize(201)
+        assert len(buffer) == 200
+
+    def test_get(self):
+        buffer = Buffer(bytearray, size=100, limit=200)
+        b1 = buffer.get(50)
+        assert len(b1) >= 50  # == 100
+        b2 = buffer.get(100)
+        assert len(b2) >= 100  # == 100
+        assert b2 is b1  # did not need resizing yet
+        b3 = buffer.get(200)
+        assert len(b3) == 200
+        assert b3 is not b2  # new, resized buffer
+        with pytest.raises(ValueError):
+            buffer.get(201)  # beyond limit
+        assert len(buffer) == 200
+
+
 def test_yes_input():
     inputs = list(TRUISH)
     input = FakeInputs(inputs)

+ 5 - 5
borg/testsuite/xattr.py

@@ -2,7 +2,7 @@ import os
 import tempfile
 import unittest
 
-from ..xattr import is_enabled, getxattr, setxattr, listxattr, get_buffer
+from ..xattr import is_enabled, getxattr, setxattr, listxattr, buffer
 from . import BaseTestCase
 
 
@@ -41,20 +41,20 @@ class XattrTestCase(BaseTestCase):
 
     def test_listxattr_buffer_growth(self):
         # make it work even with ext4, which imposes rather low limits
-        get_buffer(size=64, init=True)
+        buffer.resize(size=64, init=True)
         # xattr raw key list will be size 9 * (10 + 1), which is > 64
         keys = ['user.attr%d' % i for i in range(9)]
         for key in keys:
             setxattr(self.tmpfile.name, key, b'x')
         got_keys = listxattr(self.tmpfile.name)
         self.assert_equal_se(got_keys, keys)
-        self.assert_equal(len(get_buffer()), 128)
+        self.assert_equal(len(buffer), 128)
 
     def test_getxattr_buffer_growth(self):
         # make it work even with ext4, which imposes rather low limits
-        get_buffer(size=64, init=True)
+        buffer.resize(size=64, init=True)
         value = b'x' * 126
         setxattr(self.tmpfile.name, 'user.big', value)
         got_value = getxattr(self.tmpfile.name, 'user.big')
         self.assert_equal(value, got_value)
-        self.assert_equal(len(get_buffer()), 128)
+        self.assert_equal(len(buffer), 128)

+ 8 - 17
borg/xattr.py

@@ -6,11 +6,12 @@ import re
 import subprocess
 import sys
 import tempfile
-import threading
 from ctypes import CDLL, create_string_buffer, c_ssize_t, c_size_t, c_char_p, c_int, c_uint32, get_errno
 from ctypes.util import find_library
 from distutils.version import LooseVersion
 
+from .helpers import Buffer
+
 from .logger import create_logger
 logger = create_logger()
 
@@ -22,17 +23,7 @@ except AttributeError:
     ENOATTR = errno.ENODATA
 
 
-def get_buffer(size=None, init=False):
-    if size is not None:
-        size = int(size)
-        assert size < 2 ** 24
-        if init or len(thread_local.buffer) < size:
-            thread_local.buffer = create_string_buffer(size)
-    return thread_local.buffer
-
-
-thread_local = threading.local()
-get_buffer(size=4096, init=True)
+buffer = Buffer(create_string_buffer, limit=2**24)
 
 
 def is_enabled(path=None):
@@ -144,7 +135,7 @@ def _check(rv, path=None, detect_buffer_too_small=False):
             if isinstance(path, int):
                 path = '<FD %d>' % path
             raise OSError(e, msg, path)
-    if detect_buffer_too_small and rv >= len(get_buffer()):
+    if detect_buffer_too_small and rv >= len(buffer):
         # freebsd does not error with ERANGE if the buffer is too small,
         # it just fills the buffer, truncates and returns.
         # so, we play sure and just assume that result is truncated if
@@ -156,9 +147,9 @@ def _check(rv, path=None, detect_buffer_too_small=False):
 def _listxattr_inner(func, path):
     if isinstance(path, str):
         path = os.fsencode(path)
-    size = len(get_buffer())
+    size = len(buffer)
     while True:
-        buf = get_buffer(size)
+        buf = buffer.get(size)
         try:
             n = _check(func(path, buf, size), path, detect_buffer_too_small=True)
         except BufferTooSmallError:
@@ -171,9 +162,9 @@ def _getxattr_inner(func, path, name):
     if isinstance(path, str):
         path = os.fsencode(path)
     name = os.fsencode(name)
-    size = len(get_buffer())
+    size = len(buffer)
     while True:
-        buf = get_buffer(size)
+        buf = buffer.get(size)
         try:
             n = _check(func(path, name, buf, size), path, detect_buffer_too_small=True)
         except BufferTooSmallError: