Browse Source

Fix remote store cache issue

Jonas Borgström 12 years ago
parent
commit
177569d87b
5 changed files with 151 additions and 37 deletions
  1. 3 3
      darc/archive.py
  2. 5 3
      darc/lrucache.py
  3. 91 26
      darc/remote.py
  4. 32 3
      darc/store.py
  5. 20 2
      darc/test.py

+ 3 - 3
darc/archive.py

@@ -305,15 +305,15 @@ class Archive(object):
             try:
                 for id, chunk in izip_longest(ids, self.store.get_many(ids, peek)):
                     self.key.decrypt(id, chunk)
-            except Exception, e:
+            except Exception:
                 result(item, False)
                 return
             result(item, True)
 
     def delete(self, cache):
         unpacker = msgpack.Unpacker()
-        for id in self.metadata['items']:
-            unpacker.feed(self.key.decrypt(id, self.store.get(id)))
+        for id, chunk in izip_longest(self.metadata['items'], self.store.get_many(self.metadata['items'])):
+            unpacker.feed(self.key.decrypt(id, chunk))
             for item in unpacker:
                 try:
                     for chunk_id, size, csize in item['chunks']:

+ 5 - 3
darc/lrucache.py

@@ -20,12 +20,14 @@ class LRUCache(DictMixin):
         def __cmp__(self, other):
             return cmp(self.t, other.t)
 
-
     def __init__(self, size):
-        self._heap = []
-        self._dict = {}
         self.size = size
         self._t = 0
+        self.clear()
+
+    def clear(self):  # drop all entries; only _heap and _dict are rebound, size and _t are not touched here
+        self._heap = []
+        self._dict = {}
 
     def __setitem__(self, key, value):
         self._t += 1

+ 91 - 26
darc/remote.py

@@ -6,8 +6,9 @@ import select
 from subprocess import Popen, PIPE
 import sys
 import getpass
+import unittest
 
-from .store import Store
+from .store import Store, StoreTestCase
 from .lrucache import LRUCache
 
 BUFSIZE = 10 * 1024 * 1024
@@ -92,9 +93,29 @@ class RemoteStore(object):
         self.id = self.call('open', (location.path, create))
 
     def __del__(self):
-        self.p.stdin.close()
-        self.p.stdout.close()
-        self.p.wait()
+        self.close()
+
+    def call(self, cmd, args, wait=True):  # send one request; if wait, block until its reply, return the result or raise RPCError
+        self.msgid += 1
+        self.p.stdin.write(msgpack.packb((1, self.msgid, cmd, args)))
+        while wait:  # with wait=False the loop is skipped and None is returned right after the write
+            r, w, x = select.select(self.r_fds, [], self.x_fds, 1)
+            if x:
+                raise Exception('FD exception occured')
+            if r:
+                data = os.read(self.stdout_fd, BUFSIZE)
+                if not data:  # empty read means EOF; without this check the loop would spin forever on a closed pipe
+                    raise Exception('Remote store closed the connection (EOF)')
+                self.unpacker.feed(data)
+                for type, msgid, error, res in self.unpacker:
+                    if msgid == self.msgid:  # this is the reply to the request written above
+                        self.received_msgid = msgid
+                        if error:
+                            raise self.RPCError(error)
+                        return res
+                    pending_args = self.pending.pop(msgid, None)  # reply to an earlier queued request, if it is still tracked
+                    if pending_args is not None:  # store it under its request args as (msgid, result, error)
+                        self.cache[pending_args] = msgid, res, error
 
     def _read(self):
         data = os.read(self.stdout_fd, BUFSIZE)
@@ -104,32 +125,34 @@ class RemoteStore(object):
         to_yield = []
         for type, msgid, error, res in self.unpacker:
             self.received_msgid = msgid
+            args = self.pending.pop(msgid, None)
+            if args is not None:
+                self.cache[args] = msgid, res, error
+                for args, resp, error in self.extra.pop(msgid, []):
+                    if not resp and not error:
+                        resp, error = self.cache[args][1:]
+                    to_yield.append((resp, error))
+        for res, error in to_yield:
             if error:
                 raise self.RPCError(error)
-            args = self.pending.pop(msgid)
-            self.cache[args] = msgid, res
-            for args, resp in self.extra.pop(msgid, []):
-                to_yield.append(resp or self.cache[args][1])
-        for res in to_yield:
-            yield res
-
-    def call(self, cmd, args, wait=True):
-        for res in self.call_multi(cmd, [args], wait=wait):
-            return res
+            else:
+                yield res
 
-    def gen_request(self, cmd, argsv):
+    def gen_request(self, cmd, argsv, wait):
         data = []
         m = self.received_msgid
         for args in argsv:
+            # Make sure to invalidate any existing cache entries for non-get requests
             if not args in self.cache:
                 self.msgid += 1
                 msgid = self.msgid
                 self.pending[msgid] = args
-                self.cache[args] = msgid, None
+                self.cache[args] = msgid, None, None
                 data.append(msgpack.packb((1, msgid, cmd, args)))
-            msgid, resp = self.cache[args]
-            m = max(m, msgid)
-            self.extra.setdefault(m, []).append((args, resp))
+            if wait:
+                msgid, resp, error = self.cache[args]
+                m = max(m, msgid)
+                self.extra.setdefault(m, []).append((args, resp, error))
         return ''.join(data)
 
     def gen_cache_requests(self, cmd, peek):
@@ -144,18 +167,23 @@ class RemoteStore(object):
             self.msgid += 1
             msgid = self.msgid
             self.pending[msgid] = args
-            self.cache[args] = msgid, None
+            self.cache[args] = msgid, None, None
             data.append(msgpack.packb((1, msgid, cmd, args)))
         return ''.join(data)
 
     def call_multi(self, cmd, argsv, wait=True, peek=None):
         w_fds = [self.stdin_fd]
         left = len(argsv)
-        data = self.gen_request(cmd, argsv)
+        data = self.gen_request(cmd, argsv, wait)
         self.to_send += data
-        for args, resp in self.extra.pop(self.received_msgid, []):
+        for args, resp, error in self.extra.pop(self.received_msgid, []):
             left -= 1
-            yield resp or self.cache[args][1]
+            if not resp and not error:
+                resp, error = self.cache[args][1:]
+            if error:
+                raise self.RPCError(error)
+            else:
+                yield resp
         while left:
             r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
             if x:
@@ -173,16 +201,22 @@ class RemoteStore(object):
                     self.to_send = self.to_send[n:]
                 else:
                     w_fds = []
+                    if not wait:
+                        return
 
     def commit(self, *args):
         self.call('commit', args)
 
     def rollback(self, *args):
+        self.cache.clear()
+        self.pending.clear()
+        self.extra.clear()
         return self.call('rollback', args)
 
     def get(self, id):
         try:
-            return self.call('get', (id, ))
+            for res in self.call_multi('get', [(id, )]):
+                return res
         except self.RPCError, e:
             if e.name == 'DoesNotExist':
                 raise self.DoesNotExist
@@ -191,12 +225,43 @@ class RemoteStore(object):
     def get_many(self, ids, peek=None):
         return self.call_multi('get', [(id, ) for id in ids], peek=peek)
 
+    def _invalidate(self, id):  # forget the cache entry keyed by (id, ), if there is one
+        key = (id, )  # same 1-tuple shape that get() and get_many() pass as request args
+        if key in self.cache:
+            self.pending.pop(self.cache.pop(key)[0], None)  # element 0 of a cache entry is its msgid; stop tracking that request too
+
     def put(self, id, data, wait=True):
         try:
-            return self.call('put', (id, data), wait=wait)
+            resp = self.call('put', (id, data), wait=wait)
+            self._invalidate(id)
+            return resp
         except self.RPCError, e:
             if e.name == 'AlreadyExists':
                 raise self.AlreadyExists
 
     def delete(self, id, wait=True):
-        return self.call('delete', (id, ), wait=wait)
+        resp = self.call('delete', (id, ), wait=wait)
+        self._invalidate(id)
+        return resp
+
+    def close(self):  # close both pipes to self.p and wait on it; also invoked from __del__
+        if self.p:  # self.p is reset to None below, so a repeated call does nothing
+            self.p.stdin.close()
+            self.p.stdout.close()
+            self.p.wait()  # presumably reaps the Popen child started in __init__ -- TODO confirm
+            self.p = None
+
+
+class RemoteStoreTestCase(StoreTestCase):  # reuses the tests inherited from StoreTestCase, but against a RemoteStore
+
+    def open(self, create=False):  # returns a RemoteStore for '<tmppath>/store' addressed through a 'localhost:' location
+        from .helpers import Location
+        return RemoteStore(Location('localhost:' + os.path.join(self.tmppath, 'store')), create=create)
+
+
+def suite():  # returns the tests loaded from RemoteStoreTestCase
+    return unittest.TestLoader().loadTestsFromTestCase(RemoteStoreTestCase)
+
+if __name__ == '__main__':
+    unittest.main()
+

+ 32 - 3
darc/store.py

@@ -194,6 +194,9 @@ class Store(object):
             self.recover(self.path)
         self.open_index(self.io.head, read_only=True)
 
+    def _len(self):  # number of entries currently in self.index
+        return len(self.index)
+
     def get(self, id):
         try:
             segment, offset = self.index[id]
@@ -403,9 +406,12 @@ class LoggedIO(object):
 
 class StoreTestCase(unittest.TestCase):
 
+    def open(self, create=False):  # store factory used by setUp and the tests; returns a Store at '<tmppath>/store'
+        return Store(os.path.join(self.tmppath, 'store'), create=create)
+
     def setUp(self):
         self.tmppath = tempfile.mkdtemp()
-        self.store = Store(os.path.join(self.tmppath, 'store'), create=True)
+        self.store = self.open(create=True)
 
     def tearDown(self):
         shutil.rmtree(self.tmppath)
@@ -419,12 +425,12 @@ class StoreTestCase(unittest.TestCase):
         self.assertRaises(self.store.DoesNotExist, lambda: self.store.get(key50))
         self.store.commit()
         self.store.close()
-        store2 = Store(os.path.join(self.tmppath, 'store'))
+        store2 = self.open()
         self.assertRaises(store2.DoesNotExist, lambda: store2.get(key50))
         for x in range(100):
             if x == 50:
                 continue
-            self.assertEqual(self.store.get('%-32d' % x), 'SOMEDATA')
+            self.assertEqual(store2.get('%-32d' % x), 'SOMEDATA')
 
     def test2(self):
         """Test multiple sequential transactions
@@ -437,6 +443,29 @@ class StoreTestCase(unittest.TestCase):
         self.store.commit()
         self.assertEqual(self.store.get('00000000000000000000000000000001'), 'bar')
 
+    def test_consistency(self):
+        """Test cache consistency: each get must reflect the latest put or delete of the same key
+        """
+        self.store.put('00000000000000000000000000000000', 'foo')
+        self.assertEqual(self.store.get('00000000000000000000000000000000'), 'foo')
+        self.store.put('00000000000000000000000000000000', 'foo2')  # overwrite the key that was just read
+        self.assertEqual(self.store.get('00000000000000000000000000000000'), 'foo2')
+        self.store.put('00000000000000000000000000000000', 'bar')
+        self.assertEqual(self.store.get('00000000000000000000000000000000'), 'bar')
+        self.store.delete('00000000000000000000000000000000')
+        self.assertRaises(self.store.DoesNotExist, lambda: self.store.get('00000000000000000000000000000000'))  # a read after delete must fail
+
+    def test_consistency2(self):
+        """Test cache consistency2: after rollback, get must return the committed value again
+        """
+        self.store.put('00000000000000000000000000000000', 'foo')
+        self.assertEqual(self.store.get('00000000000000000000000000000000'), 'foo')
+        self.store.commit()  # 'foo' is the committed value
+        self.store.put('00000000000000000000000000000000', 'foo2')  # overwrite that is never committed
+        self.assertEqual(self.store.get('00000000000000000000000000000000'), 'foo2')
+        self.store.rollback()
+        self.assertEqual(self.store.get('00000000000000000000000000000000'), 'foo')  # the 'foo2' read above must not be served again
+
 
 def suite():
     return unittest.TestLoader().loadTestsFromTestCase(StoreTestCase)

+ 20 - 2
darc/test.py

@@ -9,8 +9,10 @@ import tempfile
 import unittest
 from xattr import xattr, XATTR_NOFOLLOW
 
-from . import store, helpers, lrucache
+from . import helpers, lrucache
 from .archiver import Archiver
+from .store import Store, suite as StoreSuite
+from .remote import Store, suite as RemoteStoreSuite
 
 
 class Test(unittest.TestCase):
@@ -112,6 +114,21 @@ class Test(unittest.TestCase):
         # end the same way as info_output
         assert info_output2.endswith(info_output)
 
+    def test_delete(self):  # create two archives of the same input, delete both, then check how many entries remain
+        self.create_regual_file('file1', size=1024 * 80)
+        self.create_regual_file('dir2/file2', size=1024 * 80)
+        self.darc('init', '-p', '', self.store_location)
+        self.darc('create', self.store_location + '::test', 'input')
+        self.darc('create', self.store_location + '::test.2', 'input')
+        self.darc('verify', self.store_location + '::test')
+        self.darc('verify', self.store_location + '::test.2')
+        self.darc('delete', self.store_location + '::test')
+        self.darc('verify', self.store_location + '::test.2')  # the second archive must still verify after the first is deleted
+        self.darc('delete', self.store_location + '::test.2')
+        # Make sure all data except the manifest has been deleted
+        store = Store(self.store_path)
+        self.assertEqual(store._len(), 1)  # exactly one index entry is expected to remain
+
     def test_corrupted_store(self):
         self.create_src_archive('test')
         self.darc('verify', self.store_location + '::test')
@@ -141,7 +158,8 @@ def suite():
     suite = unittest.TestSuite()
     suite.addTest(unittest.TestLoader().loadTestsFromTestCase(Test))
     suite.addTest(unittest.TestLoader().loadTestsFromTestCase(RemoteTest))
-    suite.addTest(store.suite())
+    suite.addTest(StoreSuite())
+    suite.addTest(RemoteStoreSuite())
     suite.addTest(doctest.DocTestSuite(helpers))
     suite.addTest(lrucache.suite())
     return suite