Browse Source

xattr: dynamically grow result buffer until it fits, fixes #1462

This also fixes the race condition seen in #1462, because there is now only one call per attempt instead of a size query followed by a separate read.
Either the call succeeds, in which case we get the correct length as the result and truncate the returned value to that length;
or it fails with ERANGE, in which case we double the buffer size and repeat;
or it fails with some other error, in which case we raise OSError.
Thomas Waldmann 9 years ago
parent
commit
17c77a5dc5
2 changed files with 72 additions and 18 deletions
  1. 21 1
      borg/testsuite/xattr.py
  2. 51 17
      borg/xattr.py

+ 21 - 1
borg/testsuite/xattr.py

@@ -2,7 +2,7 @@ import os
 import tempfile
 import unittest
 
-from ..xattr import is_enabled, getxattr, setxattr, listxattr
+from ..xattr import is_enabled, getxattr, setxattr, listxattr, get_buffer
 from . import BaseTestCase
 
 
@@ -38,3 +38,23 @@ class XattrTestCase(BaseTestCase):
         self.assert_equal(getxattr(self.tmpfile.fileno(), 'user.foo'), b'bar')
         self.assert_equal(getxattr(self.symlink, 'user.foo'), b'bar')
         self.assert_equal(getxattr(self.tmpfile.name, 'user.empty'), None)
+
+    def test_listxattr_buffer_growth(self):
+        # make it work even with ext4, which imposes rather low limits
+        get_buffer(size=64, init=True)
+        # xattr raw key list will be size 9 * (10 + 1), which is > 64
+        keys = ['user.attr%d' % i for i in range(9)]
+        for key in keys:
+            setxattr(self.tmpfile.name, key, b'x')
+        got_keys = listxattr(self.tmpfile.name)
+        self.assert_equal_se(got_keys, keys)
+        self.assert_equal(len(get_buffer()), 128)
+
+    def test_getxattr_buffer_growth(self):
+        # make it work even with ext4, which imposes rather low limits
+        get_buffer(size=64, init=True)
+        value = b'x' * 126
+        setxattr(self.tmpfile.name, 'user.big', value)
+        got_value = getxattr(self.tmpfile.name, 'user.big')
+        self.assert_equal(value, got_value)
+        self.assert_equal(len(get_buffer()), 128)

+ 51 - 17
borg/xattr.py

@@ -6,6 +6,7 @@ import re
 import subprocess
 import sys
 import tempfile
+import threading
 from ctypes import CDLL, create_string_buffer, c_ssize_t, c_size_t, c_char_p, c_int, c_uint32, get_errno
 from ctypes.util import find_library
 from distutils.version import LooseVersion
@@ -14,6 +15,18 @@ from .logger import create_logger
 logger = create_logger()
 
 
+def get_buffer(size=None, init=False):
+    if size is not None:
+        size = int(size)
+        if init or len(thread_local.buffer) < size:
+            thread_local.buffer = create_string_buffer(size)
+    return thread_local.buffer
+
+
+thread_local = threading.local()
+get_buffer(size=4096, init=True)
+
+
 def is_enabled(path=None):
     """Determine if xattr is enabled on the filesystem
     """
@@ -78,9 +91,17 @@ except OSError as e:
     raise Exception(msg)
 
 
+class BufferTooSmallError(Exception):
+    """the buffer given to an xattr function was too small for the result"""
+
+
 def _check(rv, path=None):
     if rv < 0:
-        raise OSError(get_errno(), path)
+        e = get_errno()
+        if e == errno.ERANGE:
+            raise BufferTooSmallError
+        else:
+            raise OSError(e, path)
     return rv
 
 if sys.platform.startswith('linux'):  # pragma: linux only
@@ -106,14 +127,20 @@ if sys.platform.startswith('linux'):  # pragma: linux only
             func = libc.listxattr
         else:
             func = libc.llistxattr
-        n = _check(func(path, None, 0), path)
-        if n == 0:
-            return []
-        namebuf = create_string_buffer(n)
-        n2 = _check(func(path, namebuf, n), path)
-        if n2 != n:
-            raise Exception('listxattr failed')
-        return [os.fsdecode(name) for name in namebuf.raw.split(b'\0')[:-1] if not name.startswith(b'system.posix_acl_')]
+        size = len(get_buffer())
+        while True:
+            buf = get_buffer(size)
+            try:
+                n = _check(func(path, buf, size), path)
+            except BufferTooSmallError:
+                size *= 2
+            else:
+                if n == 0:
+                    return []
+                names = buf.raw[:n]
+                return [os.fsdecode(name)
+                        for name in names.split(b'\0')[:-1]
+                        if not name.startswith(b'system.posix_acl_')]
 
     def getxattr(path, name, *, follow_symlinks=True):
         name = os.fsencode(name)
@@ -125,14 +152,17 @@ if sys.platform.startswith('linux'):  # pragma: linux only
             func = libc.getxattr
         else:
             func = libc.lgetxattr
-        n = _check(func(path, name, None, 0))
-        if n == 0:
-            return
-        valuebuf = create_string_buffer(n)
-        n2 = _check(func(path, name, valuebuf, n), path)
-        if n2 != n:
-            raise Exception('getxattr failed')
-        return valuebuf.raw
+        size = len(get_buffer())
+        while True:
+            buf = get_buffer(size)
+            try:
+                n = _check(func(path, name, buf, size), path)
+            except BufferTooSmallError:
+                size *= 2
+            else:
+                if n == 0:
+                    return
+                return buf.raw[:n]
 
     def setxattr(path, name, value, *, follow_symlinks=True):
         name = os.fsencode(name)
@@ -172,6 +202,7 @@ elif sys.platform == 'darwin':  # pragma: darwin only
             func = libc.flistxattr
         elif not follow_symlinks:
             flags = XATTR_NOFOLLOW
+        # TODO: fix, see linux
         n = _check(func(path, None, 0, flags), path)
         if n == 0:
             return []
@@ -191,6 +222,7 @@ elif sys.platform == 'darwin':  # pragma: darwin only
             func = libc.fgetxattr
         elif not follow_symlinks:
             flags = XATTR_NOFOLLOW
+        # TODO: fix, see linux
         n = _check(func(path, name, None, 0, 0, flags))
         if n == 0:
             return
@@ -244,6 +276,7 @@ elif sys.platform.startswith('freebsd'):  # pragma: freebsd only
             func = libc.extattr_list_file
         else:
             func = libc.extattr_list_link
+        # TODO: fix, see linux
         n = _check(func(path, ns, None, 0), path)
         if n == 0:
             return []
@@ -269,6 +302,7 @@ elif sys.platform.startswith('freebsd'):  # pragma: freebsd only
             func = libc.extattr_get_file
         else:
             func = libc.extattr_get_link
+        # TODO: fix, see linux
         n = _check(func(path, EXTATTR_NAMESPACE_USER, name, None, 0))
         if n == 0:
             return