
RepositoryCache: checksum decrypted cache

Marian Beermann 8 years ago
parent commit b544af2af1

+ 44 - 23
src/borg/remote.py

@@ -30,6 +30,7 @@ from .helpers import format_file_size
 from .logger import create_logger, setup_logging
 from .repository import Repository, MAX_OBJECT_SIZE, LIST_SCAN_LIMIT
 from .version import parse_version, format_version
+from .algorithms.checksums import xxh64

 logger = create_logger(__name__)

@@ -1086,6 +1087,9 @@ class RepositoryCache(RepositoryNoCache):
     should return the initial data (as returned by *transform*).
     """

+    class InvalidateCacheEntry(Exception):
+        pass
+
     def __init__(self, repository, pack=None, unpack=None, transform=None):
         super().__init__(repository, transform)
         self.pack = pack or (lambda data: data)
@@ -1100,6 +1104,7 @@ class RepositoryCache(RepositoryNoCache):
         self.slow_misses = 0
         self.slow_lat = 0.0
         self.evictions = 0
+        self.checksum_errors = 0
         self.enospc = 0

     def query_size_limit(self):
@@ -1144,10 +1149,10 @@

     def close(self):
         logger.debug('RepositoryCache: current items %d, size %s / %s, %d hits, %d misses, %d slow misses (+%.1fs), '
-                     '%d evictions, %d ENOSPC hit',
+                     '%d evictions, %d ENOSPC hit, %d checksum errors',
                      len(self.cache), format_file_size(self.size), format_file_size(self.size_limit),
                      self.hits, self.misses, self.slow_misses, self.slow_lat,
-                     self.evictions, self.enospc)
+                     self.evictions, self.enospc, self.checksum_errors)
         self.cache.clear()
         shutil.rmtree(self.basedir)
 
@@ -1157,30 +1162,37 @@ class RepositoryCache(RepositoryNoCache):
         for key in keys:
             if key in self.cache:
                 file = self.key_filename(key)
-                with open(file, 'rb') as fd:
-                    self.hits += 1
-                    yield self.unpack(fd.read())
-            else:
-                for key_, data in repository_iterator:
-                    if key_ == key:
-                        transformed = self.add_entry(key, data, cache)
-                        self.misses += 1
-                        yield transformed
-                        break
-                else:
-                    # slow path: eviction during this get_many removed this key from the cache
-                    t0 = time.perf_counter()
-                    data = self.repository.get(key)
-                    self.slow_lat += time.perf_counter() - t0
+                try:
+                    with open(file, 'rb') as fd:
+                        self.hits += 1
+                        yield self.unpack(fd.read())
+                        continue  # go to the next key
+                except self.InvalidateCacheEntry:
+                    self.cache.remove(key)
+                    self.size -= os.stat(file).st_size
+                    self.checksum_errors += 1
+                    os.unlink(file)
+                    # fall through to fetch the object again
+            for key_, data in repository_iterator:
+                if key_ == key:
                     transformed = self.add_entry(key, data, cache)
-                    self.slow_misses += 1
+                    self.misses += 1
                     yield transformed
+                    break
+            else:
+                # slow path: eviction during this get_many removed this key from the cache
+                t0 = time.perf_counter()
+                data = self.repository.get(key)
+                self.slow_lat += time.perf_counter() - t0
+                transformed = self.add_entry(key, data, cache)
+                self.slow_misses += 1
+                yield transformed
         # Consume any pending requests
         for _ in repository_iterator:
             pass


-def cache_if_remote(repository, *, decrypted_cache=False, pack=None, unpack=None, transform=None):
+def cache_if_remote(repository, *, decrypted_cache=False, pack=None, unpack=None, transform=None, force_cache=False):
     """
     Return a Repository(No)Cache for *repository*.

@@ -1194,21 +1206,30 @@ def cache_if_remote(repository, *, decrypted_cache=False, pack=None, unpack=None
         raise ValueError('decrypted_cache and pack/unpack/transform are incompatible')
     elif decrypted_cache:
         key = decrypted_cache
-        cache_struct = struct.Struct('=I')
+        # 32 bit csize, 64 bit (8 byte) xxh64
+        cache_struct = struct.Struct('=I8s')
         compressor = LZ4()

         def pack(data):
-            return cache_struct.pack(data[0]) + compressor.compress(data[1])
+            csize, decrypted = data
+            compressed = compressor.compress(decrypted)
+            return cache_struct.pack(csize, xxh64(compressed)) + compressed

         def unpack(data):
-            return cache_struct.unpack(data[:cache_struct.size])[0], compressor.decompress(data[cache_struct.size:])
+            data = memoryview(data)
+            csize, checksum = cache_struct.unpack(data[:cache_struct.size])
+            compressed = data[cache_struct.size:]
+            if checksum != xxh64(compressed):
+                logger.warning('Repository metadata cache: detected corrupted data in cache!')
+                raise RepositoryCache.InvalidateCacheEntry
+            return csize, compressor.decompress(compressed)

         def transform(id_, data):
             csize = len(data)
             decrypted = key.decrypt(id_, data)
             return csize, decrypted

-    if isinstance(repository, RemoteRepository):
+    if isinstance(repository, RemoteRepository) or force_cache:
         return RepositoryCache(repository, pack, unpack, transform)
     else:
         return RepositoryNoCache(repository, transform)
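The new on-disk entry format packs a 32-bit csize and the 64-bit xxh64 of the LZ4-compressed payload in front of the payload itself; unpack() recomputes the checksum before decompressing and raises InvalidateCacheEntry on a mismatch, so get_many() can evict the corrupted file and fetch the object again. Below is a minimal, self-contained sketch of that round-trip, not borg's actual code: it assumes the third-party xxhash package as a stand-in for borg's internal .algorithms.checksums.xxh64 and uses zlib in place of its LZ4 compressor.

import struct
import zlib

import xxhash  # assumed stand-in for borg's .algorithms.checksums.xxh64

cache_struct = struct.Struct('=I8s')  # 32 bit csize, 64 bit (8 byte) xxh64


class InvalidateCacheEntry(Exception):
    """The entry failed its checksum; evict it and fetch the object again."""


def pack(csize, decrypted):
    compressed = zlib.compress(decrypted)  # borg uses LZ4 here
    return cache_struct.pack(csize, xxhash.xxh64(compressed).digest()) + compressed


def unpack(data):
    csize, checksum = cache_struct.unpack(data[:cache_struct.size])
    compressed = data[cache_struct.size:]
    if checksum != xxhash.xxh64(compressed).digest():
        raise InvalidateCacheEntry
    return csize, zlib.decompress(compressed)


entry = pack(7, b'1234')
assert unpack(entry) == (7, b'1234')

# Flip one bit of the payload: the checksum no longer matches and unpack()
# raises before any decompression is attempted.
corrupted = entry[:-1] + bytes([entry[-1] ^ 2])
try:
    unpack(corrupted)
except InvalidateCacheEntry:
    pass  # the caller would remove the cache file and re-fetch the object

xxh64 is a fast non-cryptographic hash, which fits the purpose here: the check guards a local cache against accidental corruption (truncated writes, bit rot), not against tampering.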

+ 53 - 1
src/borg/testsuite/remote.py

@@ -1,13 +1,17 @@
 import errno
 import os
+import io
 import time
 from unittest.mock import patch

 import pytest

-from ..remote import SleepingBandwidthLimiter, RepositoryCache
+from ..remote import SleepingBandwidthLimiter, RepositoryCache, cache_if_remote
 from ..repository import Repository
+from ..crypto.key import PlaintextKey
+from ..compress import CompressionSpec
 from .hashindex import H
+from .key import TestKey


 class TestSleepingBandwidthLimiter:
@@ -147,3 +151,51 @@ class TestRepositoryCache:
             assert cache.evictions == 0

         assert next(iterator) == bytes(100)
+
+    @pytest.fixture
+    def key(self, repository, monkeypatch):
+        monkeypatch.setenv('BORG_PASSPHRASE', 'test')
+        key = PlaintextKey.create(repository, TestKey.MockArgs())
+        key.compressor = CompressionSpec('none').compressor
+        return key
+
+    def _put_encrypted_object(self, key, repository, data):
+        id_ = key.id_hash(data)
+        repository.put(id_, key.encrypt(data))
+        return id_
+
+    @pytest.fixture
+    def H1(self, key, repository):
+        return self._put_encrypted_object(key, repository, b'1234')
+
+    @pytest.fixture
+    def H2(self, key, repository):
+        return self._put_encrypted_object(key, repository, b'5678')
+
+    @pytest.fixture
+    def H3(self, key, repository):
+        return self._put_encrypted_object(key, repository, bytes(100))
+
+    @pytest.fixture
+    def decrypted_cache(self, key, repository):
+        return cache_if_remote(repository, decrypted_cache=key, force_cache=True)
+
+    def test_cache_corruption(self, decrypted_cache: RepositoryCache, H1, H2, H3):
+        list(decrypted_cache.get_many([H1, H2, H3]))
+
+        iterator = decrypted_cache.get_many([H1, H2, H3])
+        assert next(iterator) == (7, b'1234')
+
+        with open(decrypted_cache.key_filename(H2), 'a+b') as fd:
+            fd.seek(-1, io.SEEK_END)
+            corrupted = (int.from_bytes(fd.read(), 'little') ^ 2).to_bytes(1, 'little')
+            fd.seek(-1, io.SEEK_END)
+            fd.write(corrupted)
+            fd.truncate()
+
+        assert next(iterator) == (7, b'5678')
+        assert decrypted_cache.checksum_errors == 1
+        assert decrypted_cache.slow_misses == 1
+        assert next(iterator) == (103, bytes(100))
+        assert decrypted_cache.hits == 3
+        assert decrypted_cache.misses == 3
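A note on the assertion arithmetic: during the second get_many() all three keys are already cached, so the batch's repository iterator is created with no pending requests. When H2's entry fails its checksum, that iterator yields nothing and the object is re-fetched through the slow path (repository.get), which is why checksum_errors == 1 is accompanied by slow_misses == 1. hits still reaches 3 because self.hits is incremented before unpack() raises, and all 3 misses date from the first get_many() call that populated the cache.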