Selaa lähdekoodia

Bring function tests back to life

Jonas Borgström 14 vuotta sitten
vanhempi
sitoutus
1d0914177d
8 muutettua tiedostoa jossa 52 lisäystä ja 35 poistoa
  1. 4 1
      darc/archive.py
  2. 9 4
      darc/archiver.py
  3. 2 7
      darc/cache.py
  4. 15 2
      darc/helpers.py
  5. 7 4
      darc/key.py
  6. 0 1
      darc/remote.py
  7. 6 5
      darc/store.py
  8. 9 11
      darc/test.py

+ 4 - 1
darc/archive.py

@@ -54,6 +54,8 @@ class Archive(object):
         unpacker = msgpack.Unpacker()
         counter = Counter(0)
         def cb(chunk, error, id):
+            if error:
+                raise error
             assert not error
             counter.dec()
             data, items_hash = self.key.decrypt(chunk)
@@ -220,9 +222,10 @@ class Archive(object):
 
     def verify_file(self, item, start, result):
         def verify_chunk(chunk, error, (id, i, last)):
+            if error:
+                raise error
             if i == 0:
                 start(item)
-            assert not error
             data, hash = self.key.decrypt(chunk)
             if self.key.id_hash(data) != id:
                 result(item, False)

+ 9 - 4
darc/archiver.py

@@ -10,7 +10,8 @@ from .store import Store
 from .cache import Cache
 from .key import Key
 from .helpers import location_validator, format_file_size, format_time,\
-    format_file_mode, IncludePattern, ExcludePattern, exclude_path, to_localtime
+    format_file_mode, IncludePattern, ExcludePattern, exclude_path, to_localtime, \
+    get_cache_dir
 from .remote import StoreServer, RemoteStore
 
 class Archiver(object):
@@ -46,7 +47,9 @@ class Archiver(object):
 
     def do_init(self, args):
         store = self.open_store(args.store, create=True)
-        key = Key.create(store, args.store.to_key_filename())
+        key = Key.create(store, args.store.to_key_filename(),
+                         password=args.password)
+        return self.exit_code
 
     def do_create(self, args):
         store = self.open_store(args.archive)
@@ -63,7 +66,7 @@ class Archiver(object):
         # Add darc cache dir to inode_skip list
         skip_inodes = set()
         try:
-            st = os.stat(Cache.cache_dir_path())
+            st = os.stat(get_cache_dir())
             skip_inodes.add((st.st_ino, st.st_dev))
         except IOError:
             pass
@@ -232,7 +235,9 @@ class Archiver(object):
 
         subparser = subparsers.add_parser('init')
         subparser.set_defaults(func=self.do_init)
-        subparser.add_argument('store', metavar='ARCHIVE',
+        subparser.add_argument('-p', '--password', dest='password',
+                               help='Protect store key with password (Default: prompt)')
+        subparser.add_argument('store',
                                type=location_validator(archive=False),
                                help='Store to create')
 

+ 2 - 7
darc/cache.py

@@ -6,7 +6,7 @@ import os
 import shutil
 
 from . import NS_CHUNK, NS_ARCHIVE_METADATA
-from .helpers import error_callback
+from .helpers import error_callback, get_cache_dir
 from .hashindex import ChunkIndex
 
 
@@ -18,7 +18,7 @@ class Cache(object):
         self.txn_active = False
         self.store = store
         self.key = key
-        self.path = os.path.join(Cache.cache_dir_path(), self.store.id.encode('hex'))
+        self.path = os.path.join(get_cache_dir(), self.store.id.encode('hex'))
         if not os.path.exists(self.path):
             self.create()
         self.open()
@@ -27,11 +27,6 @@ class Cache(object):
             self.sync()
             self.commit()
 
-    @staticmethod
-    def cache_dir_path():
-        """Return path to directory used for storing users cache files"""
-        return os.path.join(os.path.expanduser('~'), '.darc', 'cache')
-
     def create(self):
         """Create a new empty store at `path`
         """

+ 15 - 2
darc/helpers.py

@@ -30,15 +30,28 @@ class Counter(object):
         return '<Counter(%r)>' % self.v
 
 
+def get_keys_dir():
+    """Determine where to store keys and cache"""
+    return os.environ.get('DARC_KEYS_DIR',
+                          os.path.join(os.path.expanduser('~'), '.darc', 'keys'))
+
+def get_cache_dir():
+    """Determine where to store keys and cache"""
+    return os.environ.get('DARC_CACHE_DIR',
+                          os.path.join(os.path.expanduser('~'), '.darc', 'cache'))
+
+
 def deferrable(f):
     def wrapper(*args, **kw):
         callback = kw.pop('callback', None)
         if callback:
             data = kw.pop('callback_data', None)
             try:
-                callback(f(*args, **kw), None, data)
+                res = f(*args, **kw)
             except Exception, e:
                 callback(None, e, data)
+            else:
+                callback(res, None, data)
         else:
             return f(*args, **kw)
     return wrapper
@@ -288,7 +301,7 @@ class Location(object):
         name = re.sub('[^\w]', '_', self.path).strip('_')
         if self.proto != 'file':
             name = self.host + '__' + name
-        return os.path.join(os.path.expanduser('~'), '.darc', 'keys', name)
+        return os.path.join(get_keys_dir(), name)
 
     def __repr__(self):
         return "Location(%s)" % self

+ 7 - 4
darc/key.py

@@ -12,7 +12,7 @@ from Crypto.Util import Counter
 from Crypto.Util.number import bytes_to_long, long_to_bytes
 from Crypto.Random import get_random_bytes
 
-from .helpers import IntegrityError
+from .helpers import IntegrityError, get_keys_dir
 
 
 class Key(object):
@@ -24,7 +24,7 @@ class Key(object):
 
     def find_key_file(self, store):
         id = store.id.encode('hex')
-        keys_dir = os.path.join(os.path.expanduser('~'), '.darc', 'keys')
+        keys_dir = get_keys_dir()
         for name in os.listdir(keys_dir):
             filename = os.path.join(keys_dir, name)
             with open(filename, 'rb') as fd:
@@ -112,13 +112,16 @@ class Key(object):
         return 0
 
     @staticmethod
-    def create(store, filename):
+    def create(store, filename, password=None):
         i = 1
         path = filename
         while os.path.exists(path):
             i += 1
             path = filename + '.%d' % i
-        password, password2 = 1, 2
+        if password is not None:
+            password2 = password
+        else:
+            password, password2 = 1, 2
         while password != password2:
             password = getpass('Keychain password: ')
             password2 = getpass('Keychain password again: ')

+ 0 - 1
darc/remote.py

@@ -171,7 +171,6 @@ class RemoteStore(object):
         try:
             return self.cmd('get', (ns, id), callback, callback_data)
         except self.RPCError, e:
-            print e.name
             if e.name == 'DoesNotExist':
                 raise self.DoesNotExist
             raise

+ 6 - 5
darc/store.py

@@ -270,9 +270,8 @@ class BandIO(object):
         fd.seek(offset)
         data = fd.read(self.header_fmt.size)
         size, magic, hash, ns_, id_ = self.header_fmt.unpack(data)
-        assert magic == 0
-        assert ns == ns_
-        assert id == id_
+        if magic != 0 or ns != ns_ or id != id_:
+            raise IntegrityError('Invalid band entry header')
         data = fd.read(size - self.header_fmt.size)
         if crc32(data) & 0xffffffff != hash:
             raise IntegrityError('Band checksum mismatch')
@@ -281,12 +280,14 @@ class BandIO(object):
     def iter_objects(self, band, lookup):
         fd = self.get_fd(band)
         fd.seek(0)
-        assert fd.read(8) == 'DARCBAND'
+        if fd.read(8) != 'DARCBAND':
+            raise IntegrityError('Invalid band header')
         offset = 8
         data = fd.read(self.header_fmt.size)
         while data:
             size, magic, hash, ns, key = self.header_fmt.unpack(data)
-            assert magic == 0
+            if magic != 0:
+                raise IntegrityError('Unknown band entry header')
             offset += size
             if lookup(ns, key):
                 data = fd.read(size - self.header_fmt.size)

+ 9 - 11
darc/test.py

@@ -24,19 +24,22 @@ class Test(unittest.TestCase):
         self.store_path = os.path.join(self.tmpdir, 'store')
         self.input_path = os.path.join(self.tmpdir, 'input')
         self.output_path = os.path.join(self.tmpdir, 'output')
+        self.keys_path = os.path.join(self.tmpdir, 'keys')
+        self.cache_path = os.path.join(self.tmpdir, 'cache')
+        os.environ['DARC_KEYS_DIR'] = self.keys_path
+        os.environ['DARC_CACHE_DIR'] = self.cache_path
         os.mkdir(self.input_path)
         os.mkdir(self.output_path)
+        os.mkdir(self.keys_path)
+        os.mkdir(self.cache_path)
         os.chdir(self.tmpdir)
-        self.keychain = '/tmp/_test_dedupstore.keychain'
-        if not os.path.exists(self.keychain):
-            self.darc('init-keychain')
 
     def tearDown(self):
         shutil.rmtree(self.tmpdir)
 
     def darc(self, *args, **kwargs):
         exit_code = kwargs.get('exit_code', 0)
-        args = ['--keychain', self.keychain] + list(args)
+        args = list(args)
         try:
             stdout, stderr = sys.stdout, sys.stderr
             output = StringIO()
@@ -52,6 +55,7 @@ class Test(unittest.TestCase):
 
     def create_src_archive(self, name):
         src_dir = os.path.join(os.getcwd(), os.path.dirname(__file__))
+        self.darc('init', '--password', '', self.store_path)
         self.darc('create', self.store_path + '::' + name, src_dir)
 
     def create_regual_file(self, name, size=0):
@@ -96,6 +100,7 @@ class Test(unittest.TestCase):
                 os.path.join(self.input_path, 'hardlink'))
         os.symlink('somewhere', os.path.join(self.input_path, 'link1'))
         os.mkfifo(os.path.join(self.input_path, 'fifo1'))
+        self.darc('init', '-p', '', self.store_path)
         self.darc('create', self.store_path + '::test', 'input')
         self.darc('create', self.store_path + '::test.2', 'input')
         self.darc('extract', self.store_path + '::test', 'output')
@@ -110,13 +115,6 @@ class Test(unittest.TestCase):
         fd.close()
         self.darc('verify', self.store_path + '::test', exit_code=1)
 
-    def test_keychain(self):
-        keychain = os.path.join(self.tmpdir, 'keychain')
-        keychain2 = os.path.join(self.tmpdir, 'keychain2')
-        self.darc('-k', keychain, 'init-keychain')
-        self.darc('-k', keychain, 'change-password')
-        self.darc('-k', keychain, 'export-restricted', keychain2)
-
 
 def suite():
     suite = unittest.TestSuite()