Kaynağa Gözat

Merge pull request #1211 from enkore/issue/1138

Fix incorrect propagation of OSErrors in create code
enkore 9 yıl önce
ebeveyn
işleme
67c69998d6
4 değiştirilmiş dosya ile 89 ekleme ve 24 silme
  1. 40 5
      borg/archive.py
  2. 3 3
      borg/archiver.py
  3. 21 16
      borg/remote.py
  4. 25 0
      borg/testsuite/archive.py

+ 40 - 5
borg/archive.py

@@ -1,4 +1,5 @@
 from binascii import hexlify
+from contextlib import contextmanager
 from datetime import datetime, timezone
 from getpass import getuser
 from itertools import groupby
@@ -45,6 +46,37 @@ flags_normal = os.O_RDONLY | getattr(os, 'O_BINARY', 0)
 flags_noatime = flags_normal | getattr(os, 'O_NOATIME', 0)
 
 
+class InputOSError(Exception):
+    """Wrapper for OSError raised while accessing input files."""
+    def __init__(self, os_error):
+        self.os_error = os_error
+        self.errno = os_error.errno
+        self.strerror = os_error.strerror
+        self.filename = os_error.filename
+
+    def __str__(self):
+        return str(self.os_error)
+
+
+@contextmanager
+def input_io():
+    """Context manager changing OSError to InputOSError."""
+    try:
+        yield
+    except OSError as os_error:
+        raise InputOSError(os_error) from os_error
+
+
+def input_io_iter(iterator):
+    while True:
+        try:
+            with input_io():
+                item = next(iterator)
+        except StopIteration:
+            return
+        yield item
+
+
 class DownloadPipeline:
 
     def __init__(self, repository, key):
@@ -464,12 +496,14 @@ Number of files: {0.stats.nfiles}'''.format(
         }
         if self.numeric_owner:
             item[b'user'] = item[b'group'] = None
-        xattrs = xattr.get_all(path, follow_symlinks=False)
+        with input_io():
+            xattrs = xattr.get_all(path, follow_symlinks=False)
         if xattrs:
             item[b'xattrs'] = StableDict(xattrs)
         if has_lchflags and st.st_flags:
             item[b'bsdflags'] = st.st_flags
-        acl_get(path, item, st, self.numeric_owner)
+        with input_io():
+            acl_get(path, item, st, self.numeric_owner)
         return item
 
     def process_dir(self, path, st):
@@ -504,7 +538,7 @@ Number of files: {0.stats.nfiles}'''.format(
         uid, gid = 0, 0
         fd = sys.stdin.buffer  # binary
         chunks = []
-        for chunk in self.chunker.chunkify(fd):
+        for chunk in input_io_iter(self.chunker.chunkify(fd)):
             chunks.append(cache.add_chunk(self.key.id_hash(chunk), chunk, self.stats))
         self.stats.nfiles += 1
         t = int_to_bigint(int(time.time()) * 1000000000)
@@ -552,10 +586,11 @@ Number of files: {0.stats.nfiles}'''.format(
         item = {b'path': safe_path}
         # Only chunkify the file if needed
         if chunks is None:
-            fh = Archive._open_rb(path)
+            with input_io():
+                fh = Archive._open_rb(path)
             with os.fdopen(fh, 'rb') as fd:
                 chunks = []
-                for chunk in self.chunker.chunkify(fd, fh):
+                for chunk in input_io_iter(self.chunker.chunkify(fd, fh)):
                     chunks.append(cache.add_chunk(self.key.id_hash(chunk), chunk, self.stats))
                     if self.show_progress:
                         self.stats.show_progress(item=item, dt=0.2)

+ 3 - 3
borg/archiver.py

@@ -29,7 +29,7 @@ from .upgrader import AtticRepositoryUpgrader, BorgRepositoryUpgrader
 from .repository import Repository
 from .cache import Cache
 from .key import key_creator, RepoKey, PassphraseKey
-from .archive import Archive, ArchiveChecker, CHUNKER_PARAMS
+from .archive import input_io, InputOSError, Archive, ArchiveChecker, CHUNKER_PARAMS
 from .remote import RepositoryServer, RemoteRepository, cache_if_remote
 
 has_lchflags = hasattr(os, 'lchflags')
@@ -198,7 +198,7 @@ class Archiver:
                     if not dry_run:
                         try:
                             status = archive.process_stdin(path, cache)
-                        except OSError as e:
+                        except InputOSError as e:
                             status = 'E'
                             self.print_warning('%s: %s', path, e)
                     else:
@@ -273,7 +273,7 @@ class Archiver:
             if not dry_run:
                 try:
                     status = archive.process_file(path, st, cache, self.ignore_inode)
-                except OSError as e:
+                except InputOSError as e:
                     status = 'E'
                     self.print_warning('%s: %s', path, e)
         elif stat.S_ISDIR(st.st_mode):

+ 21 - 16
borg/remote.py

@@ -241,6 +241,24 @@ class RemoteRepository:
                 del self.cache[args]
             return msgid
 
+        def handle_error(error, res):
+            if error == b'DoesNotExist':
+                raise Repository.DoesNotExist(self.location.orig)
+            elif error == b'AlreadyExists':
+                raise Repository.AlreadyExists(self.location.orig)
+            elif error == b'CheckNeeded':
+                raise Repository.CheckNeeded(self.location.orig)
+            elif error == b'IntegrityError':
+                raise IntegrityError(res)
+            elif error == b'PathNotAllowed':
+                raise PathNotAllowed(*res)
+            elif error == b'ObjectNotFound':
+                raise Repository.ObjectNotFound(res[0], self.location.orig)
+            elif error == b'InvalidRPCMethod':
+                raise InvalidRPCMethod(*res)
+            else:
+                raise self.RPCError(res.decode('utf-8'))
+
         calls = list(calls)
         waiting_for = []
         w_fds = [self.stdin_fd]
@@ -250,22 +268,7 @@ class RemoteRepository:
                     error, res = self.responses.pop(waiting_for[0])
                     waiting_for.pop(0)
                     if error:
-                        if error == b'DoesNotExist':
-                            raise Repository.DoesNotExist(self.location.orig)
-                        elif error == b'AlreadyExists':
-                            raise Repository.AlreadyExists(self.location.orig)
-                        elif error == b'CheckNeeded':
-                            raise Repository.CheckNeeded(self.location.orig)
-                        elif error == b'IntegrityError':
-                            raise IntegrityError(res)
-                        elif error == b'PathNotAllowed':
-                            raise PathNotAllowed(*res)
-                        elif error == b'ObjectNotFound':
-                            raise Repository.ObjectNotFound(res[0], self.location.orig)
-                        elif error == b'InvalidRPCMethod':
-                            raise InvalidRPCMethod(*res)
-                        else:
-                            raise self.RPCError(res.decode('utf-8'))
+                        handle_error(error, res)
                     else:
                         yield res
                         if not waiting_for and not calls:
@@ -287,6 +290,8 @@ class RemoteRepository:
                         type, msgid, error, res = unpacked
                         if msgid in self.ignore_responses:
                             self.ignore_responses.remove(msgid)
+                            if error:
+                                handle_error(error, res)
                         else:
                             self.responses[msgid] = error, res
                 elif fd is self.stderr_fd:

+ 25 - 0
borg/testsuite/archive.py

@@ -5,6 +5,7 @@ import msgpack
 import pytest
 
 from ..archive import Archive, CacheChunkBuffer, RobustUnpacker, valid_msgpacked_dict, ITEM_KEYS
+from ..archive import InputOSError, input_io, input_io_iter
 from ..key import PlaintextKey
 from ..helpers import Manifest
 from . import BaseTestCase
@@ -145,3 +146,27 @@ def test_key_length_msgpacked_items():
     data = {key: b''}
     item_keys_serialized = [msgpack.packb(key), ]
     assert valid_msgpacked_dict(msgpack.packb(data), item_keys_serialized)
+
+
+def test_input_io():
+    with pytest.raises(InputOSError):
+        with input_io():
+            raise OSError(123)
+
+
+def test_input_io_iter():
+    class Iterator:
+        def __init__(self, exc):
+            self.exc = exc
+
+        def __next__(self):
+            raise self.exc()
+
+    oserror_iterator = Iterator(OSError)
+    with pytest.raises(InputOSError):
+        for _ in input_io_iter(oserror_iterator):
+            pass
+
+    normal_iterator = Iterator(StopIteration)
+    for _ in input_io_iter(normal_iterator):
+        assert False, 'StopIteration handled incorrectly'