瀏覽代碼

Merge CryptoManager and KeyChain into Keychain

Jonas Borgström 14 年之前
父節點
當前提交
d70993fc83
共有 4 個文件被更改,包括 64 次插入73 次删除
  1. 18 18
      darc/archive.py
  2. 21 27
      darc/archiver.py
  3. 6 6
      darc/cache.py
  4. 19 22
      darc/keychain.py

+ 18 - 18
darc/archive.py

@@ -21,32 +21,32 @@ class Archive(object):
     class DoesNotExist(Exception):
     class DoesNotExist(Exception):
         pass
         pass
 
 
-    def __init__(self, store, crypto, name=None):
-        self.crypto = crypto
+    def __init__(self, store, keychain, name=None):
+        self.keychain = keychain
         self.store = store
         self.store = store
         self.items = []
         self.items = []
         self.chunks = []
         self.chunks = []
         self.chunk_idx = {}
         self.chunk_idx = {}
         self.hard_links = {}
         self.hard_links = {}
         if name:
         if name:
-            self.load(self.crypto.id_hash(name))
+            self.load(self.keychain.id_hash(name))
 
 
     def load(self, id):
     def load(self, id):
         self.id = id
         self.id = id
         try:
         try:
-            data, self.hash = self.crypto.decrypt(self.store.get(NS_ARCHIVE_METADATA, self.id))
+            data, self.hash = self.keychain.decrypt(self.store.get(NS_ARCHIVE_METADATA, self.id))
         except self.store.DoesNotExist:
         except self.store.DoesNotExist:
             raise self.DoesNotExist
             raise self.DoesNotExist
         self.metadata = msgpack.unpackb(data)
         self.metadata = msgpack.unpackb(data)
         assert self.metadata['version'] == 1
         assert self.metadata['version'] == 1
 
 
     def get_items(self):
     def get_items(self):
-        data, chunks_hash = self.crypto.decrypt(self.store.get(NS_ARCHIVE_CHUNKS, self.id))
+        data, chunks_hash = self.keychain.decrypt(self.store.get(NS_ARCHIVE_CHUNKS, self.id))
         chunks = msgpack.unpackb(data)
         chunks = msgpack.unpackb(data)
         assert chunks['version'] == 1
         assert chunks['version'] == 1
         assert self.metadata['chunks_hash'] == chunks_hash
         assert self.metadata['chunks_hash'] == chunks_hash
         self.chunks = chunks['chunks']
         self.chunks = chunks['chunks']
-        data, items_hash = self.crypto.decrypt(self.store.get(NS_ARCHIVE_ITEMS, self.id))
+        data, items_hash = self.keychain.decrypt(self.store.get(NS_ARCHIVE_ITEMS, self.id))
         items = msgpack.unpackb(data)
         items = msgpack.unpackb(data)
         assert items['version'] == 1
         assert items['version'] == 1
         assert self.metadata['items_hash'] == items_hash
         assert self.metadata['items_hash'] == items_hash
@@ -55,12 +55,12 @@ class Archive(object):
             self.chunk_idx[i] = chunk[0]
             self.chunk_idx[i] = chunk[0]
 
 
     def save(self, name):
     def save(self, name):
-        self.id = self.crypto.id_hash(name)
+        self.id = self.keychain.id_hash(name)
         chunks = {'version': 1, 'chunks': self.chunks}
         chunks = {'version': 1, 'chunks': self.chunks}
-        data, chunks_hash = self.crypto.encrypt_create(msgpack.packb(chunks))
+        data, chunks_hash = self.keychain.encrypt_create(msgpack.packb(chunks))
         self.store.put(NS_ARCHIVE_CHUNKS, self.id, data)
         self.store.put(NS_ARCHIVE_CHUNKS, self.id, data)
         items = {'version': 1, 'items': self.items}
         items = {'version': 1, 'items': self.items}
-        data, items_hash = self.crypto.encrypt_read(msgpack.packb(items))
+        data, items_hash = self.keychain.encrypt_read(msgpack.packb(items))
         self.store.put(NS_ARCHIVE_ITEMS, self.id, data)
         self.store.put(NS_ARCHIVE_ITEMS, self.id, data)
         metadata = {
         metadata = {
             'version': 1,
             'version': 1,
@@ -72,7 +72,7 @@ class Archive(object):
             'username': getuser(),
             'username': getuser(),
             'time': datetime.utcnow().isoformat(),
             'time': datetime.utcnow().isoformat(),
         }
         }
-        data, self.hash = self.crypto.encrypt_read(msgpack.packb(metadata))
+        data, self.hash = self.keychain.encrypt_read(msgpack.packb(metadata))
         self.store.put(NS_ARCHIVE_METADATA, self.id, data)
         self.store.put(NS_ARCHIVE_METADATA, self.id, data)
         self.store.commit()
         self.store.commit()
 
 
@@ -134,8 +134,8 @@ class Archive(object):
                     for chunk in item['chunks']:
                     for chunk in item['chunks']:
                         id = self.chunk_idx[chunk]
                         id = self.chunk_idx[chunk]
                         try:
                         try:
-                            data, hash = self.crypto.decrypt(self.store.get(NS_CHUNK, id))
-                            if self.crypto.id_hash(data) != id:
+                            data, hash = self.keychain.decrypt(self.store.get(NS_CHUNK, id))
+                            if self.keychain.id_hash(data) != id:
                                 raise IntegrityError('chunk id did not match')
                                 raise IntegrityError('chunk id did not match')
                             fd.write(data)
                             fd.write(data)
                         except ValueError:
                         except ValueError:
@@ -171,8 +171,8 @@ class Archive(object):
         for chunk in item['chunks']:
         for chunk in item['chunks']:
             id = self.chunk_idx[chunk]
             id = self.chunk_idx[chunk]
             try:
             try:
-                data, hash = self.crypto.decrypt(self.store.get(NS_CHUNK, id))
-                if self.crypto.id_hash(data) != id:
+                data, hash = self.keychain.decrypt(self.store.get(NS_CHUNK, id))
+                if self.keychain.id_hash(data) != id:
                     raise IntegrityError('chunk id did not match')
                     raise IntegrityError('chunk id did not match')
             except IntegrityError:
             except IntegrityError:
                 return False
                 return False
@@ -227,7 +227,7 @@ class Archive(object):
                 return
                 return
             else:
             else:
                 self.hard_links[st.st_ino, st.st_dev] = safe_path
                 self.hard_links[st.st_ino, st.st_dev] = safe_path
-        path_hash = self.crypto.id_hash(path.encode('utf-8'))
+        path_hash = self.keychain.id_hash(path.encode('utf-8'))
         ids, size = cache.file_known_and_unchanged(path_hash, st)
         ids, size = cache.file_known_and_unchanged(path_hash, st)
         if ids is not None:
         if ids is not None:
             # Make sure all ids are available
             # Make sure all ids are available
@@ -245,7 +245,7 @@ class Archive(object):
                 ids = []
                 ids = []
                 chunks = []
                 chunks = []
                 for chunk in chunkify(fd, CHUNK_SIZE, 30):
                 for chunk in chunkify(fd, CHUNK_SIZE, 30):
-                    id = self.crypto.id_hash(chunk)
+                    id = self.keychain.id_hash(chunk)
                     ids.append(id)
                     ids.append(id)
                     try:
                     try:
                         chunks.append(self.chunk_idx[id])
                         chunks.append(self.chunk_idx[id])
@@ -275,8 +275,8 @@ class Archive(object):
         return idx
         return idx
 
 
     @staticmethod
     @staticmethod
-    def list_archives(store, crypto):
+    def list_archives(store, keychain):
         for id in list(store.list(NS_ARCHIVE_METADATA)):
         for id in list(store.list(NS_ARCHIVE_METADATA)):
-            archive = Archive(store, crypto)
+            archive = Archive(store, keychain)
             archive.load(id)
             archive.load(id)
             yield archive
             yield archive

+ 21 - 27
darc/archiver.py

@@ -7,7 +7,7 @@ import sys
 from .archive import Archive
 from .archive import Archive
 from .store import Store
 from .store import Store
 from .cache import Cache
 from .cache import Cache
-from .crypto import CryptoManager, KeyChain
+from .keychain import Keychain
 from .helpers import location_validator, format_file_size, format_time, format_file_mode, walk_dir
 from .helpers import location_validator, format_file_size, format_time, format_file_mode, walk_dir
 
 
 
 
@@ -38,17 +38,16 @@ class Archiver(object):
 
 
     def do_create(self, args):
     def do_create(self, args):
         store = self.open_store(args.archive)
         store = self.open_store(args.archive)
-        keychain = KeyChain(args.keychain)
-        crypto = CryptoManager(keychain)
+        keychain = Keychain(args.keychain)
         try:
         try:
-            Archive(store, crypto, args.archive.archive)
+            Archive(store, keychain, args.archive.archive)
         except Archive.DoesNotExist:
         except Archive.DoesNotExist:
             pass
             pass
         else:
         else:
             self.print_error('Archive already exists')
             self.print_error('Archive already exists')
             return self.exit_code
             return self.exit_code
-        archive = Archive(store, crypto)
-        cache = Cache(store, archive.crypto)
+        archive = Archive(store, keychain)
+        cache = Cache(store, keychain)
         for path in args.paths:
         for path in args.paths:
             for path, st in walk_dir(unicode(path)):
             for path, st in walk_dir(unicode(path)):
                 if stat.S_ISDIR(st.st_mode):
                 if stat.S_ISDIR(st.st_mode):
@@ -70,9 +69,8 @@ class Archiver(object):
 
 
     def do_extract(self, args):
     def do_extract(self, args):
         store = self.open_store(args.archive)
         store = self.open_store(args.archive)
-        keychain = KeyChain(args.keychain)
-        crypto = CryptoManager(keychain)
-        archive = Archive(store, crypto, args.archive.archive)
+        keychain = Keychain(args.keychain)
+        archive = Archive(store, keychain, args.archive.archive)
         archive.get_items()
         archive.get_items()
         dirs = []
         dirs = []
         for item in archive.items:
         for item in archive.items:
@@ -89,20 +87,18 @@ class Archiver(object):
 
 
     def do_delete(self, args):
     def do_delete(self, args):
         store = self.open_store(args.archive)
         store = self.open_store(args.archive)
-        keychain = KeyChain(args.keychain)
-        crypto = CryptoManager(keychain)
-        archive = Archive(store, crypto, args.archive.archive)
-        cache = Cache(store, archive.crypto)
+        keychain = Keychain(args.keychain)
+        archive = Archive(store, keychain, args.archive.archive)
+        cache = Cache(store, keychain)
         archive.delete(cache)
         archive.delete(cache)
         return self.exit_code
         return self.exit_code
 
 
     def do_list(self, args):
     def do_list(self, args):
         store = self.open_store(args.src)
         store = self.open_store(args.src)
-        keychain = KeyChain(args.keychain)
-        crypto = CryptoManager(keychain)
+        keychain = Keychain(args.keychain)
         if args.src.archive:
         if args.src.archive:
             tmap = {1: 'p', 2: 'c', 4: 'd', 6: 'b', 010: '-', 012: 'l', 014: 's'}
             tmap = {1: 'p', 2: 'c', 4: 'd', 6: 'b', 010: '-', 012: 'l', 014: 's'}
-            archive = Archive(store, crypto, args.src.archive)
+            archive = Archive(store, keychain, args.src.archive)
             archive.get_items()
             archive.get_items()
             for item in archive.items:
             for item in archive.items:
                 type = tmap.get(item['mode'] / 4096, '?')
                 type = tmap.get(item['mode'] / 4096, '?')
@@ -112,15 +108,14 @@ class Archiver(object):
                 print '%s%s %-6s %-6s %8d %s %s' % (type, mode, item['user'],
                 print '%s%s %-6s %-6s %8d %s %s' % (type, mode, item['user'],
                                                   item['group'], size, mtime, item['path'])
                                                   item['group'], size, mtime, item['path'])
         else:
         else:
-            for archive in Archive.list_archives(store, crypto):
+            for archive in Archive.list_archives(store, keychain):
                 print '%(name)-20s %(time)s' % archive.metadata
                 print '%(name)-20s %(time)s' % archive.metadata
         return self.exit_code
         return self.exit_code
 
 
     def do_verify(self, args):
     def do_verify(self, args):
         store = self.open_store(args.archive)
         store = self.open_store(args.archive)
-        keychain = KeyChain(args.keychain)
-        crypto = CryptoManager(keychain)
-        archive = Archive(store, crypto, args.archive.archive)
+        keychain = Keychain(args.keychain)
+        archive = Archive(store, keychain, args.archive.archive)
         archive.get_items()
         archive.get_items()
         for item in archive.items:
         for item in archive.items:
             if stat.S_ISREG(item['mode']) and not 'source' in item:
             if stat.S_ISREG(item['mode']) and not 'source' in item:
@@ -134,10 +129,9 @@ class Archiver(object):
 
 
     def do_info(self, args):
     def do_info(self, args):
         store = self.open_store(args.archive)
         store = self.open_store(args.archive)
-        keychain = KeyChain(args.keychain)
-        crypto = CryptoManager(keychain)
-        archive = Archive(store, crypto, args.archive.archive)
-        cache = Cache(store, archive.crypto)
+        keychain = Keychain(args.keychain)
+        archive = Archive(store, keychain, args.archive.archive)
+        cache = Cache(store, keychain)
         osize, csize, usize = archive.stats(cache)
         osize, csize, usize = archive.stats(cache)
         print 'Name:', archive.metadata['name']
         print 'Name:', archive.metadata['name']
         print 'Hostname:', archive.metadata['hostname']
         print 'Hostname:', archive.metadata['hostname']
@@ -151,15 +145,15 @@ class Archiver(object):
         return self.exit_code
         return self.exit_code
 
 
     def do_init_keychain(self, args):
     def do_init_keychain(self, args):
-        return KeyChain.generate(args.keychain)
+        return Keychain.generate(args.keychain)
 
 
     def do_export_restricted(self, args):
     def do_export_restricted(self, args):
-        keychain = KeyChain(args.keychain)
+        keychain = Keychain(args.keychain)
         keychain.restrict(args.output)
         keychain.restrict(args.output)
         return self.exit_code
         return self.exit_code
 
 
     def do_keychain_chpass(self, args):
     def do_keychain_chpass(self, args):
-        return KeyChain(args.keychain).chpass()
+        return Keychain(args.keychain).chpass()
 
 
     def run(self, args=None):
     def run(self, args=None):
         default_keychain = os.path.join(os.path.expanduser('~'),
         default_keychain = os.path.join(os.path.expanduser('~'),

+ 6 - 6
darc/cache.py

@@ -8,9 +8,9 @@ class Cache(object):
     """Client Side cache
     """Client Side cache
     """
     """
 
 
-    def __init__(self, store, crypto):
+    def __init__(self, store, keychain):
         self.store = store
         self.store = store
-        self.crypto = crypto
+        self.keychain = keychain
         self.path = os.path.join(os.path.expanduser('~'), '.darc', 'cache',
         self.path = os.path.join(os.path.expanduser('~'), '.darc', 'cache',
                                  '%s.cache' % self.store.id.encode('hex'))
                                  '%s.cache' % self.store.id.encode('hex'))
         self.tid = -1
         self.tid = -1
@@ -22,7 +22,7 @@ class Cache(object):
         if not os.path.exists(self.path):
         if not os.path.exists(self.path):
             return
             return
         with open(self.path, 'rb') as fd:
         with open(self.path, 'rb') as fd:
-            data, hash = self.crypto.decrypt(fd.read())
+            data, hash = self.keychain.decrypt(fd.read())
             cache = msgpack.unpackb(data)
             cache = msgpack.unpackb(data)
         assert cache['version'] == 1
         assert cache['version'] == 1
         self.chunk_counts = cache['chunk_counts']
         self.chunk_counts = cache['chunk_counts']
@@ -39,7 +39,7 @@ class Cache(object):
         if self.store.tid == 0:
         if self.store.tid == 0:
             return
             return
         for id in list(self.store.list(NS_ARCHIVE_CHUNKS)):
         for id in list(self.store.list(NS_ARCHIVE_CHUNKS)):
-            data, hash = self.crypto.decrypt(self.store.get(NS_ARCHIVE_CHUNKS, id))
+            data, hash = self.keychain.decrypt(self.store.get(NS_ARCHIVE_CHUNKS, id))
             cindex = msgpack.unpackb(data)
             cindex = msgpack.unpackb(data)
             for id, size in cindex['chunks']:
             for id, size in cindex['chunks']:
                 try:
                 try:
@@ -61,7 +61,7 @@ class Cache(object):
                 'chunk_counts': self.chunk_counts,
                 'chunk_counts': self.chunk_counts,
                 'file_chunks': dict(self.filter_file_chunks()),
                 'file_chunks': dict(self.filter_file_chunks()),
         }
         }
-        data, hash = self.crypto.encrypt_create(msgpack.packb(cache))
+        data, hash = self.keychain.encrypt_create(msgpack.packb(cache))
         cachedir = os.path.dirname(self.path)
         cachedir = os.path.dirname(self.path)
         if not os.path.exists(cachedir):
         if not os.path.exists(cachedir):
             os.makedirs(cachedir)
             os.makedirs(cachedir)
@@ -71,7 +71,7 @@ class Cache(object):
     def add_chunk(self, id, data):
     def add_chunk(self, id, data):
         if self.seen_chunk(id):
         if self.seen_chunk(id):
             return self.chunk_incref(id)
             return self.chunk_incref(id)
-        data, hash = self.crypto.encrypt_read(data)
+        data, hash = self.keychain.encrypt_read(data)
         csize = len(data)
         csize = len(data)
         self.store.put(NS_CHUNK, id, data)
         self.store.put(NS_CHUNK, id, data)
         self.chunk_counts[id] = (1, csize)
         self.chunk_counts[id] = (1, csize)

+ 19 - 22
darc/crypto.py → darc/keychain.py

@@ -15,10 +15,16 @@ from .helpers import IntegrityError
 from .oaep import OAEP
 from .oaep import OAEP
 
 
 
 
-class KeyChain(object):
+class Keychain(object):
     FILE_ID = 'DARC KEYCHAIN'
     FILE_ID = 'DARC KEYCHAIN'
 
 
+    CREATE = '\1'
+    READ = '\2'
+
     def __init__(self, path=None):
     def __init__(self, path=None):
+        self._key_cache = {}
+        self.read_key = os.urandom(32)
+        self.create_key = os.urandom(32)
         self.aes_id = self.rsa_read = self.rsa_create = None
         self.aes_id = self.rsa_read = self.rsa_create = None
         self.path = path
         self.path = path
         if path:
         if path:
@@ -31,7 +37,7 @@ class KeyChain(object):
                 raise ValueError('Not a keychain')
                 raise ValueError('Not a keychain')
             cdata = fd.read()
             cdata = fd.read()
         self.password = ''
         self.password = ''
-        data = self.decrypt(cdata, '')
+        data = self._decrypt(cdata, '')
         while not data:
         while not data:
             self.password = getpass('Keychain password: ')
             self.password = getpass('Keychain password: ')
             if not self.password:
             if not self.password:
@@ -44,6 +50,10 @@ class KeyChain(object):
         self.aes_id = chain['aes_id']
         self.aes_id = chain['aes_id']
         self.rsa_read = RSA.importKey(chain['rsa_read'])
         self.rsa_read = RSA.importKey(chain['rsa_read'])
         self.rsa_create = RSA.importKey(chain['rsa_create'])
         self.rsa_create = RSA.importKey(chain['rsa_create'])
+        self.read_encrypted = OAEP(256, hash=SHA256).encode(self.read_key, os.urandom(32))
+        self.read_encrypted = self.rsa_read.encrypt(self.read_encrypted, '')[0]
+        self.create_encrypted = OAEP(256, hash=SHA256).encode(self.create_key, os.urandom(32))
+        self.create_encrypted = self.rsa_create.encrypt(self.create_encrypted, '')[0]
 
 
     def encrypt(self, data, password):
     def encrypt(self, data, password):
         salt = os.urandom(32)
         salt = os.urandom(32)
@@ -61,7 +71,7 @@ class KeyChain(object):
         }
         }
         return msgpack.packb(d)
         return msgpack.packb(d)
 
 
-    def decrypt(self, data, password):
+    def _decrypt(self, data, password):
         d = msgpack.unpackb(data)
         d = msgpack.unpackb(data)
         assert d['version'] == 1
         assert d['version'] == 1
         assert d['algorithm'] == 'SHA256'
         assert d['algorithm'] == 'SHA256'
@@ -113,7 +123,7 @@ class KeyChain(object):
             password2 = getpass('Keychain password again: ')
             password2 = getpass('Keychain password again: ')
             if password != password2:
             if password != password2:
                 print 'Passwords do not match'
                 print 'Passwords do not match'
-        chain = KeyChain()
+        chain = Keychain()
         print 'Generating keychain'
         print 'Generating keychain'
         chain.aes_id = os.urandom(32)
         chain.aes_id = os.urandom(32)
         chain.rsa_read = RSA.generate(2048)
         chain.rsa_read = RSA.generate(2048)
@@ -121,23 +131,8 @@ class KeyChain(object):
         chain.save(path, password)
         chain.save(path, password)
         return 0
         return 0
 
 
-class CryptoManager(object):
-
-    CREATE = '\1'
-    READ = '\2'
-
-    def __init__(self, keychain):
-        self._key_cache = {}
-        self.keychain = keychain
-        self.read_key = os.urandom(32)
-        self.create_key = os.urandom(32)
-        self.read_encrypted = OAEP(256, hash=SHA256).encode(self.read_key, os.urandom(32))
-        self.read_encrypted = keychain.rsa_read.encrypt(self.read_encrypted, '')[0]
-        self.create_encrypted = OAEP(256, hash=SHA256).encode(self.create_key, os.urandom(32))
-        self.create_encrypted = keychain.rsa_create.encrypt(self.create_encrypted, '')[0]
-
     def id_hash(self, data):
     def id_hash(self, data):
-        return HMAC.new(self.keychain.aes_id, data, SHA256).digest()
+        return HMAC.new(self.aes_id, data, SHA256).digest()
 
 
     def encrypt_read(self, data):
     def encrypt_read(self, data):
         data = zlib.compress(data)
         data = zlib.compress(data)
@@ -163,7 +158,7 @@ class CryptoManager(object):
     def decrypt(self, data):
     def decrypt(self, data):
         type = data[0]
         type = data[0]
         if type == self.READ:
         if type == self.READ:
-            key = self.decrypt_key(data[1:257], self.keychain.rsa_read)
+            key = self.decrypt_key(data[1:257], self.rsa_read)
             hash = data[257:289]
             hash = data[257:289]
             counter = Counter.new(128, initial_value=bytes_to_long(hash[:16]), allow_wraparound=True)
             counter = Counter.new(128, initial_value=bytes_to_long(hash[:16]), allow_wraparound=True)
             data = AES.new(key, AES.MODE_CTR, counter=counter).decrypt(data[289:])
             data = AES.new(key, AES.MODE_CTR, counter=counter).decrypt(data[289:])
@@ -171,7 +166,7 @@ class CryptoManager(object):
                 raise IntegrityError('decryption failed')
                 raise IntegrityError('decryption failed')
             return zlib.decompress(data), hash
             return zlib.decompress(data), hash
         elif type == self.CREATE:
         elif type == self.CREATE:
-            key = self.decrypt_key(data[1:257], self.keychain.rsa_create)
+            key = self.decrypt_key(data[1:257], self.rsa_create)
             hash = data[257:289]
             hash = data[257:289]
             counter = Counter.new(128, initial_value=bytes_to_long(hash[:16]), allow_wraparound=True)
             counter = Counter.new(128, initial_value=bytes_to_long(hash[:16]), allow_wraparound=True)
             data = AES.new(key, AES.MODE_CTR, '', counter=counter).decrypt(data[289:])
             data = AES.new(key, AES.MODE_CTR, '', counter=counter).decrypt(data[289:])
@@ -181,3 +176,5 @@ class CryptoManager(object):
         else:
         else:
             raise Exception('Unknown pack type %d found' % ord(type))
             raise Exception('Unknown pack type %d found' % ord(type))
 
 
+
+