2
0

remote.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. import fcntl
  2. import msgpack
  3. import os
  4. import select
  5. from subprocess import Popen, PIPE
  6. import sys
  7. import getpass
  8. from .repository import Repository
  9. from .lrucache import LRUCache
# Maximum number of bytes requested by a single os.read() call (10 MiB).
BUFSIZE = 10 * 1024 * 1024
  11. class RepositoryServer(object):
  12. def __init__(self):
  13. self.repository = None
  14. def serve(self):
  15. # Make stdin non-blocking
  16. fl = fcntl.fcntl(sys.stdin.fileno(), fcntl.F_GETFL)
  17. fcntl.fcntl(sys.stdin.fileno(), fcntl.F_SETFL, fl | os.O_NONBLOCK)
  18. # Make stdout blocking
  19. fl = fcntl.fcntl(sys.stdout.fileno(), fcntl.F_GETFL)
  20. fcntl.fcntl(sys.stdout.fileno(), fcntl.F_SETFL, fl & ~os.O_NONBLOCK)
  21. unpacker = msgpack.Unpacker(use_list=False)
  22. while True:
  23. r, w, es = select.select([sys.stdin], [], [], 10)
  24. if r:
  25. data = os.read(sys.stdin.fileno(), BUFSIZE)
  26. if not data:
  27. return
  28. unpacker.feed(data)
  29. for type, msgid, method, args in unpacker:
  30. method = method.decode('ascii')
  31. try:
  32. try:
  33. f = getattr(self, method)
  34. except AttributeError:
  35. f = getattr(self.repository, method)
  36. res = f(*args)
  37. except Exception as e:
  38. sys.stdout.buffer.write(msgpack.packb((1, msgid, e.__class__.__name__, None)))
  39. else:
  40. sys.stdout.buffer.write(msgpack.packb((1, msgid, None, res)))
  41. sys.stdout.flush()
  42. if es:
  43. return
  44. def negotiate(self, versions):
  45. return 1
  46. def open(self, path, create=False):
  47. path = os.fsdecode(path)
  48. if path.startswith('/~'):
  49. path = path[1:]
  50. self.repository = Repository(os.path.expanduser(path), create)
  51. return self.repository.id
  52. class RemoteRepository(object):
  53. class RPCError(Exception):
  54. def __init__(self, name):
  55. self.name = name
  56. def __init__(self, location, create=False):
  57. self.p = None
  58. self.cache = LRUCache(256)
  59. self.to_send = b''
  60. self.extra = {}
  61. self.pending = {}
  62. self.unpacker = msgpack.Unpacker(use_list=False)
  63. self.msgid = 0
  64. self.received_msgid = 0
  65. args = ['ssh', '-p', str(location.port), '%s@%s' % (location.user or getpass.getuser(), location.host), 'darc', 'serve']
  66. self.p = Popen(args, bufsize=0, stdin=PIPE, stdout=PIPE)
  67. self.stdin_fd = self.p.stdin.fileno()
  68. self.stdout_fd = self.p.stdout.fileno()
  69. fcntl.fcntl(self.stdin_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdin_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  70. fcntl.fcntl(self.stdout_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdout_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  71. self.r_fds = [self.stdout_fd]
  72. self.x_fds = [self.stdin_fd, self.stdout_fd]
  73. version = self.call('negotiate', (1,))
  74. if version != 1:
  75. raise Exception('Server insisted on using unsupported protocol version %d' % version)
  76. try:
  77. self.id = self.call('open', (location.path, create))
  78. except self.RPCError as e:
  79. if e.name == b'DoesNotExist':
  80. raise Repository.DoesNotExist
  81. elif e.name == b'AlreadyExists':
  82. raise Repository.AlreadyExists
  83. def __del__(self):
  84. self.close()
  85. def call(self, cmd, args, wait=True):
  86. self.msgid += 1
  87. to_send = msgpack.packb((1, self.msgid, cmd, args))
  88. w_fds = [self.stdin_fd]
  89. while wait or to_send:
  90. r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
  91. if x:
  92. raise Exception('FD exception occured')
  93. if r:
  94. data = os.read(self.stdout_fd, BUFSIZE)
  95. if not data:
  96. raise Exception('Remote host closed connection')
  97. self.unpacker.feed(data)
  98. for type, msgid, error, res in self.unpacker:
  99. if msgid == self.msgid:
  100. assert msgid == self.msgid
  101. self.received_msgid = msgid
  102. if error:
  103. raise self.RPCError(error)
  104. else:
  105. return res
  106. else:
  107. args = self.pending.pop(msgid, None)
  108. if args is not None:
  109. self.cache[args] = msgid, res, error
  110. if w:
  111. if to_send:
  112. n = os.write(self.stdin_fd, to_send)
  113. assert n > 0
  114. to_send = memoryview(to_send)[n:]
  115. else:
  116. w_fds = []
  117. def _read(self):
  118. data = os.read(self.stdout_fd, BUFSIZE)
  119. if not data:
  120. raise Exception('Remote host closed connection')
  121. self.unpacker.feed(data)
  122. to_yield = []
  123. for type, msgid, error, res in self.unpacker:
  124. self.received_msgid = msgid
  125. args = self.pending.pop(msgid, None)
  126. if args is not None:
  127. self.cache[args] = msgid, res, error
  128. for args, resp, error in self.extra.pop(msgid, []):
  129. if not resp and not error:
  130. resp, error = self.cache[args][1:]
  131. to_yield.append((resp, error))
  132. for res, error in to_yield:
  133. if error:
  134. raise self.RPCError(error)
  135. else:
  136. yield res
  137. def gen_request(self, cmd, argsv, wait):
  138. data = []
  139. m = self.received_msgid
  140. for args in argsv:
  141. # Make sure to invalidate any existing cache entries for non-get requests
  142. if not args in self.cache:
  143. self.msgid += 1
  144. msgid = self.msgid
  145. self.pending[msgid] = args
  146. self.cache[args] = msgid, None, None
  147. data.append(msgpack.packb((1, msgid, cmd, args)))
  148. if wait:
  149. msgid, resp, error = self.cache[args]
  150. m = max(m, msgid)
  151. self.extra.setdefault(m, []).append((args, resp, error))
  152. return b''.join(data)
  153. def gen_cache_requests(self, cmd, peek):
  154. data = []
  155. while True:
  156. try:
  157. args = (peek()[0],)
  158. except StopIteration:
  159. break
  160. if args in self.cache:
  161. continue
  162. self.msgid += 1
  163. msgid = self.msgid
  164. self.pending[msgid] = args
  165. self.cache[args] = msgid, None, None
  166. data.append(msgpack.packb((1, msgid, cmd, args)))
  167. return b''.join(data)
  168. def call_multi(self, cmd, argsv, wait=True, peek=None):
  169. w_fds = [self.stdin_fd]
  170. left = len(argsv)
  171. data = self.gen_request(cmd, argsv, wait)
  172. self.to_send += data
  173. for args, resp, error in self.extra.pop(self.received_msgid, []):
  174. left -= 1
  175. if not resp and not error:
  176. resp, error = self.cache[args][1:]
  177. if error:
  178. raise self.RPCError(error)
  179. else:
  180. yield resp
  181. while left:
  182. r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
  183. if x:
  184. raise Exception('FD exception occured')
  185. if r:
  186. for res in self._read():
  187. left -= 1
  188. yield res
  189. if w:
  190. if not self.to_send and peek:
  191. self.to_send = self.gen_cache_requests(cmd, peek)
  192. if self.to_send:
  193. n = os.write(self.stdin_fd, self.to_send)
  194. assert n > 0
  195. # self.to_send = memoryview(self.to_send)[n:]
  196. self.to_send = self.to_send[n:]
  197. else:
  198. w_fds = []
  199. if not wait:
  200. return
  201. def commit(self, *args):
  202. self.call('commit', args)
  203. def rollback(self, *args):
  204. self.cache.clear()
  205. self.pending.clear()
  206. self.extra.clear()
  207. return self.call('rollback', args)
  208. def get(self, id):
  209. try:
  210. for res in self.call_multi('get', [(id, )]):
  211. return res
  212. except self.RPCError as e:
  213. if e.name == b'DoesNotExist':
  214. raise Repository.DoesNotExist
  215. raise
  216. def get_many(self, ids, peek=None):
  217. return self.call_multi('get', [(id, ) for id in ids], peek=peek)
  218. def _invalidate(self, id):
  219. key = (id, )
  220. if key in self.cache:
  221. self.pending.pop(self.cache.pop(key)[0], None)
  222. def put(self, id, data, wait=True):
  223. resp = self.call('put', (id, data), wait=wait)
  224. self._invalidate(id)
  225. return resp
  226. def delete(self, id, wait=True):
  227. resp = self.call('delete', (id, ), wait=wait)
  228. self._invalidate(id)
  229. return resp
  230. def close(self):
  231. if self.p:
  232. self.p.stdin.close()
  233. self.p.stdout.close()
  234. self.p.wait()
  235. self.p = None