
Improved remote repository performance and reliability

Jonas Borgström 11 years ago
parent
commit
bd5b72a646
6 changed files with 201 additions and 249 deletions
  1. CHANGES (+1 -0)
  2. attic/archive.py (+79 -99)
  3. attic/cache.py (+6 -8)
  4. attic/remote.py (+76 -139)
  5. attic/repository.py (+7 -3)
  6. attic/testsuite/archive.py (+32 -0)

+ 1 - 0
CHANGES

@@ -8,6 +8,7 @@ Version 0.9
 
 (feature release, released on X)
 
+- Remote repository speed and reliability improvements.
 - Fix sorting of segment names to ignore NFS left over files. (#17)
 - Fix incorrect display of time (#13)
 - Improved error handling / reporting. (#12)

+ 79 - 99
attic/archive.py

@@ -23,54 +23,64 @@ has_mtime_ns = sys.version >= '3.3'
 has_lchmod = hasattr(os, 'lchmod')
 
 
-class ItemIter(object):
+class DownloadPipeline:
 
-    def __init__(self, unpacker, filter):
-        self.unpacker = iter(unpacker)
-        self.filter = filter
-        self.stack = []
-        self.peeks = 0
-        self._peek_iter = None
+    def __init__(self, repository, key):
+        self.repository = repository
+        self.key = key
 
-    def __iter__(self):
-        return self
+    def unpack_many(self, ids, filter=None):
+        unpacker = msgpack.Unpacker(use_list=False)
+        for data in self.fetch_many(ids):
+            unpacker.feed(data)
+            items = [decode_dict(item, (b'path', b'source', b'user', b'group')) for item in unpacker]
+            if filter:
+                items = [item for item in items if filter(item)]
+            for item in items:
+                if b'chunks' in item:
+                    self.repository.preload([c[0] for c in item[b'chunks']])
+            for item in items:
+                yield item
+
+    def fetch_many(self, ids, is_preloaded=False):
+        for id_, data in zip_longest(ids, self.repository.get_many(ids, is_preloaded=is_preloaded)):
+            yield self.key.decrypt(id_, data)
+
+
+class ChunkBuffer:
+    BUFFER_SIZE = 1 * 1024 * 1024
+
+    def __init__(self, cache, key, stats):
+        self.buffer = BytesIO()
+        self.packer = msgpack.Packer(unicode_errors='surrogateescape')
+        self.cache = cache
+        self.chunks = []
+        self.key = key
+        self.stats = stats
 
-    def __next__(self):
-        if self.stack:
-            item = self.stack.pop(0)
-        else:
-            self._peek_iter = None
-            item = self.get_next()
-        self.peeks = max(0, self.peeks - len(item.get(b'chunks', [])))
-        return item
+    def add(self, item):
+        self.buffer.write(self.packer.pack(item))
 
-    def get_next(self):
-        while True:
-            n = next(self.unpacker)
-            decode_dict(n, (b'path', b'source', b'user', b'group'))
-            if not self.filter or self.filter(n):
-                return n
-
-    def peek(self):
-        while True:
-            while not self._peek_iter:
-                if self.peeks > 100:
-                    raise StopIteration
-                _peek = self.get_next()
-                self.stack.append(_peek)
-                if b'chunks' in _peek:
-                    self._peek_iter = iter(_peek[b'chunks'])
-                else:
-                    self._peek_iter = None
-            try:
-                item = next(self._peek_iter)
-                self.peeks += 1
-                return item
-            except StopIteration:
-                self._peek_iter = None
+    def flush(self, flush=False):
+        if self.buffer.tell() == 0:
+            return
+        self.buffer.seek(0)
+        chunks = list(bytes(s) for s in chunkify(self.buffer, WINDOW_SIZE, CHUNK_MASK, CHUNK_MIN, self.key.chunk_seed))
+        self.buffer.seek(0)
+        self.buffer.truncate(0)
+        # Leave the last partial chunk in the buffer unless flush is True
+        end = None if flush or len(chunks) == 1 else -1
+        for chunk in chunks[:end]:
+            id_, _, _ = self.cache.add_chunk(self.key.id_hash(chunk), chunk, self.stats)
+            self.chunks.append(id_)
+        if end == -1:
+            self.buffer.write(chunks[-1])
 
+    def is_full(self):
+        return self.buffer.tell() > self.BUFFER_SIZE
 
-class Archive(object):
+
+class Archive:
 
     class DoesNotExist(Error):
         """Archive {} does not exist"""
@@ -85,13 +95,13 @@ class Archive(object):
         self.repository = repository
         self.cache = cache
         self.manifest = manifest
-        self.items = BytesIO()
-        self.items_ids = []
         self.hard_links = {}
         self.stats = Statistics()
         self.name = name
         self.checkpoint_interval = checkpoint_interval
         self.numeric_owner = numeric_owner
+        self.items_buffer = ChunkBuffer(self.cache, self.key, self.stats)
+        self.pipeline = DownloadPipeline(self.repository, self.key)
         if create:
             if name in manifest.archives:
                 raise self.AlreadyExists(name)
@@ -128,44 +138,17 @@ class Archive(object):
         return 'Archive(%r)' % self.name
 
     def iter_items(self, filter=None):
-        unpacker = msgpack.Unpacker(use_list=False)
-        i = 0
-        n = 20
-        while True:
-            items = self.metadata[b'items'][i:i + n]
-            i += n
-            if not items:
-                break
-            for id, chunk in [(id, chunk) for id, chunk in zip_longest(items, self.repository.get_many(items))]:
-                unpacker.feed(self.key.decrypt(id, chunk))
-                iter = ItemIter(unpacker, filter)
-                for item in iter:
-                    yield item, iter.peek
+        for item in self.pipeline.unpack_many(self.metadata[b'items'], filter=filter):
+            yield item, None
 
     def add_item(self, item):
-        self.items.write(msgpack.packb(item, unicode_errors='surrogateescape'))
+        self.items_buffer.add(item)
         now = time.time()
         if now - self.last_checkpoint > self.checkpoint_interval:
             self.last_checkpoint = now
             self.write_checkpoint()
-        if self.items.tell() > ITEMS_BUFFER:
-            self.flush_items()
-
-    def flush_items(self, flush=False):
-        if self.items.tell() == 0:
-            return
-        self.items.seek(0)
-        chunks = list(bytes(s) for s in chunkify(self.items, WINDOW_SIZE, CHUNK_MASK, CHUNK_MIN, self.key.chunk_seed))
-        self.items.seek(0)
-        self.items.truncate()
-        for chunk in chunks[:-1]:
-            id, _, _ = self.cache.add_chunk(self.key.id_hash(chunk), chunk, self.stats)
-            self.items_ids.append(id)
-        if flush or len(chunks) == 1:
-            id, _, _ = self.cache.add_chunk(self.key.id_hash(chunks[-1]), chunks[-1], self.stats)
-            self.items_ids.append(id)
-        else:
-            self.items.write(chunks[-1])
+        if self.items_buffer.is_full():
+            self.items_buffer.flush()
 
     def write_checkpoint(self):
         self.save(self.checkpoint_name)
@@ -176,11 +159,11 @@ class Archive(object):
         name = name or self.name
         if name in self.manifest.archives:
             raise self.AlreadyExists(name)
-        self.flush_items(flush=True)
+        self.items_buffer.flush(flush=True)
         metadata = {
             'version': 1,
             'name': name,
-            'items': self.items_ids,
+            'items': self.items_buffer.chunks,
             'cmdline': sys.argv,
             'hostname': socket.gethostname(),
             'username': getuser(),
@@ -199,6 +182,9 @@ class Archive(object):
             count, size, csize = self.cache.chunks[id]
             stats.update(size, csize, count == 1)
             self.cache.chunks[id] = count - 1, size, csize
+        def add_file_chunks(chunks):
+            for id, _, _ in chunks:
+                add(id)
         # This function is a bit evil since it abuses the cache to calculate
         # the stats. The cache transaction must be rolled back afterwards
         unpacker = msgpack.Unpacker(use_list=False)
@@ -209,12 +195,9 @@ class Archive(object):
             add(id)
             unpacker.feed(self.key.decrypt(id, chunk))
             for item in unpacker:
-                try:
-                    for id, size, csize in item[b'chunks']:
-                        add(id)
+                if b'chunks' in item:
                     stats.nfiles += 1
-                except KeyError:
-                    pass
+                    add_file_chunks(item[b'chunks'])
         cache.rollback()
         return stats
 
@@ -249,9 +232,8 @@ class Archive(object):
                 os.link(source, path)
             else:
                 with open(path, 'wb') as fd:
-                    ids = [id for id, size, csize in item[b'chunks']]
-                    for id, chunk in zip_longest(ids, self.repository.get_many(ids, peek)):
-                        data = self.key.decrypt(id, chunk)
+                    ids = [c[0] for c in item[b'chunks']]
+                    for data in self.pipeline.fetch_many(ids, is_preloaded=True):
                         fd.write(data)
                     fd.flush()
                     self.restore_attrs(path, item, fd=fd.fileno())
@@ -314,8 +296,8 @@ class Archive(object):
             start(item)
             ids = [id for id, size, csize in item[b'chunks']]
             try:
-                for id, chunk in zip_longest(ids, self.repository.get_many(ids, peek)):
-                    self.key.decrypt(id, chunk)
+                for _ in self.pipeline.fetch_many(ids, is_preloaded=True):
+                    pass
             except Exception:
                 result(item, False)
                 return
@@ -323,15 +305,14 @@ class Archive(object):
 
     def delete(self, cache):
         unpacker = msgpack.Unpacker(use_list=False)
-        for id in self.metadata[b'items']:
-            unpacker.feed(self.key.decrypt(id, self.repository.get(id)))
+        for id_, data in zip_longest(self.metadata[b'items'], self.repository.get_many(self.metadata[b'items'])):
+            unpacker.feed(self.key.decrypt(id_, data))
+            self.cache.chunk_decref(id_)
             for item in unpacker:
-                try:
+                if b'chunks' in item:
                     for chunk_id, size, csize in item[b'chunks']:
                         self.cache.chunk_decref(chunk_id)
-                except KeyError:
-                    pass
-            self.cache.chunk_decref(id)
+
         self.cache.chunk_decref(self.id)
         del self.manifest.archives[self.name]
         self.manifest.write()
@@ -385,19 +366,18 @@ class Archive(object):
         chunks = None
         if ids is not None:
             # Make sure all ids are available
-            for id in ids:
-                if not cache.seen_chunk(id):
+            for id_ in ids:
+                if not cache.seen_chunk(id_):
                     break
             else:
-                chunks = [cache.chunk_incref(id, self.stats) for id in ids]
+                chunks = [cache.chunk_incref(id_, self.stats) for id_ in ids]
         # Only chunkify the file if needed
         if chunks is None:
             with open(path, 'rb') as fd:
                 chunks = []
                 for chunk in chunkify(fd, WINDOW_SIZE, CHUNK_MASK, CHUNK_MIN, self.key.chunk_seed):
                     chunks.append(cache.add_chunk(self.key.id_hash(chunk), chunk, self.stats))
-            ids = [id for id, _, _ in chunks]
-            cache.memorize_file(path_hash, st, ids)
+            cache.memorize_file(path_hash, st, [c[0] for c in chunks])
         item = {b'path': safe_path, b'chunks': chunks}
         item.update(self.stat_attrs(st, path))
         self.stats.nfiles += 1
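
Reviewer note on archive.py: ItemIter and its peek() protocol are gone; DownloadPipeline.unpack_many() decodes a metadata chunk, announces every item's data chunks to the repository via preload(), and only then yields the items, so the remote side can keep fetches in flight while the caller works. The following is an illustrative sketch only (FakeRepository and the sample items are stand-ins, not part of the commit) showing that preload-before-yield ordering:

# Illustrative sketch, not part of the commit: the preload-before-yield ordering
# that DownloadPipeline.unpack_many() introduces, using a fake repository stand-in.
import msgpack

class FakeRepository:
    def __init__(self):
        self.preloaded = []

    def preload(self, ids):
        # a RemoteRepository would queue these ids for pipelined 'get' requests
        self.preloaded.extend(ids)

metadata = b''.join(msgpack.packb(item) for item in (
    {b'path': b'etc/hosts'},
    {b'path': b'etc/passwd', b'chunks': [(b'id-1', 10, 10), (b'id-2', 20, 20)]},
))

repository = FakeRepository()
unpacker = msgpack.Unpacker(use_list=False)
unpacker.feed(metadata)
items = list(unpacker)
for item in items:                    # first pass: announce all needed data chunks
    if b'chunks' in item:
        repository.preload([c[0] for c in item[b'chunks']])
for item in items:                    # second pass: yield items; fetches can now overlap
    print(item[b'path'])

print(repository.preloaded)           # [b'id-1', b'id-2']

With a local Repository, preload() is a no-op, so the same code path costs nothing (see the repository.py hunk below).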

+ 6 - 8
attic/cache.py

@@ -154,16 +154,14 @@ class Cache(object):
             archive = msgpack.unpackb(data)
             decode_dict(archive, (b'name', b'hostname', b'username', b'time'))  # fixme: argv
             print('Analyzing archive:', archive[b'name'])
-            for id, chunk in zip_longest(archive[b'items'], self.repository.get_many(archive[b'items'])):
-                data = self.key.decrypt(id, chunk)
-                add(id, len(data), len(chunk))
+            for id_, chunk in zip_longest(archive[b'items'], self.repository.get_many(archive[b'items'])):
+                data = self.key.decrypt(id_, chunk)
+                add(id_, len(data), len(chunk))
                 unpacker.feed(data)
                 for item in unpacker:
-                    try:
-                        for id, size, csize in item[b'chunks']:
-                            add(id, size, csize)
-                    except KeyError:
-                        pass
+                    if b'chunks' in item:
+                        for id_, size, csize in item[b'chunks']:
+                            add(id_, size, csize)
 
     def add_chunk(self, id, data, stats):
         if not self.txn_active:

+ 76 - 139
attic/remote.py

@@ -7,7 +7,6 @@ import sys
 
 from .helpers import Error
 from .repository import Repository
-from .lrucache import LRUCache
 
 BUFSIZE = 10 * 1024 * 1024
 
@@ -71,19 +70,19 @@ class RemoteRepository(object):
             self.name = name
 
     def __init__(self, location, create=False):
-        self.repository_url = '%s@%s:%s' % (location.user, location.host, location.path)
-        self.p = None
-        self.cache = LRUCache(256)
+        self.preload_ids = []
+        self.msgid = 0
         self.to_send = b''
-        self.extra = {}
-        self.pending = {}
+        self.cache = {}
+        self.ignore_responses = set()
+        self.responses = {}
         self.unpacker = msgpack.Unpacker(use_list=False)
-        self.msgid = 0
-        self.received_msgid = 0
+        self.repository_url = '%s@%s:%s' % (location.user, location.host, location.path)
+        self.p = None
         if location.host == '__testsuite__':
             args = [sys.executable, '-m', 'attic.archiver', 'serve']
         else:
-            args = ['ssh',]
+            args = ['ssh']
             if location.port:
                 args += ['-p', str(location.port)]
             if location.user:
@@ -99,11 +98,11 @@ class RemoteRepository(object):
         self.r_fds = [self.stdout_fd]
         self.x_fds = [self.stdin_fd, self.stdout_fd]
 
-        version = self.call('negotiate', (1,))
+        version = self.call('negotiate', 1)
         if version != 1:
             raise Exception('Server insisted on using unsupported protocol version %d' % version)
         try:
-            self.id = self.call('open', (location.path, create))
+            self.id = self.call('open', location.path, create)
         except self.RPCError as e:
             if e.name == b'DoesNotExist':
                 raise Repository.DoesNotExist(self.repository_url)
@@ -113,11 +112,33 @@ class RemoteRepository(object):
     def __del__(self):
         self.close()
 
-    def call(self, cmd, args, wait=True):
-        self.msgid += 1
-        to_send = msgpack.packb((1, self.msgid, cmd, args))
+    def call(self, cmd, *args, **kw):
+        for resp in self.call_many(cmd, [args], **kw):
+            return resp
+
+    def call_many(self, cmd, calls, wait=True, is_preloaded=False):
+        def fetch_from_cache(args):
+            msgid = self.cache[args].pop(0)
+            if not self.cache[args]:
+                del self.cache[args]
+            return msgid
+
+        calls = list(calls)
+        waiting_for = []
         w_fds = [self.stdin_fd]
-        while wait or to_send:
+        while wait or calls:
+            while waiting_for:
+                try:
+                    error, res = self.responses.pop(waiting_for[0])
+                    waiting_for.pop(0)
+                    if error:
+                        raise self.RPCError(error)
+                    else:
+                        yield res
+                        if not waiting_for and not calls:
+                            return
+                except KeyError:
+                    break
             r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
             if x:
                raise Exception('FD exception occurred')
@@ -127,147 +148,60 @@ class RemoteRepository(object):
                     raise ConnectionClosed()
                 self.unpacker.feed(data)
                 for type, msgid, error, res in self.unpacker:
-                    if msgid == self.msgid:
-                        self.received_msgid = msgid
-                        if error:
-                            raise self.RPCError(error)
-                        else:
-                            return res
+                    if msgid in self.ignore_responses:
+                        self.ignore_responses.remove(msgid)
                     else:
-                        args = self.pending.pop(msgid, None)
-                        if args is not None:
-                            self.cache[args] = msgid, res, error
+                        self.responses[msgid] = error, res
             if w:
-                if to_send:
-                    n = os.write(self.stdin_fd, to_send)
-                    assert n > 0
-                    to_send = memoryview(to_send)[n:]
-                if not to_send:
-                    w_fds = []
-
-    def _read(self):
-        data = os.read(self.stdout_fd, BUFSIZE)
-        if not data:
-            raise Exception('Remote host closed connection')
-        self.unpacker.feed(data)
-        to_yield = []
-        for type, msgid, error, res in self.unpacker:
-            self.received_msgid = msgid
-            args = self.pending.pop(msgid, None)
-            if args is not None:
-                self.cache[args] = msgid, res, error
-                for args, resp, error in self.extra.pop(msgid, []):
-                    if not resp and not error:
-                        resp, error = self.cache[args][1:]
-                    to_yield.append((resp, error))
-        for res, error in to_yield:
-            if error:
-                raise self.RPCError(error)
-            else:
-                yield res
-
-    def gen_request(self, cmd, argsv, wait):
-        data = []
-        m = self.received_msgid
-        for args in argsv:
-            # Make sure to invalidate any existing cache entries for non-get requests
-            if not args in self.cache:
-                self.msgid += 1
-                msgid = self.msgid
-                self.pending[msgid] = args
-                self.cache[args] = msgid, None, None
-                data.append(msgpack.packb((1, msgid, cmd, args)))
-            if wait:
-                msgid, resp, error = self.cache[args]
-                m = max(m, msgid)
-                self.extra.setdefault(m, []).append((args, resp, error))
-        return b''.join(data)
-
-    def gen_cache_requests(self, cmd, peek):
-        data = []
-        while True:
-            try:
-                args = (peek()[0],)
-            except StopIteration:
-                break
-            if args in self.cache:
-                continue
-            self.msgid += 1
-            msgid = self.msgid
-            self.pending[msgid] = args
-            self.cache[args] = msgid, None, None
-            data.append(msgpack.packb((1, msgid, cmd, args)))
-        return b''.join(data)
+                while not self.to_send and (calls or self.preload_ids) and len(waiting_for) < 100:
+                    if calls:
+                        if is_preloaded:
+                            if calls[0] in self.cache:
+                                waiting_for.append(fetch_from_cache(calls.pop(0)))
+                        else:
+                            args = calls.pop(0)
+                            if cmd == 'get' and args in self.cache:
+                                waiting_for.append(fetch_from_cache(args))
+                            else:
+                                self.msgid += 1
+                                waiting_for.append(self.msgid)
+                                self.to_send = msgpack.packb((1, self.msgid, cmd, args))
+                    if not self.to_send and self.preload_ids:
+                        args = (self.preload_ids.pop(0),)
+                        self.msgid += 1
+                        self.cache.setdefault(args, []).append(self.msgid)
+                        self.to_send = msgpack.packb((1, self.msgid, cmd, args))
 
-    def call_multi(self, cmd, argsv, wait=True, peek=None):
-        w_fds = [self.stdin_fd]
-        left = len(argsv)
-        data = self.gen_request(cmd, argsv, wait)
-        self.to_send += data
-        for args, resp, error in self.extra.pop(self.received_msgid, []):
-            left -= 1
-            if not resp and not error:
-                resp, error = self.cache[args][1:]
-            if error:
-                raise self.RPCError(error)
-            else:
-                yield resp
-        while left:
-            r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
-            if x:
-                raise Exception('FD exception occured')
-            if r:
-                for res in self._read():
-                    left -= 1
-                    yield res
-            if w:
-                if not self.to_send and peek:
-                    self.to_send = self.gen_cache_requests(cmd, peek)
                 if self.to_send:
-                    n = os.write(self.stdin_fd, self.to_send)
-                    assert n > 0
-#                    self.to_send = memoryview(self.to_send)[n:]
-                    self.to_send = self.to_send[n:]
-                else:
+                    self.to_send = self.to_send[os.write(self.stdin_fd, self.to_send):]
+                if not self.to_send and not (calls or self.preload_ids):
                     w_fds = []
-                    if not wait:
-                        return
+        self.ignore_responses |= set(waiting_for)
 
     def commit(self, *args):
-        self.call('commit', args)
+        return self.call('commit')
 
     def rollback(self, *args):
-        self.cache.clear()
-        self.pending.clear()
-        self.extra.clear()
-        return self.call('rollback', args)
+        return self.call('rollback')
 
-    def get(self, id):
+    def get(self, id_):
+        for resp in self.get_many([id_]):
+            return resp
+
+    def get_many(self, ids, is_preloaded=False):
         try:
-            for res in self.call_multi('get', [(id, )]):
-                return res
+            for resp in self.call_many('get', [(id_,) for id_ in ids], is_preloaded=is_preloaded):
+                yield resp
         except self.RPCError as e:
             if e.name == b'DoesNotExist':
                 raise Repository.DoesNotExist(self.repository_url)
             raise
 
-    def get_many(self, ids, peek=None):
-        return self.call_multi('get', [(id, ) for id in ids], peek=peek)
+    def put(self, id_, data, wait=True):
+        return self.call('put', id_, data, wait=wait)
 
-    def _invalidate(self, id):
-        key = (id, )
-        if key in self.cache:
-            self.pending.pop(self.cache.pop(key)[0], None)
-
-    def put(self, id, data, wait=True):
-        resp = self.call('put', (id, data), wait=wait)
-        self._invalidate(id)
-        return resp
-
-    def delete(self, id, wait=True):
-        resp = self.call('delete', (id, ), wait=wait)
-        self._invalidate(id)
-        return resp
+    def delete(self, id_, wait=True):
+        return self.call('delete', id_, wait=wait)
 
     def close(self):
         if self.p:
@@ -275,3 +209,6 @@ class RemoteRepository(object):
             self.p.stdout.close()
             self.p.wait()
             self.p = None
+
+    def preload(self, ids):
+        self.preload_ids += ids
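
Reviewer note on remote.py: call_many() now carries all of the pipelining state that the old gen_request/gen_cache_requests/_read trio spread around. Ids passed to preload() are sent as ordinary 'get' requests as write capacity allows, and their msgids are parked in self.cache keyed by the request args; a later get_many(..., is_preloaded=True) pops a parked msgid instead of re-sending, and at most 100 requests are kept waiting at a time. A small, illustrative model of just that bookkeeping (the class and variable names below are invented for the example):

# Illustrative model, not part of the commit: how preloaded 'get' requests are
# remembered by args so a later fetch can reuse the request already on the wire.
class PreloadBook:
    def __init__(self):
        self.msgid = 0
        self.cache = {}                # args -> [msgid, ...] already sent

    def preload(self, ids):
        for id_ in ids:
            self.msgid += 1
            self.cache.setdefault((id_,), []).append(self.msgid)

    def fetch_from_cache(self, args):
        msgid = self.cache[args].pop(0)
        if not self.cache[args]:
            del self.cache[args]
        return msgid

book = PreloadBook()
book.preload([b'chunk-a', b'chunk-b'])
print(book.fetch_from_cache((b'chunk-a',)))   # 1 -- reuses the in-flight request
print(book.cache)                             # {(b'chunk-b',): [2]}

Responses are matched purely by msgid; msgids a caller abandons are added to ignore_responses so late replies are dropped instead of accumulating in self.responses.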

+ 7 - 3
attic/repository.py

@@ -220,9 +220,9 @@ class Repository(object):
         except KeyError:
             raise self.DoesNotExist(self.path)
 
-    def get_many(self, ids, peek=None):
-        for id in ids:
-            yield self.get(id)
+    def get_many(self, ids, is_preloaded=False):
+        for id_ in ids:
+            yield self.get(id_)
 
     def put(self, id, data, wait=True):
         if not self._active_txn:
@@ -261,6 +261,10 @@ class Repository(object):
     def add_callback(self, cb, data):
         cb(None, None, data)
 
+    def preload(self, ids):
+        """Preload objects (only applies to remote repositories
+        """
+
 
 class LoggedIO(object):
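
The local backend grows a matching no-op preload(), so Archive code can hint upcoming reads unconditionally without caring which backend it talks to. A minimal sketch of that duck-typing (the toy classes are illustrative, not from the commit):

# Illustrative sketch: the same preload() call site works for either backend.
class LocalRepo:
    def preload(self, ids):
        """Preload objects (only applies to remote repositories)."""

class RemoteRepo:
    def __init__(self):
        self.preload_ids = []

    def preload(self, ids):
        self.preload_ids += ids

for repo in (LocalRepo(), RemoteRepo()):
    repo.preload([b'x'])               # local: no-op; remote: queued for pipelining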
 

+ 32 - 0
attic/testsuite/archive.py

@@ -0,0 +1,32 @@
+import msgpack
+from attic.testsuite import AtticTestCase
+from attic.archive import ChunkBuffer
+from attic.key import PlaintextKey
+
+
+class MockCache:
+
+    def __init__(self):
+        self.objects = {}
+
+    def add_chunk(self, id, data, stats=None):
+        self.objects[id] = data
+        return id, len(data), len(data)
+
+
+class ChunkBufferTestCase(AtticTestCase):
+
+    def test(self):
+        data = [{b'foo': 1}, {b'bar': 2}]
+        cache = MockCache()
+        key = PlaintextKey()
+        chunks = ChunkBuffer(cache, key, None)
+        for d in data:
+            chunks.add(d)
+            chunks.flush()
+        chunks.flush(flush=True)
+        self.assert_equal(len(chunks.chunks), 2)
+        unpacker = msgpack.Unpacker()
+        for id in chunks.chunks:
+            unpacker.feed(cache.objects[id])
+        self.assert_equal(data, list(unpacker))
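
Note on the new test: each packed item is only a few bytes, so chunkify() presumably returns a single chunk per flush(); with a single chunk, flush() writes it out even without flush=True, so the two add()/flush() cycles store two chunks and the final flush(flush=True) finds an empty buffer. That is what the length-2 assertion checks. Assuming a source checkout with the extension modules built, the case could be run on its own roughly like this (the runner snippet is an assumption, not part of the commit):

# Hypothetical runner for just this test case, assuming attic is importable.
import unittest
from attic.testsuite.archive import ChunkBufferTestCase

suite = unittest.TestLoader().loadTestsFromTestCase(ChunkBufferTestCase)
unittest.TextTestRunner(verbosity=2).run(suite)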