remote.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. import fcntl
  2. import msgpack
  3. import os
  4. import select
  5. from subprocess import Popen, PIPE
  6. import sys
  7. import getpass
  8. from .repository import Repository
  9. from .lrucache import LRUCache
  10. BUFSIZE = 10 * 1024 * 1024
  11. class ConnectionClosed(Exception):
  12. """Connection closed by remote host
  13. """
  14. class RepositoryServer(object):
  15. def __init__(self):
  16. self.repository = None
  17. def serve(self):
  18. # Make stdin non-blocking
  19. fl = fcntl.fcntl(sys.stdin.fileno(), fcntl.F_GETFL)
  20. fcntl.fcntl(sys.stdin.fileno(), fcntl.F_SETFL, fl | os.O_NONBLOCK)
  21. # Make stdout blocking
  22. fl = fcntl.fcntl(sys.stdout.fileno(), fcntl.F_GETFL)
  23. fcntl.fcntl(sys.stdout.fileno(), fcntl.F_SETFL, fl & ~os.O_NONBLOCK)
  24. unpacker = msgpack.Unpacker(use_list=False)
  25. while True:
  26. r, w, es = select.select([sys.stdin], [], [], 10)
  27. if r:
  28. data = os.read(sys.stdin.fileno(), BUFSIZE)
  29. if not data:
  30. return
  31. unpacker.feed(data)
  32. for type, msgid, method, args in unpacker:
  33. method = method.decode('ascii')
  34. try:
  35. try:
  36. f = getattr(self, method)
  37. except AttributeError:
  38. f = getattr(self.repository, method)
  39. res = f(*args)
  40. except Exception as e:
  41. sys.stdout.buffer.write(msgpack.packb((1, msgid, e.__class__.__name__, None)))
  42. else:
  43. sys.stdout.buffer.write(msgpack.packb((1, msgid, None, res)))
  44. sys.stdout.flush()
  45. if es:
  46. return
  47. def negotiate(self, versions):
  48. return 1
  49. def open(self, path, create=False):
  50. path = os.fsdecode(path)
  51. if path.startswith('/~'):
  52. path = path[1:]
  53. self.repository = Repository(os.path.expanduser(path), create)
  54. return self.repository.id
  55. class RemoteRepository(object):
  56. class RPCError(Exception):
  57. def __init__(self, name):
  58. self.name = name
  59. def __init__(self, location, create=False):
  60. self.p = None
  61. self.cache = LRUCache(256)
  62. self.to_send = b''
  63. self.extra = {}
  64. self.pending = {}
  65. self.unpacker = msgpack.Unpacker(use_list=False)
  66. self.msgid = 0
  67. self.received_msgid = 0
  68. if location.host == '__testsuite__':
  69. args = [sys.executable, '-m', 'attic.archiver', 'serve']
  70. else:
  71. args = ['ssh', '-p', str(location.port), '%s@%s' % (location.user or getpass.getuser(), location.host), 'attic', 'serve']
  72. self.p = Popen(args, bufsize=0, stdin=PIPE, stdout=PIPE)
  73. self.stdin_fd = self.p.stdin.fileno()
  74. self.stdout_fd = self.p.stdout.fileno()
  75. fcntl.fcntl(self.stdin_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdin_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  76. fcntl.fcntl(self.stdout_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdout_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  77. self.r_fds = [self.stdout_fd]
  78. self.x_fds = [self.stdin_fd, self.stdout_fd]
  79. version = self.call('negotiate', (1,))
  80. if version != 1:
  81. raise Exception('Server insisted on using unsupported protocol version %d' % version)
  82. try:
  83. self.id = self.call('open', (location.path, create))
  84. except self.RPCError as e:
  85. if e.name == b'DoesNotExist':
  86. raise Repository.DoesNotExist
  87. elif e.name == b'AlreadyExists':
  88. raise Repository.AlreadyExists
  89. def __del__(self):
  90. self.close()
  91. def call(self, cmd, args, wait=True):
  92. self.msgid += 1
  93. to_send = msgpack.packb((1, self.msgid, cmd, args))
  94. w_fds = [self.stdin_fd]
  95. while wait or to_send:
  96. r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
  97. if x:
  98. raise Exception('FD exception occured')
  99. if r:
  100. data = os.read(self.stdout_fd, BUFSIZE)
  101. if not data:
  102. raise ConnectionClosed()
  103. self.unpacker.feed(data)
  104. for type, msgid, error, res in self.unpacker:
  105. if msgid == self.msgid:
  106. self.received_msgid = msgid
  107. if error:
  108. raise self.RPCError(error)
  109. else:
  110. return res
  111. else:
  112. args = self.pending.pop(msgid, None)
  113. if args is not None:
  114. self.cache[args] = msgid, res, error
  115. if w:
  116. if to_send:
  117. n = os.write(self.stdin_fd, to_send)
  118. assert n > 0
  119. to_send = memoryview(to_send)[n:]
  120. if not to_send:
  121. w_fds = []
  122. def _read(self):
  123. data = os.read(self.stdout_fd, BUFSIZE)
  124. if not data:
  125. raise Exception('Remote host closed connection')
  126. self.unpacker.feed(data)
  127. to_yield = []
  128. for type, msgid, error, res in self.unpacker:
  129. self.received_msgid = msgid
  130. args = self.pending.pop(msgid, None)
  131. if args is not None:
  132. self.cache[args] = msgid, res, error
  133. for args, resp, error in self.extra.pop(msgid, []):
  134. if not resp and not error:
  135. resp, error = self.cache[args][1:]
  136. to_yield.append((resp, error))
  137. for res, error in to_yield:
  138. if error:
  139. raise self.RPCError(error)
  140. else:
  141. yield res
  142. def gen_request(self, cmd, argsv, wait):
  143. data = []
  144. m = self.received_msgid
  145. for args in argsv:
  146. # Make sure to invalidate any existing cache entries for non-get requests
  147. if not args in self.cache:
  148. self.msgid += 1
  149. msgid = self.msgid
  150. self.pending[msgid] = args
  151. self.cache[args] = msgid, None, None
  152. data.append(msgpack.packb((1, msgid, cmd, args)))
  153. if wait:
  154. msgid, resp, error = self.cache[args]
  155. m = max(m, msgid)
  156. self.extra.setdefault(m, []).append((args, resp, error))
  157. return b''.join(data)
  158. def gen_cache_requests(self, cmd, peek):
  159. data = []
  160. while True:
  161. try:
  162. args = (peek()[0],)
  163. except StopIteration:
  164. break
  165. if args in self.cache:
  166. continue
  167. self.msgid += 1
  168. msgid = self.msgid
  169. self.pending[msgid] = args
  170. self.cache[args] = msgid, None, None
  171. data.append(msgpack.packb((1, msgid, cmd, args)))
  172. return b''.join(data)
  173. def call_multi(self, cmd, argsv, wait=True, peek=None):
  174. w_fds = [self.stdin_fd]
  175. left = len(argsv)
  176. data = self.gen_request(cmd, argsv, wait)
  177. self.to_send += data
  178. for args, resp, error in self.extra.pop(self.received_msgid, []):
  179. left -= 1
  180. if not resp and not error:
  181. resp, error = self.cache[args][1:]
  182. if error:
  183. raise self.RPCError(error)
  184. else:
  185. yield resp
  186. while left:
  187. r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
  188. if x:
  189. raise Exception('FD exception occured')
  190. if r:
  191. for res in self._read():
  192. left -= 1
  193. yield res
  194. if w:
  195. if not self.to_send and peek:
  196. self.to_send = self.gen_cache_requests(cmd, peek)
  197. if self.to_send:
  198. n = os.write(self.stdin_fd, self.to_send)
  199. assert n > 0
  200. # self.to_send = memoryview(self.to_send)[n:]
  201. self.to_send = self.to_send[n:]
  202. else:
  203. w_fds = []
  204. if not wait:
  205. return
  206. def commit(self, *args):
  207. self.call('commit', args)
  208. def rollback(self, *args):
  209. self.cache.clear()
  210. self.pending.clear()
  211. self.extra.clear()
  212. return self.call('rollback', args)
  213. def get(self, id):
  214. try:
  215. for res in self.call_multi('get', [(id, )]):
  216. return res
  217. except self.RPCError as e:
  218. if e.name == b'DoesNotExist':
  219. raise Repository.DoesNotExist
  220. raise
  221. def get_many(self, ids, peek=None):
  222. return self.call_multi('get', [(id, ) for id in ids], peek=peek)
  223. def _invalidate(self, id):
  224. key = (id, )
  225. if key in self.cache:
  226. self.pending.pop(self.cache.pop(key)[0], None)
  227. def put(self, id, data, wait=True):
  228. resp = self.call('put', (id, data), wait=wait)
  229. self._invalidate(id)
  230. return resp
  231. def delete(self, id, wait=True):
  232. resp = self.call('delete', (id, ), wait=wait)
  233. self._invalidate(id)
  234. return resp
  235. def close(self):
  236. if self.p:
  237. self.p.stdin.close()
  238. self.p.stdout.close()
  239. self.p.wait()
  240. self.p = None