瀏覽代碼

Added keychain and switched to MODE_CTR

Jonas Borgström 14 年之前
父節點
當前提交
480562570f
共有 4 個文件被更改,包括 108 次插入35 次删除
  1. 6 7
      dedupestore/archive.py
  2. 35 7
      dedupestore/archiver.py
  3. 63 21
      dedupestore/crypto.py
  4. 4 0
      dedupestore/test.py

+ 6 - 7
dedupestore/archive.py

@@ -7,7 +7,6 @@ import sys
 
 from .cache import NS_ARCHIVES, NS_CHUNKS, NS_CINDEX
 from .chunkifier import chunkify
-from .crypto import CryptoManager
 from .helpers import uid2user, user2uid, gid2group, group2gid, IntegrityError
 
 CHUNK_SIZE = 55001
@@ -15,8 +14,8 @@ CHUNK_SIZE = 55001
 
 class Archive(object):
 
-    def __init__(self, store, name=None):
-        self.crypto = CryptoManager(store)
+    def __init__(self, store, crypto, name=None):
+        self.crypto = crypto
         self.store = store
         self.items = []
         self.chunks = []
@@ -150,8 +149,8 @@ class Archive(object):
                     except IntegrityError:
                         logging.error('%s ... ERROR', item['path'])
                         break
-                    else:
-                        logging.info('%s ... OK', item['path'])
+                else:
+                    logging.info('%s ... OK', item['path'])
 
     def delete(self, cache):
         self.store.delete(NS_ARCHIVES, self.id)
@@ -254,8 +253,8 @@ class Archive(object):
             return idx
 
     @staticmethod
-    def list_archives(store):
+    def list_archives(store, crypto):
         for id in store.list(NS_ARCHIVES):
-            archive = Archive(store)
+            archive = Archive(store, crypto)
             archive.load(id)
             yield archive

+ 35 - 7
dedupestore/archiver.py

@@ -5,6 +5,7 @@ import sys
 from .archive import Archive
 from .bandstore import BandStore
 from .cache import Cache
+from .crypto import CryptoManager, KeyChain
 from .helpers import location_validator, pretty_size, LevelFilter
 
 
@@ -19,43 +20,55 @@ class Archiver(object):
 
     def do_create(self, args):
         store = self.open_store(args.archive)
-        archive = Archive(store)
+        keychain = KeyChain(args.keychain)
+        crypto = CryptoManager(keychain)
+        archive = Archive(store, crypto)
         cache = Cache(store, archive.crypto)
         archive.create(args.archive.archive, args.paths, cache)
         return self.exit_code_from_logger()
 
     def do_extract(self, args):
         store = self.open_store(args.archive)
-        archive = Archive(store, args.archive.archive)
+        keychain = KeyChain(args.keychain)
+        crypto = CryptoManager(keychain)
+        archive = Archive(store, crypto, args.archive.archive)
         archive.extract(args.dest)
         return self.exit_code_from_logger()
 
     def do_delete(self, args):
         store = self.open_store(args.archive)
-        archive = Archive(store, args.archive.archive)
+        keychain = KeyChain(args.keychain)
+        crypto = CryptoManager(keychain)
+        archive = Archive(store, crypto, args.archive.archive)
         cache = Cache(store, archive.crypto)
         archive.delete(cache)
         return self.exit_code_from_logger()
 
     def do_list(self, args):
         store = self.open_store(args.src)
+        keychain = KeyChain(args.keychain)
+        crypto = CryptoManager(keychain)
         if args.src.archive:
-            archive = Archive(store, args.src.archive)
+            archive = Archive(store, crypto, args.src.archive)
             archive.list()
         else:
-            for archive in Archive.list_archives(store):
+            for archive in Archive.list_archives(store, crypto):
                 print archive
         return self.exit_code_from_logger()
 
     def do_verify(self, args):
         store = self.open_store(args.archive)
-        archive = Archive(store, args.archive.archive)
+        keychain = KeyChain(args.keychain)
+        crypto = CryptoManager(keychain)
+        archive = Archive(store, crypto, args.archive.archive)
         archive.verify()
         return self.exit_code_from_logger()
 
     def do_info(self, args):
         store = self.open_store(args.archive)
-        archive = Archive(store, args.archive.archive)
+        keychain = KeyChain(args.keychain)
+        crypto = CryptoManager(keychain)
+        archive = Archive(store, crypto, args.archive.archive)
         cache = Cache(store, archive.crypto)
         osize, csize, usize = archive.stats(cache)
         print 'Original size:', pretty_size(osize)
@@ -63,13 +76,28 @@ class Archiver(object):
         print 'Unique data:', pretty_size(usize)
         return self.exit_code_from_logger()
 
+    def do_keychain_generate(self, args):
+        keychain = KeyChain.generate()
+        keychain.save(args.path)
+        return 0
+
     def run(self, args=None):
         parser = argparse.ArgumentParser(description='Dedupestore')
+        parser.add_argument('-k', '--key-chain', dest='keychain', type=str,
+                            help='Key chain')
         parser.add_argument('-v', '--verbose', dest='verbose', action='store_true',
                             default=False,
                             help='Verbose output')
 
+
         subparsers = parser.add_subparsers(title='Available subcommands')
+        subparser = subparsers.add_parser('keychain')
+        subsubparsers = subparser.add_subparsers(title='Available subcommands')
+        subparser = subsubparsers.add_parser('generate')
+        subparser.add_argument('path', metavar='PATH', type=str,
+                               help='Path to keychain')
+        subparser.set_defaults(func=self.do_keychain_generate)
+
         subparser = subparsers.add_parser('create')
         subparser.set_defaults(func=self.do_create)
         subparser.add_argument('archive', metavar='ARCHIVE',

+ 63 - 21
dedupestore/crypto.py

@@ -1,65 +1,107 @@
 import os
+import logging
+import msgpack
 import zlib
 
 from Crypto.Cipher import AES
 from Crypto.Hash import SHA256, HMAC
-from Crypto.Util.number import bytes_to_long, long_to_bytes
+from Crypto.PublicKey import RSA
+from Crypto.Util import Counter
+from Crypto.Util.number import bytes_to_long
 
 from .helpers import IntegrityError
 from .oaep import OAEP
 
 
+class KeyChain(object):
+
+    def __init__(self, path=None):
+        self.aes_id = self.rsa_read = self.rsa_create = None
+        if path:
+            self.open(path)
+
+    def open(self, path):
+        with open(path, 'rb') as fd:
+            chain = msgpack.unpackb(fd.read())
+        logging.info('Key chain "%s" opened', path)
+        assert chain['version'] == 1
+        self.aes_id = chain['aes_id']
+        self.rsa_read = RSA.importKey(chain['rsa_read'])
+        self.rsa_create = RSA.importKey(chain['rsa_create'])
+
+    def save(self, path):
+        chain = {
+            'version': 1,
+            'aes_id': self.aes_id,
+            'rsa_read': self.rsa_read.exportKey('PEM'),
+            'rsa_create': self.rsa_create.exportKey('PEM'),
+        }
+        with open(path, 'wb') as fd:
+            fd.write(msgpack.packb(chain))
+            logging.info('Key chain "%s" saved', path)
+
+    @staticmethod
+    def generate():
+        chain = KeyChain()
+        chain.aes_id = os.urandom(32)
+        chain.rsa_read = RSA.generate(2048)
+        chain.rsa_create = RSA.generate(2048)
+        return chain
+
 class CryptoManager(object):
 
     CREATE = '\1'
     READ = '\2'
 
-    def __init__(self, store):
-        self.key_cache = {}
-        self.store = store
-        self.tid = store.tid
-        self.id_key = '0' * 32
+    def __init__(self, keychain):
+        self._key_cache = {}
+        self.keychain = keychain
         self.read_key = os.urandom(32)
         self.create_key = os.urandom(32)
         self.read_encrypted = OAEP(256, hash=SHA256).encode(self.read_key, os.urandom(32))
+        self.read_encrypted = keychain.rsa_read.encrypt(self.read_encrypted, '')[0]
         self.create_encrypted = OAEP(256, hash=SHA256).encode(self.create_key, os.urandom(32))
+        self.create_encrypted = keychain.rsa_create.encrypt(self.create_encrypted, '')[0]
 
     def id_hash(self, data):
-        return HMAC.new(self.id_key, data, SHA256).digest()
+        return HMAC.new(self.keychain.aes_id, data, SHA256).digest()
 
     def encrypt_read(self, data):
-        key_data = OAEP(256, hash=SHA256).encode(self.read_key, os.urandom(32))
-        #key_data = self.rsa_create.encrypt(key_data)
         data = zlib.compress(data)
         hash = SHA256.new(data).digest()
-        data = AES.new(self.read_key, AES.MODE_CFB, hash[:16]).encrypt(data)
+        counter = Counter.new(128, initial_value=bytes_to_long(hash[:16]), allow_wraparound=True)
+        data = AES.new(self.read_key, AES.MODE_CTR, '', counter=counter).encrypt(data)
         return ''.join((self.READ, self.read_encrypted, hash, data))
 
     def encrypt_create(self, data):
-        key_data = OAEP(256, hash=SHA256).encode(self.create_key, os.urandom(32))
-        #key_data = self.rsa_create.encrypt(key_data)
         data = zlib.compress(data)
         hash = SHA256.new(data).digest()
-        data = AES.new(self.create_key, AES.MODE_CFB, hash[:16]).encrypt(data)
+        counter = Counter.new(128, initial_value=bytes_to_long(hash[:16]), allow_wraparound=True)
+        data = AES.new(self.create_key, AES.MODE_CTR, '', counter=counter).encrypt(data)
         return ''.join((self.CREATE, self.create_encrypted, hash, data))
 
+    def decrypt_key(self, data, rsa_key):
+        try:
+            return self._key_cache[data]
+        except KeyError:
+            self._key_cache[data] = OAEP(256, hash=SHA256).decode(rsa_key.decrypt(data))
+            return self._key_cache[data]
+
     def decrypt(self, data):
         type = data[0]
         if type == self.READ:
-            key_data = data[1:257]
+            key = self.decrypt_key(data[1:257], self.keychain.rsa_read)
             hash = data[257:289]
-            #key_data = self.rsa_create.decrypt(key_data)
-            key = OAEP(256, hash=SHA256).decode(key_data)
-            data = AES.new(key, AES.MODE_CFB, hash[:16]).decrypt(data[289:])
+            counter = Counter.new(128, initial_value=bytes_to_long(hash[:16]), allow_wraparound=True)
+            data = AES.new(key, AES.MODE_CTR, counter=counter).decrypt(data[289:])
             if SHA256.new(data).digest() != hash:
                 raise IntegrityError('decryption failed')
             return zlib.decompress(data)
         elif type == self.CREATE:
-            key_data = data[1:257]
+            key = self.decrypt_key(data[1:257], self.keychain.rsa_create)
             hash = data[257:289]
-            #key_data = self.rsa_create.decrypt(key_data)
-            key = OAEP(256, hash=SHA256).decode(key_data)
-            data = AES.new(key, AES.MODE_CFB, hash[:16]).decrypt(data[289:])
+            counter = Counter.new(128, initial_value=bytes_to_long(hash[:16]), allow_wraparound=True)
+            data = AES.new(key, AES.MODE_CTR, '', counter=counter).decrypt(data[289:])
             if SHA256.new(data).digest() != hash:
                 raise IntegrityError('decryption failed')
             return zlib.decompress(data)

+ 4 - 0
dedupestore/test.py

@@ -13,12 +13,16 @@ class Test(unittest.TestCase):
         self.archiver = Archiver()
         self.tmpdir = tempfile.mkdtemp()
         self.store_path = os.path.join(self.tmpdir, 'store')
+        self.keychain = '/tmp/_test_dedupstore.keychain'
+        if not os.path.exists(self.keychain):
+            self.dedupestore('keychain', 'generate', self.keychain)
 
     def tearDown(self):
         shutil.rmtree(self.tmpdir)
 
     def dedupestore(self, *args, **kwargs):
         exit_code = kwargs.get('exit_code', 0)
+        args = ['--key-chain', self.keychain] + list(args)
         self.assertEqual(exit_code, self.archiver.run(args))
 
     def create_src_archive(self, name):