|
@@ -6,8 +6,9 @@ import select
|
|
|
from subprocess import Popen, PIPE
|
|
|
import sys
|
|
|
import getpass
|
|
|
+import unittest
|
|
|
|
|
|
-from .store import Store
|
|
|
+from .store import Store, StoreTestCase
|
|
|
from .lrucache import LRUCache
|
|
|
|
|
|
BUFSIZE = 10 * 1024 * 1024
|
|
@@ -92,9 +93,29 @@ class RemoteStore(object):
|
|
|
self.id = self.call('open', (location.path, create))
|
|
|
|
|
|
def __del__(self):
|
|
|
- self.p.stdin.close()
|
|
|
- self.p.stdout.close()
|
|
|
- self.p.wait()
|
|
|
+ self.close()
|
|
|
+
|
|
|
+ def call(self, cmd, args, wait=True):
|
|
|
+ self.msgid += 1
|
|
|
+ self.p.stdin.write(msgpack.packb((1, self.msgid, cmd, args)))
|
|
|
+ while wait:
|
|
|
+ r, w, x = select.select(self.r_fds, [], self.x_fds, 1)
|
|
|
+ if x:
|
|
|
+ raise Exception('FD exception occured')
|
|
|
+ if r:
|
|
|
+ self.unpacker.feed(os.read(self.stdout_fd, BUFSIZE))
|
|
|
+ for type, msgid, error, res in self.unpacker:
|
|
|
+ if msgid == self.msgid:
|
|
|
+ assert msgid == self.msgid
|
|
|
+ self.received_msgid = msgid
|
|
|
+ if error:
|
|
|
+ raise self.RPCError(error)
|
|
|
+ else:
|
|
|
+ return res
|
|
|
+ else:
|
|
|
+ args = self.pending.pop(msgid, None)
|
|
|
+ if args is not None:
|
|
|
+ self.cache[args] = msgid, res, error
|
|
|
|
|
|
def _read(self):
|
|
|
data = os.read(self.stdout_fd, BUFSIZE)
|
|
@@ -104,32 +125,34 @@ class RemoteStore(object):
|
|
|
to_yield = []
|
|
|
for type, msgid, error, res in self.unpacker:
|
|
|
self.received_msgid = msgid
|
|
|
+ args = self.pending.pop(msgid, None)
|
|
|
+ if args is not None:
|
|
|
+ self.cache[args] = msgid, res, error
|
|
|
+ for args, resp, error in self.extra.pop(msgid, []):
|
|
|
+ if not resp and not error:
|
|
|
+ resp, error = self.cache[args][1:]
|
|
|
+ to_yield.append((resp, error))
|
|
|
+ for res, error in to_yield:
|
|
|
if error:
|
|
|
raise self.RPCError(error)
|
|
|
- args = self.pending.pop(msgid)
|
|
|
- self.cache[args] = msgid, res
|
|
|
- for args, resp in self.extra.pop(msgid, []):
|
|
|
- to_yield.append(resp or self.cache[args][1])
|
|
|
- for res in to_yield:
|
|
|
- yield res
|
|
|
-
|
|
|
- def call(self, cmd, args, wait=True):
|
|
|
- for res in self.call_multi(cmd, [args], wait=wait):
|
|
|
- return res
|
|
|
+ else:
|
|
|
+ yield res
|
|
|
|
|
|
- def gen_request(self, cmd, argsv):
|
|
|
+ def gen_request(self, cmd, argsv, wait):
|
|
|
data = []
|
|
|
m = self.received_msgid
|
|
|
for args in argsv:
|
|
|
+ # Make sure to invalidate any existing cache entries for non-get requests
|
|
|
if not args in self.cache:
|
|
|
self.msgid += 1
|
|
|
msgid = self.msgid
|
|
|
self.pending[msgid] = args
|
|
|
- self.cache[args] = msgid, None
|
|
|
+ self.cache[args] = msgid, None, None
|
|
|
data.append(msgpack.packb((1, msgid, cmd, args)))
|
|
|
- msgid, resp = self.cache[args]
|
|
|
- m = max(m, msgid)
|
|
|
- self.extra.setdefault(m, []).append((args, resp))
|
|
|
+ if wait:
|
|
|
+ msgid, resp, error = self.cache[args]
|
|
|
+ m = max(m, msgid)
|
|
|
+ self.extra.setdefault(m, []).append((args, resp, error))
|
|
|
return ''.join(data)
|
|
|
|
|
|
def gen_cache_requests(self, cmd, peek):
|
|
@@ -144,18 +167,23 @@ class RemoteStore(object):
|
|
|
self.msgid += 1
|
|
|
msgid = self.msgid
|
|
|
self.pending[msgid] = args
|
|
|
- self.cache[args] = msgid, None
|
|
|
+ self.cache[args] = msgid, None, None
|
|
|
data.append(msgpack.packb((1, msgid, cmd, args)))
|
|
|
return ''.join(data)
|
|
|
|
|
|
def call_multi(self, cmd, argsv, wait=True, peek=None):
|
|
|
w_fds = [self.stdin_fd]
|
|
|
left = len(argsv)
|
|
|
- data = self.gen_request(cmd, argsv)
|
|
|
+ data = self.gen_request(cmd, argsv, wait)
|
|
|
self.to_send += data
|
|
|
- for args, resp in self.extra.pop(self.received_msgid, []):
|
|
|
+ for args, resp, error in self.extra.pop(self.received_msgid, []):
|
|
|
left -= 1
|
|
|
- yield resp or self.cache[args][1]
|
|
|
+ if not resp and not error:
|
|
|
+ resp, error = self.cache[args][1:]
|
|
|
+ if error:
|
|
|
+ raise self.RPCError(error)
|
|
|
+ else:
|
|
|
+ yield resp
|
|
|
while left:
|
|
|
r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
|
|
|
if x:
|
|
@@ -173,16 +201,22 @@ class RemoteStore(object):
|
|
|
self.to_send = self.to_send[n:]
|
|
|
else:
|
|
|
w_fds = []
|
|
|
+ if not wait:
|
|
|
+ return
|
|
|
|
|
|
def commit(self, *args):
|
|
|
self.call('commit', args)
|
|
|
|
|
|
def rollback(self, *args):
|
|
|
+ self.cache.clear()
|
|
|
+ self.pending.clear()
|
|
|
+ self.extra.clear()
|
|
|
return self.call('rollback', args)
|
|
|
|
|
|
def get(self, id):
|
|
|
try:
|
|
|
- return self.call('get', (id, ))
|
|
|
+ for res in self.call_multi('get', [(id, )]):
|
|
|
+ return res
|
|
|
except self.RPCError, e:
|
|
|
if e.name == 'DoesNotExist':
|
|
|
raise self.DoesNotExist
|
|
@@ -191,12 +225,43 @@ class RemoteStore(object):
|
|
|
def get_many(self, ids, peek=None):
|
|
|
return self.call_multi('get', [(id, ) for id in ids], peek=peek)
|
|
|
|
|
|
+ def _invalidate(self, id):
|
|
|
+ key = (id, )
|
|
|
+ if key in self.cache:
|
|
|
+ self.pending.pop(self.cache.pop(key)[0], None)
|
|
|
+
|
|
|
def put(self, id, data, wait=True):
|
|
|
try:
|
|
|
- return self.call('put', (id, data), wait=wait)
|
|
|
+ resp = self.call('put', (id, data), wait=wait)
|
|
|
+ self._invalidate(id)
|
|
|
+ return resp
|
|
|
except self.RPCError, e:
|
|
|
if e.name == 'AlreadyExists':
|
|
|
raise self.AlreadyExists
|
|
|
|
|
|
def delete(self, id, wait=True):
|
|
|
- return self.call('delete', (id, ), wait=wait)
|
|
|
+ resp = self.call('delete', (id, ), wait=wait)
|
|
|
+ self._invalidate(id)
|
|
|
+ return resp
|
|
|
+
|
|
|
+ def close(self):
|
|
|
+ if self.p:
|
|
|
+ self.p.stdin.close()
|
|
|
+ self.p.stdout.close()
|
|
|
+ self.p.wait()
|
|
|
+ self.p = None
|
|
|
+
|
|
|
+
|
|
|
+class RemoteStoreTestCase(StoreTestCase):
|
|
|
+
|
|
|
+ def open(self, create=False):
|
|
|
+ from .helpers import Location
|
|
|
+ return RemoteStore(Location('localhost:' + os.path.join(self.tmppath, 'store')), create=create)
|
|
|
+
|
|
|
+
|
|
|
+def suite():
|
|
|
+ return unittest.TestLoader().loadTestsFromTestCase(RemoteStoreTestCase)
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ unittest.main()
|
|
|
+
|