Browse Source

implement compression flexibility (attic init -c / --compression zlib|lzma)

I had to do a bit of bit-fiddling to preserve backwards compatibility.
Previous code used just 1 byte to determine encryption type and compression was hardcoded to be zlib.
It uses type bytes 0x00, 0x01 and 0x02 for that.
The record layout was rather fixed and there was no variable length part to add a compression type byte.
So I split that type byte: the upper 4bits are compression (0 means zlib as before), the lower 4 bits are for encryption.
Thomas Waldmann 10 years ago
parent
commit
7b1a3dcd5d
4 changed files with 61 additions and 20 deletions
  1. 3 0
      attic/archiver.py
  2. 55 18
      attic/key.py
  3. 2 2
      attic/testsuite/archive.py
  4. 1 0
      attic/testsuite/key.py

+ 3 - 0
attic/archiver.py

@@ -475,6 +475,9 @@ Type "Yes I am sure" if you understand this and want to continue.\n""")
         subparser.add_argument('-e', '--encryption', dest='encryption',
         subparser.add_argument('-e', '--encryption', dest='encryption',
                                choices=('none', 'passphrase', 'keyfile'), default='none',
                                choices=('none', 'passphrase', 'keyfile'), default='none',
                                help='select encryption method')
                                help='select encryption method')
+        subparser.add_argument('-c', '--compression', dest='compression',
+                               choices=('zlib', 'lzma'), default='zlib',
+                               help='select compression method')
 
 
         check_epilog = textwrap.dedent("""
         check_epilog = textwrap.dedent("""
         The check command verifies the consistency of a repository and the corresponding
         The check command verifies the consistency of a repository and the corresponding

+ 55 - 18
attic/key.py

@@ -35,24 +35,30 @@ class HMAC(hmac.HMAC):
 def key_creator(repository, args):
 def key_creator(repository, args):
     if args.encryption == 'keyfile':
     if args.encryption == 'keyfile':
         return KeyfileKey.create(repository, args)
         return KeyfileKey.create(repository, args)
-    elif args.encryption == 'passphrase':
+    if args.encryption == 'passphrase':
         return PassphraseKey.create(repository, args)
         return PassphraseKey.create(repository, args)
-    else:
+    if args.encryption == 'none':
         return PlaintextKey.create(repository, args)
         return PlaintextKey.create(repository, args)
+    raise NotImplemented(args.encryption)
 
 
 
 
 def key_factory(repository, manifest_data):
 def key_factory(repository, manifest_data):
-    if manifest_data[0] == KeyfileKey.TYPE:
+    # key type is determined by 4 lower bits of the type byte
+    key_type = manifest_data[0] & 0x0f
+    if key_type == KeyfileKey.TYPE:
         return KeyfileKey.detect(repository, manifest_data)
         return KeyfileKey.detect(repository, manifest_data)
-    elif manifest_data[0] == PassphraseKey.TYPE:
+    if key_type == PassphraseKey.TYPE:
         return PassphraseKey.detect(repository, manifest_data)
         return PassphraseKey.detect(repository, manifest_data)
-    elif manifest_data[0] == PlaintextKey.TYPE:
+    if key_type == PlaintextKey.TYPE:
         return PlaintextKey.detect(repository, manifest_data)
         return PlaintextKey.detect(repository, manifest_data)
-    else:
-        raise UnsupportedPayloadError(manifest_data[0])
+    raise UnsupportedPayloadError(manifest_data[0])
 
 
 
 
 class CompressorBase(object):
 class CompressorBase(object):
+    @classmethod
+    def create(cls, args):
+        return cls()
+
     def compress(self, data):
     def compress(self, data):
         pass
         pass
 
 
@@ -61,6 +67,8 @@ class CompressorBase(object):
 
 
 
 
 class ZlibCompressor(CompressorBase):
 class ZlibCompressor(CompressorBase):
+    TYPE = 0x00  # must be 0x00 for backwards compatibility
+
     def compress(self, data):
     def compress(self, data):
         return zlib.compress(data)
         return zlib.compress(data)
 
 
@@ -69,6 +77,8 @@ class ZlibCompressor(CompressorBase):
 
 
 
 
 class LzmaCompressor(CompressorBase):
 class LzmaCompressor(CompressorBase):
+    TYPE = 0x10
+
     def __init__(self):
     def __init__(self):
         if lzma is None:
         if lzma is None:
             raise NotImplemented("lzma compression needs Python >= 3.3 or backports.lzma from PyPi")
             raise NotImplemented("lzma compression needs Python >= 3.3 or backports.lzma from PyPi")
@@ -80,11 +90,31 @@ class LzmaCompressor(CompressorBase):
         return lzma.decompress(data)
         return lzma.decompress(data)
 
 
 
 
+def compressor_creator(args):
+    if args is None:  # used by unit tests
+        return ZlibCompressor.create(args)
+    if args.compression == 'lzma':
+        return LzmaCompressor.create(args)
+    if args.compression == 'zlib':
+        return ZlibCompressor.create(args)
+    raise NotImplemented(args.compression)
+
+
+def compressor_factory(manifest_data):
+    # compression is determined by 4 upper bits of the type byte
+    compression_type = manifest_data[0] & 0xf0
+    if compression_type == ZlibCompressor.TYPE:
+        return ZlibCompressor()
+    if compression_type == LzmaCompressor.TYPE:
+        return LzmaCompressor()
+    raise UnsupportedPayloadError(manifest_data[0])
+
+
 class KeyBase(object):
 class KeyBase(object):
 
 
-    def __init__(self):
-        self.TYPE_STR = bytes([self.TYPE])
-        self.compressor = ZlibCompressor()
+    def __init__(self, compressor):
+        self.compressor = compressor
+        self.TYPE_STR = bytes([self.TYPE | self.compressor.TYPE])
 
 
     def id_hash(self, data):
     def id_hash(self, data):
         """Return HMAC hash using the "id" HMAC key
         """Return HMAC hash using the "id" HMAC key
@@ -105,11 +135,13 @@ class PlaintextKey(KeyBase):
     @classmethod
     @classmethod
     def create(cls, repository, args):
     def create(cls, repository, args):
         print('Encryption NOT enabled.\nUse the "--encryption=passphrase|keyfile" to enable encryption.')
         print('Encryption NOT enabled.\nUse the "--encryption=passphrase|keyfile" to enable encryption.')
-        return cls()
+        compressor = compressor_creator(args)
+        return cls(compressor)
 
 
     @classmethod
     @classmethod
     def detect(cls, repository, manifest_data):
     def detect(cls, repository, manifest_data):
-        return cls()
+        compressor = compressor_factory(manifest_data)
+        return cls(compressor)
 
 
     def id_hash(self, data):
     def id_hash(self, data):
         return sha256(data).digest()
         return sha256(data).digest()
@@ -118,8 +150,9 @@ class PlaintextKey(KeyBase):
         return b''.join([self.TYPE_STR, self.compressor.compress(data)])
         return b''.join([self.TYPE_STR, self.compressor.compress(data)])
 
 
     def decrypt(self, id, data):
     def decrypt(self, id, data):
-        if data[0] != self.TYPE:
-            raise IntegrityError('Invalid encryption envelope')
+        type_str = bytes([data[0]])
+        if type_str != self.TYPE_STR:
+            raise IntegrityError('Invalid encryption envelope %r' % type_str)
         data = self.compressor.decompress(memoryview(data)[1:])
         data = self.compressor.decompress(memoryview(data)[1:])
         if id and sha256(data).digest() != id:
         if id and sha256(data).digest() != id:
             raise IntegrityError('Chunk id verification failed')
             raise IntegrityError('Chunk id verification failed')
@@ -191,7 +224,8 @@ class PassphraseKey(AESKeyBase):
 
 
     @classmethod
     @classmethod
     def create(cls, repository, args):
     def create(cls, repository, args):
-        key = cls()
+        compressor = compressor_creator(args)
+        key = cls(compressor)
         passphrase = os.environ.get('ATTIC_PASSPHRASE')
         passphrase = os.environ.get('ATTIC_PASSPHRASE')
         if passphrase is not None:
         if passphrase is not None:
             passphrase2 = passphrase
             passphrase2 = passphrase
@@ -213,7 +247,8 @@ class PassphraseKey(AESKeyBase):
     @classmethod
     @classmethod
     def detect(cls, repository, manifest_data):
     def detect(cls, repository, manifest_data):
         prompt = 'Enter passphrase for %s: ' % repository._location.orig
         prompt = 'Enter passphrase for %s: ' % repository._location.orig
-        key = cls()
+        compressor = compressor_factory(manifest_data)
+        key = cls(compressor)
         passphrase = os.environ.get('ATTIC_PASSPHRASE')
         passphrase = os.environ.get('ATTIC_PASSPHRASE')
         if passphrase is None:
         if passphrase is None:
             passphrase = getpass(prompt)
             passphrase = getpass(prompt)
@@ -238,7 +273,8 @@ class KeyfileKey(AESKeyBase):
 
 
     @classmethod
     @classmethod
     def detect(cls, repository, manifest_data):
     def detect(cls, repository, manifest_data):
-        key = cls()
+        compressor = compressor_factory(manifest_data)
+        key = cls(compressor)
         path = cls.find_key_file(repository)
         path = cls.find_key_file(repository)
         prompt = 'Enter passphrase for key file %s: ' % path
         prompt = 'Enter passphrase for key file %s: ' % path
         passphrase = os.environ.get('ATTIC_PASSPHRASE', '')
         passphrase = os.environ.get('ATTIC_PASSPHRASE', '')
@@ -346,7 +382,8 @@ class KeyfileKey(AESKeyBase):
             passphrase2 = getpass('Enter same passphrase again: ')
             passphrase2 = getpass('Enter same passphrase again: ')
             if passphrase != passphrase2:
             if passphrase != passphrase2:
                 print('Passphrases do not match')
                 print('Passphrases do not match')
-        key = cls()
+        compressor = compressor_creator(args)
+        key = cls(compressor)
         key.repository_id = repository.id
         key.repository_id = repository.id
         key.init_from_random_data(get_random_bytes(100))
         key.init_from_random_data(get_random_bytes(100))
         key.init_ciphers()
         key.init_ciphers()

+ 2 - 2
attic/testsuite/archive.py

@@ -1,7 +1,7 @@
 import msgpack
 import msgpack
 from attic.testsuite import AtticTestCase
 from attic.testsuite import AtticTestCase
 from attic.archive import CacheChunkBuffer, RobustUnpacker
 from attic.archive import CacheChunkBuffer, RobustUnpacker
-from attic.key import PlaintextKey
+from attic.key import PlaintextKey, ZlibCompressor
 
 
 
 
 class MockCache:
 class MockCache:
@@ -19,7 +19,7 @@ class ChunkBufferTestCase(AtticTestCase):
     def test(self):
     def test(self):
         data = [{b'foo': 1}, {b'bar': 2}]
         data = [{b'foo': 1}, {b'bar': 2}]
         cache = MockCache()
         cache = MockCache()
-        key = PlaintextKey()
+        key = PlaintextKey(ZlibCompressor())
         chunks = CacheChunkBuffer(cache, key, None)
         chunks = CacheChunkBuffer(cache, key, None)
         for d in data:
         for d in data:
             chunks.add(d)
             chunks.add(d)

+ 1 - 0
attic/testsuite/key.py

@@ -13,6 +13,7 @@ class KeyTestCase(AtticTestCase):
 
 
     class MockArgs(object):
     class MockArgs(object):
         repository = Location(tempfile.mkstemp()[1])
         repository = Location(tempfile.mkstemp()[1])
+        compression = 'zlib'
 
 
     keyfile2_key_file = """
     keyfile2_key_file = """
         ATTIC KEY 0000000000000000000000000000000000000000000000000000000000000000
         ATTIC KEY 0000000000000000000000000000000000000000000000000000000000000000