Bladeren bron

Improved remote store request pipelining.

Jonas Borgström 12 jaren geleden
bovenliggende
commit
46c6793979
4 gewijzigde bestanden met toevoegingen van 138 en 34 verwijderingen
  1. 53 8
      darc/archive.py
  2. 5 9
      darc/archiver.py
  3. 79 16
      darc/remote.py
  4. 1 1
      darc/store.py

+ 53 - 8
darc/archive.py

@@ -23,6 +23,49 @@ have_lchmod = hasattr(os, 'lchmod')
 linux = sys.platform == 'linux2'
 
 
+class ItemIter(object):
+    """Iterator over unpacked archive items with a bounded chunk look-ahead.
+
+    Wraps ``unpacker`` and hands out only the items accepted by ``filter``
+    (every item when ``filter`` is falsy).  ``peek()`` reads ahead of the
+    consumer: the items it pulls from the unpacker are parked on
+    ``self.stack`` and later returned by ``next()`` in their original order.
+    """
+
+    def __init__(self, unpacker, filter):
+        """Set up the iterator.
+
+        :param unpacker: iterable of item mappings (e.g. a msgpack Unpacker)
+        :param filter: optional predicate; items it rejects are dropped
+        """
+        self.unpacker = iter(unpacker)
+        self.filter = filter
+        # Items already read by peek() but not yet returned by next()
+        self.stack = []
+        # Item whose chunk list peek() is currently walking, and the
+        # iterator over that chunk list
+        self._peek = None
+        self._peek_iter = None
+
+    def __iter__(self):
+        return self
+
+    def next(self):
+        """Return the next accepted item, draining read-ahead items first.
+
+        Raises StopIteration when nothing is buffered and the unpacker is
+        exhausted.
+        """
+        if self.stack:
+            return self.stack.pop(0)
+        # Nothing is buffered any more, so drop the look-ahead position
+        self._peek = None
+        return self.get_next()
+
+    # Python 3 spelling of the iterator protocol; harmless on Python 2
+    __next__ = next
+
+    def get_next(self):
+        """Read items from the unpacker until one passes the filter.
+
+        Raises StopIteration when the unpacker is exhausted.
+        """
+        item = next(self.unpacker)
+        while self.filter and not self.filter(item):
+            item = next(self.unpacker)
+        return item
+
+    def peek(self):
+        """Return the next not-yet-peeked chunk entry of the read-ahead items.
+
+        Items are pulled from the unpacker and appended to ``self.stack`` as
+        needed; items without a 'chunks' key are buffered and skipped over.
+
+        Raises StopIteration when more than 100 items are buffered or when
+        the unpacker is exhausted.
+        """
+        while True:
+            while not self._peek or not self._peek_iter:
+                # Bound the read-ahead buffer
+                if len(self.stack) > 100:
+                    raise StopIteration
+                self._peek = self.get_next()
+                self.stack.append(self._peek)
+                if 'chunks' in self._peek:
+                    self._peek_iter = iter(self._peek['chunks'])
+                else:
+                    self._peek_iter = None
+            try:
+                return next(self._peek_iter)
+            except StopIteration:
+                # Chunk list of this item is used up; move on to the next item
+                self._peek = None
+
+
 class Archive(object):
 
     class DoesNotExist(Exception):
@@ -82,12 +125,13 @@ class Archive(object):
     def __repr__(self):
         return 'Archive(%r)' % self.name
 
-    def iter_items(self):
+    def iter_items(self, filter=None):
         unpacker = msgpack.Unpacker()
         for id in self.metadata['items']:
             unpacker.feed(self.key.decrypt(id, self.store.get(id)))
-            for item in unpacker:
-                yield item
+            iter = ItemIter(unpacker, filter)
+            for item in iter:
+                yield item, iter.peek
 
     def add_item(self, item):
         self.items.write(msgpack.packb(item))
@@ -164,7 +208,7 @@ class Archive(object):
         cache.rollback()
         return stats
 
-    def extract_item(self, item, dest=None, start_cb=None, restore_attrs=True):
+    def extract_item(self, item, dest=None, start_cb=None, restore_attrs=True, peek=None):
         dest = dest or self.cwd
         assert item['path'][0] not in ('/', '\\', ':')
         path = os.path.join(dest, encode_filename(item['path']))
@@ -193,7 +237,7 @@ class Archive(object):
                     fd = open(path, 'wb')
                     start_cb(item)
                     ids = [id for id, size, csize in item['chunks']]
-                    for id, chunk in izip(ids, self.store.get_many(ids)):
+                    for id, chunk in izip(ids, self.store.get_many(ids, peek)):
                         data = self.key.decrypt(id, chunk)
                         fd.write(data)
                     fd.close()
@@ -244,7 +288,7 @@ class Archive(object):
             # FIXME: We should really call futimes here (c extension required)
             os.utime(path, (item['mtime'], item['mtime']))
 
-    def verify_file(self, item, start, result):
+    def verify_file(self, item, start, result, peek=None):
         if not item['chunks']:
             start(item)
             result(item, True)
@@ -252,9 +296,10 @@ class Archive(object):
             start(item)
             ids = [id for id, size, csize in item['chunks']]
             try:
-                for id, chunk in izip(ids, self.store.get_many(ids)):
+                for id, chunk in izip(ids, self.store.get_many(ids, peek)):
+                    if chunk:
                         self.key.decrypt(id, chunk)
-            except Exception:
+            except Exception, e:
                 result(item, False)
                 return
             result(item, True)

+ 5 - 9
darc/archiver.py

@@ -163,14 +163,12 @@ class Archiver(object):
         archive = Archive(store, key, manifest, args.archive.archive,
                           numeric_owner=args.numeric_owner)
         dirs = []
-        for item in archive.iter_items():
-            if exclude_path(item['path'], args.patterns):
-                continue
+        for item, peek in archive.iter_items(lambda item: not exclude_path(item['path'], args.patterns)):
             if stat.S_ISDIR(item['mode']):
                 dirs.append(item)
                 archive.extract_item(item, args.dest, start_cb, restore_attrs=False)
             else:
-                archive.extract_item(item, args.dest, start_cb)
+                archive.extract_item(item, args.dest, start_cb, peek=peek)
             if dirs and not item['path'].startswith(dirs[-1]['path']):
                 archive.extract_item(dirs.pop(-1), args.dest)
         while dirs:
@@ -193,7 +191,7 @@ class Archiver(object):
         if args.src.archive:
             tmap = {1: 'p', 2: 'c', 4: 'd', 6: 'b', 010: '-', 012: 'l', 014: 's'}
             archive = Archive(store, key, manifest, args.src.archive)
-            for item in archive.iter_items():
+            for item, _ in archive.iter_items():
                 type = tmap.get(item['mode'] / 4096, '?')
                 mode = format_file_mode(item['mode'])
                 size = 0
@@ -234,11 +232,9 @@ class Archiver(object):
             else:
                 self.print_verbose('ERROR')
                 self.print_error('%s: verification failed' % item['path'])
-        for item in archive.iter_items():
-            if exclude_path(item['path'], args.patterns):
-                return
+        for item, peek in archive.iter_items(lambda item: not exclude_path(item['path'], args.patterns)):
             if stat.S_ISREG(item['mode']) and 'chunks' in item:
-                archive.verify_file(item, start_cb, result_cb)
+                archive.verify_file(item, start_cb, result_cb, peek=peek)
         return self.exit_code
 
     def do_info(self, args):

+ 79 - 16
darc/remote.py

@@ -8,6 +8,7 @@ import sys
 import getpass
 
 from .store import Store
+from .lrucache import LRUCache
 
 BUFSIZE = 10 * 1024 * 1024
 
@@ -71,11 +72,20 @@ class RemoteStore(object):
             self.name = name
 
     def __init__(self, location, create=False):
+        self.cache = LRUCache(200)
+        self.to_send = ''
+        self.extra = {}
+        self.pending_cache = {}
         self.unpacker = msgpack.Unpacker()
         self.msgid = 0
+        self.received_msgid = 0
         args = ['ssh', '-p', str(location.port), '%s@%s' % (location.user or getpass.getuser(), location.host), 'darc', 'serve']
         self.p = Popen(args, bufsize=0, stdin=PIPE, stdout=PIPE)
+        self.stdin_fd = self.p.stdin.fileno()
         self.stdout_fd = self.p.stdout.fileno()
+        self.r_fds = [self.stdout_fd]
+        self.x_fds = [self.stdin_fd, self.stdout_fd]
+
         version = self.call('negotiate', (1,))
         if version != 1:
             raise Exception('Server insisted on using unsupported protocol version %d' % version)
@@ -86,33 +96,86 @@ class RemoteStore(object):
         self.p.stdout.close()
         self.p.wait()
 
-    def _read(self, msgids):
+    def _read(self):
+        """Read one buffer of server responses and yield the results now due.
+
+        A response whose msgid is in ``self.pending_cache`` is stored in
+        ``self.cache`` under the args of the request it answers.  Results
+        queued in ``self.extra`` for a received msgid are collected and
+        yielded only after the whole buffer has been processed.
+
+        This is a generator: nothing is read until it is first iterated.
+
+        Raises Exception('EOF') when the read returns no data, and
+        self.RPCError when a response carries an error.
+        """
         data = os.read(self.stdout_fd, BUFSIZE)
+        if not data:
+            raise Exception('EOF')
         self.unpacker.feed(data)
+        to_yield = []
         for type, msgid, error, res in self.unpacker:
+            # Remember the msgid of the most recent response
+            self.received_msgid = msgid
             if error:
                 raise self.RPCError(error)
-            if msgid in msgids:
-                msgids.remove(msgid)
-                yield res
+            if msgid in self.pending_cache:
+                args = self.pending_cache.pop(msgid)
+                self.cache[args] = msgid, res
+            else:
+                # NOTE(review): looks like leftover debug output -- confirm
+                print 'unknown response'
+            # Release every result that was queued to wait for this msgid
+            for args in self.extra.pop(msgid, []):
+                to_yield.append(self.cache[args][1])
+        for res in to_yield:
+            yield res
 
     def call(self, cmd, args, wait=True):
+        """Issue a single ``cmd`` request and return its first result.
+
+        Delegates to ``call_multi`` with a one-element argument list and
+        returns the first value it yields (None if it yields nothing).
+        """
         for res in self.call_multi(cmd, [args], wait=wait):
             return res
 
-    def call_multi(self, cmd, argsv, wait=True):
-        msgids = set()
+    def gen_request(self, cmd, argsv):
+        """Build the wire data for the ``cmd`` calls that have no cache entry.
+
+        For every args tuple in ``argsv`` that is not in ``self.cache`` a new
+        msgid is allocated, the request is registered in
+        ``self.pending_cache`` and a ``(msgid, None)`` placeholder is cached.
+        Every args tuple, new or not, is queued in ``self.extra`` so that its
+        result is handed out once the matching msgid has been received.
+
+        Returns the concatenated packed requests ('' when nothing is new).
+        """
+        data = []
+        m = self.received_msgid
         for args in argsv:
-            if select.select([self.stdout_fd], [], [], 0)[0]:
-                for res in self._read(msgids):
-                    yield res
+            if not args in self.cache:
+                self.msgid += 1
+                msgid = self.msgid
+                self.pending_cache[msgid] = args
+                self.cache[args] = msgid, None
+                data.append(msgpack.packb((1, msgid, cmd, args)))
+            msgid, resp = self.cache[args]
+            # Running maximum, starting at the last received msgid: each
+            # entry is queued under the highest msgid seen so far
+            # (presumably to keep results in argsv order -- TODO confirm)
+            m = max(m, msgid)
+            self.extra.setdefault(m, []).append(args)
+        return ''.join(data)
+
+    def gen_cache_requests(self, cmd, peek):
+        """Build wire data for look-ahead requests supplied by ``peek``.
+
+        ``peek`` is called repeatedly until it raises StopIteration; the
+        first element of each value it returns becomes the single argument
+        of a ``cmd`` request (presumably a chunk id -- verify against the
+        caller).  Arguments already in ``self.cache`` are skipped.  New
+        requests are registered in ``self.pending_cache`` and get a
+        ``(msgid, None)`` cache placeholder, but are not queued in
+        ``self.extra``.
+
+        Returns the concatenated packed requests ('' when there are none).
+        """
+        data = []
+        while True:
+            try:
+                args = (peek()[0],)
+            except StopIteration:
+                break
+            if args in self.cache:
+                continue
             self.msgid += 1
             msgid = self.msgid
-            msgids.add(msgid)
-            self.p.stdin.write(msgpack.packb((1, msgid, cmd, args)))
-        while msgids and wait:
-            for res in self._read(msgids):
-                yield res
+            self.pending_cache[msgid] = args
+            self.cache[args] = msgid, None
+            data.append(msgpack.packb((1, msgid, cmd, args)))
+        return ''.join(data)
+
+    def call_multi(self, cmd, argsv, wait=True, peek=None):
+        """Issue ``cmd`` once per entry of ``argsv`` and yield the results.
+
+        Generator that yields one result per entry of ``argsv``.  Requests
+        are appended to ``self.to_send`` and written whenever stdin is
+        writable, while responses are read as soon as stdout is readable.
+
+        :param cmd: name of the remote command
+        :param argsv: sequence of argument tuples, one per call
+        :param wait: NOTE(review): accepted but never referenced in this
+            body, so callers passing wait=False still block until every
+            result has arrived -- confirm this is intended
+        :param peek: optional callable handed to ``gen_cache_requests`` to
+            produce further requests once the send buffer is empty
+
+        Raises Exception when select reports an exceptional condition on
+        the pipes.
+        """
+        w_fds = [self.stdin_fd]
+        left = len(argsv)
+        data = self.gen_request(cmd, argsv)
+        self.to_send += data
+        # Results queued under the last received msgid need no further
+        # response (presumably they are already cached -- TODO confirm)
+        for args in self.extra.pop(self.received_msgid, []):
+            left -= 1
+            yield self.cache[args][1]
+        while left:
+            # Wake up after at most one second even without any fd activity
+            r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
+            if x:
+                raise Exception('FD exception occured')
+            if r:
+                for res in self._read():
+                    left -= 1
+                    yield res
+            if w:
+                # Send buffer drained: use the idle pipe for look-ahead requests
+                if not self.to_send and peek:
+                    self.to_send = self.gen_cache_requests(cmd, peek)
+                if self.to_send:
+                    # os.write may accept only part of the buffer; keep the rest
+                    n = os.write(self.stdin_fd, self.to_send)
+                    assert n > 0
+                    self.to_send = self.to_send[n:]
+                else:
+                    # Nothing left to send: stop polling stdin for writability
+                    w_fds = []
 
     def commit(self, *args):
         self.call('commit', args)
@@ -128,8 +191,8 @@ class RemoteStore(object):
                 raise self.DoesNotExist
             raise
 
-    def get_many(self, ids):
-        return self.call_multi('get', [(id, ) for id in ids])
+    def get_many(self, ids, peek=None):
+        """Return a generator of 'get' results, one per id in ``ids``.
+
+        ``peek`` is forwarded unchanged to ``call_multi``.
+        """
+        return self.call_multi('get', [(id, ) for id in ids], peek=peek)
 
     def put(self, id, data, wait=True):
         try:

+ 1 - 1
darc/store.py

@@ -201,7 +201,7 @@ class Store(object):
         except KeyError:
             raise self.DoesNotExist
 
-    def get_many(self, ids):
+    def get_many(self, ids, peek=None):
         for id in ids:
             yield self.get(id)