Browse Source

Initial check --repair support

Jonas Borgström 11 years ago
parent
commit
33b58eac82
5 changed files with 128 additions and 100 deletions
  1. 4 1
      attic/archiver.py
  2. 5 2
      attic/remote.py
  3. 104 92
      attic/repository.py
  4. 2 3
      attic/testsuite/archiver.py
  5. 13 2
      attic/testsuite/repository.py

+ 4 - 1
attic/archiver.py

@@ -66,7 +66,7 @@ class Archiver:
         repository = self.open_repository(args.repository)
         repository = self.open_repository(args.repository)
         if args.progress is None:
         if args.progress is None:
             args.progress = is_a_terminal(sys.stdout) or args.verbose
             args.progress = is_a_terminal(sys.stdout) or args.verbose
-        if not repository.check(progress=args.progress):
+        if not repository.check(progress=args.progress, repair=args.repair):
             self.exit_code = 1
             self.exit_code = 1
         return self.exit_code
         return self.exit_code
 
 
@@ -390,6 +390,9 @@ class Archiver:
                                help='Report progress status to standard output stream')
                                help='Report progress status to standard output stream')
         subparser.add_argument('--no-progress', dest='progress', action='store_false',
         subparser.add_argument('--no-progress', dest='progress', action='store_false',
                                help='Disable progress reporting')
                                help='Disable progress reporting')
+        subparser.add_argument('--repair', dest='repair', action='store_true',
+                               default=False,
+                               help='Attempt to repair any inconsistencies found')
 
 
         subparser = subparsers.add_parser('change-passphrase', parents=[common_parser],
         subparser = subparsers.add_parser('change-passphrase', parents=[common_parser],
                                           description=self.do_change_passphrase.__doc__)
                                           description=self.do_change_passphrase.__doc__)

+ 5 - 2
attic/remote.py

@@ -180,8 +180,8 @@ class RemoteRepository(object):
                     w_fds = []
                     w_fds = []
         self.ignore_responses |= set(waiting_for)
         self.ignore_responses |= set(waiting_for)
 
 
-    def check(self, progress=False):
-        return self.call('check', progress)
+    def check(self, progress=False, repair=False):
+        return self.call('check', progress, repair)
 
 
     def commit(self, *args):
     def commit(self, *args):
         return self.call('commit')
         return self.call('commit')
@@ -189,6 +189,9 @@ class RemoteRepository(object):
     def rollback(self, *args):
     def rollback(self, *args):
         return self.call('rollback')
         return self.call('rollback')
 
 
+    def __len__(self):
+        return self.call('__len__')
+
     def get(self, id_):
     def get(self, id_):
         for resp in self.get_many([id_]):
         for resp in self.get_many([id_]):
             return resp
             return resp

+ 104 - 92
attic/repository.py

@@ -51,6 +51,7 @@ class Repository(object):
         self.io = None
         self.io = None
         self.lock = None
         self.lock = None
         self.index = None
         self.index = None
+        self._active_txn = False
         if create:
         if create:
             self.create(path)
             self.create(path)
         self.open(path)
         self.open(path)
@@ -84,6 +85,13 @@ class Repository(object):
         else:
         else:
             return None
             return None
 
 
+    def get_transaction_id(self):
+        index_transaction_id = self.get_index_transaction_id()
+        segments_transaction_id = self.io.get_segments_transaction_id(index_transaction_id or 0)
+        if index_transaction_id != segments_transaction_id:
+            raise self.CheckNeeded(self.path)
+        return index_transaction_id
+
     def open(self, path):
     def open(self, path):
         self.path = path
         self.path = path
         if not os.path.isdir(path):
         if not os.path.isdir(path):
@@ -96,7 +104,7 @@ class Repository(object):
         self.max_segment_size = self.config.getint('repository', 'max_segment_size')
         self.max_segment_size = self.config.getint('repository', 'max_segment_size')
         self.segments_per_dir = self.config.getint('repository', 'segments_per_dir')
         self.segments_per_dir = self.config.getint('repository', 'segments_per_dir')
         self.id = unhexlify(self.config.get('repository', 'id').strip())
         self.id = unhexlify(self.config.get('repository', 'id').strip())
-        self.rollback()
+        self.io = LoggedIO(self.path, self.max_segment_size, self.segments_per_dir)
 
 
     def close(self):
     def close(self):
         if self.lock:
         if self.lock:
@@ -114,22 +122,23 @@ class Repository(object):
         self.write_index()
         self.write_index()
         self.rollback()
         self.rollback()
 
 
-    def open_index(self, head, read_only=False):
-        if head is None:
-            self.lock.upgrade()
+    def get_read_only_index(self, transaction_id):
+        if transaction_id is None:
+            return {}
+        return NSIndex((os.path.join(self.path, 'index.%d') % transaction_id).encode('utf-8'), readonly=True)
+
+    def get_index(self, transaction_id):
+        self.lock.upgrade()
+        if transaction_id is None:
             self.index = NSIndex.create(os.path.join(self.path, 'index.tmp').encode('utf-8'))
             self.index = NSIndex.create(os.path.join(self.path, 'index.tmp').encode('utf-8'))
             self.segments = {}
             self.segments = {}
             self.compact = set()
             self.compact = set()
         else:
         else:
-            if read_only:
-                self.index = NSIndex((os.path.join(self.path, 'index.%d') % head).encode('utf-8'), readonly=True)
-            else:
-                self.lock.upgrade()
-                self.io.cleanup()
-                shutil.copy(os.path.join(self.path, 'index.%d' % head),
-                            os.path.join(self.path, 'index.tmp'))
-                self.index = NSIndex(os.path.join(self.path, 'index.tmp').encode('utf-8'))
-            hints = read_msgpack(os.path.join(self.path, 'hints.%d' % head))
+            self.io.cleanup(transaction_id)
+            shutil.copy(os.path.join(self.path, 'index.%d' % transaction_id),
+                        os.path.join(self.path, 'index.tmp'))
+            self.index = NSIndex(os.path.join(self.path, 'index.tmp').encode('utf-8'))
+            hints = read_msgpack(os.path.join(self.path, 'hints.%d' % transaction_id))
             if hints[b'version'] != 1:
             if hints[b'version'] != 1:
                 raise ValueError('Unknown hints file version: %d' % hints['version'])
                 raise ValueError('Unknown hints file version: %d' % hints['version'])
             self.segments = hints[b'segments']
             self.segments = hints[b'segments']
@@ -139,12 +148,13 @@ class Repository(object):
         hints = {b'version': 1,
         hints = {b'version': 1,
                  b'segments': self.segments,
                  b'segments': self.segments,
                  b'compact': list(self.compact)}
                  b'compact': list(self.compact)}
-        write_msgpack(os.path.join(self.path, 'hints.%d' % self.io.head), hints)
+        transaction_id = self.io.get_segments_transaction_id()
+        write_msgpack(os.path.join(self.path, 'hints.%d' % transaction_id), hints)
         self.index.flush()
         self.index.flush()
         os.rename(os.path.join(self.path, 'index.tmp'),
         os.rename(os.path.join(self.path, 'index.tmp'),
-                  os.path.join(self.path, 'index.%d' % self.io.head))
+                  os.path.join(self.path, 'index.%d' % transaction_id))
         # Remove old indices
         # Remove old indices
-        current = '.%d' % self.io.head
+        current = '.%d' % transaction_id
         for name in os.listdir(self.path):
         for name in os.listdir(self.path):
             if not name.startswith('index.') and not name.startswith('hints.'):
             if not name.startswith('index.') and not name.startswith('hints.'):
                 continue
                 continue
@@ -176,102 +186,95 @@ class Repository(object):
             self.io.delete_segment(segment)
             self.io.delete_segment(segment)
         self.compact = set()
         self.compact = set()
 
 
-    def recover(self, path):
-        """Recover missing index by replaying logs"""
-        start = None
-        available = self._available_indices()
-        if available:
-            start = available[-1]
-        self.open_index(start)
-        for segment, filename in self.io._segment_names():
-            if start is not None and segment <= start:
-                continue
-            self.segments[segment] = 0
-            for tag, key, offset in self.io.iter_objects(segment):
-                if tag == TAG_PUT:
-                    try:
-                        s, _ = self.index[key]
-                        self.compact.add(s)
-                        self.segments[s] -= 1
-                    except KeyError:
-                        pass
-                    self.index[key] = segment, offset
-                    self.segments[segment] += 1
-                elif tag == TAG_DELETE:
-                    try:
-                        s, _ = self.index.pop(key)
-                        self.segments[s] -= 1
-                        self.compact.add(s)
-                        self.compact.add(segment)
-                    except KeyError:
-                        pass
-            if self.segments[segment] == 0:
-                self.compact.add(segment)
-        if self.io.head is not None:
-            self.write_index()
-
-    def check(self, progress=False):
+    def check(self, progress=False, repair=False):
         """Check repository consistency
         """Check repository consistency
 
 
         This method verifies all segment checksums and makes sure
         This method verifies all segment checksums and makes sure
         the index is consistent with the data stored in the segments.
         the index is consistent with the data stored in the segments.
         """
         """
-        if not self.index:
-            self.open_index(self.io.head, read_only=True)
+        assert not self._active_txn
+        assert not self.index
+        index_transaction_id = self.get_index_transaction_id()
+        segments_transaction_id = self.io.get_segments_transaction_id(index_transaction_id)
+        if index_transaction_id is None and segments_transaction_id is None:
+            return True
+        transaction_id = max(index_transaction_id or 0, segments_transaction_id or 0)
+        self.get_index(None)
+        if index_transaction_id == segments_transaction_id:
+            current_index = self.get_read_only_index(transaction_id)
+        else:
+            current_index = None
         progress_time = None
         progress_time = None
         error_found = False
         error_found = False
+
         def report_progress(msg, error=False):
         def report_progress(msg, error=False):
             nonlocal error_found
             nonlocal error_found
             if error:
             if error:
                 error_found = True
                 error_found = True
             if error or progress:
             if error or progress:
                 print(msg, file=sys.stderr)
                 print(msg, file=sys.stderr)
-        seen = set()
+
         for segment, filename in self.io.segment_iterator():
         for segment, filename in self.io.segment_iterator():
+            if segment > transaction_id:
+                continue
             if progress:
             if progress:
                 if int(time.time()) != progress_time:
                 if int(time.time()) != progress_time:
                     progress_time = int(time.time())
                     progress_time = int(time.time())
-                    report_progress('Checking segment {}/{}'.format(segment, self.io.head))
+                    report_progress('Checking segment {}/{}'.format(segment, transaction_id))
             try:
             try:
                 objects = list(self.io.iter_objects(segment))
                 objects = list(self.io.iter_objects(segment))
             except (IntegrityError, struct.error):
             except (IntegrityError, struct.error):
                 report_progress('Error reading segment {}'.format(segment), error=True)
                 report_progress('Error reading segment {}'.format(segment), error=True)
                 objects = []
                 objects = []
+                if repair:
+                    self.io.recover_segment(segment, filename)
+                    objects = list(self.io.iter_objects(segment))
+            self.segments[segment] = 0
             for tag, key, offset in objects:
             for tag, key, offset in objects:
                 if tag == TAG_PUT:
                 if tag == TAG_PUT:
-                    if key in seen:
+                    try:
+                        s, _ = self.index[key]
+                        self.compact.add(s)
+                        self.segments[s] -= 1
                         report_progress('Key found in more than one segment. Segment={}, key={}'.format(segment, hexlify(key)), error=True)
                         report_progress('Key found in more than one segment. Segment={}, key={}'.format(segment, hexlify(key)), error=True)
-                    seen.add(key)
-                    if self.index.get(key, (0, 0)) != (segment, offset):
-                        report_progress('Index vs segment header mismatch. Segment={}, key={}'.format(segment, hexlify(key)), error=True)
+                    except KeyError:
+                        pass
+                    self.index[key] = segment, offset
+                    self.segments[segment] += 1
+                elif tag == TAG_DELETE:
+                    try:
+                        s, _ = self.index.pop(key)
+                        self.segments[s] -= 1
+                        self.compact.add(s)
+                        self.compact.add(segment)
+                    except KeyError:
+                        pass
                 elif tag == TAG_COMMIT:
                 elif tag == TAG_COMMIT:
                     continue
                     continue
                 else:
                 else:
                     report_progress('Unexpected tag {} in segment {}'.format(tag, segment), error=True)
                     report_progress('Unexpected tag {} in segment {}'.format(tag, segment), error=True)
-        if len(self.index) != len(seen):
-            report_progress('Index object count mismatch. {} != {}'.format(len(self.index), len(seen)), error=True)
+        if current_index and len(current_index) != len(self.index):
+            report_progress('Index object count mismatch. {} != {}'.format(len(current_index), len(self.index)), error=True)
         if not error_found:
         if not error_found:
             report_progress('Check complete, no errors found.')
             report_progress('Check complete, no errors found.')
+        if repair:
+            self.write_index()
         return not error_found
         return not error_found
 
 
     def rollback(self):
     def rollback(self):
         """
         """
         """
         """
-        if self.io:
-            self.io.close()
-            self.io = None
         self.index = None
         self.index = None
         self._active_txn = False
         self._active_txn = False
-        self.io = LoggedIO(self.path, self.max_segment_size, self.segments_per_dir, self.get_index_transaction_id())
 
 
-    def _len(self):
+    def __len__(self):
         if not self.index:
         if not self.index:
-            self.open_index(self.io.head, read_only=True)
+            self.index = self.get_read_only_index(self.get_transaction_id())
         return len(self.index)
         return len(self.index)
 
 
     def get(self, id_):
     def get(self, id_):
         if not self.index:
         if not self.index:
-            self.open_index(self.io.head, read_only=True)
+            self.index = self.get_read_only_index(self.get_transaction_id())
         try:
         try:
             segment, offset = self.index[id_]
             segment, offset = self.index[id_]
             return self.io.read(segment, offset, id_)
             return self.io.read(segment, offset, id_)
@@ -284,8 +287,8 @@ class Repository(object):
 
 
     def put(self, id, data, wait=True):
     def put(self, id, data, wait=True):
         if not self._active_txn:
         if not self._active_txn:
+            self.get_index(self.get_transaction_id())
             self._active_txn = True
             self._active_txn = True
-            self.open_index(self.io.head)
         try:
         try:
             segment, _ = self.index[id]
             segment, _ = self.index[id]
             self.segments[segment] -= 1
             self.segments[segment] -= 1
@@ -302,8 +305,8 @@ class Repository(object):
 
 
     def delete(self, id, wait=True):
     def delete(self, id, wait=True):
         if not self._active_txn:
         if not self._active_txn:
+            self.get_index(self.get_transaction_id())
             self._active_txn = True
             self._active_txn = True
-            self.open_index(self.io.head)
         try:
         try:
             segment, offset = self.index.pop(id)
             segment, offset = self.index.pop(id)
             self.segments[segment] -= 1
             self.segments[segment] -= 1
@@ -333,16 +336,14 @@ class LoggedIO(object):
     _commit = header_no_crc_fmt.pack(9, TAG_COMMIT)
     _commit = header_no_crc_fmt.pack(9, TAG_COMMIT)
     COMMIT = crc_fmt.pack(crc32(_commit)) + _commit
     COMMIT = crc_fmt.pack(crc32(_commit)) + _commit
 
 
-    def __init__(self, path, limit, segments_per_dir, latest_index, capacity=100):
+    def __init__(self, path, limit, segments_per_dir, capacity=100):
         self.path = path
         self.path = path
         self.fds = LRUCache(capacity)
         self.fds = LRUCache(capacity)
-        self.segment = None
+        self.segment = 0
         self.limit = limit
         self.limit = limit
         self.segments_per_dir = segments_per_dir
         self.segments_per_dir = segments_per_dir
         self.offset = 0
         self.offset = 0
         self._write_fd = None
         self._write_fd = None
-        self.head = None
-        self.verify_segments_head(latest_index)
 
 
     def close(self):
     def close(self):
         for segment in list(self.fds.keys()):
         for segment in list(self.fds.keys()):
@@ -357,31 +358,23 @@ class LoggedIO(object):
             for filename in filenames:
             for filename in filenames:
                 yield int(filename), os.path.join(dirpath, filename)
                 yield int(filename), os.path.join(dirpath, filename)
 
 
-    def verify_segments_head(self, latest_index):
+    def get_segments_transaction_id(self, index_transaction_id=0):
         """Verify that the transaction id is consistent with the index transaction id
         """Verify that the transaction id is consistent with the index transaction id
         """
         """
-        self.segment = 0
         for segment, filename in self.segment_iterator(reverse=True):
         for segment, filename in self.segment_iterator(reverse=True):
-            if latest_index is None or segment < latest_index:
+            if segment < index_transaction_id:
                 # The index is newer than any committed transaction found
                 # The index is newer than any committed transaction found
-                raise Repository.CheckNeeded()
+                return -1
             if self.is_committed_segment(filename):
             if self.is_committed_segment(filename):
-                if segment > latest_index:
-                    # The committed transaction found is newer than the index
-                    raise Repository.CheckNeeded()
-                self.head = segment
-                self.segment = self.head + 1
-                break
-        else:
-            if latest_index is not None:
-                # An index has been found but no committed transaction
-                raise Repository.CheckNeeded()
+                return segment
+        return None
 
 
-    def cleanup(self):
+    def cleanup(self, transaction_id):
         """Delete segment files left by aborted transactions
         """Delete segment files left by aborted transactions
         """
         """
+        self.segment = transaction_id + 1
         for segment, filename in self.segment_iterator(reverse=True):
         for segment, filename in self.segment_iterator(reverse=True):
-            if segment > self.head:
+            if segment > transaction_id:
                 os.unlink(filename)
                 os.unlink(filename)
             else:
             else:
                 break
                 break
@@ -456,8 +449,28 @@ class LoggedIO(object):
             offset += size
             offset += size
             header = fd.read(self.header_fmt.size)
             header = fd.read(self.header_fmt.size)
 
 
+    def recover_segment(self, segment, filename):
+        self.fds.pop(segment).close()
+        # FIXME: save a copy of the original file
+        with open(filename, 'rb') as fd:
+            data = memoryview(fd.read())
+        os.rename(filename, filename + '.beforerecover')
+        print('attempting to recover ' + filename, file=sys.stderr)
+        with open(filename, 'wb') as fd:
+            fd.write(MAGIC)
+            while len(data) >= self.header_fmt.size:
+                crc, size, tag = self.header_fmt.unpack(data[:self.header_fmt.size])
+                if size > len(data):
+                    data = data[1:]
+                    continue
+                if crc32(data[4:size]) & 0xffffffff != crc:
+                    data = data[1:]
+                    continue
+                fd.write(data[:size])
+                data = data[size:]
+
     def read(self, segment, offset, id):
     def read(self, segment, offset, id):
-        if segment == self.segment:
+        if segment == self.segment and self._write_fd:
             self._write_fd.flush()
             self._write_fd.flush()
         fd = self.get_fd(segment)
         fd = self.get_fd(segment)
         fd.seek(offset)
         fd.seek(offset)
@@ -495,7 +508,6 @@ class LoggedIO(object):
         header = self.header_no_crc_fmt.pack(self.header_fmt.size, TAG_COMMIT)
         header = self.header_no_crc_fmt.pack(self.header_fmt.size, TAG_COMMIT)
         crc = self.crc_fmt.pack(crc32(header) & 0xffffffff)
         crc = self.crc_fmt.pack(crc32(header) & 0xffffffff)
         fd.write(b''.join((crc, header)))
         fd.write(b''.join((crc, header)))
-        self.head = self.segment
         self.close_segment()
         self.close_segment()
 
 
     def close_segment(self):
     def close_segment(self):

+ 2 - 3
attic/testsuite/archiver.py

@@ -205,7 +205,7 @@ class ArchiverTestCase(AtticTestCase):
         self.attic('delete', self.repository_location + '::test.2')
         self.attic('delete', self.repository_location + '::test.2')
         # Make sure all data except the manifest has been deleted
         # Make sure all data except the manifest has been deleted
         repository = Repository(self.repository_path)
         repository = Repository(self.repository_path)
-        self.assert_equal(repository._len(), 1)
+        self.assert_equal(len(repository), 1)
 
 
     def test_corrupted_repository(self):
     def test_corrupted_repository(self):
         self.attic('init', self.repository_location)
         self.attic('init', self.repository_location)
@@ -269,8 +269,7 @@ class ArchiverTestCase(AtticTestCase):
 
 
         def verify_uniqueness():
         def verify_uniqueness():
             repository = Repository(self.repository_path)
             repository = Repository(self.repository_path)
-            repository.open_index(repository.io.head)
-            for key, _ in repository.index.iteritems():
+            for key, _ in repository.get_read_only_index(repository.get_transaction_id()).iteritems():
                 data = repository.get(key)
                 data = repository.get(key)
                 hash = sha256(data).digest()
                 hash = sha256(data).digest()
                 if not hash in seen:
                 if not hash in seen:

+ 13 - 2
attic/testsuite/repository.py

@@ -112,7 +112,7 @@ class RepositoryCheckTestCase(AtticTestCase):
         self.repository.commit()
         self.repository.commit()
 
 
     def get_head(self):
     def get_head(self):
-        return sorted(int(n) for n in os.listdir(os.path.join(self.tmppath, 'repository', 'data', '0')))[-1]
+        return sorted(int(n) for n in os.listdir(os.path.join(self.tmppath, 'repository', 'data', '0')) if n.isdigit())[-1]
 
 
     def open_index(self):
     def open_index(self):
         return NSIndex(os.path.join(self.tmppath, 'repository', 'index.{}'.format(self.get_head())))
         return NSIndex(os.path.join(self.tmppath, 'repository', 'index.{}'.format(self.get_head())))
@@ -137,12 +137,23 @@ class RepositoryCheckTestCase(AtticTestCase):
         self.assert_equal(False, self.repository.check())
         self.assert_equal(False, self.repository.check())
         self.assert_equal(set([1, 2, 3, 4, 5, 6]), self.list_objects())
         self.assert_equal(set([1, 2, 3, 4, 5, 6]), self.list_objects())
 
 
+    def test_check_repair(self):
+        self.add_objects([1, 2, 3])
+        self.add_objects([4, 5, 6])
+        self.assert_equal(set([1, 2, 3, 4, 5, 6]), self.list_objects())
+        self.assert_equal(True, self.repository.check())
+        self.corrupt_object(5)
+        self.reopen()
+        self.assert_equal(False, self.repository.check(repair=True))
+        self.assert_equal(set([1, 2, 3, 4, 6]), self.list_objects())
+
+
     def test_check_missing_or_corrupt_commit_tag(self):
     def test_check_missing_or_corrupt_commit_tag(self):
         self.add_objects([1, 2, 3])
         self.add_objects([1, 2, 3])
         self.assert_equal(set([1, 2, 3]), self.list_objects())
         self.assert_equal(set([1, 2, 3]), self.list_objects())
         with open(os.path.join(self.tmppath, 'repository', 'data', '0', str(self.get_head())), 'ab') as fd:
         with open(os.path.join(self.tmppath, 'repository', 'data', '0', str(self.get_head())), 'ab') as fd:
             fd.write(b'X')
             fd.write(b'X')
-        self.assert_raises(Repository.CheckNeeded, self.reopen)
+        self.assert_raises(Repository.CheckNeeded, lambda: self.repository.get(bytes(32)))
 
 
 class RemoteRepositoryTestCase(RepositoryTestCase):
 class RemoteRepositoryTestCase(RepositoryTestCase):