浏览代码

Merge pull request #1371 from ThomasWaldmann/fix-deadlock

fix deadlock
TW 8 年之前
父节点
当前提交
18654eaf91
共有 10 个文件被更改,包括 126 次插入79 次删除
  1. 7 7
      borg/archiver.py
  2. 3 3
      borg/cache.py
  3. 8 3
      borg/locking.py
  4. 12 13
      borg/remote.py
  5. 25 12
      borg/repository.py
  6. 2 2
      borg/testsuite/archiver.py
  7. 23 15
      borg/testsuite/locking.py
  8. 42 19
      borg/testsuite/repository.py
  9. 1 1
      borg/testsuite/upgrader.py
  10. 3 4
      borg/upgrader.py

+ 7 - 7
borg/archiver.py

@@ -68,8 +68,8 @@ def with_repository(fake=False, create=False, lock=True, exclusive=False, manife
             if argument(args, fake):
             if argument(args, fake):
                 return method(self, args, repository=None, **kwargs)
                 return method(self, args, repository=None, **kwargs)
             elif location.proto == 'ssh':
             elif location.proto == 'ssh':
-                repository = RemoteRepository(location, create=create, lock_wait=self.lock_wait, lock=lock,
-                                              append_only=append_only, args=args)
+                repository = RemoteRepository(location, create=create, exclusive=argument(args, exclusive),
+                                              lock_wait=self.lock_wait, lock=lock, append_only=append_only, args=args)
             else:
             else:
                 repository = Repository(location.path, create=create, exclusive=argument(args, exclusive),
                 repository = Repository(location.path, create=create, exclusive=argument(args, exclusive),
                                         lock_wait=self.lock_wait, lock=lock,
                                         lock_wait=self.lock_wait, lock=lock,
@@ -134,7 +134,7 @@ class Archiver:
             pass
             pass
         return self.exit_code
         return self.exit_code
 
 
-    @with_repository(exclusive='repair', manifest=False)
+    @with_repository(exclusive=True, manifest=False)
     def do_check(self, args, repository):
     def do_check(self, args, repository):
         """Check repository consistency"""
         """Check repository consistency"""
         if args.repair:
         if args.repair:
@@ -174,7 +174,7 @@ class Archiver:
         key_new.change_passphrase()  # option to change key protection passphrase, save
         key_new.change_passphrase()  # option to change key protection passphrase, save
         return EXIT_SUCCESS
         return EXIT_SUCCESS
 
 
-    @with_repository(fake='dry_run')
+    @with_repository(fake='dry_run', exclusive=True)
     def do_create(self, args, repository, manifest=None, key=None):
     def do_create(self, args, repository, manifest=None, key=None):
         """Create new archive"""
         """Create new archive"""
         matcher = PatternMatcher(fallback=True)
         matcher = PatternMatcher(fallback=True)
@@ -595,7 +595,7 @@ class Archiver:
         print(str(cache))
         print(str(cache))
         return self.exit_code
         return self.exit_code
 
 
-    @with_repository()
+    @with_repository(exclusive=True)
     def do_prune(self, args, repository, manifest, key):
     def do_prune(self, args, repository, manifest, key):
         """Prune repository archives according to specified rules"""
         """Prune repository archives according to specified rules"""
         if not any((args.hourly, args.daily,
         if not any((args.hourly, args.daily,
@@ -722,7 +722,7 @@ class Archiver:
                 print("object %s fetched." % hex_id)
                 print("object %s fetched." % hex_id)
         return EXIT_SUCCESS
         return EXIT_SUCCESS
 
 
-    @with_repository(manifest=False)
+    @with_repository(manifest=False, exclusive=True)
     def do_debug_put_obj(self, args, repository):
     def do_debug_put_obj(self, args, repository):
         """put file(s) contents into the repository"""
         """put file(s) contents into the repository"""
         for path in args.paths:
         for path in args.paths:
@@ -734,7 +734,7 @@ class Archiver:
         repository.commit()
         repository.commit()
         return EXIT_SUCCESS
         return EXIT_SUCCESS
 
 
-    @with_repository(manifest=False)
+    @with_repository(manifest=False, exclusive=True)
     def do_debug_delete_obj(self, args, repository):
     def do_debug_delete_obj(self, args, repository):
         """delete the objects with the given IDs from the repo"""
         """delete the objects with the given IDs from the repo"""
         modified = False
         modified = False

+ 3 - 3
borg/cache.py

@@ -11,7 +11,7 @@ from .logger import create_logger
 logger = create_logger()
 logger = create_logger()
 from .helpers import Error, get_cache_dir, decode_dict, int_to_bigint, \
 from .helpers import Error, get_cache_dir, decode_dict, int_to_bigint, \
     bigint_to_int, format_file_size, yes
     bigint_to_int, format_file_size, yes
-from .locking import UpgradableLock
+from .locking import Lock
 from .hashindex import ChunkIndex
 from .hashindex import ChunkIndex
 
 
 import msgpack
 import msgpack
@@ -35,7 +35,7 @@ class Cache:
     @staticmethod
     @staticmethod
     def break_lock(repository, path=None):
     def break_lock(repository, path=None):
         path = path or os.path.join(get_cache_dir(), hexlify(repository.id).decode('ascii'))
         path = path or os.path.join(get_cache_dir(), hexlify(repository.id).decode('ascii'))
-        UpgradableLock(os.path.join(path, 'lock'), exclusive=True).break_lock()
+        Lock(os.path.join(path, 'lock'), exclusive=True).break_lock()
 
 
     @staticmethod
     @staticmethod
     def destroy(repository, path=None):
     def destroy(repository, path=None):
@@ -152,7 +152,7 @@ Chunk index:    {0.total_unique_chunks:20d} {0.total_chunks:20d}"""
     def open(self, lock_wait=None):
     def open(self, lock_wait=None):
         if not os.path.isdir(self.path):
         if not os.path.isdir(self.path):
             raise Exception('%s Does not look like a Borg cache' % self.path)
             raise Exception('%s Does not look like a Borg cache' % self.path)
-        self.lock = UpgradableLock(os.path.join(self.path, 'lock'), exclusive=True, timeout=lock_wait).acquire()
+        self.lock = Lock(os.path.join(self.path, 'lock'), exclusive=True, timeout=lock_wait).acquire()
         self.rollback()
         self.rollback()
 
 
     def close(self):
     def close(self):

+ 8 - 3
borg/locking.py

@@ -217,7 +217,7 @@ class LockRoster:
         self.save(roster)
         self.save(roster)
 
 
 
 
-class UpgradableLock:
+class Lock:
     """
     """
     A Lock for a resource that can be accessed in a shared or exclusive way.
     A Lock for a resource that can be accessed in a shared or exclusive way.
     Typically, write access to a resource needs an exclusive lock (1 writer,
     Typically, write access to a resource needs an exclusive lock (1 writer,
@@ -226,7 +226,7 @@ class UpgradableLock:
 
 
     If possible, try to use the contextmanager here like::
     If possible, try to use the contextmanager here like::
 
 
-        with UpgradableLock(...) as lock:
+        with Lock(...) as lock:
             ...
             ...
 
 
     This makes sure the lock is released again if the block is left, no
     This makes sure the lock is released again if the block is left, no
@@ -242,7 +242,7 @@ class UpgradableLock:
         self._roster = LockRoster(path + '.roster', id=id)
         self._roster = LockRoster(path + '.roster', id=id)
         # an exclusive lock, used for:
         # an exclusive lock, used for:
         # - holding while doing roster queries / updates
         # - holding while doing roster queries / updates
-        # - holding while the UpgradableLock itself is exclusive
+        # - holding while the Lock instance itself is exclusive
         self._lock = ExclusiveLock(path + '.exclusive', id=id, timeout=timeout)
         self._lock = ExclusiveLock(path + '.exclusive', id=id, timeout=timeout)
 
 
     def __enter__(self):
     def __enter__(self):
@@ -299,6 +299,8 @@ class UpgradableLock:
                 self._roster.modify(SHARED, REMOVE)
                 self._roster.modify(SHARED, REMOVE)
 
 
     def upgrade(self):
     def upgrade(self):
+        # WARNING: if multiple read-lockers want to upgrade, it will deadlock because they
+        # all will wait until the other read locks go away - and that won't happen.
         if not self.is_exclusive:
         if not self.is_exclusive:
             self.acquire(exclusive=True, remove=SHARED)
             self.acquire(exclusive=True, remove=SHARED)
 
 
@@ -306,6 +308,9 @@ class UpgradableLock:
         if self.is_exclusive:
         if self.is_exclusive:
             self.acquire(exclusive=False, remove=EXCLUSIVE)
             self.acquire(exclusive=False, remove=EXCLUSIVE)
 
 
+    def got_exclusive_lock(self):
+        return self.is_exclusive and self._lock.is_locked() and self._lock.by_me()
+
     def break_lock(self):
     def break_lock(self):
         self._roster.remove()
         self._roster.remove()
         self._lock.break_lock()
         self._lock.break_lock()

+ 12 - 13
borg/remote.py

@@ -114,7 +114,7 @@ class RepositoryServer:  # pragma: no cover
     def negotiate(self, versions):
     def negotiate(self, versions):
         return RPC_PROTOCOL_VERSION
         return RPC_PROTOCOL_VERSION
 
 
-    def open(self, path, create=False, lock_wait=None, lock=True, append_only=False):
+    def open(self, path, create=False, lock_wait=None, lock=True, exclusive=None, append_only=False):
         path = os.fsdecode(path)
         path = os.fsdecode(path)
         if path.startswith('/~'):
         if path.startswith('/~'):
             path = path[1:]
             path = path[1:]
@@ -125,7 +125,9 @@ class RepositoryServer:  # pragma: no cover
                     break
                     break
             else:
             else:
                 raise PathNotAllowed(path)
                 raise PathNotAllowed(path)
-        self.repository = Repository(path, create, lock_wait=lock_wait, lock=lock, append_only=self.append_only or append_only)
+        self.repository = Repository(path, create, lock_wait=lock_wait, lock=lock,
+                                     append_only=self.append_only or append_only,
+                                     exclusive=exclusive)
         self.repository.__enter__()  # clean exit handled by serve() method
         self.repository.__enter__()  # clean exit handled by serve() method
         return self.repository.id
         return self.repository.id
 
 
@@ -141,7 +143,7 @@ class RemoteRepository:
     class NoAppendOnlyOnServer(Error):
     class NoAppendOnlyOnServer(Error):
         """Server does not support --append-only."""
         """Server does not support --append-only."""
 
 
-    def __init__(self, location, create=False, lock_wait=None, lock=True, append_only=False, args=None):
+    def __init__(self, location, create=False, exclusive=False, lock_wait=None, lock=True, append_only=False, args=None):
         self.location = self._location = location
         self.location = self._location = location
         self.preload_ids = []
         self.preload_ids = []
         self.msgid = 0
         self.msgid = 0
@@ -178,16 +180,13 @@ class RemoteRepository:
                 raise ConnectionClosedWithHint('Is borg working on the server?') from None
                 raise ConnectionClosedWithHint('Is borg working on the server?') from None
             if version != RPC_PROTOCOL_VERSION:
             if version != RPC_PROTOCOL_VERSION:
                 raise Exception('Server insisted on using unsupported protocol version %d' % version)
                 raise Exception('Server insisted on using unsupported protocol version %d' % version)
-            # Because of protocol versions, only send append_only if necessary
-            if append_only:
-                try:
-                    self.id = self.call('open', self.location.path, create, lock_wait, lock, append_only)
-                except self.RPCError as err:
-                    if err.remote_type == 'TypeError':
-                        raise self.NoAppendOnlyOnServer() from err
-                    else:
-                        raise
-            else:
+            try:
+                self.id = self.call('open', self.location.path, create, lock_wait, lock, exclusive, append_only)
+            except self.RPCError as err:
+                if err.remote_type != 'TypeError':
+                    raise
+                if append_only:
+                    raise self.NoAppendOnlyOnServer()
                 self.id = self.call('open', self.location.path, create, lock_wait, lock)
                 self.id = self.call('open', self.location.path, create, lock_wait, lock)
         except Exception:
         except Exception:
             self.close()
             self.close()

+ 25 - 12
borg/repository.py

@@ -14,7 +14,7 @@ from zlib import crc32
 import msgpack
 import msgpack
 from .helpers import Error, ErrorWithTraceback, IntegrityError, Location, ProgressIndicatorPercent
 from .helpers import Error, ErrorWithTraceback, IntegrityError, Location, ProgressIndicatorPercent
 from .hashindex import NSIndex
 from .hashindex import NSIndex
-from .locking import UpgradableLock, LockError, LockErrorT
+from .locking import Lock, LockError, LockErrorT
 from .lrucache import LRUCache
 from .lrucache import LRUCache
 from .platform import sync_dir
 from .platform import sync_dir
 
 
@@ -79,7 +79,7 @@ class Repository:
         if self.do_create:
         if self.do_create:
             self.do_create = False
             self.do_create = False
             self.create(self.path)
             self.create(self.path)
-        self.open(self.path, self.exclusive, lock_wait=self.lock_wait, lock=self.do_lock)
+        self.open(self.path, bool(self.exclusive), lock_wait=self.lock_wait, lock=self.do_lock)
         return self
         return self
 
 
     def __exit__(self, exc_type, exc_val, exc_tb):
     def __exit__(self, exc_type, exc_val, exc_tb):
@@ -161,14 +161,14 @@ class Repository:
         return self.get_index_transaction_id()
         return self.get_index_transaction_id()
 
 
     def break_lock(self):
     def break_lock(self):
-        UpgradableLock(os.path.join(self.path, 'lock')).break_lock()
+        Lock(os.path.join(self.path, 'lock')).break_lock()
 
 
     def open(self, path, exclusive, lock_wait=None, lock=True):
     def open(self, path, exclusive, lock_wait=None, lock=True):
         self.path = path
         self.path = path
         if not os.path.isdir(path):
         if not os.path.isdir(path):
             raise self.DoesNotExist(path)
             raise self.DoesNotExist(path)
         if lock:
         if lock:
-            self.lock = UpgradableLock(os.path.join(path, 'lock'), exclusive, timeout=lock_wait).acquire()
+            self.lock = Lock(os.path.join(path, 'lock'), exclusive, timeout=lock_wait).acquire()
         else:
         else:
             self.lock = None
             self.lock = None
         self.config = ConfigParser(interpolation=None)
         self.config = ConfigParser(interpolation=None)
@@ -207,14 +207,23 @@ class Repository:
 
 
     def prepare_txn(self, transaction_id, do_cleanup=True):
     def prepare_txn(self, transaction_id, do_cleanup=True):
         self._active_txn = True
         self._active_txn = True
-        try:
-            self.lock.upgrade()
-        except (LockError, LockErrorT):
-            # if upgrading the lock to exclusive fails, we do not have an
-            # active transaction. this is important for "serve" mode, where
-            # the repository instance lives on - even if exceptions happened.
-            self._active_txn = False
-            raise
+        if not self.lock.got_exclusive_lock():
+            if self.exclusive is not None:
+                # self.exclusive is either True or False, thus a new client is active here.
+                # if it is False and we get here, the caller did not use exclusive=True although
+                # it is needed for a write operation. if it is True and we get here, something else
+                # went very wrong, because we should have a exclusive lock, but we don't.
+                raise AssertionError("bug in code, exclusive lock should exist here")
+            # if we are here, this is an old client talking to a new server (expecting lock upgrade).
+            # or we are replaying segments and might need a lock upgrade for that.
+            try:
+                self.lock.upgrade()
+            except (LockError, LockErrorT):
+                # if upgrading the lock to exclusive fails, we do not have an
+                # active transaction. this is important for "serve" mode, where
+                # the repository instance lives on - even if exceptions happened.
+                self._active_txn = False
+                raise
         if not self.index or transaction_id is None:
         if not self.index or transaction_id is None:
             self.index = self.open_index(transaction_id)
             self.index = self.open_index(transaction_id)
         if transaction_id is None:
         if transaction_id is None:
@@ -308,6 +317,9 @@ class Repository:
         self.compact = set()
         self.compact = set()
 
 
     def replay_segments(self, index_transaction_id, segments_transaction_id):
     def replay_segments(self, index_transaction_id, segments_transaction_id):
+        # fake an old client, so that in case we do not have an exclusive lock yet, prepare_txn will upgrade the lock:
+        remember_exclusive = self.exclusive
+        self.exclusive = None
         self.prepare_txn(index_transaction_id, do_cleanup=False)
         self.prepare_txn(index_transaction_id, do_cleanup=False)
         try:
         try:
             segment_count = sum(1 for _ in self.io.segment_iterator())
             segment_count = sum(1 for _ in self.io.segment_iterator())
@@ -323,6 +335,7 @@ class Repository:
             pi.finish()
             pi.finish()
             self.write_index()
             self.write_index()
         finally:
         finally:
+            self.exclusive = remember_exclusive
             self.rollback()
             self.rollback()
 
 
     def _update_index(self, segment, objects, report=None):
     def _update_index(self, segment, objects, report=None):

+ 2 - 2
borg/testsuite/archiver.py

@@ -236,7 +236,7 @@ class ArchiverTestCaseBase(BaseTestCase):
         self.cmd('create', self.repository_location + '::' + name, src_dir)
         self.cmd('create', self.repository_location + '::' + name, src_dir)
 
 
     def open_archive(self, name):
     def open_archive(self, name):
-        repository = Repository(self.repository_path)
+        repository = Repository(self.repository_path, exclusive=True)
         with repository:
         with repository:
             manifest, key = Manifest.load(repository)
             manifest, key = Manifest.load(repository)
             archive = Archive(repository, key, manifest, name)
             archive = Archive(repository, key, manifest, name)
@@ -1288,7 +1288,7 @@ class ArchiverCheckTestCase(ArchiverTestCaseBase):
 
 
     def test_extra_chunks(self):
     def test_extra_chunks(self):
         self.cmd('check', self.repository_location, exit_code=0)
         self.cmd('check', self.repository_location, exit_code=0)
-        with Repository(self.repository_location) as repository:
+        with Repository(self.repository_location, exclusive=True) as repository:
             repository.put(b'01234567890123456789012345678901', b'xxxx')
             repository.put(b'01234567890123456789012345678901', b'xxxx')
             repository.commit()
             repository.commit()
         self.cmd('check', self.repository_location, exit_code=1)
         self.cmd('check', self.repository_location, exit_code=1)

+ 23 - 15
borg/testsuite/locking.py

@@ -2,7 +2,7 @@ import time
 
 
 import pytest
 import pytest
 
 
-from ..locking import get_id, TimeoutTimer, ExclusiveLock, UpgradableLock, LockRoster, \
+from ..locking import get_id, TimeoutTimer, ExclusiveLock, Lock, LockRoster, \
                       ADD, REMOVE, SHARED, EXCLUSIVE, LockTimeout
                       ADD, REMOVE, SHARED, EXCLUSIVE, LockTimeout
 
 
 
 
@@ -58,52 +58,60 @@ class TestExclusiveLock:
                 ExclusiveLock(lockpath, id=ID2, timeout=0.1).acquire()
                 ExclusiveLock(lockpath, id=ID2, timeout=0.1).acquire()
 
 
 
 
-class TestUpgradableLock:
+class TestLock:
     def test_shared(self, lockpath):
     def test_shared(self, lockpath):
-        lock1 = UpgradableLock(lockpath, exclusive=False, id=ID1).acquire()
-        lock2 = UpgradableLock(lockpath, exclusive=False, id=ID2).acquire()
+        lock1 = Lock(lockpath, exclusive=False, id=ID1).acquire()
+        lock2 = Lock(lockpath, exclusive=False, id=ID2).acquire()
         assert len(lock1._roster.get(SHARED)) == 2
         assert len(lock1._roster.get(SHARED)) == 2
         assert len(lock1._roster.get(EXCLUSIVE)) == 0
         assert len(lock1._roster.get(EXCLUSIVE)) == 0
         lock1.release()
         lock1.release()
         lock2.release()
         lock2.release()
 
 
     def test_exclusive(self, lockpath):
     def test_exclusive(self, lockpath):
-        with UpgradableLock(lockpath, exclusive=True, id=ID1) as lock:
+        with Lock(lockpath, exclusive=True, id=ID1) as lock:
             assert len(lock._roster.get(SHARED)) == 0
             assert len(lock._roster.get(SHARED)) == 0
             assert len(lock._roster.get(EXCLUSIVE)) == 1
             assert len(lock._roster.get(EXCLUSIVE)) == 1
 
 
     def test_upgrade(self, lockpath):
     def test_upgrade(self, lockpath):
-        with UpgradableLock(lockpath, exclusive=False) as lock:
+        with Lock(lockpath, exclusive=False) as lock:
             lock.upgrade()
             lock.upgrade()
             lock.upgrade()  # NOP
             lock.upgrade()  # NOP
             assert len(lock._roster.get(SHARED)) == 0
             assert len(lock._roster.get(SHARED)) == 0
             assert len(lock._roster.get(EXCLUSIVE)) == 1
             assert len(lock._roster.get(EXCLUSIVE)) == 1
 
 
     def test_downgrade(self, lockpath):
     def test_downgrade(self, lockpath):
-        with UpgradableLock(lockpath, exclusive=True) as lock:
+        with Lock(lockpath, exclusive=True) as lock:
             lock.downgrade()
             lock.downgrade()
             lock.downgrade()  # NOP
             lock.downgrade()  # NOP
             assert len(lock._roster.get(SHARED)) == 1
             assert len(lock._roster.get(SHARED)) == 1
             assert len(lock._roster.get(EXCLUSIVE)) == 0
             assert len(lock._roster.get(EXCLUSIVE)) == 0
 
 
+    def test_got_exclusive_lock(self, lockpath):
+        lock = Lock(lockpath, exclusive=True, id=ID1)
+        assert not lock.got_exclusive_lock()
+        lock.acquire()
+        assert lock.got_exclusive_lock()
+        lock.release()
+        assert not lock.got_exclusive_lock()
+
     def test_break(self, lockpath):
     def test_break(self, lockpath):
-        lock = UpgradableLock(lockpath, exclusive=True, id=ID1).acquire()
+        lock = Lock(lockpath, exclusive=True, id=ID1).acquire()
         lock.break_lock()
         lock.break_lock()
         assert len(lock._roster.get(SHARED)) == 0
         assert len(lock._roster.get(SHARED)) == 0
         assert len(lock._roster.get(EXCLUSIVE)) == 0
         assert len(lock._roster.get(EXCLUSIVE)) == 0
-        with UpgradableLock(lockpath, exclusive=True, id=ID2):
+        with Lock(lockpath, exclusive=True, id=ID2):
             pass
             pass
 
 
     def test_timeout(self, lockpath):
     def test_timeout(self, lockpath):
-        with UpgradableLock(lockpath, exclusive=False, id=ID1):
+        with Lock(lockpath, exclusive=False, id=ID1):
             with pytest.raises(LockTimeout):
             with pytest.raises(LockTimeout):
-                UpgradableLock(lockpath, exclusive=True, id=ID2, timeout=0.1).acquire()
-        with UpgradableLock(lockpath, exclusive=True, id=ID1):
+                Lock(lockpath, exclusive=True, id=ID2, timeout=0.1).acquire()
+        with Lock(lockpath, exclusive=True, id=ID1):
             with pytest.raises(LockTimeout):
             with pytest.raises(LockTimeout):
-                UpgradableLock(lockpath, exclusive=False, id=ID2, timeout=0.1).acquire()
-        with UpgradableLock(lockpath, exclusive=True, id=ID1):
+                Lock(lockpath, exclusive=False, id=ID2, timeout=0.1).acquire()
+        with Lock(lockpath, exclusive=True, id=ID1):
             with pytest.raises(LockTimeout):
             with pytest.raises(LockTimeout):
-                UpgradableLock(lockpath, exclusive=True, id=ID2, timeout=0.1).acquire()
+                Lock(lockpath, exclusive=True, id=ID2, timeout=0.1).acquire()
 
 
 
 
 @pytest.fixture()
 @pytest.fixture()

+ 42 - 19
borg/testsuite/repository.py

@@ -6,17 +6,23 @@ from unittest.mock import patch
 
 
 from ..hashindex import NSIndex
 from ..hashindex import NSIndex
 from ..helpers import Location, IntegrityError
 from ..helpers import Location, IntegrityError
-from ..locking import UpgradableLock, LockFailed
+from ..locking import Lock, LockFailed
 from ..remote import RemoteRepository, InvalidRPCMethod
 from ..remote import RemoteRepository, InvalidRPCMethod
 from ..repository import Repository, LoggedIO, TAG_COMMIT
 from ..repository import Repository, LoggedIO, TAG_COMMIT
 from . import BaseTestCase
 from . import BaseTestCase
 
 
 
 
+UNSPECIFIED = object()  # for default values where we can't use None
+
+
 class RepositoryTestCaseBase(BaseTestCase):
 class RepositoryTestCaseBase(BaseTestCase):
     key_size = 32
     key_size = 32
+    exclusive = True
 
 
-    def open(self, create=False):
-        return Repository(os.path.join(self.tmppath, 'repository'), create=create)
+    def open(self, create=False, exclusive=UNSPECIFIED):
+        if exclusive is UNSPECIFIED:
+            exclusive = self.exclusive
+        return Repository(os.path.join(self.tmppath, 'repository'), exclusive=exclusive, create=create)
 
 
     def setUp(self):
     def setUp(self):
         self.tmppath = tempfile.mkdtemp()
         self.tmppath = tempfile.mkdtemp()
@@ -27,10 +33,10 @@ class RepositoryTestCaseBase(BaseTestCase):
         self.repository.close()
         self.repository.close()
         shutil.rmtree(self.tmppath)
         shutil.rmtree(self.tmppath)
 
 
-    def reopen(self):
+    def reopen(self, exclusive=UNSPECIFIED):
         if self.repository:
         if self.repository:
             self.repository.close()
             self.repository.close()
-        self.repository = self.open()
+        self.repository = self.open(exclusive=exclusive)
 
 
 
 
 class RepositoryTestCase(RepositoryTestCaseBase):
 class RepositoryTestCase(RepositoryTestCaseBase):
@@ -156,17 +162,6 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
             self.assert_equal(len(self.repository), 3)
             self.assert_equal(len(self.repository), 3)
             self.assert_equal(self.repository.check(), True)
             self.assert_equal(self.repository.check(), True)
 
 
-    def test_replay_of_readonly_repository(self):
-        self.add_keys()
-        for name in os.listdir(self.repository.path):
-            if name.startswith('index.'):
-                os.unlink(os.path.join(self.repository.path, name))
-        with patch.object(UpgradableLock, 'upgrade', side_effect=LockFailed) as upgrade:
-            self.reopen()
-            with self.repository:
-                self.assert_raises(LockFailed, lambda: len(self.repository))
-                upgrade.assert_called_once_with()
-
     def test_crash_before_write_index(self):
     def test_crash_before_write_index(self):
         self.add_keys()
         self.add_keys()
         self.repository.write_index = None
         self.repository.write_index = None
@@ -179,6 +174,32 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
             self.assert_equal(len(self.repository), 3)
             self.assert_equal(len(self.repository), 3)
             self.assert_equal(self.repository.check(), True)
             self.assert_equal(self.repository.check(), True)
 
 
+    def test_replay_lock_upgrade_old(self):
+        self.add_keys()
+        for name in os.listdir(self.repository.path):
+            if name.startswith('index.'):
+                os.unlink(os.path.join(self.repository.path, name))
+        with patch.object(Lock, 'upgrade', side_effect=LockFailed) as upgrade:
+            self.reopen(exclusive=None)  # simulate old client that always does lock upgrades
+            with self.repository:
+                # the repo is only locked by a shared read lock, but to replay segments,
+                # we need an exclusive write lock - check if the lock gets upgraded.
+                self.assert_raises(LockFailed, lambda: len(self.repository))
+                upgrade.assert_called_once_with()
+
+    def test_replay_lock_upgrade(self):
+        self.add_keys()
+        for name in os.listdir(self.repository.path):
+            if name.startswith('index.'):
+                os.unlink(os.path.join(self.repository.path, name))
+        with patch.object(Lock, 'upgrade', side_effect=LockFailed) as upgrade:
+            self.reopen(exclusive=False)  # current client usually does not do lock upgrade, except for replay
+            with self.repository:
+                # the repo is only locked by a shared read lock, but to replay segments,
+                # we need an exclusive write lock - check if the lock gets upgraded.
+                self.assert_raises(LockFailed, lambda: len(self.repository))
+                upgrade.assert_called_once_with()
+
     def test_crash_before_deleting_compacted_segments(self):
     def test_crash_before_deleting_compacted_segments(self):
         self.add_keys()
         self.add_keys()
         self.repository.io.delete_segment = None
         self.repository.io.delete_segment = None
@@ -202,7 +223,7 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
 
 
 class RepositoryAppendOnlyTestCase(RepositoryTestCaseBase):
 class RepositoryAppendOnlyTestCase(RepositoryTestCaseBase):
     def open(self, create=False):
     def open(self, create=False):
-        return Repository(os.path.join(self.tmppath, 'repository'), create=create, append_only=True)
+        return Repository(os.path.join(self.tmppath, 'repository'), exclusive=True, create=create, append_only=True)
 
 
     def test_destroy_append_only(self):
     def test_destroy_append_only(self):
         # Can't destroy append only repo (via the API)
         # Can't destroy append only repo (via the API)
@@ -365,7 +386,8 @@ class RepositoryCheckTestCase(RepositoryTestCaseBase):
 class RemoteRepositoryTestCase(RepositoryTestCase):
 class RemoteRepositoryTestCase(RepositoryTestCase):
 
 
     def open(self, create=False):
     def open(self, create=False):
-        return RemoteRepository(Location('__testsuite__:' + os.path.join(self.tmppath, 'repository')), create=create)
+        return RemoteRepository(Location('__testsuite__:' + os.path.join(self.tmppath, 'repository')),
+                                exclusive=True, create=create)
 
 
     def test_invalid_rpc(self):
     def test_invalid_rpc(self):
         self.assert_raises(InvalidRPCMethod, lambda: self.repository.call('__init__', None))
         self.assert_raises(InvalidRPCMethod, lambda: self.repository.call('__init__', None))
@@ -394,7 +416,8 @@ class RemoteRepositoryTestCase(RepositoryTestCase):
 class RemoteRepositoryCheckTestCase(RepositoryCheckTestCase):
 class RemoteRepositoryCheckTestCase(RepositoryCheckTestCase):
 
 
     def open(self, create=False):
     def open(self, create=False):
-        return RemoteRepository(Location('__testsuite__:' + os.path.join(self.tmppath, 'repository')), create=create)
+        return RemoteRepository(Location('__testsuite__:' + os.path.join(self.tmppath, 'repository')),
+                                exclusive=True, create=create)
 
 
     def test_crash_before_compact(self):
     def test_crash_before_compact(self):
         # skip this test, we can't mock-patch a Repository class in another process!
         # skip this test, we can't mock-patch a Repository class in another process!

+ 1 - 1
borg/testsuite/upgrader.py

@@ -23,7 +23,7 @@ def repo_valid(path):
     :param path: the path to the repository
     :param path: the path to the repository
     :returns: if borg can check the repository
     :returns: if borg can check the repository
     """
     """
-    with Repository(str(path), create=False) as repository:
+    with Repository(str(path), exclusive=True, create=False) as repository:
         # can't check raises() because check() handles the error
         # can't check raises() because check() handles the error
         return repository.check()
         return repository.check()
 
 

+ 3 - 4
borg/upgrader.py

@@ -7,7 +7,7 @@ import shutil
 import time
 import time
 
 
 from .helpers import get_keys_dir, get_cache_dir, ProgressIndicatorPercent
 from .helpers import get_keys_dir, get_cache_dir, ProgressIndicatorPercent
-from .locking import UpgradableLock
+from .locking import Lock
 from .repository import Repository, MAGIC
 from .repository import Repository, MAGIC
 from .key import KeyfileKey, KeyfileNotFoundError
 from .key import KeyfileKey, KeyfileNotFoundError
 
 
@@ -39,7 +39,7 @@ class AtticRepositoryUpgrader(Repository):
                     shutil.copytree(self.path, backup, copy_function=os.link)
                     shutil.copytree(self.path, backup, copy_function=os.link)
             logger.info("opening attic repository with borg and converting")
             logger.info("opening attic repository with borg and converting")
             # now lock the repo, after we have made the copy
             # now lock the repo, after we have made the copy
-            self.lock = UpgradableLock(os.path.join(self.path, 'lock'), exclusive=True, timeout=1.0).acquire()
+            self.lock = Lock(os.path.join(self.path, 'lock'), exclusive=True, timeout=1.0).acquire()
             segments = [filename for i, filename in self.io.segment_iterator()]
             segments = [filename for i, filename in self.io.segment_iterator()]
             try:
             try:
                 keyfile = self.find_attic_keyfile()
                 keyfile = self.find_attic_keyfile()
@@ -48,8 +48,7 @@ class AtticRepositoryUpgrader(Repository):
             else:
             else:
                 self.convert_keyfiles(keyfile, dryrun)
                 self.convert_keyfiles(keyfile, dryrun)
         # partial open: just hold on to the lock
         # partial open: just hold on to the lock
-        self.lock = UpgradableLock(os.path.join(self.path, 'lock'),
-                                   exclusive=True).acquire()
+        self.lock = Lock(os.path.join(self.path, 'lock'), exclusive=True).acquire()
         try:
         try:
             self.convert_cache(dryrun)
             self.convert_cache(dryrun)
             self.convert_repo_index(dryrun=dryrun, inplace=inplace)
             self.convert_repo_index(dryrun=dryrun, inplace=inplace)