Browse Source

xattr: refactor code, deduplicate

this code would be otherwise duplicated 3 times for linux, freebsd, darwin.
Thomas Waldmann 9 years ago
parent
commit
67c6c1898c
1 changed files with 145 additions and 137 deletions
  1. 145 137
      borg/xattr.py

+ 145 - 137
borg/xattr.py

@@ -46,6 +46,7 @@ def get_all(path, follow_symlinks=True):
         if e.errno in (errno.ENOTSUP, errno.EPERM):
             return {}
 
+
 libc_name = find_library('c')
 if libc_name is None:
     # find_library didn't work, maybe we are on some minimal system that misses essential
@@ -104,6 +105,46 @@ def _check(rv, path=None):
             raise OSError(e, path)
     return rv
 
+
+def _listxattr_inner(func, path):
+    if isinstance(path, str):
+        path = os.fsencode(path)
+    size = len(get_buffer())
+    while True:
+        buf = get_buffer(size)
+        try:
+            n = _check(func(path, buf, size), path)
+        except BufferTooSmallError:
+            size *= 2
+            assert size < 2 ** 20
+        else:
+            return n, buf.raw
+
+
+def _getxattr_inner(func, path, name):
+    if isinstance(path, str):
+        path = os.fsencode(path)
+    name = os.fsencode(name)
+    size = len(get_buffer())
+    while True:
+        buf = get_buffer(size)
+        try:
+            n = _check(func(path, name, buf, size), path)
+        except BufferTooSmallError:
+            size *= 2
+        else:
+            return n, buf.raw
+
+
+def _setxattr_inner(func, path, name, value):
+    if isinstance(path, str):
+        path = os.fsencode(path)
+    name = os.fsencode(name)
+    value = value and os.fsencode(value)
+    size = len(value) if value else 0
+    _check(func(path, name, value, size), path)
+
+
 if sys.platform.startswith('linux'):  # pragma: linux only
     libc.llistxattr.argtypes = (c_char_p, c_char_p, c_size_t)
     libc.llistxattr.restype = c_ssize_t
@@ -119,63 +160,49 @@ if sys.platform.startswith('linux'):  # pragma: linux only
     libc.fgetxattr.restype = c_ssize_t
 
     def listxattr(path, *, follow_symlinks=True):
-        if isinstance(path, str):
-            path = os.fsencode(path)
-        if isinstance(path, int):
-            func = libc.flistxattr
-        elif follow_symlinks:
-            func = libc.listxattr
-        else:
-            func = libc.llistxattr
-        size = len(get_buffer())
-        while True:
-            buf = get_buffer(size)
-            try:
-                n = _check(func(path, buf, size), path)
-            except BufferTooSmallError:
-                size *= 2
+        def func(path, buf, size):
+            if isinstance(path, int):
+                return libc.flistxattr(path, buf, size)
             else:
-                if n == 0:
-                    return []
-                names = buf.raw[:n]
-                return [os.fsdecode(name)
-                        for name in names.split(b'\0')[:-1]
-                        if not name.startswith(b'system.posix_acl_')]
+                if follow_symlinks:
+                    return libc.listxattr(path, buf, size)
+                else:
+                    return libc.llistxattr(path, buf, size)
+
+        n, buf = _listxattr_inner(func, path)
+        if n == 0:
+            return []
+        names = buf[:n].split(b'\0')[:-1]
+        return [os.fsdecode(name) for name in names
+                if not name.startswith(b'system.posix_acl_')]
 
     def getxattr(path, name, *, follow_symlinks=True):
-        name = os.fsencode(name)
-        if isinstance(path, str):
-            path = os.fsencode(path)
-        if isinstance(path, int):
-            func = libc.fgetxattr
-        elif follow_symlinks:
-            func = libc.getxattr
-        else:
-            func = libc.lgetxattr
-        size = len(get_buffer())
-        while True:
-            buf = get_buffer(size)
-            try:
-                n = _check(func(path, name, buf, size), path)
-            except BufferTooSmallError:
-                size *= 2
+        def func(path, name, buf, size):
+            if isinstance(path, int):
+                return libc.fgetxattr(path, name, buf, size)
             else:
-                if n == 0:
-                    return
-                return buf.raw[:n]
+                if follow_symlinks:
+                    return libc.getxattr(path, name, buf, size)
+                else:
+                    return libc.lgetxattr(path, name, buf, size)
+
+        n, buf = _getxattr_inner(func, path, name)
+        if n == 0:
+            return
+        return buf[:n]
 
     def setxattr(path, name, value, *, follow_symlinks=True):
-        name = os.fsencode(name)
-        value = value and os.fsencode(value)
-        if isinstance(path, str):
-            path = os.fsencode(path)
-        if isinstance(path, int):
-            func = libc.fsetxattr
-        elif follow_symlinks:
-            func = libc.setxattr
-        else:
-            func = libc.lsetxattr
-        _check(func(path, name, value, len(value) if value else 0, 0), path)
+        def func(path, name, value, size):
+            flags = 0
+            if isinstance(path, int):
+                return libc.fsetxattr(path, name, value, size, flags)
+            else:
+                if follow_symlinks:
+                    return libc.setxattr(path, name, value, size, flags)
+                else:
+                    return libc.lsetxattr(path, name, value, size, flags)
+
+        _setxattr_inner(func, path, name, value)
 
 elif sys.platform == 'darwin':  # pragma: darwin only
     libc.listxattr.argtypes = (c_char_p, c_char_p, c_size_t, c_int)
@@ -191,62 +218,53 @@ elif sys.platform == 'darwin':  # pragma: darwin only
     libc.fgetxattr.argtypes = (c_int, c_char_p, c_char_p, c_size_t, c_uint32, c_int)
     libc.fgetxattr.restype = c_ssize_t
 
+    XATTR_NOFLAGS = 0x0000
     XATTR_NOFOLLOW = 0x0001
 
     def listxattr(path, *, follow_symlinks=True):
-        func = libc.listxattr
-        flags = 0
-        if isinstance(path, str):
-            path = os.fsencode(path)
-        if isinstance(path, int):
-            func = libc.flistxattr
-        elif not follow_symlinks:
-            flags = XATTR_NOFOLLOW
-        # TODO: fix, see linux
-        n = _check(func(path, None, 0, flags), path)
+        def func(path, buf, size):
+            if isinstance(path, int):
+                return libc.flistxattr(path, buf, size, XATTR_NOFLAGS)
+            else:
+                if follow_symlinks:
+                    return libc.listxattr(path, buf, size, XATTR_NOFLAGS)
+                else:
+                    return libc.listxattr(path, buf, size, XATTR_NOFOLLOW)
+
+        n, buf = _listxattr_inner(func, path)
         if n == 0:
             return []
-        namebuf = create_string_buffer(n)
-        n2 = _check(func(path, namebuf, n, flags), path)
-        if n2 != n:
-            raise Exception('listxattr failed')
-        return [os.fsdecode(name) for name in namebuf.raw.split(b'\0')[:-1]]
+        names = buf[:n].split(b'\0')[:-1]
+        return [os.fsdecode(name) for name in names]
 
     def getxattr(path, name, *, follow_symlinks=True):
-        name = os.fsencode(name)
-        func = libc.getxattr
-        flags = 0
-        if isinstance(path, str):
-            path = os.fsencode(path)
-        if isinstance(path, int):
-            func = libc.fgetxattr
-        elif not follow_symlinks:
-            flags = XATTR_NOFOLLOW
-        # TODO: fix, see linux
-        n = _check(func(path, name, None, 0, 0, flags))
+        def func(path, name, buf, size):
+            if isinstance(path, int):
+                return libc.fgetxattr(path, name, buf, size, 0, XATTR_NOFLAGS)
+            else:
+                if follow_symlinks:
+                    return libc.getxattr(path, name, buf, size, 0, XATTR_NOFLAGS)
+                else:
+                    return libc.getxattr(path, name, buf, size, 0, XATTR_NOFOLLOW)
+
+        n, buf = _getxattr_inner(func, path, name)
         if n == 0:
             return
-        valuebuf = create_string_buffer(n)
-        n2 = _check(func(path, name, valuebuf, n, 0, flags), path)
-        if n2 != n:
-            raise Exception('getxattr failed')
-        return valuebuf.raw
+        return buf[:n]
 
     def setxattr(path, name, value, *, follow_symlinks=True):
-        name = os.fsencode(name)
-        value = value and os.fsencode(value)
-        func = libc.setxattr
-        flags = 0
-        if isinstance(path, str):
-            path = os.fsencode(path)
-        if isinstance(path, int):
-            func = libc.fsetxattr
-        elif not follow_symlinks:
-            flags = XATTR_NOFOLLOW
-        _check(func(path, name, value, len(value) if value else 0, 0, flags), path)
+        def func(path, name, value, size):
+            if isinstance(path, int):
+                return libc.fsetxattr(path, name, value, size, 0, XATTR_NOFLAGS)
+            else:
+                if follow_symlinks:
+                    return libc.setxattr(path, name, value, size, 0, XATTR_NOFLAGS)
+                else:
+                    return libc.setxattr(path, name, value, size, 0, XATTR_NOFOLLOW)
+
+        _setxattr_inner(func, path, name, value)
 
 elif sys.platform.startswith('freebsd'):  # pragma: freebsd only
-    EXTATTR_NAMESPACE_USER = 0x0001
     libc.extattr_list_fd.argtypes = (c_int, c_int, c_char_p, c_size_t)
     libc.extattr_list_fd.restype = c_ssize_t
     libc.extattr_list_link.argtypes = (c_char_p, c_int, c_char_p, c_size_t)
@@ -265,27 +283,23 @@ elif sys.platform.startswith('freebsd'):  # pragma: freebsd only
     libc.extattr_set_link.restype = c_int
     libc.extattr_set_file.argtypes = (c_char_p, c_int, c_char_p, c_char_p, c_size_t)
     libc.extattr_set_file.restype = c_int
+    ns = EXTATTR_NAMESPACE_USER = 0x0001
 
     def listxattr(path, *, follow_symlinks=True):
-        ns = EXTATTR_NAMESPACE_USER
-        if isinstance(path, str):
-            path = os.fsencode(path)
-        if isinstance(path, int):
-            func = libc.extattr_list_fd
-        elif follow_symlinks:
-            func = libc.extattr_list_file
-        else:
-            func = libc.extattr_list_link
-        # TODO: fix, see linux
-        n = _check(func(path, ns, None, 0), path)
+        def func(path, buf, size):
+            if isinstance(path, int):
+                return libc.extattr_list_fd(path, ns, buf, size)
+            else:
+                if follow_symlinks:
+                    return libc.extattr_list_file(path, ns, buf, size)
+                else:
+                    return libc.extattr_list_link(path, ns, buf, size)
+
+        n, buf = _listxattr_inner(func, path)
         if n == 0:
             return []
-        namebuf = create_string_buffer(n)
-        n2 = _check(func(path, ns, namebuf, n), path)
-        if n2 != n:
-            raise Exception('listxattr failed')
         names = []
-        mv = memoryview(namebuf.raw)
+        mv = memoryview(buf)
         while mv:
             length = mv[0]
             names.append(os.fsdecode(bytes(mv[1:1 + length])))
@@ -293,37 +307,31 @@ elif sys.platform.startswith('freebsd'):  # pragma: freebsd only
         return names
 
     def getxattr(path, name, *, follow_symlinks=True):
-        name = os.fsencode(name)
-        if isinstance(path, str):
-            path = os.fsencode(path)
-        if isinstance(path, int):
-            func = libc.extattr_get_fd
-        elif follow_symlinks:
-            func = libc.extattr_get_file
-        else:
-            func = libc.extattr_get_link
-        # TODO: fix, see linux
-        n = _check(func(path, EXTATTR_NAMESPACE_USER, name, None, 0))
+        def func(path, name, buf, size):
+            if isinstance(path, int):
+                return libc.extattr_get_fd(path, ns, name, buf, size)
+            else:
+                if follow_symlinks:
+                    return libc.extattr_get_file(path, ns, name, buf, size)
+                else:
+                    return libc.extattr_get_link(path, ns, name, buf, size)
+
+        n, buf = _getxattr_inner(func, path, name)
         if n == 0:
             return
-        valuebuf = create_string_buffer(n)
-        n2 = _check(func(path, EXTATTR_NAMESPACE_USER, name, valuebuf, n), path)
-        if n2 != n:
-            raise Exception('getxattr failed')
-        return valuebuf.raw
+        return buf[:n]
 
     def setxattr(path, name, value, *, follow_symlinks=True):
-        name = os.fsencode(name)
-        value = value and os.fsencode(value)
-        if isinstance(path, str):
-            path = os.fsencode(path)
-        if isinstance(path, int):
-            func = libc.extattr_set_fd
-        elif follow_symlinks:
-            func = libc.extattr_set_file
-        else:
-            func = libc.extattr_set_link
-        _check(func(path, EXTATTR_NAMESPACE_USER, name, value, len(value) if value else 0), path)
+        def func(path, name, value, size):
+            if isinstance(path, int):
+                return libc.extattr_set_fd(path, ns, name, value, size)
+            else:
+                if follow_symlinks:
+                    return libc.extattr_set_file(path, ns, name, value, size)
+                else:
+                    return libc.extattr_set_link(path, ns, name, value, size)
+
+        _setxattr_inner(func, path, name, value)
 
 else:  # pragma: unknown platform only
     def listxattr(path, *, follow_symlinks=True):