Răsfoiți Sursa

Added support for unencrypted and passphrase encrypted stores

Jonas Borgström 12 ani în urmă
părinte
comite
f28933254a
4 a modificat fișierele cu 325 adăugiri și 167 ștergeri
  1. 30 39
      darc/archiver.py
  2. 16 13
      darc/helpers.py
  3. 272 111
      darc/key.py
  4. 7 4
      darc/test.py

+ 30 - 39
darc/archiver.py

@@ -8,7 +8,7 @@ import sys
 from .archive import Archive
 from .archive import Archive
 from .store import Store
 from .store import Store
 from .cache import Cache
 from .cache import Cache
-from .key import Key
+from .key import key_creator
 from .helpers import location_validator, format_time, \
 from .helpers import location_validator, format_time, \
     format_file_mode, IncludePattern, ExcludePattern, exclude_path, adjust_patterns, to_localtime, \
     format_file_mode, IncludePattern, ExcludePattern, exclude_path, adjust_patterns, to_localtime, \
     get_cache_dir, format_timedelta, prune_split, Manifest, Location
     get_cache_dir, format_timedelta, prune_split, Manifest, Location
@@ -22,9 +22,11 @@ class Archiver(object):
 
 
     def open_store(self, location, create=False):
     def open_store(self, location, create=False):
         if location.proto == 'ssh':
         if location.proto == 'ssh':
-            return RemoteStore(location, create=create)
+            store = RemoteStore(location, create=create)
         else:
         else:
-            return Store(location.path, create=create)
+            store = Store(location.path, create=create)
+        store._location = location
+        return store
 
 
     def print_error(self, msg, *args):
     def print_error(self, msg, *args):
         msg = args and msg % args or msg
         msg = args and msg % args or msg
@@ -45,31 +47,24 @@ class Archiver(object):
     def do_init(self, args):
     def do_init(self, args):
         print 'Initializing store "%s"' % args.store.orig
         print 'Initializing store "%s"' % args.store.orig
         store = self.open_store(args.store, create=True)
         store = self.open_store(args.store, create=True)
-        key = Key.create(store, args.store.to_key_filename(), password=args.password)
-        print 'Key file "%s" created.' % key.path
-        print 'Remember that this file (and password) is needed to access your data. Keep it safe!'
-        print
-        manifest = Manifest(store, key, dont_load=True)
+        key = key_creator(store, args)
+        manifest = Manifest()
+        manifest.store = store
+        manifest.key = key
         manifest.write()
         manifest.write()
         store.commit()
         store.commit()
         return self.exit_code
         return self.exit_code
 
 
-    def do_chpasswd(self, args):
-        if os.path.isfile(args.store_or_file):
-            key = Key()
-            key.open(args.store_or_file)
-        else:
-            store = self.open_store(Location(args.store_or_file))
-            key = Key(store)
-        key.chpasswd()
-        print 'Key file "%s" updated' % key.path
+    def do_change_passphrase(self, args):
+        store = self.open_store(Location(args.store))
+        manifest, key = Manifest.load(store)
+        key.change_passphrase()
         return self.exit_code
         return self.exit_code
 
 
     def do_create(self, args):
     def do_create(self, args):
         t0 = datetime.now()
         t0 = datetime.now()
         store = self.open_store(args.archive)
         store = self.open_store(args.archive)
-        key = Key(store)
-        manifest = Manifest(store, key)
+        manifest, key = Manifest.load(store)
         cache = Cache(store, key, manifest)
         cache = Cache(store, key, manifest)
         archive = Archive(store, key, manifest, args.archive.archive, cache=cache,
         archive = Archive(store, key, manifest, args.archive.archive, cache=cache,
                           create=True, checkpoint_interval=args.checkpoint_interval,
                           create=True, checkpoint_interval=args.checkpoint_interval,
@@ -158,8 +153,7 @@ class Archiver(object):
             self.print_verbose(item['path'])
             self.print_verbose(item['path'])
 
 
         store = self.open_store(args.archive)
         store = self.open_store(args.archive)
-        key = Key(store)
-        manifest = Manifest(store, key)
+        manifest, key = Manifest.load(store)
         archive = Archive(store, key, manifest, args.archive.archive,
         archive = Archive(store, key, manifest, args.archive.archive,
                           numeric_owner=args.numeric_owner)
                           numeric_owner=args.numeric_owner)
         dirs = []
         dirs = []
@@ -177,8 +171,7 @@ class Archiver(object):
 
 
     def do_delete(self, args):
     def do_delete(self, args):
         store = self.open_store(args.archive)
         store = self.open_store(args.archive)
-        key = Key(store)
-        manifest = Manifest(store, key)
+        manifest, key = Manifest.load(store)
         cache = Cache(store, key, manifest)
         cache = Cache(store, key, manifest)
         archive = Archive(store, key, manifest, args.archive.archive, cache=cache)
         archive = Archive(store, key, manifest, args.archive.archive, cache=cache)
         archive.delete(cache)
         archive.delete(cache)
@@ -186,8 +179,7 @@ class Archiver(object):
 
 
     def do_list(self, args):
     def do_list(self, args):
         store = self.open_store(args.src)
         store = self.open_store(args.src)
-        key = Key(store)
-        manifest = Manifest(store, key)
+        manifest, key = Manifest.load(store)
         if args.src.archive:
         if args.src.archive:
             tmap = {1: 'p', 2: 'c', 4: 'd', 6: 'b', 010: '-', 012: 'l', 014: 's'}
             tmap = {1: 'p', 2: 'c', 4: 'd', 6: 'b', 010: '-', 012: 'l', 014: 's'}
             archive = Archive(store, key, manifest, args.src.archive)
             archive = Archive(store, key, manifest, args.src.archive)
@@ -219,8 +211,7 @@ class Archiver(object):
 
 
     def do_verify(self, args):
     def do_verify(self, args):
         store = self.open_store(args.archive)
         store = self.open_store(args.archive)
-        key = Key(store)
-        manifest = Manifest(store, key)
+        manifest, key = Manifest.load(store)
         archive = Archive(store, key, manifest, args.archive.archive)
         archive = Archive(store, key, manifest, args.archive.archive)
 
 
         def start_cb(item):
         def start_cb(item):
@@ -239,8 +230,7 @@ class Archiver(object):
 
 
     def do_info(self, args):
     def do_info(self, args):
         store = self.open_store(args.archive)
         store = self.open_store(args.archive)
-        key = Key(store)
-        manifest = Manifest(store, key)
+        manifest, key = Manifest.load(store)
         cache = Cache(store, key, manifest)
         cache = Cache(store, key, manifest)
         archive = Archive(store, key, manifest, args.archive.archive, cache=cache)
         archive = Archive(store, key, manifest, args.archive.archive, cache=cache)
         stats = archive.calc_stats(cache)
         stats = archive.calc_stats(cache)
@@ -255,8 +245,7 @@ class Archiver(object):
 
 
     def do_prune(self, args):
     def do_prune(self, args):
         store = self.open_store(args.store)
         store = self.open_store(args.store)
-        key = Key(store)
-        manifest = Manifest(store, key)
+        manifest, key = Manifest.load(store)
         cache = Cache(store, key, manifest)
         cache = Cache(store, key, manifest)
         archives = list(sorted(Archive.list_archives(store, key, manifest, cache),
         archives = list(sorted(Archive.list_archives(store, key, manifest, cache),
                                key=attrgetter('ts'), reverse=True))
                                key=attrgetter('ts'), reverse=True))
@@ -284,7 +273,7 @@ class Archiver(object):
         for archive in keep:
         for archive in keep:
             self.print_verbose('Keeping archive "%s"' % archive.name)
             self.print_verbose('Keeping archive "%s"' % archive.name)
         for archive in to_delete:
         for archive in to_delete:
-            self.print_verbose('Purging archive "%s"', archive.name)
+            self.print_verbose('Pruning archive "%s"', archive.name)
             archive.delete(cache)
             archive.delete(cache)
         return self.exit_code
         return self.exit_code
 
 
@@ -307,17 +296,19 @@ class Archiver(object):
 
 
         subparser = subparsers.add_parser('init', parents=[common_parser])
         subparser = subparsers.add_parser('init', parents=[common_parser])
         subparser.set_defaults(func=self.do_init)
         subparser.set_defaults(func=self.do_init)
-        subparser.add_argument('-p', '--password', dest='password',
-                               help='Protect store key with password (Default: prompt)')
         subparser.add_argument('store',
         subparser.add_argument('store',
                                type=location_validator(archive=False),
                                type=location_validator(archive=False),
                                help='Store to create')
                                help='Store to create')
+        subparser.add_argument('--key-file', dest='keyfile',
+                               action='store_true', default=False,
+                               help='Encrypt data using key file')
+        subparser.add_argument('--passphrase', dest='passphrase',
+                               action='store_true', default=False,
+                               help='Encrypt data using passphrase derived key')
 
 
-        subparser = subparsers.add_parser('change-password', parents=[common_parser])
-        subparser.set_defaults(func=self.do_chpasswd)
-        subparser.add_argument('store_or_file', metavar='STORE_OR_KEY_FILE',
-                               type=str,
-                               help='Key file to operate on')
+        subparser = subparsers.add_parser('change-passphrase', parents=[common_parser])
+        subparser.set_defaults(func=self.do_change_passphrase)
+        subparser.add_argument('store', type=location_validator(archive=False))
 
 
         subparser = subparsers.add_parser('create', parents=[common_parser])
         subparser = subparsers.add_parser('create', parents=[common_parser])
         subparser.set_defaults(func=self.do_create)
         subparser.set_defaults(func=self.do_create)

+ 16 - 13
darc/helpers.py

@@ -18,23 +18,26 @@ class Manifest(object):
 
 
     MANIFEST_ID = '\0' * 32
     MANIFEST_ID = '\0' * 32
 
 
-    def __init__(self, store, key, dont_load=False):
-        self.store = store
-        self.key = key
+    def __init__(self):
         self.archives = {}
         self.archives = {}
         self.config = {}
         self.config = {}
-        if not dont_load:
-            self.load()
 
 
-    def load(self):
-        data = self.key.decrypt(None, self.store.get(self.MANIFEST_ID))
-        self.id = self.key.id_hash(data)
-        manifest = msgpack.unpackb(data)
-        if not manifest.get('version') == 1:
+    @classmethod
+    def load(cls, store):
+        from .key import key_factory
+        manifest = cls()
+        manifest.store = store
+        cdata = store.get(manifest.MANIFEST_ID)
+        manifest.key = key = key_factory(store, cdata)
+        data = key.decrypt(None, cdata)
+        manifest.id = key.id_hash(data)
+        m = msgpack.unpackb(data)
+        if not m.get('version') == 1:
             raise ValueError('Invalid manifest version')
             raise ValueError('Invalid manifest version')
-        self.archives = manifest['archives']
-        self.config = manifest['config']
-        self.key.post_manifest_load(self.config)
+        manifest.archives = m['archives']
+        manifest.config = m['config']
+        key.post_manifest_load(manifest.config)
+        return manifest, key
 
 
     def write(self):
     def write(self):
         self.key.pre_manifest_write(self)
         self.key.pre_manifest_write(self)

+ 272 - 111
darc/key.py

@@ -1,8 +1,9 @@
 from __future__ import with_statement
 from __future__ import with_statement
 from getpass import getpass
 from getpass import getpass
-import hashlib
 import os
 import os
 import msgpack
 import msgpack
+import tempfile
+import unittest
 import zlib
 import zlib
 
 
 from Crypto.Cipher import AES
 from Crypto.Cipher import AES
@@ -16,68 +17,230 @@ from .helpers import IntegrityError, get_keys_dir
 
 
 PREFIX = '\0' * 8
 PREFIX = '\0' * 8
 
 
+KEYFILE = '\0'
+PASSPHRASE = '\1'
+PLAINTEXT = '\2'
+
+
+def key_creator(store, args):
+    if args.keyfile:
+        return KeyfileKey.create(store, args)
+    elif args.passphrase:
+        return PassphraseKey.create(store, args)
+    else:
+        return PlaintextKey.create(store, args)
+
+
+def key_factory(store, manifest_data):
+    if manifest_data[0] == KEYFILE:
+        return KeyfileKey.detect(store, manifest_data)
+    elif manifest_data[0] == PASSPHRASE:
+        return PassphraseKey.detect(store, manifest_data)
+    elif manifest_data[0] == PLAINTEXT:
+        return PlaintextKey.detect(store, manifest_data)
+    else:
+        raise Exception('Unkown Key type %d' % ord(manifest_data[0]))
+
 
 
 def SHA256_PDF(p, s):
 def SHA256_PDF(p, s):
     return HMAC.new(p, s, SHA256).digest()
     return HMAC.new(p, s, SHA256).digest()
 
 
 
 
-class Key(object):
+class KeyBase(object):
+
+    def id_hash(self, data):
+        """Return HMAC hash using the "id" HMAC key
+        """
+
+    def encrypt(self, data):
+        pass
+
+    def decrypt(self, id, data):
+        pass
+
+    def post_manifest_load(self, config):
+        pass
+
+    def pre_manifest_write(self, manifest):
+        pass
+
+
+class PlaintextKey(KeyBase):
+    TYPE = PLAINTEXT
+
+    chunk_seed = 0
+
+    @classmethod
+    def create(cls, store, args):
+        print 'Encryption NOT enabled.\nUse the --key-file or --passphrase options to enable encryption.'
+        return cls()
+
+    @classmethod
+    def detect(cls, store, manifest_data):
+        return cls()
+
+    def id_hash(self, data):
+        return SHA256.new(data).digest()
+
+    def encrypt(self, data):
+        return ''.join([self.TYPE, zlib.compress(data)])
+
+    def decrypt(self, id, data):
+        if data[0] != self.TYPE:
+            raise IntegrityError('Invalid encryption envelope')
+        data = zlib.decompress(data[1:])
+        if id and SHA256.new(data).digest() != id:
+            raise IntegrityError('Chunk id verification failed')
+        return data
+
+
+class AESKeyBase(KeyBase):
+
+    def post_manifest_load(self, config):
+        iv = bytes_to_long(config['aes_counter']) + 100
+        self.counter = Counter.new(64, initial_value=iv, prefix=PREFIX)
+
+    def pre_manifest_write(self, manifest):
+        manifest.config['aes_counter'] = long_to_bytes(self.counter.next_value(), 8)
+
+    def id_hash(self, data):
+        """Return HMAC hash using the "id" HMAC key
+        """
+        return HMAC.new(self.id_key, data, SHA256).digest()
+
+    def encrypt(self, data):
+        data = zlib.compress(data)
+        nonce = long_to_bytes(self.counter.next_value(), 8)
+        data = ''.join((nonce, AES.new(self.enc_key, AES.MODE_CTR, '',
+                                       counter=self.counter).encrypt(data)))
+        hash = HMAC.new(self.enc_hmac_key, data, SHA256).digest()
+        return ''.join((self.TYPE, hash, data))
+
+    def decrypt(self, id, data):
+        if data[0] != self.TYPE:
+            raise IntegrityError('Invalid encryption envelope')
+        hash = data[1:33]
+        if HMAC.new(self.enc_hmac_key, data[33:], SHA256).digest() != hash:
+            raise IntegrityError('Encryption envelope checksum mismatch')
+        nonce = bytes_to_long(data[33:41])
+        counter = Counter.new(64, initial_value=nonce, prefix=PREFIX)
+        data = zlib.decompress(AES.new(self.enc_key, AES.MODE_CTR, counter=counter).decrypt(data[41:]))
+        if id and HMAC.new(self.id_key, data, SHA256).digest() != id:
+            raise IntegrityError('Chunk id verification failed')
+        return data
+
+    def init_from_random_data(self, data):
+        self.enc_key = data[0:32]
+        self.enc_hmac_key = data[32:64]
+        self.id_key = data[64:96]
+        self.chunk_seed = bytes_to_long(data[96:100])
+        # Convert to signed int32
+        if self.chunk_seed & 0x80000000:
+            self.chunk_seed = self.chunk_seed - 0xffffffff - 1
+        self.counter = Counter.new(64, initial_value=1, prefix=PREFIX)
+
+
+class PassphraseKey(AESKeyBase):
+    TYPE = PASSPHRASE
+    iterations = 10000
+
+    @classmethod
+    def create(cls, store, args):
+        key = cls()
+        passphrase = os.environ.get('DARC_PASSPHRASE')
+        if passphrase is not None:
+            passphrase2 = passphrase
+        else:
+            passphrase, passphrase2 = 1, 2
+        while passphrase != passphrase2:
+            passphrase = getpass('Enter passphrase: ')
+            if not passphrase:
+                print 'Passphrase must not be blank'
+                continue
+            passphrase2 = getpass('Enter same passphrase again: ')
+            if passphrase != passphrase2:
+                print 'Passphrases do not match'
+        key.init(store, passphrase)
+        if passphrase:
+            print 'Remember your passphrase. Your data will be inaccessible without it.'
+        return key
+
+    @classmethod
+    def detect(cls, store, manifest_data):
+        prompt = 'Enter passphrase for %s: ' % store._location.orig
+        key = cls()
+        passphrase = os.environ.get('DARC_PASSPHRASE')
+        if passphrase is None:
+            passphrase = getpass(prompt)
+        while True:
+            key.init(store, passphrase)
+            try:
+                key.decrypt(None, manifest_data)
+                return key
+            except IntegrityError:
+                passphrase = getpass(prompt)
+
+    def init(self, store, passphrase):
+        self.init_from_random_data(PBKDF2(passphrase, store.id, 100, self.iterations, SHA256_PDF))
+
+
+class KeyfileKey(AESKeyBase):
     FILE_ID = 'DARC KEY'
     FILE_ID = 'DARC KEY'
+    TYPE = KEYFILE
 
 
-    def __init__(self, store=None, password=None):
-        if store:
-            self.open(self.find_key_file(store), password=password)
+    @classmethod
+    def detect(cls, store, manifest_data):
+        key = cls()
+        path = cls.find_key_file(store)
+        prompt = 'Enter passphrase for key file %s: ' % path
+        passphrase = os.environ.get('DARC_PASSPHRASE', '')
+        while not key.load(path, passphrase):
+            passphrase = getpass(prompt)
+        return key
 
 
-    def find_key_file(self, store):
+    @classmethod
+    def find_key_file(cls, store):
         id = store.id.encode('hex')
         id = store.id.encode('hex')
         keys_dir = get_keys_dir()
         keys_dir = get_keys_dir()
         for name in os.listdir(keys_dir):
         for name in os.listdir(keys_dir):
             filename = os.path.join(keys_dir, name)
             filename = os.path.join(keys_dir, name)
             with open(filename, 'rb') as fd:
             with open(filename, 'rb') as fd:
                 line = fd.readline().strip()
                 line = fd.readline().strip()
-                if line and line.startswith(self.FILE_ID) and line[9:] == id:
+                if line and line.startswith(cls.FILE_ID) and line[9:] == id:
                     return filename
                     return filename
         raise Exception('Key file for store with ID %s not found' % id)
         raise Exception('Key file for store with ID %s not found' % id)
 
 
-    def open(self, filename, prompt=None, password=None):
-        prompt = prompt or 'Enter password for %s: ' % filename
+    def load(self, filename, passphrase):
         with open(filename, 'rb') as fd:
         with open(filename, 'rb') as fd:
-            lines = fd.readlines()
-            if not lines[0].startswith(self.FILE_ID) != self.FILE_ID:
-                raise ValueError('Not a DARC key file')
-            self.store_id = lines[0][len(self.FILE_ID):].strip().decode('hex')
-            cdata = (''.join(lines[1:])).decode('base64')
-        self.password = password or ''
-        data = self.decrypt_key_file(cdata, self.password)
-        while not data:
-            self.password = getpass(prompt)
-            if not self.password:
-                raise Exception('Key decryption failed')
-            data = self.decrypt_key_file(cdata, self.password)
-            if not data:
-                print 'Incorrect password'
-        key = msgpack.unpackb(data)
-        if key['version'] != 1:
-            raise IntegrityError('Invalid key file header')
-        self.store_id = key['store_id']
-        self.enc_key = key['enc_key']
-        self.enc_hmac_key = key['enc_hmac_key']
-        self.id_key = key['id_key']
-        self.chunk_seed = key['chunk_seed']
-        self.counter = Counter.new(64, initial_value=1, prefix=PREFIX)
-        self.path = filename
-
-    def post_manifest_load(self, config):
-        iv = bytes_to_long(config['aes_counter']) + 100
-        self.counter = Counter.new(64, initial_value=iv, prefix=PREFIX)
+            cdata = (''.join(fd.readlines()[1:])).decode('base64')
+        data = self.decrypt_key_file(cdata, passphrase)
+        if data:
+            key = msgpack.unpackb(data)
+            if key['version'] != 1:
+                raise IntegrityError('Invalid key file header')
+            self.store_id = key['store_id']
+            self.enc_key = key['enc_key']
+            self.enc_hmac_key = key['enc_hmac_key']
+            self.id_key = key['id_key']
+            self.chunk_seed = key['chunk_seed']
+            self.counter = Counter.new(64, initial_value=1, prefix=PREFIX)
+            self.path = filename
+            return True
 
 
-    def pre_manifest_write(self, manifest):
-        manifest.config['aes_counter'] = long_to_bytes(self.counter.next_value(), 8)
+    def decrypt_key_file(self, data, passphrase):
+        d = msgpack.unpackb(data)
+        assert d['version'] == 1
+        assert d['algorithm'] == 'SHA256'
+        key = PBKDF2(passphrase, d['salt'], 32, d['iterations'], SHA256_PDF)
+        data = AES.new(key, AES.MODE_CTR, counter=Counter.new(128)).decrypt(d['data'])
+        if HMAC.new(key, data, SHA256).digest() != d['hash']:
+            return None
+        return data
 
 
-    def encrypt_key_file(self, data, password):
+    def encrypt_key_file(self, data, passphrase):
         salt = get_random_bytes(32)
         salt = get_random_bytes(32)
         iterations = 10000
         iterations = 10000
-        key = PBKDF2(password, salt, 32, iterations, SHA256_PDF)
+        key = PBKDF2(passphrase, salt, 32, iterations, SHA256_PDF)
         hash = HMAC.new(key, data, SHA256).digest()
         hash = HMAC.new(key, data, SHA256).digest()
         cdata = AES.new(key, AES.MODE_CTR, counter=Counter.new(128)).encrypt(data)
         cdata = AES.new(key, AES.MODE_CTR, counter=Counter.new(128)).encrypt(data)
         d = {
         d = {
@@ -90,17 +253,7 @@ class Key(object):
         }
         }
         return msgpack.packb(d)
         return msgpack.packb(d)
 
 
-    def decrypt_key_file(self, data, password):
-        d = msgpack.unpackb(data)
-        assert d['version'] == 1
-        assert d['algorithm'] == 'SHA256'
-        key = PBKDF2(password, d['salt'], 32, d['iterations'], SHA256_PDF)
-        data = AES.new(key, AES.MODE_CTR, counter=Counter.new(128)).decrypt(d['data'])
-        if HMAC.new(key, data, SHA256).digest() != d['hash']:
-            return None
-        return data
-
-    def save(self, path, password):
+    def save(self, path, passphrase):
         key = {
         key = {
             'version': 1,
             'version': 1,
             'store_id': self.store_id,
             'store_id': self.store_id,
@@ -109,77 +262,85 @@ class Key(object):
             'id_key': self.enc_key,
             'id_key': self.enc_key,
             'chunk_seed': self.chunk_seed,
             'chunk_seed': self.chunk_seed,
         }
         }
-        data = self.encrypt_key_file(msgpack.packb(key), password)
+        data = self.encrypt_key_file(msgpack.packb(key), passphrase)
         with open(path, 'wb') as fd:
         with open(path, 'wb') as fd:
             fd.write('%s %s\n' % (self.FILE_ID, self.store_id.encode('hex')))
             fd.write('%s %s\n' % (self.FILE_ID, self.store_id.encode('hex')))
             fd.write(data.encode('base64'))
             fd.write(data.encode('base64'))
         self.path = path
         self.path = path
 
 
-    def chpasswd(self):
-        password, password2 = 1, 2
-        while password != password2:
-            password = getpass('New password: ')
-            password2 = getpass('New password again: ')
-            if password != password2:
-                print 'Passwords do not match'
-        self.save(self.path, password)
-        return 0
-
-    @staticmethod
-    def create(store, filename, password=None):
-        i = 1
+    def change_passphrase(self):
+        passphrase, passphrase2 = 1, 2
+        while passphrase != passphrase2:
+            passphrase = getpass('New passphrase: ')
+            passphrase2 = getpass('Enter same passphrase again: ')
+            if passphrase != passphrase2:
+                print 'Passphrases do not match'
+        self.save(self.path, passphrase)
+        print 'Key file "%s" updated' % self.path
+
+    @classmethod
+    def create(cls, store, args):
+        filename = args.store.to_key_filename()
         path = filename
         path = filename
+        i = 1
         while os.path.exists(path):
         while os.path.exists(path):
             i += 1
             i += 1
             path = filename + '.%d' % i
             path = filename + '.%d' % i
-        if password is not None:
-            password2 = password
+        passphrase = os.environ.get('DARC_PASSPHRASE')
+        if passphrase is not None:
+            passphrase2 = passphrase
         else:
         else:
-            password, password2 = 1, 2
-        while password != password2:
-            password = getpass('Key file password (Leave blank for no password): ')
-            password2 = getpass('Key file password again: ')
-            if password != password2:
-                print 'Passwords do not match'
-        key = Key()
+            passphrase, passphrase2 = 1, 2
+        while passphrase != passphrase2:
+            passphrase = getpass('Enter passphrase (empty for no passphrase):')
+            passphrase2 = getpass('Enter same passphrase again: ')
+            if passphrase != passphrase2:
+                print 'Passphrases do not match'
+        key = cls()
         key.store_id = store.id
         key.store_id = store.id
-        # Chunk AES256 encryption key
-        key.enc_key = get_random_bytes(32)
-        # Chunk encryption HMAC key
-        key.enc_hmac_key = get_random_bytes(32)
-        # Chunk id HMAC key
-        key.id_key = get_random_bytes(32)
-        # Chunkifier seed
-        key.chunk_seed = bytes_to_long(get_random_bytes(4))
-        # Convert to signed int32
-        if key.chunk_seed & 0x80000000:
-            key.chunk_seed = key.chunk_seed - 0xffffffff - 1
-        key.save(path, password)
-        return Key(store, password=password)
+        key.init_from_random_data(get_random_bytes(100))
+        key.save(path, passphrase)
+        print 'Key file "%s" created.' % key.path
+        print 'Keep this file safe. Your data will be inaccessible without it.'
+        return key
 
 
-    def id_hash(self, data):
-        """Return HMAC hash using the "id" HMAC key
-        """
-        return HMAC.new(self.id_key, data, SHA256).digest()
 
 
-    def encrypt(self, data):
-        data = zlib.compress(data)
-        nonce = long_to_bytes(self.counter.next_value(), 8)
-        data = ''.join((nonce, AES.new(self.enc_key, AES.MODE_CTR, '',
-                                       counter=self.counter).encrypt(data)))
-        hash = HMAC.new(self.enc_hmac_key, data, SHA256).digest()
-        return ''.join(('\0', hash, data))
+class KeyTestCase(unittest.TestCase):
 
 
-    def decrypt(self, id, data):
-        if data[0] != '\0':
-            raise IntegrityError('Invalid encryption envelope')
-        hash = data[1:33]
-        if HMAC.new(self.enc_hmac_key, data[33:], SHA256).digest() != hash:
-            raise IntegrityError('Encryption envelope checksum mismatch')
-        nonce = bytes_to_long(data[33:41])
-        counter = Counter.new(64, initial_value=nonce, prefix=PREFIX)
-        data = zlib.decompress(AES.new(self.enc_key, AES.MODE_CTR, counter=counter).decrypt(data[41:]))
-        if id and HMAC.new(self.id_key, data, SHA256).digest() != id:
-            raise IntegrityError('Chunk id verification failed')
-        return data
+    class MockStore(object):
+        id = '\0' * 32
+
+    def test_plaintext(self):
+        key = PlaintextKey.create(None, None)
+        data = 'foo'
+        self.assertEqual(key.id_hash(data).encode('hex'), '2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae')
+        self.assertEqual(data, key.decrypt(key.id_hash(data), key.encrypt(data)))
+
+    def test_keyfile(self):
+        class MockArgs(object):
+            class StoreArg(object):
+                def to_key_filename(self):
+                    return tempfile.mkstemp()[1]
+            store = StoreArg()
+        os.environ['DARC_PASSPHRASE'] = 'test'
+        key = KeyfileKey.create(self.MockStore(), MockArgs())
+        data = 'foo'
+        self.assertEqual(data, key.decrypt(key.id_hash(data), key.encrypt(data)))
+
+    def test_passphrase(self):
+        os.environ['DARC_PASSPHRASE'] = 'test'
+        key = PassphraseKey.create(self.MockStore(), None)
+        self.assertEqual(key.id_key.encode('hex'), 'f28e915da78a972786da47fee6c4bd2960a421b9bdbdb35a7942eb82552e9a72')
+        self.assertEqual(key.enc_hmac_key.encode('hex'), '169c6082f209e524ea97e2c75318936f6e93c101b9345942a95491e9ae1738ca')
+        self.assertEqual(key.enc_key.encode('hex'), 'c05dd423843d4dd32a52e4dc07bb11acabe215917fc5cf3a3df6c92b47af79ba')
+        self.assertEqual(key.chunk_seed, -324662077)
+        data = 'foo'
+        self.assertEqual(key.id_hash(data).encode('hex'), '016c27cd40dc8e84f196f3b43a9424e8472897e09f6935d0d3a82fb41664bad7')
+        self.assertEqual(data, key.decrypt(key.id_hash(data), key.encrypt(data)))
+
+
+def suite():
+    return unittest.TestLoader().loadTestsFromTestCase(KeyTestCase)
 
 
+if __name__ == '__main__':
+    unittest.main()

+ 7 - 4
darc/test.py

@@ -11,6 +11,7 @@ from xattr import xattr, XATTR_NOFOLLOW
 
 
 from . import helpers, lrucache
 from . import helpers, lrucache
 from .archiver import Archiver
 from .archiver import Archiver
+from .key import suite as KeySuite
 from .store import Store, suite as StoreSuite
 from .store import Store, suite as StoreSuite
 from .remote import Store, suite as RemoteStoreSuite
 from .remote import Store, suite as RemoteStoreSuite
 
 
@@ -40,6 +41,7 @@ class Test(unittest.TestCase):
         shutil.rmtree(self.tmpdir)
         shutil.rmtree(self.tmpdir)
 
 
     def darc(self, *args, **kwargs):
     def darc(self, *args, **kwargs):
+        os.environ['DARC_PASSPHRASE'] = ''
         exit_code = kwargs.get('exit_code', 0)
         exit_code = kwargs.get('exit_code', 0)
         args = list(args)
         args = list(args)
         try:
         try:
@@ -57,7 +59,7 @@ class Test(unittest.TestCase):
 
 
     def create_src_archive(self, name):
     def create_src_archive(self, name):
         src_dir = os.path.join(os.getcwd(), os.path.dirname(__file__))
         src_dir = os.path.join(os.getcwd(), os.path.dirname(__file__))
-        self.darc('init', '--password', '', self.store_location)
+        self.darc('init', self.store_location)
         self.darc('create', self.store_location + '::' + name, src_dir)
         self.darc('create', self.store_location + '::' + name, src_dir)
 
 
     def create_regual_file(self, name, size=0):
     def create_regual_file(self, name, size=0):
@@ -102,7 +104,7 @@ class Test(unittest.TestCase):
                 os.path.join(self.input_path, 'hardlink'))
                 os.path.join(self.input_path, 'hardlink'))
         os.symlink('somewhere', os.path.join(self.input_path, 'link1'))
         os.symlink('somewhere', os.path.join(self.input_path, 'link1'))
         os.mkfifo(os.path.join(self.input_path, 'fifo1'))
         os.mkfifo(os.path.join(self.input_path, 'fifo1'))
-        self.darc('init', '-p', '', self.store_location)
+        self.darc('init', self.store_location)
         self.darc('create', self.store_location + '::test', 'input')
         self.darc('create', self.store_location + '::test', 'input')
         self.darc('create', self.store_location + '::test.2', 'input')
         self.darc('create', self.store_location + '::test.2', 'input')
         self.darc('extract', self.store_location + '::test', 'output')
         self.darc('extract', self.store_location + '::test', 'output')
@@ -117,7 +119,7 @@ class Test(unittest.TestCase):
     def test_delete(self):
     def test_delete(self):
         self.create_regual_file('file1', size=1024 * 80)
         self.create_regual_file('file1', size=1024 * 80)
         self.create_regual_file('dir2/file2', size=1024 * 80)
         self.create_regual_file('dir2/file2', size=1024 * 80)
-        self.darc('init', '-p', '', self.store_location)
+        self.darc('init', self.store_location)
         self.darc('create', self.store_location + '::test', 'input')
         self.darc('create', self.store_location + '::test', 'input')
         self.darc('create', self.store_location + '::test.2', 'input')
         self.darc('create', self.store_location + '::test.2', 'input')
         self.darc('verify', self.store_location + '::test')
         self.darc('verify', self.store_location + '::test')
@@ -141,7 +143,7 @@ class Test(unittest.TestCase):
 
 
     def test_prune_store(self):
     def test_prune_store(self):
         src_dir = os.path.join(os.getcwd(), os.path.dirname(__file__))
         src_dir = os.path.join(os.getcwd(), os.path.dirname(__file__))
-        self.darc('init', '-p', '', self.store_location)
+        self.darc('init', self.store_location)
         self.darc('create', self.store_location + '::test1', src_dir)
         self.darc('create', self.store_location + '::test1', src_dir)
         self.darc('create', self.store_location + '::test2', src_dir)
         self.darc('create', self.store_location + '::test2', src_dir)
         self.darc('prune', self.store_location, '--daily=2')
         self.darc('prune', self.store_location, '--daily=2')
@@ -158,6 +160,7 @@ def suite():
     suite = unittest.TestSuite()
     suite = unittest.TestSuite()
     suite.addTest(unittest.TestLoader().loadTestsFromTestCase(Test))
     suite.addTest(unittest.TestLoader().loadTestsFromTestCase(Test))
     suite.addTest(unittest.TestLoader().loadTestsFromTestCase(RemoteTest))
     suite.addTest(unittest.TestLoader().loadTestsFromTestCase(RemoteTest))
+    suite.addTest(KeySuite())
     suite.addTest(StoreSuite())
     suite.addTest(StoreSuite())
     suite.addTest(RemoteStoreSuite())
     suite.addTest(RemoteStoreSuite())
     suite.addTest(doctest.DocTestSuite(helpers))
     suite.addTest(doctest.DocTestSuite(helpers))