remote.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. import fcntl
  2. import msgpack
  3. import os
  4. import select
  5. from subprocess import Popen, PIPE
  6. import sys
  7. import getpass
  8. from .repository import Repository
  9. from .lrucache import LRUCache
  10. BUFSIZE = 10 * 1024 * 1024
  11. class ConnectionClosed(Exception):
  12. """Connection closed by remote host
  13. """
  14. class RepositoryServer(object):
  15. def __init__(self):
  16. self.repository = None
  17. def serve(self):
  18. # Make stdin non-blocking
  19. fl = fcntl.fcntl(sys.stdin.fileno(), fcntl.F_GETFL)
  20. fcntl.fcntl(sys.stdin.fileno(), fcntl.F_SETFL, fl | os.O_NONBLOCK)
  21. # Make stdout blocking
  22. fl = fcntl.fcntl(sys.stdout.fileno(), fcntl.F_GETFL)
  23. fcntl.fcntl(sys.stdout.fileno(), fcntl.F_SETFL, fl & ~os.O_NONBLOCK)
  24. unpacker = msgpack.Unpacker(use_list=False)
  25. while True:
  26. r, w, es = select.select([sys.stdin], [], [], 10)
  27. if r:
  28. data = os.read(sys.stdin.fileno(), BUFSIZE)
  29. if not data:
  30. return
  31. unpacker.feed(data)
  32. for type, msgid, method, args in unpacker:
  33. method = method.decode('ascii')
  34. try:
  35. try:
  36. f = getattr(self, method)
  37. except AttributeError:
  38. f = getattr(self.repository, method)
  39. res = f(*args)
  40. except Exception as e:
  41. sys.stdout.buffer.write(msgpack.packb((1, msgid, e.__class__.__name__, None)))
  42. else:
  43. sys.stdout.buffer.write(msgpack.packb((1, msgid, None, res)))
  44. sys.stdout.flush()
  45. if es:
  46. return
  47. def negotiate(self, versions):
  48. return 1
  49. def open(self, path, create=False):
  50. path = os.fsdecode(path)
  51. if path.startswith('/~'):
  52. path = path[1:]
  53. self.repository = Repository(os.path.expanduser(path), create)
  54. return self.repository.id
  55. class RemoteRepository(object):
  56. class RPCError(Exception):
  57. def __init__(self, name):
  58. self.name = name
  59. def __init__(self, location, create=False):
  60. self.p = None
  61. self.cache = LRUCache(256)
  62. self.to_send = b''
  63. self.extra = {}
  64. self.pending = {}
  65. self.unpacker = msgpack.Unpacker(use_list=False)
  66. self.msgid = 0
  67. self.received_msgid = 0
  68. args = ['ssh', '-p', str(location.port), '%s@%s' % (location.user or getpass.getuser(), location.host), 'darc', 'serve']
  69. self.p = Popen(args, bufsize=0, stdin=PIPE, stdout=PIPE)
  70. self.stdin_fd = self.p.stdin.fileno()
  71. self.stdout_fd = self.p.stdout.fileno()
  72. fcntl.fcntl(self.stdin_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdin_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  73. fcntl.fcntl(self.stdout_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdout_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  74. self.r_fds = [self.stdout_fd]
  75. self.x_fds = [self.stdin_fd, self.stdout_fd]
  76. version = self.call('negotiate', (1,))
  77. if version != 1:
  78. raise Exception('Server insisted on using unsupported protocol version %d' % version)
  79. try:
  80. self.id = self.call('open', (location.path, create))
  81. except self.RPCError as e:
  82. if e.name == b'DoesNotExist':
  83. raise Repository.DoesNotExist
  84. elif e.name == b'AlreadyExists':
  85. raise Repository.AlreadyExists
  86. def __del__(self):
  87. self.close()
  88. def call(self, cmd, args, wait=True):
  89. self.msgid += 1
  90. to_send = msgpack.packb((1, self.msgid, cmd, args))
  91. w_fds = [self.stdin_fd]
  92. while wait or to_send:
  93. r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
  94. if x:
  95. raise Exception('FD exception occured')
  96. if r:
  97. data = os.read(self.stdout_fd, BUFSIZE)
  98. if not data:
  99. raise ConnectionClosed()
  100. self.unpacker.feed(data)
  101. for type, msgid, error, res in self.unpacker:
  102. if msgid == self.msgid:
  103. assert msgid == self.msgid
  104. self.received_msgid = msgid
  105. if error:
  106. raise self.RPCError(error)
  107. else:
  108. return res
  109. else:
  110. args = self.pending.pop(msgid, None)
  111. if args is not None:
  112. self.cache[args] = msgid, res, error
  113. if w:
  114. if to_send:
  115. n = os.write(self.stdin_fd, to_send)
  116. assert n > 0
  117. to_send = memoryview(to_send)[n:]
  118. else:
  119. w_fds = []
  120. def _read(self):
  121. data = os.read(self.stdout_fd, BUFSIZE)
  122. if not data:
  123. raise Exception('Remote host closed connection')
  124. self.unpacker.feed(data)
  125. to_yield = []
  126. for type, msgid, error, res in self.unpacker:
  127. self.received_msgid = msgid
  128. args = self.pending.pop(msgid, None)
  129. if args is not None:
  130. self.cache[args] = msgid, res, error
  131. for args, resp, error in self.extra.pop(msgid, []):
  132. if not resp and not error:
  133. resp, error = self.cache[args][1:]
  134. to_yield.append((resp, error))
  135. for res, error in to_yield:
  136. if error:
  137. raise self.RPCError(error)
  138. else:
  139. yield res
  140. def gen_request(self, cmd, argsv, wait):
  141. data = []
  142. m = self.received_msgid
  143. for args in argsv:
  144. # Make sure to invalidate any existing cache entries for non-get requests
  145. if not args in self.cache:
  146. self.msgid += 1
  147. msgid = self.msgid
  148. self.pending[msgid] = args
  149. self.cache[args] = msgid, None, None
  150. data.append(msgpack.packb((1, msgid, cmd, args)))
  151. if wait:
  152. msgid, resp, error = self.cache[args]
  153. m = max(m, msgid)
  154. self.extra.setdefault(m, []).append((args, resp, error))
  155. return b''.join(data)
  156. def gen_cache_requests(self, cmd, peek):
  157. data = []
  158. while True:
  159. try:
  160. args = (peek()[0],)
  161. except StopIteration:
  162. break
  163. if args in self.cache:
  164. continue
  165. self.msgid += 1
  166. msgid = self.msgid
  167. self.pending[msgid] = args
  168. self.cache[args] = msgid, None, None
  169. data.append(msgpack.packb((1, msgid, cmd, args)))
  170. return b''.join(data)
  171. def call_multi(self, cmd, argsv, wait=True, peek=None):
  172. w_fds = [self.stdin_fd]
  173. left = len(argsv)
  174. data = self.gen_request(cmd, argsv, wait)
  175. self.to_send += data
  176. for args, resp, error in self.extra.pop(self.received_msgid, []):
  177. left -= 1
  178. if not resp and not error:
  179. resp, error = self.cache[args][1:]
  180. if error:
  181. raise self.RPCError(error)
  182. else:
  183. yield resp
  184. while left:
  185. r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
  186. if x:
  187. raise Exception('FD exception occured')
  188. if r:
  189. for res in self._read():
  190. left -= 1
  191. yield res
  192. if w:
  193. if not self.to_send and peek:
  194. self.to_send = self.gen_cache_requests(cmd, peek)
  195. if self.to_send:
  196. n = os.write(self.stdin_fd, self.to_send)
  197. assert n > 0
  198. # self.to_send = memoryview(self.to_send)[n:]
  199. self.to_send = self.to_send[n:]
  200. else:
  201. w_fds = []
  202. if not wait:
  203. return
  204. def commit(self, *args):
  205. self.call('commit', args)
  206. def rollback(self, *args):
  207. self.cache.clear()
  208. self.pending.clear()
  209. self.extra.clear()
  210. return self.call('rollback', args)
  211. def get(self, id):
  212. try:
  213. for res in self.call_multi('get', [(id, )]):
  214. return res
  215. except self.RPCError as e:
  216. if e.name == b'DoesNotExist':
  217. raise Repository.DoesNotExist
  218. raise
  219. def get_many(self, ids, peek=None):
  220. return self.call_multi('get', [(id, ) for id in ids], peek=peek)
  221. def _invalidate(self, id):
  222. key = (id, )
  223. if key in self.cache:
  224. self.pending.pop(self.cache.pop(key)[0], None)
  225. def put(self, id, data, wait=True):
  226. resp = self.call('put', (id, data), wait=wait)
  227. self._invalidate(id)
  228. return resp
  229. def delete(self, id, wait=True):
  230. resp = self.call('delete', (id, ), wait=wait)
  231. self._invalidate(id)
  232. return resp
  233. def close(self):
  234. if self.p:
  235. self.p.stdin.close()
  236. self.p.stdout.close()
  237. self.p.wait()
  238. self.p = None