remote.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. from __future__ import with_statement
  2. import fcntl
  3. import msgpack
  4. import os
  5. import select
  6. from subprocess import Popen, PIPE
  7. import sys
  8. import getpass
  9. import unittest
  10. from .store import Store, StoreTestCase
  11. from .lrucache import LRUCache
  # Upper bound, in bytes, on a single os.read() from the RPC pipes (10 MiB).
  BUFSIZE = 10 << 20
  class StoreServer(object):
      """Server side of the remote store protocol, spoken over stdin/stdout.

      Requests arrive as msgpack-encoded ``(type, msgid, method, args)`` tuples
      on stdin; each one is answered on stdout with ``(1, msgid, error, result)``.
      ``method`` is looked up on the server object first and, failing that, on
      the currently open store.
      """

      def __init__(self):
          # Remains None until the client sends an ``open`` request.
          self.store = None

      def serve(self):
          """Handle requests until stdin reaches EOF or select() flags it as exceptional.

          A handler that raises is reported back to the client as the
          exception's class name, with a None result.
          """
          in_fd = sys.stdin.fileno()
          out_fd = sys.stdout.fileno()
          # Reads must not block: os.read() should return whatever is available.
          in_flags = fcntl.fcntl(in_fd, fcntl.F_GETFL)
          fcntl.fcntl(in_fd, fcntl.F_SETFL, in_flags | os.O_NONBLOCK)
          # Writes must block so every reply is written out completely.
          out_flags = fcntl.fcntl(out_fd, fcntl.F_GETFL)
          fcntl.fcntl(out_fd, fcntl.F_SETFL, out_flags & ~os.O_NONBLOCK)
          unpacker = msgpack.Unpacker(use_list=False)
          while True:
              readable, _, broken = select.select([sys.stdin], [], [], 10)
              if readable:
                  chunk = os.read(in_fd, BUFSIZE)
                  if not chunk:
                      # Client closed its end of the pipe.
                      return
                  unpacker.feed(chunk)
                  for _, msgid, method, args in unpacker:
                      try:
                          try:
                              handler = getattr(self, method)
                          except AttributeError:
                              handler = getattr(self.store, method)
                          reply = (1, msgid, None, handler(*args))
                      except Exception:
                          failure = sys.exc_info()[1]
                          reply = (1, msgid, failure.__class__.__name__, None)
                      sys.stdout.write(msgpack.packb(reply))
                      sys.stdout.flush()
              if broken:
                  return

      def negotiate(self, versions):
          """Pick the protocol version: always 1, whatever *versions* the client offers."""
          return 1

      def open(self, path, create=False):
          """Open the store at *path* (passing *create* through to Store) and return its id."""
          # A client path of '/~...' is home-relative: dropping the leading '/'
          # lets expanduser() expand the '~'.
          store_path = os.path.expanduser(path[1:] if path.startswith('/~') else path)
          self.store = Store(store_path, create)
          return self.store.id
  class RemoteStore(object):
      """Client for a store served by ``darc serve`` on a remote host.

      The server is started as ``ssh -p PORT USER@HOST darc serve``. It is
      spoken to over the child's stdin/stdout with msgpack-encoded
      ``(1, msgid, method, args)`` requests. Replies of the form
      ``(type, msgid, error, result)`` are matched to requests by msgid.
      ``get`` requests can be pipelined through call_multi(); their results
      are kept in an LRU cache keyed by the request's argument tuple.
      """

      class RPCError(Exception):
          """The server reported an exception; ``name`` is its class name."""

          def __init__(self, name):
              Exception.__init__(self, name)
              self.name = name

      def __init__(self, location, create=False):
          """Start the ssh child for *location* and open (or create) its store.

          Raises Store.DoesNotExist / Store.AlreadyExists when the server
          reports those errors. Any other server error from ``open``
          propagates as RPCError. Raises Exception if the server answers
          the version negotiation with anything other than 1.
          """
          # args tuple -> (msgid, result, error). Result and error stay None
          # while the reply is still outstanding.
          self.cache = LRUCache(256)
          # Encoded requests queued by call_multi() but not yet written. After
          # a partial write this may start in the middle of a message.
          self.to_send = ''
          # msgid -> [(args, resp, error), ...] that call_multi() yields once
          # the reply carrying that msgid has been received.
          self.extra = {}
          # msgid -> args for pipelined requests still awaiting a reply.
          self.pending = {}
          self.unpacker = msgpack.Unpacker(use_list=False)
          self.msgid = 0
          self.received_msgid = 0
          args = ['ssh', '-p', str(location.port), '%s@%s' % (location.user or getpass.getuser(), location.host), 'darc', 'serve']
          self.p = Popen(args, bufsize=0, stdin=PIPE, stdout=PIPE)
          self.stdin_fd = self.p.stdin.fileno()
          self.stdout_fd = self.p.stdout.fileno()
          # Both pipe ends are non-blocking; all I/O is driven by select().
          fcntl.fcntl(self.stdin_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdin_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
          fcntl.fcntl(self.stdout_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdout_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
          self.r_fds = [self.stdout_fd]
          self.x_fds = [self.stdin_fd, self.stdout_fd]
          version = self.call('negotiate', (1,))
          if version != 1:
              raise Exception('Server insisted on using unsupported protocol version %d' % version)
          try:
              self.id = self.call('open', (location.path, create))
          except self.RPCError:
              # sys.exc_info() is valid both on Python 2.5 (no "except X as e")
              # and on Python 3 (no "except X, e").
              name = sys.exc_info()[1].name
              if name == 'DoesNotExist':
                  raise Store.DoesNotExist
              elif name == 'AlreadyExists':
                  raise Store.AlreadyExists
              # Any other failure must propagate; otherwise the object would
              # exist without self.id.
              raise

      def __del__(self):
          # close() tolerates a half-built instance (e.g. Popen failed).
          self.close()

      def call(self, cmd, args, wait=True):
          """Send one ``cmd(*args)`` request and, if *wait*, block for its reply.

          Returns the server's result, or raises RPCError if the server
          reported an error. With wait=False the request is only written out
          and None is returned. A later reply to such a request is not in
          self.pending, so it is dropped without its error being checked.

          NOTE(review): the request is written directly and self.to_send is
          ignored. If a call_multi() generator left a partially written
          message there, the two byte streams could interleave. Confirm
          callers never reach that state.
          """
          self.msgid += 1
          to_send = msgpack.packb((1, self.msgid, cmd, args))
          w_fds = [self.stdin_fd]
          while wait or to_send:
              r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
              if x:
                  raise Exception('FD exception occured')
              if r:
                  data = os.read(self.stdout_fd, BUFSIZE)
                  if not data:
                      raise Exception('Remote host closed connection')
                  self.unpacker.feed(data)
                  for _, msgid, error, res in self.unpacker:
                      if msgid == self.msgid:
                          self.received_msgid = msgid
                          if error:
                              raise self.RPCError(error)
                          return res
                      # Reply to an earlier pipelined request: cache it so
                      # call_multi() can pick it up later.
                      pending_args = self.pending.pop(msgid, None)
                      if pending_args is not None:
                          self.cache[pending_args] = msgid, res, error
              if w:
                  if to_send:
                      n = os.write(self.stdin_fd, to_send)
                      assert n > 0
                      to_send = to_send[n:]
                  else:
                      # Everything written; stop selecting for writability.
                      w_fds = []

      def _resolve_extra(self, args, resp, error):
          """Return (resp, error) for a queued call_multi() entry, filling placeholders from the cache."""
          # Entries recorded before their reply arrived hold (None, None); by the
          # time they are resolved, the reply has been stored in self.cache.
          if resp is None and error is None:
              resp, error = self.cache[args][1:]
          return resp, error

      def _read(self):
          """Read available replies and yield the results call_multi() is waiting for.

          Every reply in the chunk is recorded before anything is yielded.
          Results are then yielded in turn, raising RPCError at the first one
          that carries an error.
          """
          data = os.read(self.stdout_fd, BUFSIZE)
          if not data:
              raise Exception('Remote host closed connection')
          self.unpacker.feed(data)
          to_yield = []
          for _, msgid, error, res in self.unpacker:
              self.received_msgid = msgid
              args = self.pending.pop(msgid, None)
              if args is not None:
                  self.cache[args] = msgid, res, error
                  for entry in self.extra.pop(msgid, []):
                      to_yield.append(self._resolve_extra(*entry))
          for res, error in to_yield:
              if error:
                  raise self.RPCError(error)
              yield res

      def gen_request(self, cmd, argsv, wait):
          """Encode ``cmd`` requests for the args tuples in *argsv* and return the bytes.

          Args tuples already in the cache (answered or still in flight) are
          not sent again. When *wait* is true, each args tuple is registered
          in self.extra under the highest msgid seen so far. The server
          answers in request order (as StoreServer.serve does), so every
          queued result is available once that reply arrives.
          """
          data = []
          m = self.received_msgid
          for args in argsv:
              if args not in self.cache:
                  self.msgid += 1
                  msgid = self.msgid
                  self.pending[msgid] = args
                  # Placeholder until the reply arrives.
                  self.cache[args] = msgid, None, None
                  data.append(msgpack.packb((1, msgid, cmd, args)))
              if wait:
                  msgid, resp, error = self.cache[args]
                  m = max(m, msgid)
                  self.extra.setdefault(m, []).append((args, resp, error))
          return ''.join(data)

      def gen_cache_requests(self, cmd, peek):
          """Encode prefetch ``cmd`` requests for upcoming ids not yet cached.

          *peek* is called until it raises StopIteration. The first element
          of each value it returns is used as the id. Replies are only cached
          via self.pending; nothing is queued in self.extra.
          """
          data = []
          while True:
              try:
                  args = (peek()[0],)
              except StopIteration:
                  break
              if args in self.cache:
                  continue
              self.msgid += 1
              msgid = self.msgid
              self.pending[msgid] = args
              self.cache[args] = msgid, None, None
              data.append(msgpack.packb((1, msgid, cmd, args)))
          return ''.join(data)

      def call_multi(self, cmd, argsv, wait=True, peek=None):
          """Pipeline ``cmd`` for every args tuple in *argsv* (a generator).

          With *wait* true, yields one result per entry of *argsv*, raising
          RPCError for an entry the server failed. With *wait* false it
          returns once the send buffer has been fully written. If *peek* is
          given, prefetch requests from gen_cache_requests() are queued
          whenever the send buffer runs empty.
          """
          w_fds = [self.stdin_fd]
          left = len(argsv)
          self.to_send += self.gen_request(cmd, argsv, wait)
          # Entries whose replies had already been received when they were queued.
          for entry in self.extra.pop(self.received_msgid, []):
              left -= 1
              resp, error = self._resolve_extra(*entry)
              if error:
                  raise self.RPCError(error)
              yield resp
          while left:
              r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
              if x:
                  raise Exception('FD exception occured')
              if r:
                  for res in self._read():
                      left -= 1
                      yield res
              if w:
                  if not self.to_send and peek:
                      self.to_send = self.gen_cache_requests(cmd, peek)
                  if self.to_send:
                      n = os.write(self.stdin_fd, self.to_send)
                      assert n > 0
                      self.to_send = self.to_send[n:]
                  else:
                      w_fds = []
                      if not wait:
                          return

      def commit(self, *args):
          """Invoke the server-side ``commit`` and return its result."""
          return self.call('commit', args)

      def rollback(self, *args):
          """Drop all client-side cached/pipelined state, then invoke server ``rollback``."""
          # self.to_send is deliberately left alone: it may hold a partially
          # written message, and dropping it would desynchronise the stream.
          self.cache.clear()
          self.pending.clear()
          self.extra.clear()
          return self.call('rollback', args)

      def get(self, id):
          """Fetch one object by *id*; raises Store.DoesNotExist if the server lacks it."""
          try:
              for res in self.call_multi('get', [(id, )]):
                  return res
          except self.RPCError:
              if sys.exc_info()[1].name == 'DoesNotExist':
                  raise Store.DoesNotExist
              raise

      def get_many(self, ids, peek=None):
          """Return a generator of results for *ids*, pipelining the ``get`` requests."""
          return self.call_multi('get', [(id, ) for id in ids], peek=peek)

      def _invalidate(self, id):
          """Forget any cached or in-flight ``get`` result for *id*."""
          key = (id, )
          if key in self.cache:
              # Also drop the pending entry so a late reply is not re-cached.
              self.pending.pop(self.cache.pop(key)[0], None)

      def put(self, id, data, wait=True):
          """Store *data* under *id* remotely, then invalidate any cached copy."""
          resp = self.call('put', (id, data), wait=wait)
          self._invalidate(id)
          return resp

      def delete(self, id, wait=True):
          """Delete *id* remotely, then invalidate any cached copy."""
          resp = self.call('delete', (id, ), wait=wait)
          self._invalidate(id)
          return resp

      def close(self):
          """Close the pipes and reap the ssh child; safe to call repeatedly."""
          # getattr: __init__ may have failed before self.p was assigned.
          p = getattr(self, 'p', None)
          if p:
              p.stdin.close()
              p.stdout.close()
              p.wait()
              self.p = None
  class RemoteStoreTestCase(StoreTestCase):
      """StoreTestCase variant whose stores are RemoteStore instances.

      RemoteStore reaches the store via ``ssh ... darc serve``, so running
      these tests needs ssh access to localhost and a ``darc`` command there.
      """

      def open(self, create=False):
          """Return a RemoteStore for a 'store' directory under self.tmppath."""
          from .helpers import Location
          store_path = os.path.join(self.tmppath, 'store')
          location = Location('localhost:' + store_path)
          return RemoteStore(location, create=create)
  def suite():
      """Return a unittest suite containing every RemoteStoreTestCase test."""
      loader = unittest.TestLoader()
      return loader.loadTestsFromTestCase(RemoteStoreTestCase)
  if __name__ == '__main__':
      # Running the module directly hands control to unittest's command-line
      # runner, which collects the test cases defined in this module.
      unittest.main(module='__main__')