浏览代码

implement compression flexibility (attic init -c / --compression zlib|lzma)

I had to do a bit of bit-fiddling to preserve backwards compatibility.
Previous code used just 1 byte to determine encryption type and compression was hardcoded to be zlib.
It uses type bytes 0x00, 0x01 and 0x02 for that.
The record layout was rather fixed and there was no variable length part to add a compression type byte.
So I split that type byte: the upper 4bits are compression (0 means zlib as before), the lower 4 bits are for encryption.
Thomas Waldmann 10 年之前
父节点
当前提交
7b1a3dcd5d
共有 4 个文件被更改,包括 61 次插入20 次删除
  1. 3 0
      attic/archiver.py
  2. 55 18
      attic/key.py
  3. 2 2
      attic/testsuite/archive.py
  4. 1 0
      attic/testsuite/key.py

+ 3 - 0
attic/archiver.py

@@ -475,6 +475,9 @@ Type "Yes I am sure" if you understand this and want to continue.\n""")
         subparser.add_argument('-e', '--encryption', dest='encryption',
                                choices=('none', 'passphrase', 'keyfile'), default='none',
                                help='select encryption method')
+        subparser.add_argument('-c', '--compression', dest='compression',
+                               choices=('zlib', 'lzma'), default='zlib',
+                               help='select compression method')
 
         check_epilog = textwrap.dedent("""
         The check command verifies the consistency of a repository and the corresponding

+ 55 - 18
attic/key.py

@@ -35,24 +35,30 @@ class HMAC(hmac.HMAC):
 def key_creator(repository, args):
     if args.encryption == 'keyfile':
         return KeyfileKey.create(repository, args)
-    elif args.encryption == 'passphrase':
+    if args.encryption == 'passphrase':
         return PassphraseKey.create(repository, args)
-    else:
+    if args.encryption == 'none':
         return PlaintextKey.create(repository, args)
+    raise NotImplemented(args.encryption)
 
 
 def key_factory(repository, manifest_data):
-    if manifest_data[0] == KeyfileKey.TYPE:
+    # key type is determined by 4 lower bits of the type byte
+    key_type = manifest_data[0] & 0x0f
+    if key_type == KeyfileKey.TYPE:
         return KeyfileKey.detect(repository, manifest_data)
-    elif manifest_data[0] == PassphraseKey.TYPE:
+    if key_type == PassphraseKey.TYPE:
         return PassphraseKey.detect(repository, manifest_data)
-    elif manifest_data[0] == PlaintextKey.TYPE:
+    if key_type == PlaintextKey.TYPE:
         return PlaintextKey.detect(repository, manifest_data)
-    else:
-        raise UnsupportedPayloadError(manifest_data[0])
+    raise UnsupportedPayloadError(manifest_data[0])
 
 
 class CompressorBase(object):
+    @classmethod
+    def create(cls, args):
+        return cls()
+
     def compress(self, data):
         pass
 
@@ -61,6 +67,8 @@ class CompressorBase(object):
 
 
 class ZlibCompressor(CompressorBase):
+    TYPE = 0x00  # must be 0x00 for backwards compatibility
+
     def compress(self, data):
         return zlib.compress(data)
 
@@ -69,6 +77,8 @@ class ZlibCompressor(CompressorBase):
 
 
 class LzmaCompressor(CompressorBase):
+    TYPE = 0x10
+
     def __init__(self):
         if lzma is None:
             raise NotImplemented("lzma compression needs Python >= 3.3 or backports.lzma from PyPi")
@@ -80,11 +90,31 @@ class LzmaCompressor(CompressorBase):
         return lzma.decompress(data)
 
 
+def compressor_creator(args):
+    if args is None:  # used by unit tests
+        return ZlibCompressor.create(args)
+    if args.compression == 'lzma':
+        return LzmaCompressor.create(args)
+    if args.compression == 'zlib':
+        return ZlibCompressor.create(args)
+    raise NotImplemented(args.compression)
+
+
+def compressor_factory(manifest_data):
+    # compression is determined by 4 upper bits of the type byte
+    compression_type = manifest_data[0] & 0xf0
+    if compression_type == ZlibCompressor.TYPE:
+        return ZlibCompressor()
+    if compression_type == LzmaCompressor.TYPE:
+        return LzmaCompressor()
+    raise UnsupportedPayloadError(manifest_data[0])
+
+
 class KeyBase(object):
 
-    def __init__(self):
-        self.TYPE_STR = bytes([self.TYPE])
-        self.compressor = ZlibCompressor()
+    def __init__(self, compressor):
+        self.compressor = compressor
+        self.TYPE_STR = bytes([self.TYPE | self.compressor.TYPE])
 
     def id_hash(self, data):
         """Return HMAC hash using the "id" HMAC key
@@ -105,11 +135,13 @@ class PlaintextKey(KeyBase):
     @classmethod
     def create(cls, repository, args):
         print('Encryption NOT enabled.\nUse the "--encryption=passphrase|keyfile" to enable encryption.')
-        return cls()
+        compressor = compressor_creator(args)
+        return cls(compressor)
 
     @classmethod
     def detect(cls, repository, manifest_data):
-        return cls()
+        compressor = compressor_factory(manifest_data)
+        return cls(compressor)
 
     def id_hash(self, data):
         return sha256(data).digest()
@@ -118,8 +150,9 @@ class PlaintextKey(KeyBase):
         return b''.join([self.TYPE_STR, self.compressor.compress(data)])
 
     def decrypt(self, id, data):
-        if data[0] != self.TYPE:
-            raise IntegrityError('Invalid encryption envelope')
+        type_str = bytes([data[0]])
+        if type_str != self.TYPE_STR:
+            raise IntegrityError('Invalid encryption envelope %r' % type_str)
         data = self.compressor.decompress(memoryview(data)[1:])
         if id and sha256(data).digest() != id:
             raise IntegrityError('Chunk id verification failed')
@@ -191,7 +224,8 @@ class PassphraseKey(AESKeyBase):
 
     @classmethod
     def create(cls, repository, args):
-        key = cls()
+        compressor = compressor_creator(args)
+        key = cls(compressor)
         passphrase = os.environ.get('ATTIC_PASSPHRASE')
         if passphrase is not None:
             passphrase2 = passphrase
@@ -213,7 +247,8 @@ class PassphraseKey(AESKeyBase):
     @classmethod
     def detect(cls, repository, manifest_data):
         prompt = 'Enter passphrase for %s: ' % repository._location.orig
-        key = cls()
+        compressor = compressor_factory(manifest_data)
+        key = cls(compressor)
         passphrase = os.environ.get('ATTIC_PASSPHRASE')
         if passphrase is None:
             passphrase = getpass(prompt)
@@ -238,7 +273,8 @@ class KeyfileKey(AESKeyBase):
 
     @classmethod
     def detect(cls, repository, manifest_data):
-        key = cls()
+        compressor = compressor_factory(manifest_data)
+        key = cls(compressor)
         path = cls.find_key_file(repository)
         prompt = 'Enter passphrase for key file %s: ' % path
         passphrase = os.environ.get('ATTIC_PASSPHRASE', '')
@@ -346,7 +382,8 @@ class KeyfileKey(AESKeyBase):
             passphrase2 = getpass('Enter same passphrase again: ')
             if passphrase != passphrase2:
                 print('Passphrases do not match')
-        key = cls()
+        compressor = compressor_creator(args)
+        key = cls(compressor)
         key.repository_id = repository.id
         key.init_from_random_data(get_random_bytes(100))
         key.init_ciphers()

+ 2 - 2
attic/testsuite/archive.py

@@ -1,7 +1,7 @@
 import msgpack
 from attic.testsuite import AtticTestCase
 from attic.archive import CacheChunkBuffer, RobustUnpacker
-from attic.key import PlaintextKey
+from attic.key import PlaintextKey, ZlibCompressor
 
 
 class MockCache:
@@ -19,7 +19,7 @@ class ChunkBufferTestCase(AtticTestCase):
     def test(self):
         data = [{b'foo': 1}, {b'bar': 2}]
         cache = MockCache()
-        key = PlaintextKey()
+        key = PlaintextKey(ZlibCompressor())
         chunks = CacheChunkBuffer(cache, key, None)
         for d in data:
             chunks.add(d)

+ 1 - 0
attic/testsuite/key.py

@@ -13,6 +13,7 @@ class KeyTestCase(AtticTestCase):
 
     class MockArgs(object):
         repository = Location(tempfile.mkstemp()[1])
+        compression = 'zlib'
 
     keyfile2_key_file = """
         ATTIC KEY 0000000000000000000000000000000000000000000000000000000000000000