Quellcode durchsuchen

repository: Fix potential race condition

If we crash between compact_segments() and write_index() and the
transaction deletes objects that are newer than the current index
might become undeleted.
Jonas Borgström vor 11 Jahren
Ursprung
Commit
6425d16aa8
2 geänderte Dateien mit 53 neuen und 63 gelöschten Zeilen
  1. 47 46
      attic/repository.py
  2. 6 17
      attic/testsuite/repository.py

+ 47 - 46
attic/repository.py

@@ -6,7 +6,6 @@ import os
 import shutil
 import struct
 import sys
-import time
 from zlib import crc32
 
 from .hashindex import NSIndex
@@ -87,15 +86,17 @@ class Repository(object):
     def get_transaction_id(self):
         index_transaction_id = self.get_index_transaction_id()
         segments_transaction_id = self.io.get_segments_transaction_id()
+        if index_transaction_id is not None and segments_transaction_id is None:
+            raise self.CheckNeeded(self.path)
         # Attempt to automatically rebuild index if we crashed between commit
         # tag write and index save
-        if (index_transaction_id if index_transaction_id is not None else -1) < (segments_transaction_id if segments_transaction_id is not None else -1):
-            self.replay_segments(index_transaction_id, segments_transaction_id)
-            index_transaction_id = self.get_index_transaction_id()
-
         if index_transaction_id != segments_transaction_id:
-            raise self.CheckNeeded(self.path)
-        return index_transaction_id
+            if index_transaction_id is not None and index_transaction_id > segments_transaction_id:
+                replay_from = None
+            else:
+                replay_from = index_transaction_id
+            self.replay_segments(replay_from, segments_transaction_id)
+        return self.get_index_transaction_id()
 
     def open(self, path):
         self.path = path
@@ -175,21 +176,24 @@ class Repository(object):
         """
         if not self.compact:
             return
-
-        def lookup(tag, key):
-            return tag == TAG_PUT and self.index.get(key, (-1, -1))[0] == segment
+        index_transaction_id = self.get_index_transaction_id()
         segments = self.segments
         for segment in sorted(self.compact):
-            if segments[segment] > 0:
-                for tag, key, data in self.io.iter_objects(segment, lookup, include_data=True):
-                    new_segment, offset = self.io.write_put(key, data)
-                    self.index[key] = new_segment, offset
-                    segments.setdefault(new_segment, 0)
-                    segments[new_segment] += 1
-                    segments[segment] -= 1
+            if self.io.segment_exists(segment):
+                for tag, key, data in self.io.iter_objects(segment, include_data=True):
+                    if tag == TAG_PUT and self.index.get(key, (-1, -1))[0] == segment:
+                        new_segment, offset = self.io.write_put(key, data)
+                        self.index[key] = new_segment, offset
+                        segments.setdefault(new_segment, 0)
+                        segments[new_segment] += 1
+                        segments[segment] -= 1
+                    elif tag == TAG_DELETE:
+                        if index_transaction_id is None or segment > index_transaction_id:
+                            self.io.write_delete(key)
                 assert segments[segment] == 0
+
         self.io.write_commit()
-        for segment in self.compact:
+        for segment in sorted(self.compact):
             assert self.segments.pop(segment) == 0
             self.io.delete_segment(segment)
         self.compact = set()
@@ -215,10 +219,10 @@ class Repository(object):
                 elif tag == TAG_DELETE:
                     try:
                         s, _ = self.index.pop(key)
+                        self.segments[s] -= 1
+                        self.compact.add(s)
                     except KeyError:
-                        raise self.CheckNeeded(self.path)
-                    self.segments[s] -= 1
-                    self.compact.add(s)
+                        pass
                     self.compact.add(segment)
                 elif tag == TAG_COMMIT:
                     continue
@@ -246,21 +250,16 @@ class Repository(object):
 
         assert not self._active_txn
         report_progress('Starting repository check...')
-        index_transaction_id = self.get_index_transaction_id()
-        segments_transaction_id = self.io.get_segments_transaction_id()
-        if index_transaction_id is None and segments_transaction_id is None:
-            return True
-        if segments_transaction_id is not None:
-            transaction_id = segments_transaction_id
-        else:
-            transaction_id = index_transaction_id
-        self.get_index(None)
-        if index_transaction_id == segments_transaction_id:
+        try:
+            transaction_id = self.get_transaction_id()
             current_index = self.get_read_only_index(transaction_id)
-        else:
+        except Exception:
+            transaction_id = self.io.get_segments_transaction_id()
             current_index = None
-            report_progress('No suitable index found', error=True)
-
+        if transaction_id is None:
+            transaction_id = self.get_index_transaction_id()
+        segments_transaction_id = self.io.get_segments_transaction_id()
+        self.get_index(None)
         for segment, filename in self.io.segment_iterator():
             if segment > transaction_id:
                 continue
@@ -302,7 +301,7 @@ class Repository(object):
             self.io.write_commit()
             self.io.close_segment()
         if current_index and not repair:
-            if len(current_index) != len(self.index) and False:
+            if len(current_index) != len(self.index):
                 report_progress('Index object count mismatch. {} != {}'.format(len(current_index), len(self.index)), error=True)
             elif current_index:
                 for key, value in self.index.iteritems():
@@ -369,13 +368,13 @@ class Repository(object):
             self.get_index(self.get_transaction_id())
         try:
             segment, offset = self.index.pop(id)
-            self.segments[segment] -= 1
-            self.compact.add(segment)
-            segment = self.io.write_delete(id)
-            self.compact.add(segment)
-            self.segments.setdefault(segment, 0)
         except KeyError:
             raise self.DoesNotExist(self.path)
+        self.segments[segment] -= 1
+        self.compact.add(segment)
+        segment = self.io.write_delete(id)
+        self.compact.add(segment)
+        self.segments.setdefault(segment, 0)
 
     def preload(self, ids):
         """Preload objects (only applies to remote repositories
@@ -479,7 +478,10 @@ class LoggedIO(object):
         except OSError:
             pass
 
-    def iter_objects(self, segment, lookup=None, include_data=False):
+    def segment_exists(self, segment):
+        return os.path.exists(self.segment_filename(segment))
+
+    def iter_objects(self, segment, include_data=False):
         fd = self.get_fd(segment)
         fd.seek(0)
         if fd.read(8) != MAGIC:
@@ -498,11 +500,10 @@ class LoggedIO(object):
             key = None
             if tag in (TAG_PUT, TAG_DELETE):
                 key = rest[:32]
-            if not lookup or lookup(tag, key):
-                if include_data:
-                    yield tag, key, rest[32:]
-                else:
-                    yield tag, key, offset
+            if include_data:
+                yield tag, key, rest[32:]
+            else:
+                yield tag, key, offset
             offset += size
             header = fd.read(self.header_fmt.size)
 

+ 6 - 17
attic/testsuite/repository.py

@@ -115,9 +115,11 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
     def add_keys(self):
         self.repository.put(b'00000000000000000000000000000000', b'foo')
         self.repository.put(b'00000000000000000000000000000001', b'bar')
+        self.repository.put(b'00000000000000000000000000000003', b'bar')
         self.repository.commit()
         self.repository.put(b'00000000000000000000000000000001', b'bar2')
         self.repository.put(b'00000000000000000000000000000002', b'boo')
+        self.repository.delete(b'00000000000000000000000000000003')
 
     def test_replay_of_missing_index(self):
         self.add_keys()
@@ -125,7 +127,7 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
             if name.startswith('index.'):
                 os.unlink(os.path.join(self.repository.path, name))
         self.reopen()
-        self.assert_equal(len(self.repository), 2)
+        self.assert_equal(len(self.repository), 3)
         self.assert_equal(self.repository.check(), True)
 
     def test_crash_before_compact_segments(self):
@@ -174,7 +176,6 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
         self.assert_equal(len(self.repository), 3)
 
 
-
 class RepositoryCheckTestCase(RepositoryTestCaseBase):
 
     def list_indices(self):
@@ -249,10 +250,6 @@ class RepositoryCheckTestCase(RepositoryTestCaseBase):
     def test_repair_missing_commit_segment(self):
         self.add_objects([[1, 2, 3], [4, 5, 6]])
         self.delete_segment(1)
-        self.assert_raises(Repository.CheckNeeded, lambda: self.get_objects(4))
-        self.check(status=False)
-        self.assert_raises(Repository.CheckNeeded, lambda: self.get_objects(4))
-        self.check(repair=True, status=True)
         self.assert_raises(Repository.DoesNotExist, lambda: self.get_objects(4))
         self.assert_equal(set([1, 2, 3]), self.list_objects())
 
@@ -261,11 +258,9 @@ class RepositoryCheckTestCase(RepositoryTestCaseBase):
         with open(os.path.join(self.tmppath, 'repository', 'data', '0', '1'), 'r+b') as fd:
             fd.seek(-1, os.SEEK_END)
             fd.write(b'X')
-        self.assert_raises(Repository.CheckNeeded, lambda: self.get_objects(4))
-        self.check(status=False)
-        self.check(repair=True, status=True)
-        self.get_objects(3)
         self.assert_raises(Repository.DoesNotExist, lambda: self.get_objects(4))
+        self.check(status=True)
+        self.get_objects(3)
         self.assert_equal(set([1, 2, 3]), self.list_objects())
 
     def test_repair_no_commits(self):
@@ -286,8 +281,6 @@ class RepositoryCheckTestCase(RepositoryTestCaseBase):
     def test_repair_missing_index(self):
         self.add_objects([[1, 2, 3], [4, 5, 6]])
         self.delete_index()
-        self.check(status=False)
-        self.check(repair=True, status=True)
         self.check(status=True)
         self.get_objects(4)
         self.assert_equal(set([1, 2, 3, 4, 5, 6]), self.list_objects())
@@ -296,12 +289,8 @@ class RepositoryCheckTestCase(RepositoryTestCaseBase):
         self.add_objects([[1, 2, 3], [4, 5, 6]])
         self.assert_equal(self.list_indices(), ['index.1'])
         self.rename_index('index.100')
-        self.assert_equal(self.list_indices(), ['index.100'])
-        self.assert_raises(Repository.CheckNeeded, lambda: self.get_objects(4))
-        self.check(status=False)
-        self.check(repair=True, status=True)
-        self.assert_equal(self.list_indices(), ['index.1'])
         self.check(status=True)
+        self.assert_equal(self.list_indices(), ['index.1'])
         self.get_objects(4)
         self.assert_equal(set([1, 2, 3, 4, 5, 6]), self.list_objects())