Forráskód Böngészése

Merge pull request #777 from enkore/feature-ctxmng

Feature ctxmng: Repository context manager
TW 9 éve
szülő
commit
db171e998e

+ 147 - 141
borg/archiver.py

@@ -40,18 +40,56 @@ UMASK_DEFAULT = 0o077
 DASHES = '-' * 78
 DASHES = '-' * 78
 
 
 
 
-class ToggleAction(argparse.Action):
-    """argparse action to handle "toggle" flags easily
+def argument(args, str_or_bool):
+    """If bool is passed, return it. If str is passed, retrieve named attribute from args."""
+    if isinstance(str_or_bool, str):
+        return getattr(args, str_or_bool)
+    return str_or_bool
 
 
-    toggle flags are in the form of ``--foo``, ``--no-foo``.
 
 
-    the ``--no-foo`` argument still needs to be passed to the
-    ``add_argument()`` call, but it simplifies the ``--no``
-    detection.
+def with_repository(fake=False, create=False, lock=True, exclusive=False, manifest=True, cache=False):
     """
     """
-    def __call__(self, parser, ns, values, option):
-        """set the given flag to true unless ``--no`` is passed"""
-        setattr(ns, self.dest, not option.startswith('--no-'))
+    Method decorator for subcommand-handling methods: do_XYZ(self, args, repository, …)
+
+    If a parameter (where allowed) is a str the attribute named of args is used instead.
+    :param fake: (str or bool) use None instead of repository, don't do anything else
+    :param create: create repository
+    :param lock: lock repository
+    :param exclusive: (str or bool) lock repository exclusively (for writing)
+    :param manifest: load manifest and key, pass them as keyword arguments
+    :param cache: open cache, pass it as keyword argument (implies manifest)
+    """
+    def decorator(method):
+        @functools.wraps(method)
+        def wrapper(self, args, **kwargs):
+            location = args.location  # note: 'location' must be always present in args
+            if argument(args, fake):
+                return method(self, args, repository=None, **kwargs)
+            elif location.proto == 'ssh':
+                repository = RemoteRepository(location, create=create, lock_wait=self.lock_wait, lock=lock, args=args)
+            else:
+                repository = Repository(location.path, create=create, exclusive=argument(args, exclusive),
+                                        lock_wait=self.lock_wait, lock=lock)
+            with repository:
+                if manifest or cache:
+                    kwargs['manifest'], kwargs['key'] = Manifest.load(repository)
+                if cache:
+                    with Cache(repository, kwargs['key'], kwargs['manifest'],
+                               do_files=getattr(args, 'cache_files', False), lock_wait=self.lock_wait) as cache_:
+                        return method(self, args, repository=repository, cache=cache_, **kwargs)
+                else:
+                    return method(self, args, repository=repository, **kwargs)
+        return wrapper
+    return decorator
+
+
+def with_archive(method):
+    @functools.wraps(method)
+    def wrapper(self, args, repository, key, manifest, **kwargs):
+        archive = Archive(repository, key, manifest, args.location.archive,
+                          numeric_owner=getattr(args, 'numeric_owner', False), cache=kwargs.get('cache'))
+        return method(self, args, repository=repository, manifest=manifest, key=key, archive=archive, **kwargs)
+    return wrapper
 
 
 
 
 class Archiver:
 class Archiver:
@@ -60,14 +98,6 @@ class Archiver:
         self.exit_code = EXIT_SUCCESS
         self.exit_code = EXIT_SUCCESS
         self.lock_wait = lock_wait
         self.lock_wait = lock_wait
 
 
-    def open_repository(self, args, create=False, exclusive=False, lock=True):
-        location = args.location  # note: 'location' must be always present in args
-        if location.proto == 'ssh':
-            repository = RemoteRepository(location, create=create, lock_wait=self.lock_wait, lock=lock, args=args)
-        else:
-            repository = Repository(location.path, create=create, exclusive=exclusive, lock_wait=self.lock_wait, lock=lock)
-        return repository
-
     def print_error(self, msg, *args):
     def print_error(self, msg, *args):
         msg = args and msg % args or msg
         msg = args and msg % args or msg
         self.exit_code = EXIT_ERROR
         self.exit_code = EXIT_ERROR
@@ -126,10 +156,10 @@ class Archiver:
         """
         """
         return RepositoryServer(restrict_to_paths=args.restrict_to_paths).serve()
         return RepositoryServer(restrict_to_paths=args.restrict_to_paths).serve()
 
 
-    def do_init(self, args):
+    @with_repository(create=True, exclusive=True, manifest=False)
+    def do_init(self, args, repository):
         """Initialize an empty repository"""
         """Initialize an empty repository"""
         logger.info('Initializing repository at "%s"' % args.location.canonical_path())
         logger.info('Initializing repository at "%s"' % args.location.canonical_path())
-        repository = self.open_repository(args, create=True, exclusive=True)
         key = key_creator(repository, args)
         key = key_creator(repository, args)
         manifest = Manifest(key, repository)
         manifest = Manifest(key, repository)
         manifest.key = key
         manifest.key = key
@@ -139,9 +169,9 @@ class Archiver:
             pass
             pass
         return self.exit_code
         return self.exit_code
 
 
-    def do_check(self, args):
+    @with_repository(exclusive='repair', manifest=False)
+    def do_check(self, args, repository):
         """Check repository consistency"""
         """Check repository consistency"""
-        repository = self.open_repository(args, exclusive=args.repair)
         if args.repair:
         if args.repair:
             msg = ("'check --repair' is an experimental feature that might result in data loss." +
             msg = ("'check --repair' is an experimental feature that might result in data loss." +
                    "\n" +
                    "\n" +
@@ -158,16 +188,15 @@ class Archiver:
             return EXIT_WARNING
             return EXIT_WARNING
         return EXIT_SUCCESS
         return EXIT_SUCCESS
 
 
-    def do_change_passphrase(self, args):
+    @with_repository()
+    def do_change_passphrase(self, args, repository, manifest, key):
         """Change repository key file passphrase"""
         """Change repository key file passphrase"""
-        repository = self.open_repository(args)
-        manifest, key = Manifest.load(repository)
         key.change_passphrase()
         key.change_passphrase()
         return EXIT_SUCCESS
         return EXIT_SUCCESS
 
 
-    def do_migrate_to_repokey(self, args):
+    @with_repository(manifest=False)
+    def do_migrate_to_repokey(self, args, repository):
         """Migrate passphrase -> repokey"""
         """Migrate passphrase -> repokey"""
-        repository = self.open_repository(args)
         manifest_data = repository.get(Manifest.MANIFEST_ID)
         manifest_data = repository.get(Manifest.MANIFEST_ID)
         key_old = PassphraseKey.detect(repository, manifest_data)
         key_old = PassphraseKey.detect(repository, manifest_data)
         key_new = RepoKey(repository)
         key_new = RepoKey(repository)
@@ -180,7 +209,8 @@ class Archiver:
         key_new.change_passphrase()  # option to change key protection passphrase, save
         key_new.change_passphrase()  # option to change key protection passphrase, save
         return EXIT_SUCCESS
         return EXIT_SUCCESS
 
 
-    def do_create(self, args):
+    @with_repository(fake='dry_run')
+    def do_create(self, args, repository, manifest=None, key=None):
         """Create new archive"""
         """Create new archive"""
         matcher = PatternMatcher(fallback=True)
         matcher = PatternMatcher(fallback=True)
         if args.excludes:
         if args.excludes:
@@ -245,8 +275,6 @@ class Archiver:
         dry_run = args.dry_run
         dry_run = args.dry_run
         t0 = datetime.utcnow()
         t0 = datetime.utcnow()
         if not dry_run:
         if not dry_run:
-            repository = self.open_repository(args, exclusive=True)
-            manifest, key = Manifest.load(repository)
             compr_args = dict(buffer=COMPR_BUFFER)
             compr_args = dict(buffer=COMPR_BUFFER)
             compr_args.update(args.compression)
             compr_args.update(args.compression)
             key.compressor = Compressor(**compr_args)
             key.compressor = Compressor(**compr_args)
@@ -339,17 +367,15 @@ class Archiver:
                 status = '-'  # dry run, item was not backed up
                 status = '-'  # dry run, item was not backed up
         self.print_file_status(status, path)
         self.print_file_status(status, path)
 
 
-    def do_extract(self, args):
+    @with_repository()
+    @with_archive
+    def do_extract(self, args, repository, manifest, key, archive):
         """Extract archive contents"""
         """Extract archive contents"""
         # be restrictive when restoring files, restore permissions later
         # be restrictive when restoring files, restore permissions later
         if sys.getfilesystemencoding() == 'ascii':
         if sys.getfilesystemencoding() == 'ascii':
             logger.warning('Warning: File system encoding is "ascii", extracting non-ascii filenames will not be supported.')
             logger.warning('Warning: File system encoding is "ascii", extracting non-ascii filenames will not be supported.')
             if sys.platform.startswith(('linux', 'freebsd', 'netbsd', 'openbsd', 'darwin', )):
             if sys.platform.startswith(('linux', 'freebsd', 'netbsd', 'openbsd', 'darwin', )):
                 logger.warning('Hint: You likely need to fix your locale setup. E.g. install locales and use: LANG=en_US.UTF-8')
                 logger.warning('Hint: You likely need to fix your locale setup. E.g. install locales and use: LANG=en_US.UTF-8')
-        repository = self.open_repository(args)
-        manifest, key = Manifest.load(repository)
-        archive = Archive(repository, key, manifest, args.location.archive,
-                          numeric_owner=args.numeric_owner)
 
 
         matcher, include_patterns = self.build_matcher(args.excludes, args.paths)
         matcher, include_patterns = self.build_matcher(args.excludes, args.paths)
 
 
@@ -403,7 +429,9 @@ class Archiver:
                 self.print_warning("Include pattern '%s' never matched.", pattern)
                 self.print_warning("Include pattern '%s' never matched.", pattern)
         return self.exit_code
         return self.exit_code
 
 
-    def do_diff(self, args):
+    @with_repository()
+    @with_archive
+    def do_diff(self, args, repository, manifest, key, archive):
         """Diff contents of two archives"""
         """Diff contents of two archives"""
         def format_bytes(count):
         def format_bytes(count):
             if count is None:
             if count is None:
@@ -499,9 +527,7 @@ class Archiver:
                     b'chunks': [],
                     b'chunks': [],
                 }, deleted=True)
                 }, deleted=True)
 
 
-        repository = self.open_repository(args)
-        manifest, key = Manifest.load(repository)
-        archive1 = Archive(repository, key, manifest, args.location.archive)
+        archive1 = archive
         archive2 = Archive(repository, key, manifest, args.archive2)
         archive2 = Archive(repository, key, manifest, args.archive2)
 
 
         can_compare_chunk_ids = archive1.metadata.get(b'chunker_params', False) == archive2.metadata.get(
         can_compare_chunk_ids = archive1.metadata.get(b'chunker_params', False) == archive2.metadata.get(
@@ -520,55 +546,52 @@ class Archiver:
                 self.print_warning("Include pattern '%s' never matched.", pattern)
                 self.print_warning("Include pattern '%s' never matched.", pattern)
         return self.exit_code
         return self.exit_code
 
 
-    def do_rename(self, args):
+    @with_repository(exclusive=True, cache=True)
+    @with_archive
+    def do_rename(self, args, repository, manifest, key, cache, archive):
         """Rename an existing archive"""
         """Rename an existing archive"""
-        repository = self.open_repository(args, exclusive=True)
-        manifest, key = Manifest.load(repository)
-        with Cache(repository, key, manifest, lock_wait=self.lock_wait) as cache:
-            archive = Archive(repository, key, manifest, args.location.archive, cache=cache)
-            archive.rename(args.name)
-            manifest.write()
-            repository.commit()
-            cache.commit()
+        archive.rename(args.name)
+        manifest.write()
+        repository.commit()
+        cache.commit()
         return self.exit_code
         return self.exit_code
 
 
-    def do_delete(self, args):
+    @with_repository(exclusive=True, cache=True)
+    def do_delete(self, args, repository, manifest, key, cache):
         """Delete an existing repository or archive"""
         """Delete an existing repository or archive"""
-        repository = self.open_repository(args, exclusive=True)
-        manifest, key = Manifest.load(repository)
-        with Cache(repository, key, manifest, do_files=args.cache_files, lock_wait=self.lock_wait) as cache:
-            if args.location.archive:
-                archive = Archive(repository, key, manifest, args.location.archive, cache=cache)
-                stats = Statistics()
-                archive.delete(stats, progress=args.progress)
-                manifest.write()
-                repository.commit(save_space=args.save_space)
-                cache.commit()
-                logger.info("Archive deleted.")
-                if args.stats:
-                    log_multi(DASHES,
-                              stats.summary.format(label='Deleted data:', stats=stats),
-                              str(cache),
-                              DASHES)
-            else:
-                if not args.cache_only:
-                    msg = []
-                    msg.append("You requested to completely DELETE the repository *including* all archives it contains:")
-                    for archive_info in manifest.list_archive_infos(sort_by='ts'):
-                        msg.append(format_archive(archive_info))
-                    msg.append("Type 'YES' if you understand this and want to continue: ")
-                    msg = '\n'.join(msg)
-                    if not yes(msg, false_msg="Aborting.", truish=('YES', ),
-                               env_var_override='BORG_DELETE_I_KNOW_WHAT_I_AM_DOING'):
-                        self.exit_code = EXIT_ERROR
-                        return self.exit_code
-                    repository.destroy()
-                    logger.info("Repository deleted.")
-                cache.destroy()
-                logger.info("Cache deleted.")
+        if args.location.archive:
+            archive = Archive(repository, key, manifest, args.location.archive, cache=cache)
+            stats = Statistics()
+            archive.delete(stats, progress=args.progress)
+            manifest.write()
+            repository.commit(save_space=args.save_space)
+            cache.commit()
+            logger.info("Archive deleted.")
+            if args.stats:
+                log_multi(DASHES,
+                          stats.summary.format(label='Deleted data:', stats=stats),
+                          str(cache),
+                          DASHES)
+        else:
+            if not args.cache_only:
+                msg = []
+                msg.append("You requested to completely DELETE the repository *including* all archives it contains:")
+                for archive_info in manifest.list_archive_infos(sort_by='ts'):
+                    msg.append(format_archive(archive_info))
+                msg.append("Type 'YES' if you understand this and want to continue: ")
+                msg = '\n'.join(msg)
+                if not yes(msg, false_msg="Aborting.", truish=('YES', ),
+                           env_var_override='BORG_DELETE_I_KNOW_WHAT_I_AM_DOING'):
+                    self.exit_code = EXIT_ERROR
+                    return self.exit_code
+                repository.destroy()
+                logger.info("Repository deleted.")
+            cache.destroy()
+            logger.info("Cache deleted.")
         return self.exit_code
         return self.exit_code
 
 
-    def do_mount(self, args):
+    @with_repository()
+    def do_mount(self, args, repository, manifest, key):
         """Mount archive or an entire repository as a FUSE fileystem"""
         """Mount archive or an entire repository as a FUSE fileystem"""
         try:
         try:
             from .fuse import FuseOperations
             from .fuse import FuseOperations
@@ -580,29 +603,23 @@ class Archiver:
             self.print_error('%s: Mountpoint must be a writable directory' % args.mountpoint)
             self.print_error('%s: Mountpoint must be a writable directory' % args.mountpoint)
             return self.exit_code
             return self.exit_code
 
 
-        repository = self.open_repository(args)
-        try:
-            with cache_if_remote(repository) as cached_repo:
-                manifest, key = Manifest.load(repository)
-                if args.location.archive:
-                    archive = Archive(repository, key, manifest, args.location.archive)
-                else:
-                    archive = None
-                operations = FuseOperations(key, repository, manifest, archive, cached_repo)
-                logger.info("Mounting filesystem")
-                try:
-                    operations.mount(args.mountpoint, args.options, args.foreground)
-                except RuntimeError:
-                    # Relevant error message already printed to stderr by fuse
-                    self.exit_code = EXIT_ERROR
-        finally:
-            repository.close()
+        with cache_if_remote(repository) as cached_repo:
+            if args.location.archive:
+                archive = Archive(repository, key, manifest, args.location.archive)
+            else:
+                archive = None
+            operations = FuseOperations(key, repository, manifest, archive, cached_repo)
+            logger.info("Mounting filesystem")
+            try:
+                operations.mount(args.mountpoint, args.options, args.foreground)
+            except RuntimeError:
+                # Relevant error message already printed to stderr by fuse
+                self.exit_code = EXIT_ERROR
         return self.exit_code
         return self.exit_code
 
 
-    def do_list(self, args):
+    @with_repository()
+    def do_list(self, args, repository, manifest, key):
         """List archive or repository contents"""
         """List archive or repository contents"""
-        repository = self.open_repository(args)
-        manifest, key = Manifest.load(repository)
         if args.location.archive:
         if args.location.archive:
             matcher, _ = self.build_matcher(args.excludes, args.paths)
             matcher, _ = self.build_matcher(args.excludes, args.paths)
 
 
@@ -626,7 +643,6 @@ class Archiver:
                     write = sys.stdout.buffer.write
                     write = sys.stdout.buffer.write
                 for item in archive.iter_items(lambda item: matcher.match(item[b'path'])):
                 for item in archive.iter_items(lambda item: matcher.match(item[b'path'])):
                     write(formatter.format_item(item).encode('utf-8', errors='surrogateescape'))
                     write(formatter.format_item(item).encode('utf-8', errors='surrogateescape'))
-            repository.close()
         else:
         else:
             for archive_info in manifest.list_archive_infos(sort_by='ts'):
             for archive_info in manifest.list_archive_infos(sort_by='ts'):
                 if args.prefix and not archive_info.name.startswith(args.prefix):
                 if args.prefix and not archive_info.name.startswith(args.prefix):
@@ -637,30 +653,27 @@ class Archiver:
                     print(format_archive(archive_info))
                     print(format_archive(archive_info))
         return self.exit_code
         return self.exit_code
 
 
-    def do_info(self, args):
+    @with_repository(cache=True)
+    @with_archive
+    def do_info(self, args, repository, manifest, key, archive, cache):
         """Show archive details such as disk space used"""
         """Show archive details such as disk space used"""
-        repository = self.open_repository(args)
-        manifest, key = Manifest.load(repository)
-        with Cache(repository, key, manifest, do_files=args.cache_files, lock_wait=self.lock_wait) as cache:
-            archive = Archive(repository, key, manifest, args.location.archive, cache=cache)
-            stats = archive.calc_stats(cache)
-            print('Name:', archive.name)
-            print('Fingerprint: %s' % hexlify(archive.id).decode('ascii'))
-            print('Hostname:', archive.metadata[b'hostname'])
-            print('Username:', archive.metadata[b'username'])
-            print('Time (start): %s' % format_time(to_localtime(archive.ts)))
-            print('Time (end):   %s' % format_time(to_localtime(archive.ts_end)))
-            print('Command line:', remove_surrogates(' '.join(archive.metadata[b'cmdline'])))
-            print('Number of files: %d' % stats.nfiles)
-            print()
-            print(str(stats))
-            print(str(cache))
+        stats = archive.calc_stats(cache)
+        print('Name:', archive.name)
+        print('Fingerprint: %s' % hexlify(archive.id).decode('ascii'))
+        print('Hostname:', archive.metadata[b'hostname'])
+        print('Username:', archive.metadata[b'username'])
+        print('Time (start): %s' % format_time(to_localtime(archive.ts)))
+        print('Time (end):   %s' % format_time(to_localtime(archive.ts_end)))
+        print('Command line:', remove_surrogates(' '.join(archive.metadata[b'cmdline'])))
+        print('Number of files: %d' % stats.nfiles)
+        print()
+        print(str(stats))
+        print(str(cache))
         return self.exit_code
         return self.exit_code
 
 
-    def do_prune(self, args):
+    @with_repository()
+    def do_prune(self, args, repository, manifest, key):
         """Prune repository archives according to specified rules"""
         """Prune repository archives according to specified rules"""
-        repository = self.open_repository(args, exclusive=True)
-        manifest, key = Manifest.load(repository)
         archives = manifest.list_archive_infos(sort_by='ts', reverse=True)  # just a ArchiveInfo list
         archives = manifest.list_archive_infos(sort_by='ts', reverse=True)  # just a ArchiveInfo list
         if args.hourly + args.daily + args.weekly + args.monthly + args.yearly == 0 and args.within is None:
         if args.hourly + args.daily + args.weekly + args.monthly + args.yearly == 0 and args.within is None:
             self.print_error('At least one of the "keep-within", "keep-hourly", "keep-daily", "keep-weekly", '
             self.print_error('At least one of the "keep-within", "keep-hourly", "keep-daily", "keep-weekly", '
@@ -725,10 +738,9 @@ class Archiver:
             print("warning: %s" % e)
             print("warning: %s" % e)
         return self.exit_code
         return self.exit_code
 
 
-    def do_debug_dump_archive_items(self, args):
+    @with_repository()
+    def do_debug_dump_archive_items(self, args, repository, manifest, key):
         """dump (decrypted, decompressed) archive items metadata (not: data)"""
         """dump (decrypted, decompressed) archive items metadata (not: data)"""
-        repository = self.open_repository(args)
-        manifest, key = Manifest.load(repository)
         archive = Archive(repository, key, manifest, args.location.archive)
         archive = Archive(repository, key, manifest, args.location.archive)
         for i, item_id in enumerate(archive.metadata[b'items']):
         for i, item_id in enumerate(archive.metadata[b'items']):
             data = key.decrypt(item_id, repository.get(item_id))
             data = key.decrypt(item_id, repository.get(item_id))
@@ -739,10 +751,9 @@ class Archiver:
         print('Done.')
         print('Done.')
         return EXIT_SUCCESS
         return EXIT_SUCCESS
 
 
-    def do_debug_get_obj(self, args):
+    @with_repository(manifest=False)
+    def do_debug_get_obj(self, args, repository):
         """get object contents from the repository and write it into file"""
         """get object contents from the repository and write it into file"""
-        repository = self.open_repository(args)
-        manifest, key = Manifest.load(repository)
         hex_id = args.id
         hex_id = args.id
         try:
         try:
             id = unhexlify(hex_id)
             id = unhexlify(hex_id)
@@ -759,10 +770,9 @@ class Archiver:
                 print("object %s fetched." % hex_id)
                 print("object %s fetched." % hex_id)
         return EXIT_SUCCESS
         return EXIT_SUCCESS
 
 
-    def do_debug_put_obj(self, args):
+    @with_repository(manifest=False)
+    def do_debug_put_obj(self, args, repository):
         """put file(s) contents into the repository"""
         """put file(s) contents into the repository"""
-        repository = self.open_repository(args)
-        manifest, key = Manifest.load(repository)
         for path in args.paths:
         for path in args.paths:
             with open(path, "rb") as f:
             with open(path, "rb") as f:
                 data = f.read()
                 data = f.read()
@@ -772,10 +782,9 @@ class Archiver:
         repository.commit()
         repository.commit()
         return EXIT_SUCCESS
         return EXIT_SUCCESS
 
 
-    def do_debug_delete_obj(self, args):
+    @with_repository(manifest=False)
+    def do_debug_delete_obj(self, args, repository):
         """delete the objects with the given IDs from the repo"""
         """delete the objects with the given IDs from the repo"""
-        repository = self.open_repository(args)
-        manifest, key = Manifest.load(repository)
         modified = False
         modified = False
         for hex_id in args.ids:
         for hex_id in args.ids:
             try:
             try:
@@ -794,14 +803,11 @@ class Archiver:
         print('Done.')
         print('Done.')
         return EXIT_SUCCESS
         return EXIT_SUCCESS
 
 
-    def do_break_lock(self, args):
+    @with_repository(lock=False, manifest=False)
+    def do_break_lock(self, args, repository):
         """Break the repository lock (e.g. in case it was left by a dead borg."""
         """Break the repository lock (e.g. in case it was left by a dead borg."""
-        repository = self.open_repository(args, lock=False)
-        try:
-            repository.break_lock()
-            Cache.break_lock(repository)
-        finally:
-            repository.close()
+        repository.break_lock()
+        Cache.break_lock(repository)
         return self.exit_code
         return self.exit_code
 
 
     helptext = {}
     helptext = {}

+ 15 - 1
borg/remote.py

@@ -77,6 +77,7 @@ class RepositoryServer:  # pragma: no cover
             if r:
             if r:
                 data = os.read(stdin_fd, BUFSIZE)
                 data = os.read(stdin_fd, BUFSIZE)
                 if not data:
                 if not data:
+                    self.repository.close()
                     return
                     return
                 unpacker.feed(data)
                 unpacker.feed(data)
                 for unpacked in unpacker:
                 for unpacked in unpacker:
@@ -100,6 +101,7 @@ class RepositoryServer:  # pragma: no cover
                     else:
                     else:
                         os.write(stdout_fd, msgpack.packb((1, msgid, None, res)))
                         os.write(stdout_fd, msgpack.packb((1, msgid, None, res)))
             if es:
             if es:
+                self.repository.close()
                 return
                 return
 
 
     def negotiate(self, versions):
     def negotiate(self, versions):
@@ -117,6 +119,7 @@ class RepositoryServer:  # pragma: no cover
             else:
             else:
                 raise PathNotAllowed(path)
                 raise PathNotAllowed(path)
         self.repository = Repository(path, create, lock_wait=lock_wait, lock=lock)
         self.repository = Repository(path, create, lock_wait=lock_wait, lock=lock)
+        self.repository.__enter__()  # clean exit handled by serve() method
         return self.repository.id
         return self.repository.id
 
 
 
 
@@ -164,11 +167,21 @@ class RemoteRepository:
         self.id = self.call('open', location.path, create, lock_wait, lock)
         self.id = self.call('open', location.path, create, lock_wait, lock)
 
 
     def __del__(self):
     def __del__(self):
-        self.close()
+        if self.p:
+            self.close()
+            assert False, "cleanup happened in Repository.__del__"
 
 
     def __repr__(self):
     def __repr__(self):
         return '<%s %s>' % (self.__class__.__name__, self.location.canonical_path())
         return '<%s %s>' % (self.__class__.__name__, self.location.canonical_path())
 
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if exc_type is not None:
+            self.rollback()
+        self.close()
+
     def borg_cmd(self, args, testing):
     def borg_cmd(self, args, testing):
         """return a borg serve command line"""
         """return a borg serve command line"""
         # give some args/options to "borg serve" process as they were given to us
         # give some args/options to "borg serve" process as they were given to us
@@ -392,6 +405,7 @@ class RepositoryCache(RepositoryNoCache):
         super().__init__(repository)
         super().__init__(repository)
         tmppath = tempfile.mkdtemp(prefix='borg-tmp')
         tmppath = tempfile.mkdtemp(prefix='borg-tmp')
         self.caching_repo = Repository(tmppath, create=True, exclusive=True)
         self.caching_repo = Repository(tmppath, create=True, exclusive=True)
+        self.caching_repo.__enter__()  # handled by context manager in base class
 
 
     def close(self):
     def close(self):
         if self.caching_repo is not None:
         if self.caching_repo is not None:

+ 19 - 4
borg/repository.py

@@ -59,16 +59,31 @@ class Repository:
         self.lock = None
         self.lock = None
         self.index = None
         self.index = None
         self._active_txn = False
         self._active_txn = False
-        if create:
-            self.create(self.path)
-        self.open(self.path, exclusive, lock_wait=lock_wait, lock=lock)
+        self.lock_wait = lock_wait
+        self.do_lock = lock
+        self.do_create = create
+        self.exclusive = exclusive
 
 
     def __del__(self):
     def __del__(self):
-        self.close()
+        if self.lock:
+            self.close()
+            assert False, "cleanup happened in Repository.__del__"
 
 
     def __repr__(self):
     def __repr__(self):
         return '<%s %s>' % (self.__class__.__name__, self.path)
         return '<%s %s>' % (self.__class__.__name__, self.path)
 
 
+    def __enter__(self):
+        if self.do_create:
+            self.do_create = False
+            self.create(self.path)
+        self.open(self.path, self.exclusive, lock_wait=self.lock_wait, lock=self.do_lock)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if exc_type is not None:
+            self.rollback()
+        self.close()
+
     def create(self, path):
     def create(self, path):
         """Create a new empty repository at `path`
         """Create a new empty repository at `path`
         """
         """

+ 4 - 0
borg/testsuite/__init__.py

@@ -8,6 +8,7 @@ import sysconfig
 import time
 import time
 import unittest
 import unittest
 from ..xattr import get_all
 from ..xattr import get_all
+from ..logger import setup_logging
 
 
 try:
 try:
     import llfuse
     import llfuse
@@ -30,6 +31,9 @@ else:
 if sys.platform.startswith('netbsd'):
 if sys.platform.startswith('netbsd'):
     st_mtime_ns_round = -4  # only >1 microsecond resolution here?
     st_mtime_ns_round = -4  # only >1 microsecond resolution here?
 
 
+# Ensure that the loggers exist for all tests
+setup_logging()
+
 
 
 class BaseTestCase(unittest.TestCase):
 class BaseTestCase(unittest.TestCase):
     """
     """

+ 47 - 36
borg/testsuite/archiver.py

@@ -367,7 +367,8 @@ class ArchiverTestCase(ArchiverTestCaseBase):
             assert sto.st_atime_ns == atime * 1e9
             assert sto.st_atime_ns == atime * 1e9
 
 
     def _extract_repository_id(self, path):
     def _extract_repository_id(self, path):
-        return Repository(self.repository_path).id
+        with Repository(self.repository_path) as repository:
+            return repository.id
 
 
     def _set_repository_id(self, path, id):
     def _set_repository_id(self, path, id):
         config = ConfigParser(interpolation=None)
         config = ConfigParser(interpolation=None)
@@ -375,7 +376,8 @@ class ArchiverTestCase(ArchiverTestCaseBase):
         config.set('repository', 'id', hexlify(id).decode('ascii'))
         config.set('repository', 'id', hexlify(id).decode('ascii'))
         with open(os.path.join(path, 'config'), 'w') as fd:
         with open(os.path.join(path, 'config'), 'w') as fd:
             config.write(fd)
             config.write(fd)
-        return Repository(self.repository_path).id
+        with Repository(self.repository_path) as repository:
+            return repository.id
 
 
     def test_sparse_file(self):
     def test_sparse_file(self):
         # no sparse file support on Mac OS X
         # no sparse file support on Mac OS X
@@ -745,8 +747,8 @@ class ArchiverTestCase(ArchiverTestCaseBase):
         self.cmd('extract', '--dry-run', self.repository_location + '::test.3')
         self.cmd('extract', '--dry-run', self.repository_location + '::test.3')
         self.cmd('extract', '--dry-run', self.repository_location + '::test.4')
         self.cmd('extract', '--dry-run', self.repository_location + '::test.4')
         # Make sure both archives have been renamed
         # Make sure both archives have been renamed
-        repository = Repository(self.repository_path)
-        manifest, key = Manifest.load(repository)
+        with Repository(self.repository_path) as repository:
+            manifest, key = Manifest.load(repository)
         self.assert_equal(len(manifest.archives), 2)
         self.assert_equal(len(manifest.archives), 2)
         self.assert_in('test.3', manifest.archives)
         self.assert_in('test.3', manifest.archives)
         self.assert_in('test.4', manifest.archives)
         self.assert_in('test.4', manifest.archives)
@@ -763,8 +765,8 @@ class ArchiverTestCase(ArchiverTestCaseBase):
         self.cmd('extract', '--dry-run', self.repository_location + '::test.2')
         self.cmd('extract', '--dry-run', self.repository_location + '::test.2')
         self.cmd('delete', '--stats', self.repository_location + '::test.2')
         self.cmd('delete', '--stats', self.repository_location + '::test.2')
         # Make sure all data except the manifest has been deleted
         # Make sure all data except the manifest has been deleted
-        repository = Repository(self.repository_path)
-        self.assert_equal(len(repository), 1)
+        with Repository(self.repository_path) as repository:
+            self.assert_equal(len(repository), 1)
 
 
     def test_delete_repo(self):
     def test_delete_repo(self):
         self.create_regular_file('file1', size=1024 * 80)
         self.create_regular_file('file1', size=1024 * 80)
@@ -772,6 +774,11 @@ class ArchiverTestCase(ArchiverTestCaseBase):
         self.cmd('init', self.repository_location)
         self.cmd('init', self.repository_location)
         self.cmd('create', self.repository_location + '::test', 'input')
         self.cmd('create', self.repository_location + '::test', 'input')
         self.cmd('create', self.repository_location + '::test.2', 'input')
         self.cmd('create', self.repository_location + '::test.2', 'input')
+        os.environ['BORG_DELETE_I_KNOW_WHAT_I_AM_DOING'] = 'no'
+        self.cmd('delete', self.repository_location, exit_code=2)
+        self.archiver.exit_code = 0
+        assert os.path.exists(self.repository_path)
+        os.environ['BORG_DELETE_I_KNOW_WHAT_I_AM_DOING'] = 'YES'
         self.cmd('delete', self.repository_location)
         self.cmd('delete', self.repository_location)
         # Make sure the repo is gone
         # Make sure the repo is gone
         self.assertFalse(os.path.exists(self.repository_path))
         self.assertFalse(os.path.exists(self.repository_path))
@@ -810,8 +817,8 @@ class ArchiverTestCase(ArchiverTestCaseBase):
         self.cmd('init', self.repository_location)
         self.cmd('init', self.repository_location)
         self.cmd('create', '--dry-run', self.repository_location + '::test', 'input')
         self.cmd('create', '--dry-run', self.repository_location + '::test', 'input')
         # Make sure no archive has been created
         # Make sure no archive has been created
-        repository = Repository(self.repository_path)
-        manifest, key = Manifest.load(repository)
+        with Repository(self.repository_path) as repository:
+            manifest, key = Manifest.load(repository)
         self.assert_equal(len(manifest.archives), 0)
         self.assert_equal(len(manifest.archives), 0)
 
 
     def test_progress(self):
     def test_progress(self):
@@ -1045,17 +1052,17 @@ class ArchiverTestCase(ArchiverTestCaseBase):
         used = set()  # counter values already used
         used = set()  # counter values already used
 
 
         def verify_uniqueness():
         def verify_uniqueness():
-            repository = Repository(self.repository_path)
-            for key, _ in repository.open_index(repository.get_transaction_id()).iteritems():
-                data = repository.get(key)
-                hash = sha256(data).digest()
-                if hash not in seen:
-                    seen.add(hash)
-                    num_blocks = num_aes_blocks(len(data) - 41)
-                    nonce = bytes_to_long(data[33:41])
-                    for counter in range(nonce, nonce + num_blocks):
-                        self.assert_not_in(counter, used)
-                        used.add(counter)
+            with Repository(self.repository_path) as repository:
+                for key, _ in repository.open_index(repository.get_transaction_id()).iteritems():
+                    data = repository.get(key)
+                    hash = sha256(data).digest()
+                    if hash not in seen:
+                        seen.add(hash)
+                        num_blocks = num_aes_blocks(len(data) - 41)
+                        nonce = bytes_to_long(data[33:41])
+                        for counter in range(nonce, nonce + num_blocks):
+                            self.assert_not_in(counter, used)
+                            used.add(counter)
 
 
         self.create_test_files()
         self.create_test_files()
         os.environ['BORG_PASSPHRASE'] = 'passphrase'
         os.environ['BORG_PASSPHRASE'] = 'passphrase'
@@ -1122,8 +1129,9 @@ class ArchiverCheckTestCase(ArchiverTestCaseBase):
 
 
     def open_archive(self, name):
     def open_archive(self, name):
         repository = Repository(self.repository_path)
         repository = Repository(self.repository_path)
-        manifest, key = Manifest.load(repository)
-        archive = Archive(repository, key, manifest, name)
+        with repository:
+            manifest, key = Manifest.load(repository)
+            archive = Archive(repository, key, manifest, name)
         return archive, repository
         return archive, repository
 
 
     def test_check_usage(self):
     def test_check_usage(self):
@@ -1141,35 +1149,39 @@ class ArchiverCheckTestCase(ArchiverTestCaseBase):
 
 
     def test_missing_file_chunk(self):
     def test_missing_file_chunk(self):
         archive, repository = self.open_archive('archive1')
         archive, repository = self.open_archive('archive1')
-        for item in archive.iter_items():
-            if item[b'path'].endswith('testsuite/archiver.py'):
-                repository.delete(item[b'chunks'][-1][0])
-                break
-        repository.commit()
+        with repository:
+            for item in archive.iter_items():
+                if item[b'path'].endswith('testsuite/archiver.py'):
+                    repository.delete(item[b'chunks'][-1][0])
+                    break
+            repository.commit()
         self.cmd('check', self.repository_location, exit_code=1)
         self.cmd('check', self.repository_location, exit_code=1)
         self.cmd('check', '--repair', self.repository_location, exit_code=0)
         self.cmd('check', '--repair', self.repository_location, exit_code=0)
         self.cmd('check', self.repository_location, exit_code=0)
         self.cmd('check', self.repository_location, exit_code=0)
 
 
     def test_missing_archive_item_chunk(self):
     def test_missing_archive_item_chunk(self):
         archive, repository = self.open_archive('archive1')
         archive, repository = self.open_archive('archive1')
-        repository.delete(archive.metadata[b'items'][-5])
-        repository.commit()
+        with repository:
+            repository.delete(archive.metadata[b'items'][-5])
+            repository.commit()
         self.cmd('check', self.repository_location, exit_code=1)
         self.cmd('check', self.repository_location, exit_code=1)
         self.cmd('check', '--repair', self.repository_location, exit_code=0)
         self.cmd('check', '--repair', self.repository_location, exit_code=0)
         self.cmd('check', self.repository_location, exit_code=0)
         self.cmd('check', self.repository_location, exit_code=0)
 
 
     def test_missing_archive_metadata(self):
     def test_missing_archive_metadata(self):
         archive, repository = self.open_archive('archive1')
         archive, repository = self.open_archive('archive1')
-        repository.delete(archive.id)
-        repository.commit()
+        with repository:
+            repository.delete(archive.id)
+            repository.commit()
         self.cmd('check', self.repository_location, exit_code=1)
         self.cmd('check', self.repository_location, exit_code=1)
         self.cmd('check', '--repair', self.repository_location, exit_code=0)
         self.cmd('check', '--repair', self.repository_location, exit_code=0)
         self.cmd('check', self.repository_location, exit_code=0)
         self.cmd('check', self.repository_location, exit_code=0)
 
 
     def test_missing_manifest(self):
     def test_missing_manifest(self):
         archive, repository = self.open_archive('archive1')
         archive, repository = self.open_archive('archive1')
-        repository.delete(Manifest.MANIFEST_ID)
-        repository.commit()
+        with repository:
+            repository.delete(Manifest.MANIFEST_ID)
+            repository.commit()
         self.cmd('check', self.repository_location, exit_code=1)
         self.cmd('check', self.repository_location, exit_code=1)
         output = self.cmd('check', '-v', '--repair', self.repository_location, exit_code=0)
         output = self.cmd('check', '-v', '--repair', self.repository_location, exit_code=0)
         self.assert_in('archive1', output)
         self.assert_in('archive1', output)
@@ -1178,10 +1190,9 @@ class ArchiverCheckTestCase(ArchiverTestCaseBase):
 
 
     def test_extra_chunks(self):
     def test_extra_chunks(self):
         self.cmd('check', self.repository_location, exit_code=0)
         self.cmd('check', self.repository_location, exit_code=0)
-        repository = Repository(self.repository_location)
-        repository.put(b'01234567890123456789012345678901', b'xxxx')
-        repository.commit()
-        repository.close()
+        with Repository(self.repository_location) as repository:
+            repository.put(b'01234567890123456789012345678901', b'xxxx')
+            repository.commit()
         self.cmd('check', self.repository_location, exit_code=1)
         self.cmd('check', self.repository_location, exit_code=1)
         self.cmd('check', self.repository_location, exit_code=1)
         self.cmd('check', self.repository_location, exit_code=1)
         self.cmd('check', '--repair', self.repository_location, exit_code=0)
         self.cmd('check', '--repair', self.repository_location, exit_code=0)

+ 32 - 25
borg/testsuite/repository.py

@@ -21,6 +21,7 @@ class RepositoryTestCaseBase(BaseTestCase):
     def setUp(self):
     def setUp(self):
         self.tmppath = tempfile.mkdtemp()
         self.tmppath = tempfile.mkdtemp()
         self.repository = self.open(create=True)
         self.repository = self.open(create=True)
+        self.repository.__enter__()
 
 
     def tearDown(self):
     def tearDown(self):
         self.repository.close()
         self.repository.close()
@@ -43,13 +44,12 @@ class RepositoryTestCase(RepositoryTestCaseBase):
         self.assert_raises(Repository.ObjectNotFound, lambda: self.repository.get(key50))
         self.assert_raises(Repository.ObjectNotFound, lambda: self.repository.get(key50))
         self.repository.commit()
         self.repository.commit()
         self.repository.close()
         self.repository.close()
-        repository2 = self.open()
-        self.assert_raises(Repository.ObjectNotFound, lambda: repository2.get(key50))
-        for x in range(100):
-            if x == 50:
-                continue
-            self.assert_equal(repository2.get(('%-32d' % x).encode('ascii')), b'SOMEDATA')
-        repository2.close()
+        with self.open() as repository2:
+            self.assert_raises(Repository.ObjectNotFound, lambda: repository2.get(key50))
+            for x in range(100):
+                if x == 50:
+                    continue
+                self.assert_equal(repository2.get(('%-32d' % x).encode('ascii')), b'SOMEDATA')
 
 
     def test2(self):
     def test2(self):
         """Test multiple sequential transactions
         """Test multiple sequential transactions
@@ -100,13 +100,14 @@ class RepositoryTestCase(RepositoryTestCaseBase):
         self.repository.close()
         self.repository.close()
         # replace
         # replace
         self.repository = self.open()
         self.repository = self.open()
-        self.repository.put(b'00000000000000000000000000000000', b'bar')
-        self.repository.commit()
-        self.repository.close()
+        with self.repository:
+            self.repository.put(b'00000000000000000000000000000000', b'bar')
+            self.repository.commit()
         # delete
         # delete
         self.repository = self.open()
         self.repository = self.open()
-        self.repository.delete(b'00000000000000000000000000000000')
-        self.repository.commit()
+        with self.repository:
+            self.repository.delete(b'00000000000000000000000000000000')
+            self.repository.commit()
 
 
     def test_list(self):
     def test_list(self):
         for x in range(100):
         for x in range(100):
@@ -139,8 +140,9 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
             if name.startswith('index.'):
             if name.startswith('index.'):
                 os.unlink(os.path.join(self.repository.path, name))
                 os.unlink(os.path.join(self.repository.path, name))
         self.reopen()
         self.reopen()
-        self.assert_equal(len(self.repository), 3)
-        self.assert_equal(self.repository.check(), True)
+        with self.repository:
+            self.assert_equal(len(self.repository), 3)
+            self.assert_equal(self.repository.check(), True)
 
 
     def test_crash_before_compact_segments(self):
     def test_crash_before_compact_segments(self):
         self.add_keys()
         self.add_keys()
@@ -150,8 +152,9 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
         except TypeError:
         except TypeError:
             pass
             pass
         self.reopen()
         self.reopen()
-        self.assert_equal(len(self.repository), 3)
-        self.assert_equal(self.repository.check(), True)
+        with self.repository:
+            self.assert_equal(len(self.repository), 3)
+            self.assert_equal(self.repository.check(), True)
 
 
     def test_replay_of_readonly_repository(self):
     def test_replay_of_readonly_repository(self):
         self.add_keys()
         self.add_keys()
@@ -160,8 +163,9 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
                 os.unlink(os.path.join(self.repository.path, name))
                 os.unlink(os.path.join(self.repository.path, name))
         with patch.object(UpgradableLock, 'upgrade', side_effect=LockFailed) as upgrade:
         with patch.object(UpgradableLock, 'upgrade', side_effect=LockFailed) as upgrade:
             self.reopen()
             self.reopen()
-            self.assert_raises(LockFailed, lambda: len(self.repository))
-            upgrade.assert_called_once_with()
+            with self.repository:
+                self.assert_raises(LockFailed, lambda: len(self.repository))
+                upgrade.assert_called_once_with()
 
 
     def test_crash_before_write_index(self):
     def test_crash_before_write_index(self):
         self.add_keys()
         self.add_keys()
@@ -171,8 +175,9 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
         except TypeError:
         except TypeError:
             pass
             pass
         self.reopen()
         self.reopen()
-        self.assert_equal(len(self.repository), 3)
-        self.assert_equal(self.repository.check(), True)
+        with self.repository:
+            self.assert_equal(len(self.repository), 3)
+            self.assert_equal(self.repository.check(), True)
 
 
     def test_crash_before_deleting_compacted_segments(self):
     def test_crash_before_deleting_compacted_segments(self):
         self.add_keys()
         self.add_keys()
@@ -182,9 +187,10 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
         except TypeError:
         except TypeError:
             pass
             pass
         self.reopen()
         self.reopen()
-        self.assert_equal(len(self.repository), 3)
-        self.assert_equal(self.repository.check(), True)
-        self.assert_equal(len(self.repository), 3)
+        with self.repository:
+            self.assert_equal(len(self.repository), 3)
+            self.assert_equal(self.repository.check(), True)
+            self.assert_equal(len(self.repository), 3)
 
 
 
 
 class RepositoryCheckTestCase(RepositoryTestCaseBase):
 class RepositoryCheckTestCase(RepositoryTestCaseBase):
@@ -313,8 +319,9 @@ class RepositoryCheckTestCase(RepositoryTestCaseBase):
             self.repository.commit()
             self.repository.commit()
             compact.assert_called_once_with(save_space=False)
             compact.assert_called_once_with(save_space=False)
         self.reopen()
         self.reopen()
-        self.check(repair=True)
-        self.assert_equal(self.repository.get(bytes(32)), b'data2')
+        with self.repository:
+            self.check(repair=True)
+            self.assert_equal(self.repository.get(bytes(32)), b'data2')
 
 
 
 
 class RemoteRepositoryTestCase(RepositoryTestCase):
 class RemoteRepositoryTestCase(RepositoryTestCase):

+ 24 - 26
borg/testsuite/upgrader.py

@@ -23,11 +23,9 @@ def repo_valid(path):
     :param path: the path to the repository
     :param path: the path to the repository
     :returns: if borg can check the repository
     :returns: if borg can check the repository
     """
     """
-    repository = Repository(str(path), create=False)
-    # can't check raises() because check() handles the error
-    state = repository.check()
-    repository.close()
-    return state
+    with Repository(str(path), create=False) as repository:
+        # can't check raises() because check() handles the error
+        return repository.check()
 
 
 
 
 def key_valid(path):
 def key_valid(path):
@@ -79,11 +77,11 @@ def test_convert_segments(tmpdir, attic_repo, inplace):
     """
     """
     # check should fail because of magic number
     # check should fail because of magic number
     assert not repo_valid(tmpdir)
     assert not repo_valid(tmpdir)
-    repo = AtticRepositoryUpgrader(str(tmpdir), create=False)
-    segments = [filename for i, filename in repo.io.segment_iterator()]
-    repo.close()
-    repo.convert_segments(segments, dryrun=False, inplace=inplace)
-    repo.convert_cache(dryrun=False)
+    repository = AtticRepositoryUpgrader(str(tmpdir), create=False)
+    with repository:
+        segments = [filename for i, filename in repository.io.segment_iterator()]
+    repository.convert_segments(segments, dryrun=False, inplace=inplace)
+    repository.convert_cache(dryrun=False)
     assert repo_valid(tmpdir)
     assert repo_valid(tmpdir)
 
 
 
 
@@ -138,9 +136,9 @@ def test_keys(tmpdir, attic_repo, attic_key_file):
     define above)
     define above)
     :param attic_key_file: an attic.key.KeyfileKey (fixture created above)
     :param attic_key_file: an attic.key.KeyfileKey (fixture created above)
     """
     """
-    repository = AtticRepositoryUpgrader(str(tmpdir), create=False)
-    keyfile = AtticKeyfileKey.find_key_file(repository)
-    AtticRepositoryUpgrader.convert_keyfiles(keyfile, dryrun=False)
+    with AtticRepositoryUpgrader(str(tmpdir), create=False) as repository:
+        keyfile = AtticKeyfileKey.find_key_file(repository)
+        AtticRepositoryUpgrader.convert_keyfiles(keyfile, dryrun=False)
     assert key_valid(attic_key_file.path)
     assert key_valid(attic_key_file.path)
 
 
 
 
@@ -167,19 +165,19 @@ def test_convert_all(tmpdir, attic_repo, attic_key_file, inplace):
         return stat_segment(path).st_ino
         return stat_segment(path).st_ino
 
 
     orig_inode = first_inode(attic_repo.path)
     orig_inode = first_inode(attic_repo.path)
-    repo = AtticRepositoryUpgrader(str(tmpdir), create=False)
-    # replicate command dispatch, partly
-    os.umask(UMASK_DEFAULT)
-    backup = repo.upgrade(dryrun=False, inplace=inplace)
-    if inplace:
-        assert backup is None
-        assert first_inode(repo.path) == orig_inode
-    else:
-        assert backup
-        assert first_inode(repo.path) != first_inode(backup)
-        # i have seen cases where the copied tree has world-readable
-        # permissions, which is wrong
-        assert stat_segment(backup).st_mode & UMASK_DEFAULT == 0
+    with AtticRepositoryUpgrader(str(tmpdir), create=False) as repository:
+        # replicate command dispatch, partly
+        os.umask(UMASK_DEFAULT)
+        backup = repository.upgrade(dryrun=False, inplace=inplace)
+        if inplace:
+            assert backup is None
+            assert first_inode(repository.path) == orig_inode
+        else:
+            assert backup
+            assert first_inode(repository.path) != first_inode(backup)
+            # i have seen cases where the copied tree has world-readable
+            # permissions, which is wrong
+            assert stat_segment(backup).st_mode & UMASK_DEFAULT == 0
 
 
     assert key_valid(attic_key_file.path)
     assert key_valid(attic_key_file.path)
     assert repo_valid(tmpdir)
     assert repo_valid(tmpdir)

+ 24 - 23
borg/upgrader.py

@@ -30,23 +30,23 @@ class AtticRepositoryUpgrader(Repository):
         we nevertheless do the order in reverse, as we prefer to do
         we nevertheless do the order in reverse, as we prefer to do
         the fast stuff first, to improve interactivity.
         the fast stuff first, to improve interactivity.
         """
         """
-        backup = None
-        if not inplace:
-            backup = '{}.upgrade-{:%Y-%m-%d-%H:%M:%S}'.format(self.path, datetime.datetime.now())
-            logger.info('making a hardlink copy in %s', backup)
-            if not dryrun:
-                shutil.copytree(self.path, backup, copy_function=os.link)
-        logger.info("opening attic repository with borg and converting")
-        # now lock the repo, after we have made the copy
-        self.lock = UpgradableLock(os.path.join(self.path, 'lock'), exclusive=True, timeout=1.0).acquire()
-        segments = [filename for i, filename in self.io.segment_iterator()]
-        try:
-            keyfile = self.find_attic_keyfile()
-        except KeyfileNotFoundError:
-            logger.warning("no key file found for repository")
-        else:
-            self.convert_keyfiles(keyfile, dryrun)
-        self.close()
+        with self:
+            backup = None
+            if not inplace:
+                backup = '{}.upgrade-{:%Y-%m-%d-%H:%M:%S}'.format(self.path, datetime.datetime.now())
+                logger.info('making a hardlink copy in %s', backup)
+                if not dryrun:
+                    shutil.copytree(self.path, backup, copy_function=os.link)
+            logger.info("opening attic repository with borg and converting")
+            # now lock the repo, after we have made the copy
+            self.lock = UpgradableLock(os.path.join(self.path, 'lock'), exclusive=True, timeout=1.0).acquire()
+            segments = [filename for i, filename in self.io.segment_iterator()]
+            try:
+                keyfile = self.find_attic_keyfile()
+            except KeyfileNotFoundError:
+                logger.warning("no key file found for repository")
+            else:
+                self.convert_keyfiles(keyfile, dryrun)
         # partial open: just hold on to the lock
         # partial open: just hold on to the lock
         self.lock = UpgradableLock(os.path.join(self.path, 'lock'),
         self.lock = UpgradableLock(os.path.join(self.path, 'lock'),
                                    exclusive=True).acquire()
                                    exclusive=True).acquire()
@@ -282,12 +282,13 @@ class BorgRepositoryUpgrader(Repository):
         """convert an old borg repository to a current borg repository
         """convert an old borg repository to a current borg repository
         """
         """
         logger.info("converting borg 0.xx to borg current")
         logger.info("converting borg 0.xx to borg current")
-        try:
-            keyfile = self.find_borg0xx_keyfile()
-        except KeyfileNotFoundError:
-            logger.warning("no key file found for repository")
-        else:
-            self.move_keyfiles(keyfile, dryrun)
+        with self:
+            try:
+                keyfile = self.find_borg0xx_keyfile()
+            except KeyfileNotFoundError:
+                logger.warning("no key file found for repository")
+            else:
+                self.move_keyfiles(keyfile, dryrun)
 
 
     def find_borg0xx_keyfile(self):
     def find_borg0xx_keyfile(self):
         return Borg0xxKeyfileKey.find_key_file(self)
         return Borg0xxKeyfileKey.find_key_file(self)