remote.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. import fcntl
  2. import msgpack
  3. import os
  4. import select
  5. from subprocess import Popen, PIPE
  6. import sys
  7. import getpass
  8. from .repository import Repository
  9. from .lrucache import LRUCache
  10. BUFSIZE = 10 * 1024 * 1024
  11. class ConnectionClosed(Exception):
  12. """Connection closed by remote host
  13. """
  14. class RepositoryServer(object):
  15. def __init__(self):
  16. self.repository = None
  17. def serve(self):
  18. # Make stdin non-blocking
  19. fl = fcntl.fcntl(sys.stdin.fileno(), fcntl.F_GETFL)
  20. fcntl.fcntl(sys.stdin.fileno(), fcntl.F_SETFL, fl | os.O_NONBLOCK)
  21. # Make stdout blocking
  22. fl = fcntl.fcntl(sys.stdout.fileno(), fcntl.F_GETFL)
  23. fcntl.fcntl(sys.stdout.fileno(), fcntl.F_SETFL, fl & ~os.O_NONBLOCK)
  24. unpacker = msgpack.Unpacker(use_list=False)
  25. while True:
  26. r, w, es = select.select([sys.stdin], [], [], 10)
  27. if r:
  28. data = os.read(sys.stdin.fileno(), BUFSIZE)
  29. if not data:
  30. return
  31. unpacker.feed(data)
  32. for type, msgid, method, args in unpacker:
  33. method = method.decode('ascii')
  34. try:
  35. try:
  36. f = getattr(self, method)
  37. except AttributeError:
  38. f = getattr(self.repository, method)
  39. res = f(*args)
  40. except Exception as e:
  41. sys.stdout.buffer.write(msgpack.packb((1, msgid, e.__class__.__name__, None)))
  42. else:
  43. sys.stdout.buffer.write(msgpack.packb((1, msgid, None, res)))
  44. sys.stdout.flush()
  45. if es:
  46. return
  47. def negotiate(self, versions):
  48. return 1
  49. def open(self, path, create=False):
  50. path = os.fsdecode(path)
  51. if path.startswith('/~'):
  52. path = path[1:]
  53. self.repository = Repository(os.path.expanduser(path), create)
  54. return self.repository.id
  55. class RemoteRepository(object):
  56. class RPCError(Exception):
  57. def __init__(self, name):
  58. self.name = name
  59. def __init__(self, location, create=False):
  60. self.p = None
  61. self.cache = LRUCache(256)
  62. self.to_send = b''
  63. self.extra = {}
  64. self.pending = {}
  65. self.unpacker = msgpack.Unpacker(use_list=False)
  66. self.msgid = 0
  67. self.received_msgid = 0
  68. if location.host == '__testsuite__':
  69. args = [sys.executable, '-m', 'attic.archiver', 'serve']
  70. else:
  71. args = ['ssh', '-p', str(location.port), '%s@%s' % (location.user or getpass.getuser(), location.host), 'attic', 'serve']
  72. self.p = Popen(args, bufsize=0, stdin=PIPE, stdout=PIPE)
  73. self.stdin_fd = self.p.stdin.fileno()
  74. self.stdout_fd = self.p.stdout.fileno()
  75. fcntl.fcntl(self.stdin_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdin_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  76. fcntl.fcntl(self.stdout_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdout_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  77. self.r_fds = [self.stdout_fd]
  78. self.x_fds = [self.stdin_fd, self.stdout_fd]
  79. version = self.call('negotiate', (1,))
  80. if version != 1:
  81. raise Exception('Server insisted on using unsupported protocol version %d' % version)
  82. try:
  83. self.id = self.call('open', (location.path, create))
  84. except self.RPCError as e:
  85. if e.name == b'DoesNotExist':
  86. raise Repository.DoesNotExist
  87. elif e.name == b'AlreadyExists':
  88. raise Repository.AlreadyExists
  89. def __del__(self):
  90. self.close()
  91. def call(self, cmd, args, wait=True):
  92. self.msgid += 1
  93. to_send = msgpack.packb((1, self.msgid, cmd, args))
  94. w_fds = [self.stdin_fd]
  95. while wait or to_send:
  96. r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
  97. if x:
  98. raise Exception('FD exception occured')
  99. if r:
  100. data = os.read(self.stdout_fd, BUFSIZE)
  101. if not data:
  102. raise ConnectionClosed()
  103. self.unpacker.feed(data)
  104. for type, msgid, error, res in self.unpacker:
  105. if msgid == self.msgid:
  106. assert msgid == self.msgid
  107. self.received_msgid = msgid
  108. if error:
  109. raise self.RPCError(error)
  110. else:
  111. return res
  112. else:
  113. args = self.pending.pop(msgid, None)
  114. if args is not None:
  115. self.cache[args] = msgid, res, error
  116. if w:
  117. if to_send:
  118. n = os.write(self.stdin_fd, to_send)
  119. assert n > 0
  120. to_send = memoryview(to_send)[n:]
  121. else:
  122. w_fds = []
  123. def _read(self):
  124. data = os.read(self.stdout_fd, BUFSIZE)
  125. if not data:
  126. raise Exception('Remote host closed connection')
  127. self.unpacker.feed(data)
  128. to_yield = []
  129. for type, msgid, error, res in self.unpacker:
  130. self.received_msgid = msgid
  131. args = self.pending.pop(msgid, None)
  132. if args is not None:
  133. self.cache[args] = msgid, res, error
  134. for args, resp, error in self.extra.pop(msgid, []):
  135. if not resp and not error:
  136. resp, error = self.cache[args][1:]
  137. to_yield.append((resp, error))
  138. for res, error in to_yield:
  139. if error:
  140. raise self.RPCError(error)
  141. else:
  142. yield res
  143. def gen_request(self, cmd, argsv, wait):
  144. data = []
  145. m = self.received_msgid
  146. for args in argsv:
  147. # Make sure to invalidate any existing cache entries for non-get requests
  148. if not args in self.cache:
  149. self.msgid += 1
  150. msgid = self.msgid
  151. self.pending[msgid] = args
  152. self.cache[args] = msgid, None, None
  153. data.append(msgpack.packb((1, msgid, cmd, args)))
  154. if wait:
  155. msgid, resp, error = self.cache[args]
  156. m = max(m, msgid)
  157. self.extra.setdefault(m, []).append((args, resp, error))
  158. return b''.join(data)
  159. def gen_cache_requests(self, cmd, peek):
  160. data = []
  161. while True:
  162. try:
  163. args = (peek()[0],)
  164. except StopIteration:
  165. break
  166. if args in self.cache:
  167. continue
  168. self.msgid += 1
  169. msgid = self.msgid
  170. self.pending[msgid] = args
  171. self.cache[args] = msgid, None, None
  172. data.append(msgpack.packb((1, msgid, cmd, args)))
  173. return b''.join(data)
  174. def call_multi(self, cmd, argsv, wait=True, peek=None):
  175. w_fds = [self.stdin_fd]
  176. left = len(argsv)
  177. data = self.gen_request(cmd, argsv, wait)
  178. self.to_send += data
  179. for args, resp, error in self.extra.pop(self.received_msgid, []):
  180. left -= 1
  181. if not resp and not error:
  182. resp, error = self.cache[args][1:]
  183. if error:
  184. raise self.RPCError(error)
  185. else:
  186. yield resp
  187. while left:
  188. r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
  189. if x:
  190. raise Exception('FD exception occured')
  191. if r:
  192. for res in self._read():
  193. left -= 1
  194. yield res
  195. if w:
  196. if not self.to_send and peek:
  197. self.to_send = self.gen_cache_requests(cmd, peek)
  198. if self.to_send:
  199. n = os.write(self.stdin_fd, self.to_send)
  200. assert n > 0
  201. # self.to_send = memoryview(self.to_send)[n:]
  202. self.to_send = self.to_send[n:]
  203. else:
  204. w_fds = []
  205. if not wait:
  206. return
  207. def commit(self, *args):
  208. self.call('commit', args)
  209. def rollback(self, *args):
  210. self.cache.clear()
  211. self.pending.clear()
  212. self.extra.clear()
  213. return self.call('rollback', args)
  214. def get(self, id):
  215. try:
  216. for res in self.call_multi('get', [(id, )]):
  217. return res
  218. except self.RPCError as e:
  219. if e.name == b'DoesNotExist':
  220. raise Repository.DoesNotExist
  221. raise
  222. def get_many(self, ids, peek=None):
  223. return self.call_multi('get', [(id, ) for id in ids], peek=peek)
  224. def _invalidate(self, id):
  225. key = (id, )
  226. if key in self.cache:
  227. self.pending.pop(self.cache.pop(key)[0], None)
  228. def put(self, id, data, wait=True):
  229. resp = self.call('put', (id, data), wait=wait)
  230. self._invalidate(id)
  231. return resp
  232. def delete(self, id, wait=True):
  233. resp = self.call('delete', (id, ), wait=wait)
  234. self._invalidate(id)
  235. return resp
  236. def close(self):
  237. if self.p:
  238. self.p.stdin.close()
  239. self.p.stdout.close()
  240. self.p.wait()
  241. self.p = None