Преглед на файлове

Merge pull request #777 from enkore/feature-ctxmng

Feature ctxmng: Repository context manager
TW преди 9 години
родител
ревизия
db171e998e
променени са 8 файла, в които са добавени 312 реда и са изтрити 256 реда
  1. 147 141
      borg/archiver.py
  2. 15 1
      borg/remote.py
  3. 19 4
      borg/repository.py
  4. 4 0
      borg/testsuite/__init__.py
  5. 47 36
      borg/testsuite/archiver.py
  6. 32 25
      borg/testsuite/repository.py
  7. 24 26
      borg/testsuite/upgrader.py
  8. 24 23
      borg/upgrader.py

+ 147 - 141
borg/archiver.py

@@ -40,18 +40,56 @@ UMASK_DEFAULT = 0o077
 DASHES = '-' * 78
 
 
-class ToggleAction(argparse.Action):
-    """argparse action to handle "toggle" flags easily
+def argument(args, str_or_bool):
+    """If bool is passed, return it. If str is passed, retrieve named attribute from args."""
+    if isinstance(str_or_bool, str):
+        return getattr(args, str_or_bool)
+    return str_or_bool
 
-    toggle flags are in the form of ``--foo``, ``--no-foo``.
 
-    the ``--no-foo`` argument still needs to be passed to the
-    ``add_argument()`` call, but it simplifies the ``--no``
-    detection.
+def with_repository(fake=False, create=False, lock=True, exclusive=False, manifest=True, cache=False):
     """
-    def __call__(self, parser, ns, values, option):
-        """set the given flag to true unless ``--no`` is passed"""
-        setattr(ns, self.dest, not option.startswith('--no-'))
+    Method decorator for subcommand-handling methods: do_XYZ(self, args, repository, …)
+
+    If a parameter (where allowed) is a str the attribute named of args is used instead.
+    :param fake: (str or bool) use None instead of repository, don't do anything else
+    :param create: create repository
+    :param lock: lock repository
+    :param exclusive: (str or bool) lock repository exclusively (for writing)
+    :param manifest: load manifest and key, pass them as keyword arguments
+    :param cache: open cache, pass it as keyword argument (implies manifest)
+    """
+    def decorator(method):
+        @functools.wraps(method)
+        def wrapper(self, args, **kwargs):
+            location = args.location  # note: 'location' must be always present in args
+            if argument(args, fake):
+                return method(self, args, repository=None, **kwargs)
+            elif location.proto == 'ssh':
+                repository = RemoteRepository(location, create=create, lock_wait=self.lock_wait, lock=lock, args=args)
+            else:
+                repository = Repository(location.path, create=create, exclusive=argument(args, exclusive),
+                                        lock_wait=self.lock_wait, lock=lock)
+            with repository:
+                if manifest or cache:
+                    kwargs['manifest'], kwargs['key'] = Manifest.load(repository)
+                if cache:
+                    with Cache(repository, kwargs['key'], kwargs['manifest'],
+                               do_files=getattr(args, 'cache_files', False), lock_wait=self.lock_wait) as cache_:
+                        return method(self, args, repository=repository, cache=cache_, **kwargs)
+                else:
+                    return method(self, args, repository=repository, **kwargs)
+        return wrapper
+    return decorator
+
+
+def with_archive(method):
+    @functools.wraps(method)
+    def wrapper(self, args, repository, key, manifest, **kwargs):
+        archive = Archive(repository, key, manifest, args.location.archive,
+                          numeric_owner=getattr(args, 'numeric_owner', False), cache=kwargs.get('cache'))
+        return method(self, args, repository=repository, manifest=manifest, key=key, archive=archive, **kwargs)
+    return wrapper
 
 
 class Archiver:
@@ -60,14 +98,6 @@ class Archiver:
         self.exit_code = EXIT_SUCCESS
         self.lock_wait = lock_wait
 
-    def open_repository(self, args, create=False, exclusive=False, lock=True):
-        location = args.location  # note: 'location' must be always present in args
-        if location.proto == 'ssh':
-            repository = RemoteRepository(location, create=create, lock_wait=self.lock_wait, lock=lock, args=args)
-        else:
-            repository = Repository(location.path, create=create, exclusive=exclusive, lock_wait=self.lock_wait, lock=lock)
-        return repository
-
     def print_error(self, msg, *args):
         msg = args and msg % args or msg
         self.exit_code = EXIT_ERROR
@@ -126,10 +156,10 @@ class Archiver:
         """
         return RepositoryServer(restrict_to_paths=args.restrict_to_paths).serve()
 
-    def do_init(self, args):
+    @with_repository(create=True, exclusive=True, manifest=False)
+    def do_init(self, args, repository):
         """Initialize an empty repository"""
         logger.info('Initializing repository at "%s"' % args.location.canonical_path())
-        repository = self.open_repository(args, create=True, exclusive=True)
         key = key_creator(repository, args)
         manifest = Manifest(key, repository)
         manifest.key = key
@@ -139,9 +169,9 @@ class Archiver:
             pass
         return self.exit_code
 
-    def do_check(self, args):
+    @with_repository(exclusive='repair', manifest=False)
+    def do_check(self, args, repository):
         """Check repository consistency"""
-        repository = self.open_repository(args, exclusive=args.repair)
         if args.repair:
             msg = ("'check --repair' is an experimental feature that might result in data loss." +
                    "\n" +
@@ -158,16 +188,15 @@ class Archiver:
             return EXIT_WARNING
         return EXIT_SUCCESS
 
-    def do_change_passphrase(self, args):
+    @with_repository()
+    def do_change_passphrase(self, args, repository, manifest, key):
         """Change repository key file passphrase"""
-        repository = self.open_repository(args)
-        manifest, key = Manifest.load(repository)
         key.change_passphrase()
         return EXIT_SUCCESS
 
-    def do_migrate_to_repokey(self, args):
+    @with_repository(manifest=False)
+    def do_migrate_to_repokey(self, args, repository):
         """Migrate passphrase -> repokey"""
-        repository = self.open_repository(args)
         manifest_data = repository.get(Manifest.MANIFEST_ID)
         key_old = PassphraseKey.detect(repository, manifest_data)
         key_new = RepoKey(repository)
@@ -180,7 +209,8 @@ class Archiver:
         key_new.change_passphrase()  # option to change key protection passphrase, save
         return EXIT_SUCCESS
 
-    def do_create(self, args):
+    @with_repository(fake='dry_run')
+    def do_create(self, args, repository, manifest=None, key=None):
         """Create new archive"""
         matcher = PatternMatcher(fallback=True)
         if args.excludes:
@@ -245,8 +275,6 @@ class Archiver:
         dry_run = args.dry_run
         t0 = datetime.utcnow()
         if not dry_run:
-            repository = self.open_repository(args, exclusive=True)
-            manifest, key = Manifest.load(repository)
             compr_args = dict(buffer=COMPR_BUFFER)
             compr_args.update(args.compression)
             key.compressor = Compressor(**compr_args)
@@ -339,17 +367,15 @@ class Archiver:
                 status = '-'  # dry run, item was not backed up
         self.print_file_status(status, path)
 
-    def do_extract(self, args):
+    @with_repository()
+    @with_archive
+    def do_extract(self, args, repository, manifest, key, archive):
         """Extract archive contents"""
         # be restrictive when restoring files, restore permissions later
         if sys.getfilesystemencoding() == 'ascii':
             logger.warning('Warning: File system encoding is "ascii", extracting non-ascii filenames will not be supported.')
             if sys.platform.startswith(('linux', 'freebsd', 'netbsd', 'openbsd', 'darwin', )):
                 logger.warning('Hint: You likely need to fix your locale setup. E.g. install locales and use: LANG=en_US.UTF-8')
-        repository = self.open_repository(args)
-        manifest, key = Manifest.load(repository)
-        archive = Archive(repository, key, manifest, args.location.archive,
-                          numeric_owner=args.numeric_owner)
 
         matcher, include_patterns = self.build_matcher(args.excludes, args.paths)
 
@@ -403,7 +429,9 @@ class Archiver:
                 self.print_warning("Include pattern '%s' never matched.", pattern)
         return self.exit_code
 
-    def do_diff(self, args):
+    @with_repository()
+    @with_archive
+    def do_diff(self, args, repository, manifest, key, archive):
         """Diff contents of two archives"""
         def format_bytes(count):
             if count is None:
@@ -499,9 +527,7 @@ class Archiver:
                     b'chunks': [],
                 }, deleted=True)
 
-        repository = self.open_repository(args)
-        manifest, key = Manifest.load(repository)
-        archive1 = Archive(repository, key, manifest, args.location.archive)
+        archive1 = archive
         archive2 = Archive(repository, key, manifest, args.archive2)
 
         can_compare_chunk_ids = archive1.metadata.get(b'chunker_params', False) == archive2.metadata.get(
@@ -520,55 +546,52 @@ class Archiver:
                 self.print_warning("Include pattern '%s' never matched.", pattern)
         return self.exit_code
 
-    def do_rename(self, args):
+    @with_repository(exclusive=True, cache=True)
+    @with_archive
+    def do_rename(self, args, repository, manifest, key, cache, archive):
         """Rename an existing archive"""
-        repository = self.open_repository(args, exclusive=True)
-        manifest, key = Manifest.load(repository)
-        with Cache(repository, key, manifest, lock_wait=self.lock_wait) as cache:
-            archive = Archive(repository, key, manifest, args.location.archive, cache=cache)
-            archive.rename(args.name)
-            manifest.write()
-            repository.commit()
-            cache.commit()
+        archive.rename(args.name)
+        manifest.write()
+        repository.commit()
+        cache.commit()
         return self.exit_code
 
-    def do_delete(self, args):
+    @with_repository(exclusive=True, cache=True)
+    def do_delete(self, args, repository, manifest, key, cache):
         """Delete an existing repository or archive"""
-        repository = self.open_repository(args, exclusive=True)
-        manifest, key = Manifest.load(repository)
-        with Cache(repository, key, manifest, do_files=args.cache_files, lock_wait=self.lock_wait) as cache:
-            if args.location.archive:
-                archive = Archive(repository, key, manifest, args.location.archive, cache=cache)
-                stats = Statistics()
-                archive.delete(stats, progress=args.progress)
-                manifest.write()
-                repository.commit(save_space=args.save_space)
-                cache.commit()
-                logger.info("Archive deleted.")
-                if args.stats:
-                    log_multi(DASHES,
-                              stats.summary.format(label='Deleted data:', stats=stats),
-                              str(cache),
-                              DASHES)
-            else:
-                if not args.cache_only:
-                    msg = []
-                    msg.append("You requested to completely DELETE the repository *including* all archives it contains:")
-                    for archive_info in manifest.list_archive_infos(sort_by='ts'):
-                        msg.append(format_archive(archive_info))
-                    msg.append("Type 'YES' if you understand this and want to continue: ")
-                    msg = '\n'.join(msg)
-                    if not yes(msg, false_msg="Aborting.", truish=('YES', ),
-                               env_var_override='BORG_DELETE_I_KNOW_WHAT_I_AM_DOING'):
-                        self.exit_code = EXIT_ERROR
-                        return self.exit_code
-                    repository.destroy()
-                    logger.info("Repository deleted.")
-                cache.destroy()
-                logger.info("Cache deleted.")
+        if args.location.archive:
+            archive = Archive(repository, key, manifest, args.location.archive, cache=cache)
+            stats = Statistics()
+            archive.delete(stats, progress=args.progress)
+            manifest.write()
+            repository.commit(save_space=args.save_space)
+            cache.commit()
+            logger.info("Archive deleted.")
+            if args.stats:
+                log_multi(DASHES,
+                          stats.summary.format(label='Deleted data:', stats=stats),
+                          str(cache),
+                          DASHES)
+        else:
+            if not args.cache_only:
+                msg = []
+                msg.append("You requested to completely DELETE the repository *including* all archives it contains:")
+                for archive_info in manifest.list_archive_infos(sort_by='ts'):
+                    msg.append(format_archive(archive_info))
+                msg.append("Type 'YES' if you understand this and want to continue: ")
+                msg = '\n'.join(msg)
+                if not yes(msg, false_msg="Aborting.", truish=('YES', ),
+                           env_var_override='BORG_DELETE_I_KNOW_WHAT_I_AM_DOING'):
+                    self.exit_code = EXIT_ERROR
+                    return self.exit_code
+                repository.destroy()
+                logger.info("Repository deleted.")
+            cache.destroy()
+            logger.info("Cache deleted.")
         return self.exit_code
 
-    def do_mount(self, args):
+    @with_repository()
+    def do_mount(self, args, repository, manifest, key):
         """Mount archive or an entire repository as a FUSE fileystem"""
         try:
             from .fuse import FuseOperations
@@ -580,29 +603,23 @@ class Archiver:
             self.print_error('%s: Mountpoint must be a writable directory' % args.mountpoint)
             return self.exit_code
 
-        repository = self.open_repository(args)
-        try:
-            with cache_if_remote(repository) as cached_repo:
-                manifest, key = Manifest.load(repository)
-                if args.location.archive:
-                    archive = Archive(repository, key, manifest, args.location.archive)
-                else:
-                    archive = None
-                operations = FuseOperations(key, repository, manifest, archive, cached_repo)
-                logger.info("Mounting filesystem")
-                try:
-                    operations.mount(args.mountpoint, args.options, args.foreground)
-                except RuntimeError:
-                    # Relevant error message already printed to stderr by fuse
-                    self.exit_code = EXIT_ERROR
-        finally:
-            repository.close()
+        with cache_if_remote(repository) as cached_repo:
+            if args.location.archive:
+                archive = Archive(repository, key, manifest, args.location.archive)
+            else:
+                archive = None
+            operations = FuseOperations(key, repository, manifest, archive, cached_repo)
+            logger.info("Mounting filesystem")
+            try:
+                operations.mount(args.mountpoint, args.options, args.foreground)
+            except RuntimeError:
+                # Relevant error message already printed to stderr by fuse
+                self.exit_code = EXIT_ERROR
         return self.exit_code
 
-    def do_list(self, args):
+    @with_repository()
+    def do_list(self, args, repository, manifest, key):
         """List archive or repository contents"""
-        repository = self.open_repository(args)
-        manifest, key = Manifest.load(repository)
         if args.location.archive:
             matcher, _ = self.build_matcher(args.excludes, args.paths)
 
@@ -626,7 +643,6 @@ class Archiver:
                     write = sys.stdout.buffer.write
                 for item in archive.iter_items(lambda item: matcher.match(item[b'path'])):
                     write(formatter.format_item(item).encode('utf-8', errors='surrogateescape'))
-            repository.close()
         else:
             for archive_info in manifest.list_archive_infos(sort_by='ts'):
                 if args.prefix and not archive_info.name.startswith(args.prefix):
@@ -637,30 +653,27 @@ class Archiver:
                     print(format_archive(archive_info))
         return self.exit_code
 
-    def do_info(self, args):
+    @with_repository(cache=True)
+    @with_archive
+    def do_info(self, args, repository, manifest, key, archive, cache):
         """Show archive details such as disk space used"""
-        repository = self.open_repository(args)
-        manifest, key = Manifest.load(repository)
-        with Cache(repository, key, manifest, do_files=args.cache_files, lock_wait=self.lock_wait) as cache:
-            archive = Archive(repository, key, manifest, args.location.archive, cache=cache)
-            stats = archive.calc_stats(cache)
-            print('Name:', archive.name)
-            print('Fingerprint: %s' % hexlify(archive.id).decode('ascii'))
-            print('Hostname:', archive.metadata[b'hostname'])
-            print('Username:', archive.metadata[b'username'])
-            print('Time (start): %s' % format_time(to_localtime(archive.ts)))
-            print('Time (end):   %s' % format_time(to_localtime(archive.ts_end)))
-            print('Command line:', remove_surrogates(' '.join(archive.metadata[b'cmdline'])))
-            print('Number of files: %d' % stats.nfiles)
-            print()
-            print(str(stats))
-            print(str(cache))
+        stats = archive.calc_stats(cache)
+        print('Name:', archive.name)
+        print('Fingerprint: %s' % hexlify(archive.id).decode('ascii'))
+        print('Hostname:', archive.metadata[b'hostname'])
+        print('Username:', archive.metadata[b'username'])
+        print('Time (start): %s' % format_time(to_localtime(archive.ts)))
+        print('Time (end):   %s' % format_time(to_localtime(archive.ts_end)))
+        print('Command line:', remove_surrogates(' '.join(archive.metadata[b'cmdline'])))
+        print('Number of files: %d' % stats.nfiles)
+        print()
+        print(str(stats))
+        print(str(cache))
         return self.exit_code
 
-    def do_prune(self, args):
+    @with_repository()
+    def do_prune(self, args, repository, manifest, key):
         """Prune repository archives according to specified rules"""
-        repository = self.open_repository(args, exclusive=True)
-        manifest, key = Manifest.load(repository)
         archives = manifest.list_archive_infos(sort_by='ts', reverse=True)  # just a ArchiveInfo list
         if args.hourly + args.daily + args.weekly + args.monthly + args.yearly == 0 and args.within is None:
             self.print_error('At least one of the "keep-within", "keep-hourly", "keep-daily", "keep-weekly", '
@@ -725,10 +738,9 @@ class Archiver:
             print("warning: %s" % e)
         return self.exit_code
 
-    def do_debug_dump_archive_items(self, args):
+    @with_repository()
+    def do_debug_dump_archive_items(self, args, repository, manifest, key):
         """dump (decrypted, decompressed) archive items metadata (not: data)"""
-        repository = self.open_repository(args)
-        manifest, key = Manifest.load(repository)
         archive = Archive(repository, key, manifest, args.location.archive)
         for i, item_id in enumerate(archive.metadata[b'items']):
             data = key.decrypt(item_id, repository.get(item_id))
@@ -739,10 +751,9 @@ class Archiver:
         print('Done.')
         return EXIT_SUCCESS
 
-    def do_debug_get_obj(self, args):
+    @with_repository(manifest=False)
+    def do_debug_get_obj(self, args, repository):
         """get object contents from the repository and write it into file"""
-        repository = self.open_repository(args)
-        manifest, key = Manifest.load(repository)
         hex_id = args.id
         try:
             id = unhexlify(hex_id)
@@ -759,10 +770,9 @@ class Archiver:
                 print("object %s fetched." % hex_id)
         return EXIT_SUCCESS
 
-    def do_debug_put_obj(self, args):
+    @with_repository(manifest=False)
+    def do_debug_put_obj(self, args, repository):
         """put file(s) contents into the repository"""
-        repository = self.open_repository(args)
-        manifest, key = Manifest.load(repository)
         for path in args.paths:
             with open(path, "rb") as f:
                 data = f.read()
@@ -772,10 +782,9 @@ class Archiver:
         repository.commit()
         return EXIT_SUCCESS
 
-    def do_debug_delete_obj(self, args):
+    @with_repository(manifest=False)
+    def do_debug_delete_obj(self, args, repository):
         """delete the objects with the given IDs from the repo"""
-        repository = self.open_repository(args)
-        manifest, key = Manifest.load(repository)
         modified = False
         for hex_id in args.ids:
             try:
@@ -794,14 +803,11 @@ class Archiver:
         print('Done.')
         return EXIT_SUCCESS
 
-    def do_break_lock(self, args):
+    @with_repository(lock=False, manifest=False)
+    def do_break_lock(self, args, repository):
         """Break the repository lock (e.g. in case it was left by a dead borg."""
-        repository = self.open_repository(args, lock=False)
-        try:
-            repository.break_lock()
-            Cache.break_lock(repository)
-        finally:
-            repository.close()
+        repository.break_lock()
+        Cache.break_lock(repository)
         return self.exit_code
 
     helptext = {}

+ 15 - 1
borg/remote.py

@@ -77,6 +77,7 @@ class RepositoryServer:  # pragma: no cover
             if r:
                 data = os.read(stdin_fd, BUFSIZE)
                 if not data:
+                    self.repository.close()
                     return
                 unpacker.feed(data)
                 for unpacked in unpacker:
@@ -100,6 +101,7 @@ class RepositoryServer:  # pragma: no cover
                     else:
                         os.write(stdout_fd, msgpack.packb((1, msgid, None, res)))
             if es:
+                self.repository.close()
                 return
 
     def negotiate(self, versions):
@@ -117,6 +119,7 @@ class RepositoryServer:  # pragma: no cover
             else:
                 raise PathNotAllowed(path)
         self.repository = Repository(path, create, lock_wait=lock_wait, lock=lock)
+        self.repository.__enter__()  # clean exit handled by serve() method
         return self.repository.id
 
 
@@ -164,11 +167,21 @@ class RemoteRepository:
         self.id = self.call('open', location.path, create, lock_wait, lock)
 
     def __del__(self):
-        self.close()
+        if self.p:
+            self.close()
+            assert False, "cleanup happened in Repository.__del__"
 
     def __repr__(self):
         return '<%s %s>' % (self.__class__.__name__, self.location.canonical_path())
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if exc_type is not None:
+            self.rollback()
+        self.close()
+
     def borg_cmd(self, args, testing):
         """return a borg serve command line"""
         # give some args/options to "borg serve" process as they were given to us
@@ -392,6 +405,7 @@ class RepositoryCache(RepositoryNoCache):
         super().__init__(repository)
         tmppath = tempfile.mkdtemp(prefix='borg-tmp')
         self.caching_repo = Repository(tmppath, create=True, exclusive=True)
+        self.caching_repo.__enter__()  # handled by context manager in base class
 
     def close(self):
         if self.caching_repo is not None:

+ 19 - 4
borg/repository.py

@@ -59,16 +59,31 @@ class Repository:
         self.lock = None
         self.index = None
         self._active_txn = False
-        if create:
-            self.create(self.path)
-        self.open(self.path, exclusive, lock_wait=lock_wait, lock=lock)
+        self.lock_wait = lock_wait
+        self.do_lock = lock
+        self.do_create = create
+        self.exclusive = exclusive
 
     def __del__(self):
-        self.close()
+        if self.lock:
+            self.close()
+            assert False, "cleanup happened in Repository.__del__"
 
     def __repr__(self):
         return '<%s %s>' % (self.__class__.__name__, self.path)
 
+    def __enter__(self):
+        if self.do_create:
+            self.do_create = False
+            self.create(self.path)
+        self.open(self.path, self.exclusive, lock_wait=self.lock_wait, lock=self.do_lock)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if exc_type is not None:
+            self.rollback()
+        self.close()
+
     def create(self, path):
         """Create a new empty repository at `path`
         """

+ 4 - 0
borg/testsuite/__init__.py

@@ -8,6 +8,7 @@ import sysconfig
 import time
 import unittest
 from ..xattr import get_all
+from ..logger import setup_logging
 
 try:
     import llfuse
@@ -30,6 +31,9 @@ else:
 if sys.platform.startswith('netbsd'):
     st_mtime_ns_round = -4  # only >1 microsecond resolution here?
 
+# Ensure that the loggers exist for all tests
+setup_logging()
+
 
 class BaseTestCase(unittest.TestCase):
     """

+ 47 - 36
borg/testsuite/archiver.py

@@ -367,7 +367,8 @@ class ArchiverTestCase(ArchiverTestCaseBase):
             assert sto.st_atime_ns == atime * 1e9
 
     def _extract_repository_id(self, path):
-        return Repository(self.repository_path).id
+        with Repository(self.repository_path) as repository:
+            return repository.id
 
     def _set_repository_id(self, path, id):
         config = ConfigParser(interpolation=None)
@@ -375,7 +376,8 @@ class ArchiverTestCase(ArchiverTestCaseBase):
         config.set('repository', 'id', hexlify(id).decode('ascii'))
         with open(os.path.join(path, 'config'), 'w') as fd:
             config.write(fd)
-        return Repository(self.repository_path).id
+        with Repository(self.repository_path) as repository:
+            return repository.id
 
     def test_sparse_file(self):
         # no sparse file support on Mac OS X
@@ -745,8 +747,8 @@ class ArchiverTestCase(ArchiverTestCaseBase):
         self.cmd('extract', '--dry-run', self.repository_location + '::test.3')
         self.cmd('extract', '--dry-run', self.repository_location + '::test.4')
         # Make sure both archives have been renamed
-        repository = Repository(self.repository_path)
-        manifest, key = Manifest.load(repository)
+        with Repository(self.repository_path) as repository:
+            manifest, key = Manifest.load(repository)
         self.assert_equal(len(manifest.archives), 2)
         self.assert_in('test.3', manifest.archives)
         self.assert_in('test.4', manifest.archives)
@@ -763,8 +765,8 @@ class ArchiverTestCase(ArchiverTestCaseBase):
         self.cmd('extract', '--dry-run', self.repository_location + '::test.2')
         self.cmd('delete', '--stats', self.repository_location + '::test.2')
         # Make sure all data except the manifest has been deleted
-        repository = Repository(self.repository_path)
-        self.assert_equal(len(repository), 1)
+        with Repository(self.repository_path) as repository:
+            self.assert_equal(len(repository), 1)
 
     def test_delete_repo(self):
         self.create_regular_file('file1', size=1024 * 80)
@@ -772,6 +774,11 @@ class ArchiverTestCase(ArchiverTestCaseBase):
         self.cmd('init', self.repository_location)
         self.cmd('create', self.repository_location + '::test', 'input')
         self.cmd('create', self.repository_location + '::test.2', 'input')
+        os.environ['BORG_DELETE_I_KNOW_WHAT_I_AM_DOING'] = 'no'
+        self.cmd('delete', self.repository_location, exit_code=2)
+        self.archiver.exit_code = 0
+        assert os.path.exists(self.repository_path)
+        os.environ['BORG_DELETE_I_KNOW_WHAT_I_AM_DOING'] = 'YES'
         self.cmd('delete', self.repository_location)
         # Make sure the repo is gone
         self.assertFalse(os.path.exists(self.repository_path))
@@ -810,8 +817,8 @@ class ArchiverTestCase(ArchiverTestCaseBase):
         self.cmd('init', self.repository_location)
         self.cmd('create', '--dry-run', self.repository_location + '::test', 'input')
         # Make sure no archive has been created
-        repository = Repository(self.repository_path)
-        manifest, key = Manifest.load(repository)
+        with Repository(self.repository_path) as repository:
+            manifest, key = Manifest.load(repository)
         self.assert_equal(len(manifest.archives), 0)
 
     def test_progress(self):
@@ -1045,17 +1052,17 @@ class ArchiverTestCase(ArchiverTestCaseBase):
         used = set()  # counter values already used
 
         def verify_uniqueness():
-            repository = Repository(self.repository_path)
-            for key, _ in repository.open_index(repository.get_transaction_id()).iteritems():
-                data = repository.get(key)
-                hash = sha256(data).digest()
-                if hash not in seen:
-                    seen.add(hash)
-                    num_blocks = num_aes_blocks(len(data) - 41)
-                    nonce = bytes_to_long(data[33:41])
-                    for counter in range(nonce, nonce + num_blocks):
-                        self.assert_not_in(counter, used)
-                        used.add(counter)
+            with Repository(self.repository_path) as repository:
+                for key, _ in repository.open_index(repository.get_transaction_id()).iteritems():
+                    data = repository.get(key)
+                    hash = sha256(data).digest()
+                    if hash not in seen:
+                        seen.add(hash)
+                        num_blocks = num_aes_blocks(len(data) - 41)
+                        nonce = bytes_to_long(data[33:41])
+                        for counter in range(nonce, nonce + num_blocks):
+                            self.assert_not_in(counter, used)
+                            used.add(counter)
 
         self.create_test_files()
         os.environ['BORG_PASSPHRASE'] = 'passphrase'
@@ -1122,8 +1129,9 @@ class ArchiverCheckTestCase(ArchiverTestCaseBase):
 
     def open_archive(self, name):
         repository = Repository(self.repository_path)
-        manifest, key = Manifest.load(repository)
-        archive = Archive(repository, key, manifest, name)
+        with repository:
+            manifest, key = Manifest.load(repository)
+            archive = Archive(repository, key, manifest, name)
         return archive, repository
 
     def test_check_usage(self):
@@ -1141,35 +1149,39 @@ class ArchiverCheckTestCase(ArchiverTestCaseBase):
 
     def test_missing_file_chunk(self):
         archive, repository = self.open_archive('archive1')
-        for item in archive.iter_items():
-            if item[b'path'].endswith('testsuite/archiver.py'):
-                repository.delete(item[b'chunks'][-1][0])
-                break
-        repository.commit()
+        with repository:
+            for item in archive.iter_items():
+                if item[b'path'].endswith('testsuite/archiver.py'):
+                    repository.delete(item[b'chunks'][-1][0])
+                    break
+            repository.commit()
         self.cmd('check', self.repository_location, exit_code=1)
         self.cmd('check', '--repair', self.repository_location, exit_code=0)
         self.cmd('check', self.repository_location, exit_code=0)
 
     def test_missing_archive_item_chunk(self):
         archive, repository = self.open_archive('archive1')
-        repository.delete(archive.metadata[b'items'][-5])
-        repository.commit()
+        with repository:
+            repository.delete(archive.metadata[b'items'][-5])
+            repository.commit()
         self.cmd('check', self.repository_location, exit_code=1)
         self.cmd('check', '--repair', self.repository_location, exit_code=0)
         self.cmd('check', self.repository_location, exit_code=0)
 
     def test_missing_archive_metadata(self):
         archive, repository = self.open_archive('archive1')
-        repository.delete(archive.id)
-        repository.commit()
+        with repository:
+            repository.delete(archive.id)
+            repository.commit()
         self.cmd('check', self.repository_location, exit_code=1)
         self.cmd('check', '--repair', self.repository_location, exit_code=0)
         self.cmd('check', self.repository_location, exit_code=0)
 
     def test_missing_manifest(self):
         archive, repository = self.open_archive('archive1')
-        repository.delete(Manifest.MANIFEST_ID)
-        repository.commit()
+        with repository:
+            repository.delete(Manifest.MANIFEST_ID)
+            repository.commit()
         self.cmd('check', self.repository_location, exit_code=1)
         output = self.cmd('check', '-v', '--repair', self.repository_location, exit_code=0)
         self.assert_in('archive1', output)
@@ -1178,10 +1190,9 @@ class ArchiverCheckTestCase(ArchiverTestCaseBase):
 
     def test_extra_chunks(self):
         self.cmd('check', self.repository_location, exit_code=0)
-        repository = Repository(self.repository_location)
-        repository.put(b'01234567890123456789012345678901', b'xxxx')
-        repository.commit()
-        repository.close()
+        with Repository(self.repository_location) as repository:
+            repository.put(b'01234567890123456789012345678901', b'xxxx')
+            repository.commit()
         self.cmd('check', self.repository_location, exit_code=1)
         self.cmd('check', self.repository_location, exit_code=1)
         self.cmd('check', '--repair', self.repository_location, exit_code=0)

+ 32 - 25
borg/testsuite/repository.py

@@ -21,6 +21,7 @@ class RepositoryTestCaseBase(BaseTestCase):
     def setUp(self):
         self.tmppath = tempfile.mkdtemp()
         self.repository = self.open(create=True)
+        self.repository.__enter__()
 
     def tearDown(self):
         self.repository.close()
@@ -43,13 +44,12 @@ class RepositoryTestCase(RepositoryTestCaseBase):
         self.assert_raises(Repository.ObjectNotFound, lambda: self.repository.get(key50))
         self.repository.commit()
         self.repository.close()
-        repository2 = self.open()
-        self.assert_raises(Repository.ObjectNotFound, lambda: repository2.get(key50))
-        for x in range(100):
-            if x == 50:
-                continue
-            self.assert_equal(repository2.get(('%-32d' % x).encode('ascii')), b'SOMEDATA')
-        repository2.close()
+        with self.open() as repository2:
+            self.assert_raises(Repository.ObjectNotFound, lambda: repository2.get(key50))
+            for x in range(100):
+                if x == 50:
+                    continue
+                self.assert_equal(repository2.get(('%-32d' % x).encode('ascii')), b'SOMEDATA')
 
     def test2(self):
         """Test multiple sequential transactions
@@ -100,13 +100,14 @@ class RepositoryTestCase(RepositoryTestCaseBase):
         self.repository.close()
         # replace
         self.repository = self.open()
-        self.repository.put(b'00000000000000000000000000000000', b'bar')
-        self.repository.commit()
-        self.repository.close()
+        with self.repository:
+            self.repository.put(b'00000000000000000000000000000000', b'bar')
+            self.repository.commit()
         # delete
         self.repository = self.open()
-        self.repository.delete(b'00000000000000000000000000000000')
-        self.repository.commit()
+        with self.repository:
+            self.repository.delete(b'00000000000000000000000000000000')
+            self.repository.commit()
 
     def test_list(self):
         for x in range(100):
@@ -139,8 +140,9 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
             if name.startswith('index.'):
                 os.unlink(os.path.join(self.repository.path, name))
         self.reopen()
-        self.assert_equal(len(self.repository), 3)
-        self.assert_equal(self.repository.check(), True)
+        with self.repository:
+            self.assert_equal(len(self.repository), 3)
+            self.assert_equal(self.repository.check(), True)
 
     def test_crash_before_compact_segments(self):
         self.add_keys()
@@ -150,8 +152,9 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
         except TypeError:
             pass
         self.reopen()
-        self.assert_equal(len(self.repository), 3)
-        self.assert_equal(self.repository.check(), True)
+        with self.repository:
+            self.assert_equal(len(self.repository), 3)
+            self.assert_equal(self.repository.check(), True)
 
     def test_replay_of_readonly_repository(self):
         self.add_keys()
@@ -160,8 +163,9 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
                 os.unlink(os.path.join(self.repository.path, name))
         with patch.object(UpgradableLock, 'upgrade', side_effect=LockFailed) as upgrade:
             self.reopen()
-            self.assert_raises(LockFailed, lambda: len(self.repository))
-            upgrade.assert_called_once_with()
+            with self.repository:
+                self.assert_raises(LockFailed, lambda: len(self.repository))
+                upgrade.assert_called_once_with()
 
     def test_crash_before_write_index(self):
         self.add_keys()
@@ -171,8 +175,9 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
         except TypeError:
             pass
         self.reopen()
-        self.assert_equal(len(self.repository), 3)
-        self.assert_equal(self.repository.check(), True)
+        with self.repository:
+            self.assert_equal(len(self.repository), 3)
+            self.assert_equal(self.repository.check(), True)
 
     def test_crash_before_deleting_compacted_segments(self):
         self.add_keys()
@@ -182,9 +187,10 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
         except TypeError:
             pass
         self.reopen()
-        self.assert_equal(len(self.repository), 3)
-        self.assert_equal(self.repository.check(), True)
-        self.assert_equal(len(self.repository), 3)
+        with self.repository:
+            self.assert_equal(len(self.repository), 3)
+            self.assert_equal(self.repository.check(), True)
+            self.assert_equal(len(self.repository), 3)
 
 
 class RepositoryCheckTestCase(RepositoryTestCaseBase):
@@ -313,8 +319,9 @@ class RepositoryCheckTestCase(RepositoryTestCaseBase):
             self.repository.commit()
             compact.assert_called_once_with(save_space=False)
         self.reopen()
-        self.check(repair=True)
-        self.assert_equal(self.repository.get(bytes(32)), b'data2')
+        with self.repository:
+            self.check(repair=True)
+            self.assert_equal(self.repository.get(bytes(32)), b'data2')
 
 
 class RemoteRepositoryTestCase(RepositoryTestCase):

+ 24 - 26
borg/testsuite/upgrader.py

@@ -23,11 +23,9 @@ def repo_valid(path):
     :param path: the path to the repository
     :returns: if borg can check the repository
     """
-    repository = Repository(str(path), create=False)
-    # can't check raises() because check() handles the error
-    state = repository.check()
-    repository.close()
-    return state
+    with Repository(str(path), create=False) as repository:
+        # can't check raises() because check() handles the error
+        return repository.check()
 
 
 def key_valid(path):
@@ -79,11 +77,11 @@ def test_convert_segments(tmpdir, attic_repo, inplace):
     """
     # check should fail because of magic number
     assert not repo_valid(tmpdir)
-    repo = AtticRepositoryUpgrader(str(tmpdir), create=False)
-    segments = [filename for i, filename in repo.io.segment_iterator()]
-    repo.close()
-    repo.convert_segments(segments, dryrun=False, inplace=inplace)
-    repo.convert_cache(dryrun=False)
+    repository = AtticRepositoryUpgrader(str(tmpdir), create=False)
+    with repository:
+        segments = [filename for i, filename in repository.io.segment_iterator()]
+    repository.convert_segments(segments, dryrun=False, inplace=inplace)
+    repository.convert_cache(dryrun=False)
     assert repo_valid(tmpdir)
 
 
@@ -138,9 +136,9 @@ def test_keys(tmpdir, attic_repo, attic_key_file):
     define above)
     :param attic_key_file: an attic.key.KeyfileKey (fixture created above)
     """
-    repository = AtticRepositoryUpgrader(str(tmpdir), create=False)
-    keyfile = AtticKeyfileKey.find_key_file(repository)
-    AtticRepositoryUpgrader.convert_keyfiles(keyfile, dryrun=False)
+    with AtticRepositoryUpgrader(str(tmpdir), create=False) as repository:
+        keyfile = AtticKeyfileKey.find_key_file(repository)
+        AtticRepositoryUpgrader.convert_keyfiles(keyfile, dryrun=False)
     assert key_valid(attic_key_file.path)
 
 
@@ -167,19 +165,19 @@ def test_convert_all(tmpdir, attic_repo, attic_key_file, inplace):
         return stat_segment(path).st_ino
 
     orig_inode = first_inode(attic_repo.path)
-    repo = AtticRepositoryUpgrader(str(tmpdir), create=False)
-    # replicate command dispatch, partly
-    os.umask(UMASK_DEFAULT)
-    backup = repo.upgrade(dryrun=False, inplace=inplace)
-    if inplace:
-        assert backup is None
-        assert first_inode(repo.path) == orig_inode
-    else:
-        assert backup
-        assert first_inode(repo.path) != first_inode(backup)
-        # i have seen cases where the copied tree has world-readable
-        # permissions, which is wrong
-        assert stat_segment(backup).st_mode & UMASK_DEFAULT == 0
+    with AtticRepositoryUpgrader(str(tmpdir), create=False) as repository:
+        # replicate command dispatch, partly
+        os.umask(UMASK_DEFAULT)
+        backup = repository.upgrade(dryrun=False, inplace=inplace)
+        if inplace:
+            assert backup is None
+            assert first_inode(repository.path) == orig_inode
+        else:
+            assert backup
+            assert first_inode(repository.path) != first_inode(backup)
+            # i have seen cases where the copied tree has world-readable
+            # permissions, which is wrong
+            assert stat_segment(backup).st_mode & UMASK_DEFAULT == 0
 
     assert key_valid(attic_key_file.path)
     assert repo_valid(tmpdir)

+ 24 - 23
borg/upgrader.py

@@ -30,23 +30,23 @@ class AtticRepositoryUpgrader(Repository):
         we nevertheless do the order in reverse, as we prefer to do
         the fast stuff first, to improve interactivity.
         """
-        backup = None
-        if not inplace:
-            backup = '{}.upgrade-{:%Y-%m-%d-%H:%M:%S}'.format(self.path, datetime.datetime.now())
-            logger.info('making a hardlink copy in %s', backup)
-            if not dryrun:
-                shutil.copytree(self.path, backup, copy_function=os.link)
-        logger.info("opening attic repository with borg and converting")
-        # now lock the repo, after we have made the copy
-        self.lock = UpgradableLock(os.path.join(self.path, 'lock'), exclusive=True, timeout=1.0).acquire()
-        segments = [filename for i, filename in self.io.segment_iterator()]
-        try:
-            keyfile = self.find_attic_keyfile()
-        except KeyfileNotFoundError:
-            logger.warning("no key file found for repository")
-        else:
-            self.convert_keyfiles(keyfile, dryrun)
-        self.close()
+        with self:
+            backup = None
+            if not inplace:
+                backup = '{}.upgrade-{:%Y-%m-%d-%H:%M:%S}'.format(self.path, datetime.datetime.now())
+                logger.info('making a hardlink copy in %s', backup)
+                if not dryrun:
+                    shutil.copytree(self.path, backup, copy_function=os.link)
+            logger.info("opening attic repository with borg and converting")
+            # now lock the repo, after we have made the copy
+            self.lock = UpgradableLock(os.path.join(self.path, 'lock'), exclusive=True, timeout=1.0).acquire()
+            segments = [filename for i, filename in self.io.segment_iterator()]
+            try:
+                keyfile = self.find_attic_keyfile()
+            except KeyfileNotFoundError:
+                logger.warning("no key file found for repository")
+            else:
+                self.convert_keyfiles(keyfile, dryrun)
         # partial open: just hold on to the lock
         self.lock = UpgradableLock(os.path.join(self.path, 'lock'),
                                    exclusive=True).acquire()
@@ -282,12 +282,13 @@ class BorgRepositoryUpgrader(Repository):
         """convert an old borg repository to a current borg repository
         """
         logger.info("converting borg 0.xx to borg current")
-        try:
-            keyfile = self.find_borg0xx_keyfile()
-        except KeyfileNotFoundError:
-            logger.warning("no key file found for repository")
-        else:
-            self.move_keyfiles(keyfile, dryrun)
+        with self:
+            try:
+                keyfile = self.find_borg0xx_keyfile()
+            except KeyfileNotFoundError:
+                logger.warning("no key file found for repository")
+            else:
+                self.move_keyfiles(keyfile, dryrun)
 
     def find_borg0xx_keyfile(self):
         return Borg0xxKeyfileKey.find_key_file(self)