
Storage rewrite and more compact encryption envelope

Jonas Borgström 13 years ago
parent
commit
ef8c4a0097
9 changed files with 349 additions and 248 deletions
  1. darc/__init__.py   +0 -1
  2. darc/archive.py    +12 -39
  3. darc/archiver.py   +26 -19
  4. darc/cache.py      +8 -13
  5. darc/helpers.py    +43 -0
  6. darc/key.py        +20 -9
  7. darc/remote.py     +2 -2
  8. darc/store.py      +237 -164
  9. darc/test.py       +1 -1

+ 0 - 1
darc/__init__.py

@@ -1,2 +1 @@
 # This is a python package
-

+ 12 - 39
darc/archive.py

@@ -26,46 +26,22 @@ class Archive(object):
     class DoesNotExist(Exception):
         pass
 
-    def __init__(self, store, key, name=None, cache=None):
+    def __init__(self, store, key, manifest, name=None, cache=None):
         self.key = key
         self.store = store
         self.cache = cache
+        self.manifest = manifest
         self.items = StringIO()
         self.items_ids = []
         self.hard_links = {}
         self.stats = Statistics()
         if name:
-            manifest = Archive.read_manifest(self.store, self.key)
             try:
-                info = manifest['archives'][name]
+                info = self.manifest.archives[name]
             except KeyError:
                 raise Archive.DoesNotExist
             self.load(info['id'])
 
-    @staticmethod
-    def read_manifest(store, key):
-        mid = store.meta['manifest']
-        if not mid:
-            return {'version': 1, 'archives': {}}
-        mid = mid.decode('hex')
-        data = key.decrypt(mid, store.get(mid))
-        manifest = msgpack.unpackb(data)
-        if not manifest.get('version') == 1:
-            raise ValueError('Invalid manifest version')
-        return manifest
-
-    def write_manifest(self, manifest):
-        mid = self.store.meta['manifest']
-        if mid:
-            self.cache.chunk_decref(mid.decode('hex'))
-        if manifest:
-            data = msgpack.packb(manifest)
-            mid = self.key.id_hash(data)
-            self.cache.add_chunk(mid, data, self.stats)
-            self.store.meta['manifest'] = mid.encode('hex')
-        else:
-            self.store.meta['manifest'] = ''
-
     def load(self, id):
         self.id = id
         data = self.key.decrypt(self.id, self.store.get(self.id))
@@ -123,6 +99,8 @@ class Archive(object):
             self.items.write(chunks[-1])
 
     def save(self, name, cache):
+        if name in self.manifest.archives:
+            raise ValueError('Archive %s already exists' % name)
         self.flush_items(flush=True)
         metadata = {
             'version': 1,
@@ -136,10 +114,8 @@ class Archive(object):
         data = msgpack.packb(metadata)
         self.id = self.key.id_hash(data)
         cache.add_chunk(self.id, data, self.stats)
-        manifest = Archive.read_manifest(self.store, self.key)
-        assert not name in manifest['archives']
-        manifest['archives'][name] = {'id': self.id, 'time': metadata['time']}
-        self.write_manifest(manifest)
+        self.manifest.archives[name] = {'id': self.id, 'time': metadata['time']}
+        self.manifest.write()
         self.store.commit()
         cache.commit()
 
@@ -290,10 +266,8 @@ class Archive(object):
             self.store.get(id, callback=callback, callback_data=id)
         self.store.flush_rpc()
         self.cache.chunk_decref(self.id)
-        manifest = Archive.read_manifest(self.store, self.key)
-        assert self.name in manifest['archives']
-        del manifest['archives'][self.name]
-        self.write_manifest(manifest)
+        del self.manifest.archives[self.name]
+        self.manifest.write()
         self.store.commit()
         cache.commit()
 
@@ -371,10 +345,9 @@ class Archive(object):
         self.add_item(item)
 
     @staticmethod
-    def list_archives(store, key, cache=None):
-        manifest = Archive.read_manifest(store, key)
-        for name, info in manifest['archives'].items():
-            archive = Archive(store, key, cache=cache)
+    def list_archives(store, key, manifest, cache=None):
+        for name, info in manifest.archives.items():
+            archive = Archive(store, key, manifest, cache=cache)
             archive.load(info['id'])
             yield archive
 

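Note: the net effect in this file is that Archive no longer owns manifest I/O; read_manifest()/write_manifest() are gone and a loaded Manifest (added in darc/helpers.py below) is passed in and mutated directly. One behavioral change: save() now raises ValueError for a duplicate archive name before any metadata chunk is written, where the old code only hit an assert after re-reading the manifest. A condensed sketch of the new calling pattern, assuming an open store, key, manifest and cache as set up in darc/archiver.py:

    archive = Archive(store, key, manifest, cache=cache)
    # ... add items ...
    archive.save('monday', cache)   # ValueError if 'monday' is already taken;
                                    # otherwise updates manifest.archives and
                                    # calls manifest.write() before committing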
+ 26 - 19
darc/archiver.py

@@ -11,7 +11,7 @@ from .cache import Cache
 from .key import Key
 from .helpers import location_validator, format_time, \
     format_file_mode, IncludePattern, ExcludePattern, exclude_path, to_localtime, \
-    get_cache_dir, format_timedelta, purge_split
+    get_cache_dir, format_timedelta, purge_split, Manifest
 from .remote import StoreServer, RemoteStore
 
 class Archiver(object):
@@ -43,23 +43,24 @@ class Archiver(object):
 
     def do_init(self, args):
         store = self.open_store(args.store, create=True)
-        key = Key.create(store, args.store.to_key_filename(),
+        Key.create(store, args.store.to_key_filename(),
                          password=args.password)
+        key = Key(store)
+        manifest = Manifest(store, key, dont_load=True)
+        manifest.write()
+        store.commit()
         return self.exit_code
 
     def do_create(self, args):
         t0 = datetime.now()
         store = self.open_store(args.archive)
         key = Key(store)
-        try:
-            Archive(store, key, args.archive.archive)
-        except Archive.DoesNotExist:
-            pass
-        else:
+        manifest = Manifest(store, key)
+        if args.archive.archive in manifest.archives:
             self.print_error('Archive already exists')
             return self.exit_code
-        cache = Cache(store, key)
-        archive = Archive(store, key, cache=cache)
+        cache = Cache(store, key, manifest)
+        archive = Archive(store, key, manifest, cache=cache)
         # Add darc cache dir to inode_skip list
         skip_inodes = set()
         try:
@@ -142,7 +143,8 @@ class Archiver(object):
                 archive.extract_item(dirs.pop(-1), args.dest)
         store = self.open_store(args.archive)
         key = Key(store)
-        archive = Archive(store, key, args.archive.archive)
+        manifest = Manifest(store, key)
+        archive = Archive(store, key, manifest, args.archive.archive)
         dirs = []
         archive.iter_items(extract_cb)
         store.flush_rpc()
@@ -153,8 +155,9 @@ class Archiver(object):
     def do_delete(self, args):
         store = self.open_store(args.archive)
         key = Key(store)
-        cache = Cache(store, key)
-        archive = Archive(store, key, args.archive.archive, cache=cache)
+        manifest = Manifest(store, key)
+        cache = Cache(store, key, manifest)
+        archive = Archive(store, key, manifest, args.archive.archive, cache=cache)
         archive.delete(cache)
         return self.exit_code
 
@@ -183,20 +186,22 @@
 
         store = self.open_store(args.src)
         key = Key(store)
+        manifest = Manifest(store, key)
         if args.src.archive:
             tmap = {1: 'p', 2: 'c', 4: 'd', 6: 'b', 010: '-', 012: 'l', 014: 's'}
-            archive = Archive(store, key, args.src.archive)
+            archive = Archive(store, key, manifest, args.src.archive)
             archive.iter_items(callback)
             store.flush_rpc()
         else:
-            for archive in sorted(Archive.list_archives(store, key), key=attrgetter('ts')):
+            for archive in sorted(Archive.list_archives(store, key, manifest), key=attrgetter('ts')):
                 print '%-20s %s' % (archive.metadata['name'], to_localtime(archive.ts).strftime('%c'))
         return self.exit_code
 
     def do_verify(self, args):
         store = self.open_store(args.archive)
         key = Key(store)
-        archive = Archive(store, key, args.archive.archive)
+        manifest = Manifest(store, key)
+        archive = Archive(store, key, manifest, args.archive.archive)
         def start_cb(item):
             self.print_verbose('%s ...', item['path'], newline=False)
         def result_cb(item, success):
@@ -217,8 +222,9 @@ class Archiver(object):
     def do_info(self, args):
         store = self.open_store(args.archive)
         key = Key(store)
-        cache = Cache(store, key)
-        archive = Archive(store, key, args.archive.archive, cache=cache)
+        manifest = Manifest(store, key)
+        cache = Cache(store, key, manifest)
+        archive = Archive(store, key, manifest, args.archive.archive, cache=cache)
         stats = archive.calc_stats(cache)
         print 'Name:', archive.name
         print 'Fingerprint: %s' % archive.id.encode('hex')
@@ -232,8 +238,9 @@ class Archiver(object):
     def do_purge(self, args):
         store = self.open_store(args.store)
         key = Key(store)
-        cache = Cache(store, key)
-        archives = list(sorted(Archive.list_archives(store, key, cache),
+        manifest = Manifest(store, key)
+        cache = Cache(store, key, manifest)
+        archives = list(sorted(Archive.list_archives(store, key, manifest, cache),
                                key=attrgetter('ts'), reverse=True))
         if args.hourly + args.daily + args.weekly + args.monthly + args.yearly == 0:
             self.print_error('At least one of the "hourly", "daily", "weekly", "monthly" or "yearly" '

+ 8 - 13
darc/cache.py

@@ -5,7 +5,6 @@ import msgpack
 import os
 import shutil
 
-from .archive import Archive
 from .helpers import error_callback, get_cache_dir
 from .hashindex import ChunkIndex
 
@@ -14,15 +13,16 @@ class Cache(object):
     """Client Side cache
     """Client Side cache
     """
     """
 
 
-    def __init__(self, store, key):
+    def __init__(self, store, key, manifest):
         self.txn_active = False
         self.txn_active = False
         self.store = store
         self.store = store
         self.key = key
         self.key = key
-        self.path = os.path.join(get_cache_dir(), store.meta['id'])
+        self.manifest = manifest
+        self.path = os.path.join(get_cache_dir(), store.id.encode('hex'))
         if not os.path.exists(self.path):
         if not os.path.exists(self.path):
             self.create()
             self.create()
         self.open()
         self.open()
-        if self.manifest != store.meta['manifest']:
+        if self.manifest.id != self.manifest_id:
             self.sync()
             self.sync()
             self.commit()
             self.commit()
 
 
@@ -35,7 +35,7 @@ class Cache(object):
         config = RawConfigParser()
         config.add_section('cache')
         config.set('cache', 'version', '1')
-        config.set('cache', 'store', self.store.meta['id'])
+        config.set('cache', 'store', self.store.id.encode('hex'))
         config.set('cache', 'manifest', '')
         with open(os.path.join(self.path, 'config'), 'wb') as fd:
             config.write(fd)
@@ -54,7 +54,7 @@ class Cache(object):
         if self.config.getint('cache', 'version') != 1:
             raise Exception('%s Does not look like a darc cache')
         self.id = self.config.get('cache', 'store')
-        self.manifest = self.config.get('cache', 'manifest')
+        self.manifest_id = self.config.get('cache', 'manifest').decode('hex')
         self.chunks = ChunkIndex(os.path.join(self.path, 'chunks'))
         self.files = None
 
@@ -91,7 +91,7 @@ class Cache(object):
             with open(os.path.join(self.path, 'files'), 'wb') as fd:
                 for item in self.files.iteritems():
                     msgpack.pack(item, fd)
-        self.config.set('cache', 'manifest', self.store.meta['manifest'])
+        self.config.set('cache', 'manifest', self.manifest.id.encode('hex'))
         with open(os.path.join(self.path, 'config'), 'w') as fd:
             self.config.write(fd)
         self.chunks.flush()
@@ -141,13 +141,8 @@ class Cache(object):
         self.begin_txn()
         print 'Initializing cache...'
         self.chunks.clear()
-        # Add manifest chunk to chunk index
-        mid = self.store.meta['manifest'].decode('hex')
-        cdata = self.store.get(mid)
-        mdata = self.key.decrypt(mid, cdata)
-        self.chunks[mid] = 1, len(mdata), len(cdata)
         unpacker = msgpack.Unpacker()
-        for name, info in Archive.read_manifest(self.store, self.key)['archives'].items():
+        for name, info in self.manifest.archives.items():
             id = info['id']
             cdata = self.store.get(id)
             data = self.key.decrypt(id, cdata)

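Note: the chunk-index rebuild no longer seeds a special manifest chunk, because the manifest is no longer reference-counted through the cache; it lives at a fixed store id and is simply overwritten with put(). Cache staleness is now detected by comparing content ids rather than config strings; condensed from __init__()/open()/commit() above:

    # manifest_id: hex-decoded id saved in the cache config at the last commit
    # manifest.id: id_hash of the manifest plaintext that was just loaded
    if self.manifest.id != self.manifest_id:
        self.sync()     # rebuild chunk counts from every archive's metadata
        self.commit()   # records the new manifest id in the cache config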
+ 43 - 0
darc/helpers.py

@@ -4,6 +4,7 @@ from datetime import datetime, timedelta
 from fnmatch import fnmatchcase
 from operator import attrgetter
 import grp
+import msgpack
 import os
 import pwd
 import re
@@ -12,6 +13,38 @@ import sys
 import time
 import urllib
 
+class Manifest(object):
+
+    MANIFEST_ID = '\0' * 32
+
+    def __init__(self, store, key, dont_load=False):
+        self.store = store
+        self.key = key
+        self.archives = {}
+        self.config = {}
+        if not dont_load:
+            self.load()
+
+    def load(self):
+        data = self.key.decrypt(None, self.store.get(self.MANIFEST_ID))
+        self.id = self.key.id_hash(data)
+        manifest = msgpack.unpackb(data)
+        if not manifest.get('version') == 1:
+            raise ValueError('Invalid manifest version')
+        self.archives = manifest['archives']
+        self.config = manifest['config']
+        self.key.post_manifest_load(self.config)
+
+    def write(self):
+        self.key.pre_manifest_write(self)
+        data = msgpack.packb({
+            'version': 1,
+            'archives': self.archives,
+            'config': self.config,
+        })
+        self.id = self.key.id_hash(data)
+        self.store.put(self.MANIFEST_ID, self.key.encrypt(data))
+
 
 def purge_split(archives, pattern, n, skip=[]):
     items = {}
@@ -359,3 +392,13 @@ def location_validator(archive=None):
     return validator
 
 
+def read_msgpack(filename):
+    with open(filename, 'rb') as fd:
+        return msgpack.unpack(fd)
+
+def write_msgpack(filename, d):
+    with open(filename+'.tmp', 'wb') as fd:
+        msgpack.pack(d, fd)
+        fd.flush()
+        os.fsync(fd)
+    os.rename(filename+'.tmp', filename)

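Note: two details worth spelling out. First, the manifest lives at the fixed all-zero id, so a client can bootstrap from nothing but the store and the key; and since that id is a constant rather than a content hash, load() passes id=None to decrypt(), which (per darc/key.py below) skips the chunk-id check while still verifying the envelope HMAC. A usage sketch mirroring do_init() in darc/archiver.py above (the store path is hypothetical):

    store = Store('/path/to/store', create=True)
    key = Key(store)                                 # assumes a key file exists
    manifest = Manifest(store, key, dont_load=True)
    manifest.write()                                 # put under '\0' * 32
    store.commit()

    # any later client can find it without out-of-band metadata:
    manifest = Manifest(store, key)                  # store.get('\0' * 32)

Second, write_msgpack() is crash-safe: the dict is packed into a .tmp file, flushed and fsync'd, then renamed over the old file, so a reader sees either the old or the new contents, never a torn write. The new store below relies on this for its hints files.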
+ 20 - 9
darc/key.py

@@ -14,6 +14,7 @@ from Crypto.Random import get_random_bytes
 
 from .helpers import IntegrityError, get_keys_dir
 
+PREFIX = '\0' * 8
 
 class Key(object):
     FILE_ID = 'DARC KEY'
@@ -23,7 +24,7 @@ class Key(object):
             self.open(self.find_key_file(store))
 
     def find_key_file(self, store):
-        id = store.meta['id']
+        id = store.id.encode('hex')
         keys_dir = get_keys_dir()
         for name in os.listdir(keys_dir):
             filename = os.path.join(keys_dir, name)
@@ -57,7 +58,14 @@ class Key(object):
         self.enc_hmac_key = key['enc_hmac_key']
         self.id_key = key['id_key']
         self.chunk_seed = key['chunk_seed']
-        self.counter = Counter.new(128, initial_value=bytes_to_long(os.urandom(16)), allow_wraparound=True)
+        self.counter = Counter.new(64, initial_value=1, prefix=PREFIX)
+
+    def post_manifest_load(self, config):
+        iv = bytes_to_long(config['aes_counter'])+100
+        self.counter = Counter.new(64, initial_value=iv, prefix=PREFIX)
+
+    def pre_manifest_write(self, manifest):
+        manifest.config['aes_counter'] = long_to_bytes(self.counter.next_value(), 8)
 
     def encrypt_key_file(self, data, password):
         salt = get_random_bytes(32)
@@ -127,7 +135,7 @@ class Key(object):
             if password != password2:
                 print 'Passwords do not match'
         key = Key()
-        key.store_id = store.meta['id'].decode('hex')
+        key.store_id = store.id
         # Chunk AES256 encryption key
         key.enc_key = get_random_bytes(32)
         # Chunk encryption HMAC key
@@ -135,7 +143,10 @@ class Key(object):
         # Chunk id HMAC key
         key.id_key = get_random_bytes(32)
         # Chunkifier seed
-        key.chunk_seed = bytes_to_long(get_random_bytes(4)) & 0x7fffffff
+        key.chunk_seed = bytes_to_long(get_random_bytes(4))
+        # Convert to signed int32
+        if key.chunk_seed & 0x80000000:
+            key.chunk_seed = key.chunk_seed - 0xffffffff - 1
         key.save(path, password)
         return 0
 
@@ -146,7 +157,7 @@
 
     def encrypt(self, data):
         data = zlib.compress(data)
-        nonce = long_to_bytes(self.counter.next_value(), 16)
+        nonce = long_to_bytes(self.counter.next_value(), 8)
         data = ''.join((nonce, AES.new(self.enc_key, AES.MODE_CTR, '',
                                        counter=self.counter).encrypt(data)))
         hash = HMAC.new(self.enc_hmac_key, data, SHA256).digest()
@@ -158,10 +169,10 @@ class Key(object):
         hash = data[1:33]
         if HMAC.new(self.enc_hmac_key, data[33:], SHA256).digest() != hash:
             raise IntegrityError('Encryption envelope checksum mismatch')
-        nonce = bytes_to_long(data[33:49])
-        counter = Counter.new(128, initial_value=nonce, allow_wraparound=True)
-        data = zlib.decompress(AES.new(self.enc_key, AES.MODE_CTR, counter=counter).decrypt(data[49:]))
-        if HMAC.new(self.id_key, data, SHA256).digest() != id:
+        nonce = bytes_to_long(data[33:41])
+        counter = Counter.new(64, initial_value=nonce, prefix=PREFIX)
+        data = zlib.decompress(AES.new(self.enc_key, AES.MODE_CTR, counter=counter).decrypt(data[41:]))
+        if id and HMAC.new(self.id_key, data, SHA256).digest() != id:
             raise IntegrityError('Chunk id verification failed')
         return data
 

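Note: this is the "more compact encryption envelope" of the commit title. The CTR nonce shrinks from 16 bytes to 8 because the upper 64 bits of the counter block are now the fixed zero PREFIX, saving 8 bytes per stored object; the counter is persisted in the manifest config (pre_manifest_write) and resumed on load with a +100 margin to skip past values consumed after that snapshot, so counter values are never reused under one key. A sketch of the layout implied by decrypt() above, assuming the imports and names of darc/key.py and treating the leading byte as a one-byte type marker (an assumption; it is not shown in these hunks):

    # envelope layout (byte offsets), as implied by Key.decrypt():
    #
    #   0        1                 33         41
    #   | type?  | HMAC-SHA256     | nonce    | AES-256-CTR ciphertext |
    #   | 1 byte | 32 bytes        | 8 bytes  | zlib-compressed data   |
    #
    # the HMAC (keyed with enc_hmac_key) covers nonce + ciphertext, and the
    # full 128-bit counter block is PREFIX ('\0' * 8) + nonce
    nonce = bytes_to_long(envelope[33:41])
    counter = Counter.new(64, initial_value=nonce, prefix=PREFIX)
    plaintext = zlib.decompress(
        AES.new(enc_key, AES.MODE_CTR, counter=counter).decrypt(envelope[41:]))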
+ 2 - 2
darc/remote.py

@@ -65,7 +65,7 @@ class StoreServer(object):
         if path.startswith('/~'):
             path = path[1:]
         self.store = Store(os.path.expanduser(path), create)
-        return self.store.meta
+        return self.store.id
 
 
 class RemoteStore(object):
@@ -110,7 +110,7 @@ class RemoteStore(object):
         self.msgid = 0
         self.recursion = 0
         self.odata = []
-        self.meta = self.cmd('open', (location.path, create))
+        self.id = self.cmd('open', (location.path, create))
 
     def wait(self, write=True):
         with self.channel.lock:

+ 237 - 164
darc/store.py

@@ -1,9 +1,8 @@
 from __future__ import with_statement
 from ConfigParser import RawConfigParser
-import errno
 import fcntl
 import os
-import msgpack
+import re
 import shutil
 import struct
 import tempfile
@@ -11,9 +10,15 @@ import unittest
 from zlib import crc32
 
 from .hashindex import NSIndex
-from .helpers import IntegrityError, deferrable
+from .helpers import IntegrityError, deferrable, read_msgpack, write_msgpack
 from .lrucache import LRUCache
 
+MAX_OBJECT_SIZE = 20 * 1024 * 1024
+
+TAG_PUT = 0
+TAG_DELETE = 1
+TAG_COMMIT = 2
+
 
 class Store(object):
     """Filesystem based transactional key value store
@@ -22,8 +27,8 @@ class Store(object):
     dir/README
     dir/config
     dir/data/<X / SEGMENTS_PER_DIR>/<X>
-    dir/segments
-    dir/index
+    dir/index.X
+    dir/hints.X
     """
     DEFAULT_MAX_SEGMENT_SIZE = 5 * 1024 * 1024
     DEFAULT_SEGMENTS_PER_DIR = 10000
@@ -33,7 +38,6 @@
 
 
     def __init__(self, path, create=False):
-        self.txn_active = False
         if create:
             self.create(path)
         self.open(path)
@@ -53,132 +57,138 @@ class Store(object):
         config.set('store', 'version', '1')
         config.set('store', 'segments_per_dir', self.DEFAULT_SEGMENTS_PER_DIR)
         config.set('store', 'max_segment_size', self.DEFAULT_MAX_SEGMENT_SIZE)
-        config.set('store', 'next_segment', '0')
-        config.add_section('meta')
-        config.set('meta', 'manifest', '')
-        config.set('meta', 'id', os.urandom(32).encode('hex'))
-        NSIndex.create(os.path.join(path, 'index'))
-        self.write_dict(os.path.join(path, 'segments'), {})
+        config.set('store', 'id', os.urandom(32).encode('hex'))
         with open(os.path.join(path, 'config'), 'w') as fd:
             config.write(fd)
 
     def open(self, path):
+        self.head = None
         self.path = path
         if not os.path.isdir(path):
             raise Exception('%s Does not look like a darc store' % path)
         self.lock_fd = open(os.path.join(path, 'README'), 'r+')
         fcntl.flock(self.lock_fd, fcntl.LOCK_EX)
+        self.config = RawConfigParser()
+        self.config.read(os.path.join(self.path, 'config'))
+        if self.config.getint('store', 'version') != 1:
+            raise Exception('%s Does not look like a darc store')
+        self.max_segment_size = self.config.getint('store', 'max_segment_size')
+        self.segments_per_dir = self.config.getint('store', 'segments_per_dir')
+        self.id = self.config.get('store', 'id').decode('hex')
         self.rollback()
 
-    def read_dict(self, filename):
-        with open(filename, 'rb') as fd:
-            return msgpack.unpackb(fd.read())
-
-    def write_dict(self, filename, d):
-        with open(filename+'.tmp', 'wb') as fd:
-            fd.write(msgpack.packb(d))
-        os.rename(filename+'.tmp', filename)
-
-    def delete_segments(self):
-        delete_path = os.path.join(self.path, 'delete')
-        if os.path.exists(delete_path):
-            segments = self.read_dict(os.path.join(self.path, 'segments'))
-            for segment in self.read_dict(delete_path):
-                assert segments.pop(segment, 0) == 0
-                self.io.delete_segment(segment, missing_ok=True)
-            self.write_dict(os.path.join(self.path, 'segments'), segments)
-
-    def begin_txn(self):
-        txn_dir = os.path.join(self.path, 'txn.tmp')
-        # Initialize transaction snapshot
-        os.mkdir(txn_dir)
-        shutil.copy(os.path.join(self.path, 'config'), txn_dir)
-        shutil.copy(os.path.join(self.path, 'index'), txn_dir)
-        shutil.copy(os.path.join(self.path, 'segments'), txn_dir)
-        os.rename(os.path.join(self.path, 'txn.tmp'),
-                  os.path.join(self.path, 'txn.active'))
-        self.compact = set()
-        self.txn_active = True
-
     def close(self):
-        self.rollback()
         self.lock_fd.close()
 
-    def commit(self, meta=None):
+    def commit(self, rollback=True):
         """Commit transaction
         """
-        meta = meta or self.meta
+        self.io.write_commit()
         self.compact_segments()
-        self.io.close()
-        self.config.set('store', 'next_segment', self.io.segment + 1)
-        self.config.remove_section('meta')
-        self.config.add_section('meta')
-        for k, v in meta.items():
-            self.config.set('meta', k, v)
-        with open(os.path.join(self.path, 'config'), 'w') as fd:
-            self.config.write(fd)
-        self.index.flush()
-        self.write_dict(os.path.join(self.path, 'segments'), self.segments)
-        # If we crash before this line, the transaction will be
-        # rolled back by open()
-        os.rename(os.path.join(self.path, 'txn.active'),
-                  os.path.join(self.path, 'txn.commit'))
+        self.write_index()
         self.rollback()
 
+    def _available_indices(self, reverse=False):
+        names = [int(name[6:]) for name in os.listdir(self.path) if re.match('index\.\d+', name)]
+        names.sort(reverse=reverse)
+        return names
+
+    def open_index(self, head):
+        if head is None:
+            self.index = NSIndex.create(os.path.join(self.path, 'index.tmp'))
+            self.segments = {}
+            self.compact = set()
+        else:
+            shutil.copy(os.path.join(self.path, 'index.%d' % head),
+                        os.path.join(self.path, 'index.tmp'))
+            self.index = NSIndex(os.path.join(self.path, 'index.tmp'))
+            hints = read_msgpack(os.path.join(self.path, 'hints.%d' % head))
+            if hints['version'] != 1:
+                raise ValueError('Unknown hints file version: %d' % hints['version'])
+            self.segments = hints['segments']
+            self.compact = set(hints['compact'])
+
+    def write_index(self):
+        hints = {'version': 1,
+                 'segments': self.segments,
+                 'compact': list(self.compact)}
+        write_msgpack(os.path.join(self.path, 'hints.%d' % self.io.head), hints)
+        self.index.flush()
+        os.rename(os.path.join(self.path, 'index.tmp'),
+                  os.path.join(self.path, 'index.%d' % self.io.head))
+        # Remove old indices
+        current = '.%d' % self.io.head
+        for name in os.listdir(self.path):
+            if not name.startswith('index.') and not name.startswith('hints.'):
+                continue
+            if name.endswith(current):
+                continue
+            os.unlink(os.path.join(self.path, name))
+
     def compact_segments(self):
         """Compact sparse segments by copying data into new segments
         """
         if not self.compact:
             return
-        self.io.close_segment()
-        def lookup(key):
-            return self.index.get(key, (-1, -1))[0] == segment
+        def lookup(tag, key):
+            return tag == TAG_PUT and self.index.get(key, (-1, -1))[0] == segment
         segments = self.segments
         for segment in self.compact:
             if segments[segment] > 0:
-                for key, data in self.io.iter_objects(segment, lookup):
-                    new_segment, offset = self.io.write(key, data)
+                for tag, key, data in self.io.iter_objects(segment, lookup, include_data=True):
+                    new_segment, offset = self.io.write_put(key, data)
                     self.index[key] = new_segment, offset
                     segments.setdefault(new_segment, 0)
                     segments[new_segment] += 1
                     segments[segment] -= 1
-        self.write_dict(os.path.join(self.path, 'delete'), tuple(self.compact))
+                assert segments[segment] == 0
+        self.io.write_commit()
+        for segment in self.compact:
+            assert self.segments.pop(segment) == 0
+            self.io.delete_segment(segment)
+        self.compact = set()
+
+    def recover(self, path):
+        """Recover missing index by replaying logs"""
+        start = None
+        available = self._available_indices()
+        if available:
+            start = available[-1]
+        self.open_index(start)
+        for segment, filename in self.io._segment_names():
+            if start is not None and segment <= start:
+                continue
+            self.segments[segment] = 0
+            for tag, key, offset in self.io.iter_objects(segment):
+                if tag == TAG_PUT:
+                    try:
+                        s, _ = self.index[key]
+                        self.compact.add(s)
+                        self.segments[s] -= 1
+                    except KeyError:
+                        pass
+                    self.index[key] = segment, offset
+                    self.segments[segment] += 1
+                elif tag == TAG_DELETE:
+                    try:
+                        s, _ = self.index.pop(key)
+                        self.segments[s] -= 1
+                        self.compact.add(s)
+                        self.compact.add(segment)
+                    except KeyError:
+                        pass
+            if self.segments[segment] == 0:
+                self.compact.add(segment)
+        if self.io.head is not None:
+            self.write_index()
 
     def rollback(self):
         """
         """
-        # Commit any half committed transaction
-        if os.path.exists(os.path.join(self.path, 'txn.commit')):
-            self.delete_segments()
-            os.rename(os.path.join(self.path, 'txn.commit'),
-                      os.path.join(self.path, 'txn.tmp'))
-
-        delete_path = os.path.join(self.path, 'delete')
-        if os.path.exists(delete_path):
-            os.unlink(delete_path)
-        # Roll back active transaction
-        txn_dir = os.path.join(self.path, 'txn.active')
-        if os.path.exists(txn_dir):
-            shutil.copy(os.path.join(txn_dir, 'config'), self.path)
-            shutil.copy(os.path.join(txn_dir, 'index'), self.path)
-            shutil.copy(os.path.join(txn_dir, 'segments'), self.path)
-            os.rename(txn_dir, os.path.join(self.path, 'txn.tmp'))
-        # Remove partially removed transaction
-        if os.path.exists(os.path.join(self.path, 'txn.tmp')):
-            shutil.rmtree(os.path.join(self.path, 'txn.tmp'))
-        self.index = NSIndex(os.path.join(self.path, 'index'))
-        self.segments = self.read_dict(os.path.join(self.path, 'segments'))
-        self.config = RawConfigParser()
-        self.config.read(os.path.join(self.path, 'config'))
-        if self.config.getint('store', 'version') != 1:
-            raise Exception('%s Does not look like a darc store')
-        next_segment = self.config.getint('store', 'next_segment')
-        max_segment_size = self.config.getint('store', 'max_segment_size')
-        segments_per_dir = self.config.getint('store', 'segments_per_dir')
-        self.meta = dict(self.config.items('meta'))
-        self.io = SegmentIO(self.path, next_segment, max_segment_size, segments_per_dir)
-        self.io.cleanup()
-        self.txn_active = False
+        self.io = LoggedIO(self.path, self.max_segment_size, self.segments_per_dir)
+        if self.io.head is not None and not os.path.exists(os.path.join(self.path, 'index.%d' % self.io.head)):
+            self.recover(self.path)
+        self.open_index(self.io.head)
 
     @deferrable
     def get(self, id):
@@ -190,27 +200,25 @@ class Store(object):
 
     @deferrable
     def put(self, id, data):
-        if not self.txn_active:
-            self.begin_txn()
         try:
             segment, _ = self.index[id]
             self.segments[segment] -= 1
             self.compact.add(segment)
+            self.compact.add(self.io.write_delete(id))
         except KeyError:
             pass
-        segment, offset = self.io.write(id, data)
+        segment, offset = self.io.write_put(id, data)
         self.segments.setdefault(segment, 0)
         self.segments[segment] += 1
         self.index[id] = segment, offset
 
     @deferrable
     def delete(self, id):
-        if not self.txn_active:
-            self.begin_txn()
         try:
             segment, offset = self.index.pop(id)
             self.segments[segment] -= 1
             self.compact.add(segment)
+            self.compact.add(self.io.write_delete(id))
         except KeyError:
             raise self.DoesNotExist
 
@@ -218,109 +226,169 @@ class Store(object):
         pass
 
 
-class SegmentIO(object):
+class LoggedIO(object):
+
+    header_fmt = struct.Struct('<IIB')
+    assert header_fmt.size == 9
+    put_header_fmt = struct.Struct('<IIB32s')
+    assert put_header_fmt.size == 41
+    header_no_crc_fmt = struct.Struct('<IB')
+    assert header_no_crc_fmt.size == 5
+    crc_fmt = struct.Struct('<I')
+    assert crc_fmt.size == 4
 
 
-    header_fmt = struct.Struct('<IBI32s')
-    assert header_fmt.size == 41
+    _commit = header_no_crc_fmt.pack(9, TAG_COMMIT)
+    COMMIT = crc_fmt.pack(crc32(_commit)) + _commit
 
 
-    def __init__(self, path, next_segment, limit, segments_per_dir, capacity=100):
+    def __init__(self, path, limit, segments_per_dir, capacity=100):
         self.path = path
         self.path = path
         self.fds = LRUCache(capacity)
-        self.segment = next_segment
+        self.segment = None
         self.limit = limit
         self.segments_per_dir = segments_per_dir
         self.offset = 0
+        self.head = None
+        self.cleanup()
 
 
     def close(self):
     def close(self):
         for segment in self.fds.keys():
         for segment in self.fds.keys():
             self.fds.pop(segment).close()
             self.fds.pop(segment).close()
-	self.fds = None # Just to make sure we're disabled
+        self.close_segment()
+        self.fds = None # Just to make sure we're disabled
+
+    def _segment_names(self, reverse=False):
+        for dirpath, dirs, filenames in os.walk(os.path.join(self.path, 'data')):
+            dirs.sort(lambda a, b: cmp(int(a), int(b)), reverse=reverse)
+            filenames.sort(lambda a, b: cmp(int(a), int(b)), reverse=reverse)
+            for filename in filenames:
+                yield int(filename), os.path.join(dirpath, filename)
 
     def cleanup(self):
         """Delete segment files left by aborted transactions
         """
-        segment = self.segment
-        while True:
-            filename = self.segment_filename(segment)
-            if not os.path.exists(filename):
-                break
-            os.unlink(filename)
-            segment += 1
+        self.head = None
+        self.segment = 0
+        for segment, filename in self._segment_names(reverse=True):
+            if self.is_complete_segment(filename):
+                self.head = segment
+                self.segment = self.head + 1
+                return
+            else:
+                os.unlink(filename)
+
+    def is_complete_segment(self, filename):
+        with open(filename, 'rb') as fd:
+            fd.seek(-self.header_fmt.size, 2)
+            return fd.read(self.header_fmt.size) == self.COMMIT
 
     def segment_filename(self, segment):
         return os.path.join(self.path, 'data', str(segment / self.segments_per_dir), str(segment))
 
-    def get_fd(self, segment, write=False):
+    def get_write_fd(self):
+        if self.offset and self.offset > self.limit:
+            self.close_segment()
+        if not self._write_fd:
+            if self.segment % self.segments_per_dir == 0:
+                dirname = os.path.join(self.path, 'data', str(self.segment / self.segments_per_dir))
+                if not os.path.exists(dirname):
+                    os.mkdir(dirname)
+            self._write_fd = open(self.segment_filename(self.segment), 'ab')
+            self._write_fd.write('DSEGMENT')
+            self.offset = 8
+        return self._write_fd
+
+    def get_fd(self, segment):
         try:
             return self.fds[segment]
         except KeyError:
-            if write and segment % self.segments_per_dir == 0:
-                dirname = os.path.join(self.path, 'data', str(segment / self.segments_per_dir))
-                if not os.path.exists(dirname):
-                    os.mkdir(dirname)
-            fd = open(self.segment_filename(segment), write and 'w+' or 'rb')
+            fd = open(self.segment_filename(segment), 'rb')
             self.fds[segment] = fd
             return fd
 
-    def delete_segment(self, segment, missing_ok=False):
+    def delete_segment(self, segment):
         try:
             os.unlink(self.segment_filename(segment))
         except OSError, e:
-            if not missing_ok or e.errno != errno.ENOENT:
-                raise
-
-    def read(self, segment, offset, id):
-        fd = self.get_fd(segment)
-        fd.seek(offset)
-        data = fd.read(self.header_fmt.size)
-        size, magic, hash, id_ = self.header_fmt.unpack(data)
-        if magic != 0 or id != id_:
-            raise IntegrityError('Invalid segment entry header')
-        data = fd.read(size - self.header_fmt.size)
-        if crc32(data) & 0xffffffff != hash:
-            raise IntegrityError('Segment checksum mismatch')
-        return data
+            pass
 
 
-    def iter_objects(self, segment, lookup):
+    def iter_objects(self, segment, lookup=None, include_data=False):
         fd = self.get_fd(segment)
         fd = self.get_fd(segment)
         fd.seek(0)
         if fd.read(8) != 'DSEGMENT':
             raise IntegrityError('Invalid segment header')
         offset = 8
-        while data:
-            size, magic, hash, key = self.header_fmt.unpack(data)
-            if magic != 0:
-                raise IntegrityError('Unknown segment entry header')
+        header = fd.read(self.header_fmt.size)
+        while header:
+            crc, size, tag = self.header_fmt.unpack(header)
+            if size > MAX_OBJECT_SIZE:
+                raise IntegrityError('Invalid segment object size')
+            rest = fd.read(size - self.header_fmt.size)
+            if crc32(rest, crc32(buffer(header, 4))) & 0xffffffff != crc:
+                raise IntegrityError('Segment checksum mismatch')
+            if tag not in (TAG_PUT, TAG_DELETE, TAG_COMMIT):
+                raise IntegrityError('Invalid segment entry header')
+            key = None
+            if tag in (TAG_PUT, TAG_DELETE):
+                key = rest[:32]
+            if not lookup or lookup(tag, key):
+                if include_data:
+                    yield tag, key, rest[32:]
+                else:
+                    yield tag, key, offset
             offset += size
-            if lookup(key):
-                data = fd.read(size - self.header_fmt.size)
-                if crc32(data) & 0xffffffff != hash:
-                    raise IntegrityError('Segment checksum mismatch')
-                yield key, data
-            else:
-                fd.seek(offset)
-            data = fd.read(self.header_fmt.size)
+            header = fd.read(self.header_fmt.size)
 
-    def write(self, id, data):
-        size = len(data) + self.header_fmt.size
-        if self.offset and self.offset + size > self.limit:
-            self.close_segment()
-        fd = self.get_fd(self.segment, write=True)
-        fd.seek(self.offset)
-        if self.offset == 0:
-            fd.write('DSEGMENT')
-            self.offset = 8
+    def read(self, segment, offset, id):
+        if segment == self.segment:
+            self._write_fd.flush()
+        fd = self.get_fd(segment)
+        fd.seek(offset)
+        header = fd.read(self.put_header_fmt.size)
+        crc, size, tag, key = self.put_header_fmt.unpack(header)
+        if size > MAX_OBJECT_SIZE:
+            raise IntegrityError('Invalid segment object size')
+        data = fd.read(size - self.put_header_fmt.size)
+        if crc32(data, crc32(buffer(header, 4))) & 0xffffffff != crc:
+            raise IntegrityError('Segment checksum mismatch')
+        if tag != TAG_PUT or id != key:
+            raise IntegrityError('Invalid segment entry header')
+        return data
+
+    def write_put(self, id, data):
+        size = len(data) + self.put_header_fmt.size
+        fd = self.get_write_fd()
         offset = self.offset
-        hash = crc32(data) & 0xffffffff
-        fd.write(self.header_fmt.pack(size, 0, hash, id))
-        fd.write(data)
+        header = self.header_no_crc_fmt.pack(size, TAG_PUT)
+        crc = self.crc_fmt.pack(crc32(data, crc32(id, crc32(header))) & 0xffffffff)
+        fd.write(''.join((crc, header, id, data)))
         self.offset += size
         return self.segment, offset
 
+    def write_delete(self, id):
+        fd = self.get_write_fd()
+        offset = self.offset
+        header = self.header_no_crc_fmt.pack(self.put_header_fmt.size, TAG_DELETE)
+        crc = self.crc_fmt.pack(crc32(id, crc32(header)) & 0xffffffff)
+        fd.write(''.join((crc, header, id)))
+        self.offset += self.put_header_fmt.size
+        return self.segment
+
+    def write_commit(self):
+        fd = self.get_write_fd()
+        header = self.header_no_crc_fmt.pack(self.header_fmt.size, TAG_COMMIT)
+        crc = self.crc_fmt.pack(crc32(header) & 0xffffffff)
+        fd.write(''.join((crc, header)))
+        self.head = self.segment
+        self.close_segment()
+
     def close_segment(self):
-        self.segment += 1
-        self.offset = 0
+        if self._write_fd:
+            self.segment += 1
+            self.offset = 0
+            os.fsync(self._write_fd)
+            self._write_fd.close()
+            self._write_fd = None
 
 
 class StoreTestCase(unittest.TestCase):
@@ -342,6 +410,11 @@ class StoreTestCase(unittest.TestCase):
         self.store.commit()
         self.store.close()
         store2 = Store(os.path.join(self.tmppath, 'store'))
+        self.assertRaises(store2.DoesNotExist, lambda: store2.get(key50))
+        for x in range(100):
+            if x == 50:
+                continue
+            self.assertEqual(self.store.get('%-32d' % x), 'SOMEDATA')
 
     def test2(self):
         """Test multiple sequential transactions

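Note: this is the heart of the storage rewrite. The copy-on-commit transaction directory (txn.active/txn.commit) is replaced by a log-structured design: every put/delete is appended to the current segment as a CRC-protected tagged entry, TAG_COMMIT marks a consistent point, and index.X/hints.X are disposable caches that recover() can rebuild by replaying segments (cleanup() drops any trailing segment that does not end in a commit). A sketch of the entry format implied by the struct definitions above, with a minimal parser using descriptive names not taken from the source:

    import struct
    from zlib import crc32

    TAG_PUT, TAG_DELETE, TAG_COMMIT = 0, 1, 2
    HEADER = struct.Struct('<IIB')   # crc32, size, tag (little-endian)

    # segment file = 'DSEGMENT' magic followed by entries:
    #   TAG_PUT:    crc(4) size(4) tag(1) key(32) data  -> size = 41 + len(data)
    #   TAG_DELETE: crc(4) size(4) tag(1) key(32)       -> size = 41
    #   TAG_COMMIT: crc(4) size(4) tag(1)               -> size = 9
    # the CRC covers everything after the crc field itself

    def iter_entries(blob):
        """Sketch: walk the entries of one complete segment file."""
        assert blob[:8] == 'DSEGMENT'
        offset = 8
        while offset + HEADER.size <= len(blob):
            crc, size, tag = HEADER.unpack_from(blob, offset)
            if crc32(blob[offset + 4:offset + size]) & 0xffffffff != crc:
                raise ValueError('segment checksum mismatch at %d' % offset)
            rest = blob[offset + HEADER.size:offset + size]
            key = rest[:32] if tag in (TAG_PUT, TAG_DELETE) else None
            data = rest[32:] if tag == TAG_PUT else None
            yield tag, key, data
            offset += size

This also explains the test change below: after init plus one create, compaction has copied all live data into segment 2 and deleted segments 0 and 1, so the corruption test now pokes at data/0/2 instead of data/0/0.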
+ 1 - 1
darc/test.py

@@ -115,7 +115,7 @@ class Test(unittest.TestCase):
     def test_corrupted_store(self):
         self.create_src_archive('test')
         self.darc('verify', self.store_path + '::test')
-        fd = open(os.path.join(self.tmpdir, 'store', 'data', '0', '0'), 'r+')
+        fd = open(os.path.join(self.tmpdir, 'store', 'data', '0', '2'), 'r+')
         fd.seek(100)
         fd.write('X')
         fd.close()