remote.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. import fcntl
  2. import msgpack
  3. import os
  4. import select
  5. from subprocess import Popen, PIPE
  6. import sys
  7. import getpass
  8. import unittest
  9. from .store import Store, StoreTestCase
  10. from .lrucache import LRUCache
  11. BUFSIZE = 10 * 1024 * 1024
  12. class StoreServer(object):
  13. def __init__(self):
  14. self.store = None
  15. def serve(self):
  16. # Make stdin non-blocking
  17. fl = fcntl.fcntl(sys.stdin.fileno(), fcntl.F_GETFL)
  18. fcntl.fcntl(sys.stdin.fileno(), fcntl.F_SETFL, fl | os.O_NONBLOCK)
  19. # Make stdout blocking
  20. fl = fcntl.fcntl(sys.stdout.fileno(), fcntl.F_GETFL)
  21. fcntl.fcntl(sys.stdout.fileno(), fcntl.F_SETFL, fl & ~os.O_NONBLOCK)
  22. unpacker = msgpack.Unpacker(use_list=False)
  23. while True:
  24. r, w, es = select.select([sys.stdin], [], [], 10)
  25. if r:
  26. data = os.read(sys.stdin.fileno(), BUFSIZE)
  27. if not data:
  28. return
  29. unpacker.feed(data)
  30. for type, msgid, method, args in unpacker:
  31. method = method.decode('ascii')
  32. try:
  33. try:
  34. f = getattr(self, method)
  35. except AttributeError:
  36. f = getattr(self.store, method)
  37. res = f(*args)
  38. except Exception as e:
  39. sys.stdout.buffer.write(msgpack.packb((1, msgid, e.__class__.__name__, None)))
  40. else:
  41. sys.stdout.buffer.write(msgpack.packb((1, msgid, None, res)))
  42. sys.stdout.flush()
  43. if es:
  44. return
  45. def negotiate(self, versions):
  46. return 1
  47. def open(self, path, create=False):
  48. path = os.fsdecode(path)
  49. if path.startswith('/~'):
  50. path = path[1:]
  51. self.store = Store(os.path.expanduser(path), create)
  52. return self.store.id
  53. class RemoteStore(object):
  54. class RPCError(Exception):
  55. def __init__(self, name):
  56. self.name = name
  57. def __init__(self, location, create=False):
  58. self.p = None
  59. self.cache = LRUCache(256)
  60. self.to_send = b''
  61. self.extra = {}
  62. self.pending = {}
  63. self.unpacker = msgpack.Unpacker(use_list=False)
  64. self.msgid = 0
  65. self.received_msgid = 0
  66. args = ['ssh', '-p', str(location.port), '%s@%s' % (location.user or getpass.getuser(), location.host), 'darc', 'serve']
  67. self.p = Popen(args, bufsize=0, stdin=PIPE, stdout=PIPE)
  68. self.stdin_fd = self.p.stdin.fileno()
  69. self.stdout_fd = self.p.stdout.fileno()
  70. fcntl.fcntl(self.stdin_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdin_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  71. fcntl.fcntl(self.stdout_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdout_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  72. self.r_fds = [self.stdout_fd]
  73. self.x_fds = [self.stdin_fd, self.stdout_fd]
  74. version = self.call('negotiate', (1,))
  75. if version != 1:
  76. raise Exception('Server insisted on using unsupported protocol version %d' % version)
  77. try:
  78. self.id = self.call('open', (location.path, create))
  79. except self.RPCError as e:
  80. if e.name == b'DoesNotExist':
  81. raise Store.DoesNotExist
  82. elif e.name == b'AlreadyExists':
  83. raise Store.AlreadyExists
  84. def __del__(self):
  85. self.close()
  86. def call(self, cmd, args, wait=True):
  87. self.msgid += 1
  88. to_send = msgpack.packb((1, self.msgid, cmd, args))
  89. w_fds = [self.stdin_fd]
  90. while wait or to_send:
  91. r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
  92. if x:
  93. raise Exception('FD exception occured')
  94. if r:
  95. data = os.read(self.stdout_fd, BUFSIZE)
  96. if not data:
  97. raise Exception('Remote host closed connection')
  98. self.unpacker.feed(data)
  99. for type, msgid, error, res in self.unpacker:
  100. if msgid == self.msgid:
  101. assert msgid == self.msgid
  102. self.received_msgid = msgid
  103. if error:
  104. raise self.RPCError(error)
  105. else:
  106. return res
  107. else:
  108. args = self.pending.pop(msgid, None)
  109. if args is not None:
  110. self.cache[args] = msgid, res, error
  111. if w:
  112. if to_send:
  113. n = os.write(self.stdin_fd, to_send)
  114. assert n > 0
  115. to_send = memoryview(to_send)[n:]
  116. else:
  117. w_fds = []
  118. def _read(self):
  119. data = os.read(self.stdout_fd, BUFSIZE)
  120. if not data:
  121. raise Exception('Remote host closed connection')
  122. self.unpacker.feed(data)
  123. to_yield = []
  124. for type, msgid, error, res in self.unpacker:
  125. self.received_msgid = msgid
  126. args = self.pending.pop(msgid, None)
  127. if args is not None:
  128. self.cache[args] = msgid, res, error
  129. for args, resp, error in self.extra.pop(msgid, []):
  130. if not resp and not error:
  131. resp, error = self.cache[args][1:]
  132. to_yield.append((resp, error))
  133. for res, error in to_yield:
  134. if error:
  135. raise self.RPCError(error)
  136. else:
  137. yield res
  138. def gen_request(self, cmd, argsv, wait):
  139. data = []
  140. m = self.received_msgid
  141. for args in argsv:
  142. # Make sure to invalidate any existing cache entries for non-get requests
  143. if not args in self.cache:
  144. self.msgid += 1
  145. msgid = self.msgid
  146. self.pending[msgid] = args
  147. self.cache[args] = msgid, None, None
  148. data.append(msgpack.packb((1, msgid, cmd, args)))
  149. if wait:
  150. msgid, resp, error = self.cache[args]
  151. m = max(m, msgid)
  152. self.extra.setdefault(m, []).append((args, resp, error))
  153. return b''.join(data)
  154. def gen_cache_requests(self, cmd, peek):
  155. data = []
  156. while True:
  157. try:
  158. args = (peek()[0],)
  159. except StopIteration:
  160. break
  161. if args in self.cache:
  162. continue
  163. self.msgid += 1
  164. msgid = self.msgid
  165. self.pending[msgid] = args
  166. self.cache[args] = msgid, None, None
  167. data.append(msgpack.packb((1, msgid, cmd, args)))
  168. return b''.join(data)
  169. def call_multi(self, cmd, argsv, wait=True, peek=None):
  170. w_fds = [self.stdin_fd]
  171. left = len(argsv)
  172. data = self.gen_request(cmd, argsv, wait)
  173. self.to_send += data
  174. for args, resp, error in self.extra.pop(self.received_msgid, []):
  175. left -= 1
  176. if not resp and not error:
  177. resp, error = self.cache[args][1:]
  178. if error:
  179. raise self.RPCError(error)
  180. else:
  181. yield resp
  182. while left:
  183. r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
  184. if x:
  185. raise Exception('FD exception occured')
  186. if r:
  187. for res in self._read():
  188. left -= 1
  189. yield res
  190. if w:
  191. if not self.to_send and peek:
  192. self.to_send = self.gen_cache_requests(cmd, peek)
  193. if self.to_send:
  194. n = os.write(self.stdin_fd, self.to_send)
  195. assert n > 0
  196. # self.to_send = memoryview(self.to_send)[n:]
  197. self.to_send = self.to_send[n:]
  198. else:
  199. w_fds = []
  200. if not wait:
  201. return
  202. def commit(self, *args):
  203. self.call('commit', args)
  204. def rollback(self, *args):
  205. self.cache.clear()
  206. self.pending.clear()
  207. self.extra.clear()
  208. return self.call('rollback', args)
  209. def get(self, id):
  210. try:
  211. for res in self.call_multi('get', [(id, )]):
  212. return res
  213. except self.RPCError as e:
  214. if e.name == b'DoesNotExist':
  215. raise Store.DoesNotExist
  216. raise
  217. def get_many(self, ids, peek=None):
  218. return self.call_multi('get', [(id, ) for id in ids], peek=peek)
  219. def _invalidate(self, id):
  220. key = (id, )
  221. if key in self.cache:
  222. self.pending.pop(self.cache.pop(key)[0], None)
  223. def put(self, id, data, wait=True):
  224. resp = self.call('put', (id, data), wait=wait)
  225. self._invalidate(id)
  226. return resp
  227. def delete(self, id, wait=True):
  228. resp = self.call('delete', (id, ), wait=wait)
  229. self._invalidate(id)
  230. return resp
  231. def close(self):
  232. if self.p:
  233. self.p.stdin.close()
  234. self.p.stdout.close()
  235. self.p.wait()
  236. self.p = None
  237. class RemoteStoreTestCase(StoreTestCase):
  238. def open(self, create=False):
  239. from .helpers import Location
  240. return RemoteStore(Location('localhost:' + os.path.join(self.tmppath, 'store')), create=create)
  241. def suite():
  242. return unittest.TestLoader().loadTestsFromTestCase(RemoteStoreTestCase)
if __name__ == '__main__':
    # Running this module directly executes its unittest test cases.
    unittest.main()