2
0
View source

file_integrity: split in IntegrityCheckedFile + Detached variant

Marian Beermann 8 years ago
parent
commit
39051ac5f1

+ 51 - 28
src/borg/crypto/file_integrity.py

@@ -104,7 +104,7 @@ class FileIntegrityError(IntegrityError):
 
 
 class IntegrityCheckedFile(FileLikeWrapper):
-    def __init__(self, path, write, filename=None, override_fd=None):
+    def __init__(self, path, write, filename=None, override_fd=None, integrity_data=None):
         self.path = path
         self.writing = write
         mode = 'wb' if write else 'rb'
@@ -114,10 +114,10 @@ class IntegrityCheckedFile(FileLikeWrapper):
 
         self.hash_filename(filename)
 
-        if write:
+        if write or not integrity_data:
             self.digests = {}
         else:
-            self.digests = self.read_integrity_file(path, self.hasher)
+            self.digests = self.parse_integrity_data(path, integrity_data, self.hasher)
             # TODO: When we're reading but don't have any digests, i.e. no integrity file existed,
             # TODO: then we could just short-circuit.
 
@@ -126,32 +126,27 @@ class IntegrityCheckedFile(FileLikeWrapper):
         # In Borg the name itself encodes the context (eg. index.N, cache, files),
         # while the path doesn't matter, and moving e.g. a repository or cache directory is supported.
         # Changing the name however imbues a change of context that is not permissible.
+        # While Borg does not use anything except ASCII in these file names, it's important to use
+        # the same encoding everywhere for portability. Using os.fsencode() would be wrong.
         filename = os.path.basename(filename or self.path)
         self.hasher.update(('%10d' % len(filename)).encode())
         self.hasher.update(filename.encode())
 
-    @staticmethod
-    def integrity_file_path(path):
-        return path + '.integrity'
-
     @classmethod
-    def read_integrity_file(cls, path, hasher):
+    def parse_integrity_data(cls, path: str, data: str, hasher: SHA512FileHashingWrapper):
         try:
-            with open(cls.integrity_file_path(path), 'r') as fd:
-                integrity_file = json.load(fd)
-                # Provisions for agility now, implementation later, but make sure the on-disk joint is oiled.
-                algorithm = integrity_file['algorithm']
-                if algorithm != hasher.ALGORITHM:
-                    logger.warning('Cannot verify integrity of %s: Unknown algorithm %r', path, algorithm)
-                    return
-                digests = integrity_file['digests']
-                # Require at least presence of the final digest
-                digests['final']
-                return digests
-        except FileNotFoundError:
-            logger.info('No integrity file found for %s', path)
-        except (OSError, ValueError, TypeError, KeyError) as e:
-            logger.warning('Could not read integrity file for %s: %s', path, e)
+            integrity_file = json.loads(data)
+            # Provisions for agility now, implementation later, but make sure the on-disk joint is oiled.
+            algorithm = integrity_file['algorithm']
+            if algorithm != hasher.ALGORITHM:
+                logger.warning('Cannot verify integrity of %s: Unknown algorithm %r', path, algorithm)
+                return
+            digests = integrity_file['digests']
+            # Require at least presence of the final digest
+            digests['final']
+            return digests
+        except (ValueError, TypeError, KeyError) as e:
+            logger.warning('Could not parse integrity data for %s: %s', path, e)
             raise FileIntegrityError(path)
 
     def hash_part(self, partname, is_final=False):
@@ -173,10 +168,38 @@ class IntegrityCheckedFile(FileLikeWrapper):
         if exception:
             return
         if self.writing:
-            with open(self.integrity_file_path(self.path), 'w') as fd:
-                json.dump({
-                    'algorithm': self.hasher.ALGORITHM,
-                    'digests': self.digests,
-                }, fd)
+            self.store_integrity_data(json.dumps({
+                'algorithm': self.hasher.ALGORITHM,
+                'digests': self.digests,
+            }))
         elif self.digests:
             logger.debug('Verified integrity of %s', self.path)
+
+    def store_integrity_data(self, data: str):
+        self.integrity_data = data
+
+
+class DetachedIntegrityCheckedFile(IntegrityCheckedFile):
+    def __init__(self, path, write, filename=None, override_fd=None):
+        super().__init__(path, write, filename, override_fd)
+        if not write:
+            self.digests = self.read_integrity_file(self.path, self.hasher)
+
+    @staticmethod
+    def integrity_file_path(path):
+        return path + '.integrity'
+
+    @classmethod
+    def read_integrity_file(cls, path, hasher):
+        try:
+            with open(cls.integrity_file_path(path), 'r') as fd:
+                return cls.parse_integrity_data(path, fd.read(), hasher)
+        except FileNotFoundError:
+            logger.info('No integrity file found for %s', path)
+        except OSError as e:
+            logger.warning('Could not read integrity file for %s: %s', path, e)
+            raise FileIntegrityError(path)
+
+    def store_integrity_data(self, data: str):
+        with open(self.integrity_file_path(self.path), 'w') as fd:
+            fd.write(data)

+ 19 - 19
src/borg/testsuite/file_integrity.py

@@ -1,21 +1,21 @@
 
 import pytest
 
-from ..crypto.file_integrity import IntegrityCheckedFile, FileIntegrityError
+from ..crypto.file_integrity import IntegrityCheckedFile, DetachedIntegrityCheckedFile, FileIntegrityError
 
 
 class TestReadIntegrityFile:
     def test_no_integrity(self, tmpdir):
         protected_file = tmpdir.join('file')
         protected_file.write('1234')
-        assert IntegrityCheckedFile.read_integrity_file(str(protected_file), None) is None
+        assert DetachedIntegrityCheckedFile.read_integrity_file(str(protected_file), None) is None
 
     def test_truncated_integrity(self, tmpdir):
         protected_file = tmpdir.join('file')
         protected_file.write('1234')
         tmpdir.join('file.integrity').write('')
         with pytest.raises(FileIntegrityError):
-            IntegrityCheckedFile.read_integrity_file(str(protected_file), None)
+            DetachedIntegrityCheckedFile.read_integrity_file(str(protected_file), None)
 
     def test_unknown_algorithm(self, tmpdir):
         class SomeHasher:
@@ -24,7 +24,7 @@ class TestReadIntegrityFile:
         protected_file = tmpdir.join('file')
         protected_file.write('1234')
         tmpdir.join('file.integrity').write('{"algorithm": "HMAC_SERIOUSHASH", "digests": "1234"}')
-        assert IntegrityCheckedFile.read_integrity_file(str(protected_file), SomeHasher()) is None
+        assert DetachedIntegrityCheckedFile.read_integrity_file(str(protected_file), SomeHasher()) is None
 
     @pytest.mark.parametrize('json', (
         '{"ALGORITHM": "HMAC_SERIOUSHASH", "digests": "1234"}',
@@ -38,7 +38,7 @@ class TestReadIntegrityFile:
         protected_file.write('1234')
         tmpdir.join('file.integrity').write(json)
         with pytest.raises(FileIntegrityError):
-            IntegrityCheckedFile.read_integrity_file(str(protected_file), None)
+            DetachedIntegrityCheckedFile.read_integrity_file(str(protected_file), None)
 
     def test_valid(self, tmpdir):
         class SomeHasher:
@@ -47,35 +47,35 @@ class TestReadIntegrityFile:
         protected_file = tmpdir.join('file')
         protected_file.write('1234')
         tmpdir.join('file.integrity').write('{"algorithm": "HMAC_FOO1", "digests": {"final": "1234"}}')
-        assert IntegrityCheckedFile.read_integrity_file(str(protected_file), SomeHasher()) == {'final': '1234'}
+        assert DetachedIntegrityCheckedFile.read_integrity_file(str(protected_file), SomeHasher()) == {'final': '1234'}
 
 
-class TestIntegrityCheckedFile:
+class TestDetachedIntegrityCheckedFile:
     @pytest.fixture
     def integrity_protected_file(self, tmpdir):
         path = str(tmpdir.join('file'))
-        with IntegrityCheckedFile(path, write=True) as fd:
+        with DetachedIntegrityCheckedFile(path, write=True) as fd:
             fd.write(b'foo and bar')
         return path
 
     def test_simple(self, tmpdir, integrity_protected_file):
         assert tmpdir.join('file').check(file=True)
         assert tmpdir.join('file.integrity').check(file=True)
-        with IntegrityCheckedFile(integrity_protected_file, write=False) as fd:
+        with DetachedIntegrityCheckedFile(integrity_protected_file, write=False) as fd:
             assert fd.read() == b'foo and bar'
 
     def test_corrupted_file(self, integrity_protected_file):
         with open(integrity_protected_file, 'ab') as fd:
             fd.write(b' extra data')
         with pytest.raises(FileIntegrityError):
-            with IntegrityCheckedFile(integrity_protected_file, write=False) as fd:
+            with DetachedIntegrityCheckedFile(integrity_protected_file, write=False) as fd:
                 assert fd.read() == b'foo and bar extra data'
 
     def test_corrupted_file_partial_read(self, integrity_protected_file):
         with open(integrity_protected_file, 'ab') as fd:
             fd.write(b' extra data')
         with pytest.raises(FileIntegrityError):
-            with IntegrityCheckedFile(integrity_protected_file, write=False) as fd:
+            with DetachedIntegrityCheckedFile(integrity_protected_file, write=False) as fd:
                 data = b'foo and bar'
                 assert fd.read(len(data)) == data
 
@@ -88,7 +88,7 @@ class TestIntegrityCheckedFile:
         tmpdir.join('file').move(new_path)
         tmpdir.join('file.integrity').move(new_path + '.integrity')
         with pytest.raises(FileIntegrityError):
-            with IntegrityCheckedFile(str(new_path), write=False) as fd:
+            with DetachedIntegrityCheckedFile(str(new_path), write=False) as fd:
                 assert fd.read() == b'foo and bar'
 
     def test_moved_file(self, tmpdir, integrity_protected_file):
@@ -96,27 +96,27 @@ class TestIntegrityCheckedFile:
         tmpdir.join('file').move(new_dir.join('file'))
         tmpdir.join('file.integrity').move(new_dir.join('file.integrity'))
         new_path = str(new_dir.join('file'))
-        with IntegrityCheckedFile(new_path, write=False) as fd:
+        with DetachedIntegrityCheckedFile(new_path, write=False) as fd:
             assert fd.read() == b'foo and bar'
 
     def test_no_integrity(self, tmpdir, integrity_protected_file):
         tmpdir.join('file.integrity').remove()
-        with IntegrityCheckedFile(integrity_protected_file, write=False) as fd:
+        with DetachedIntegrityCheckedFile(integrity_protected_file, write=False) as fd:
             assert fd.read() == b'foo and bar'
 
 
-class TestIntegrityCheckedFileParts:
+class TestDetachedIntegrityCheckedFileParts:
     @pytest.fixture
     def integrity_protected_file(self, tmpdir):
         path = str(tmpdir.join('file'))
-        with IntegrityCheckedFile(path, write=True) as fd:
+        with DetachedIntegrityCheckedFile(path, write=True) as fd:
             fd.write(b'foo and bar')
             fd.hash_part('foopart')
             fd.write(b' other data')
         return path
 
     def test_simple(self, integrity_protected_file):
-        with IntegrityCheckedFile(integrity_protected_file, write=False) as fd:
+        with DetachedIntegrityCheckedFile(integrity_protected_file, write=False) as fd:
             data1 = b'foo and bar'
             assert fd.read(len(data1)) == data1
             fd.hash_part('foopart')
@@ -127,7 +127,7 @@ class TestIntegrityCheckedFileParts:
             # Because some hash_part failed, the final digest will fail as well - again - even if we catch
             # the failing hash_part. This is intentional: (1) it makes the code simpler (2) it's a good fail-safe
             # against overly broad exception handling.
-            with IntegrityCheckedFile(integrity_protected_file, write=False) as fd:
+            with DetachedIntegrityCheckedFile(integrity_protected_file, write=False) as fd:
                 data1 = b'foo and bar'
                 assert fd.read(len(data1)) == data1
                 with pytest.raises(FileIntegrityError):
@@ -140,7 +140,7 @@ class TestIntegrityCheckedFileParts:
         with open(integrity_protected_file, 'ab') as fd:
             fd.write(b'some extra stuff that does not belong')
         with pytest.raises(FileIntegrityError):
-            with IntegrityCheckedFile(integrity_protected_file, write=False) as fd:
+            with DetachedIntegrityCheckedFile(integrity_protected_file, write=False) as fd:
                 data1 = b'foo and bar'
                 try:
                     assert fd.read(len(data1)) == data1