2
0

remote.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. import fcntl
  2. import msgpack
  3. import os
  4. import select
  5. from subprocess import Popen, PIPE
  6. import sys
  7. from .helpers import Error
  8. from .repository import Repository
  9. from .lrucache import LRUCache
  10. BUFSIZE = 10 * 1024 * 1024
class ConnectionClosed(Error):
    # Raised by the RPC client when a read from the peer's pipe returns no
    # data, i.e. the remote side has closed its end.
    # NOTE(review): the docstring below may be used as the user-facing message
    # by the Error base class — confirm before rewording it.
    """Connection closed by remote host"""
  13. class RepositoryServer(object):
  14. def __init__(self):
  15. self.repository = None
  16. def serve(self):
  17. # Make stdin non-blocking
  18. fl = fcntl.fcntl(sys.stdin.fileno(), fcntl.F_GETFL)
  19. fcntl.fcntl(sys.stdin.fileno(), fcntl.F_SETFL, fl | os.O_NONBLOCK)
  20. # Make stdout blocking
  21. fl = fcntl.fcntl(sys.stdout.fileno(), fcntl.F_GETFL)
  22. fcntl.fcntl(sys.stdout.fileno(), fcntl.F_SETFL, fl & ~os.O_NONBLOCK)
  23. unpacker = msgpack.Unpacker(use_list=False)
  24. while True:
  25. r, w, es = select.select([sys.stdin], [], [], 10)
  26. if r:
  27. data = os.read(sys.stdin.fileno(), BUFSIZE)
  28. if not data:
  29. return
  30. unpacker.feed(data)
  31. for type, msgid, method, args in unpacker:
  32. method = method.decode('ascii')
  33. try:
  34. try:
  35. f = getattr(self, method)
  36. except AttributeError:
  37. f = getattr(self.repository, method)
  38. res = f(*args)
  39. except Exception as e:
  40. sys.stdout.buffer.write(msgpack.packb((1, msgid, e.__class__.__name__, None)))
  41. else:
  42. sys.stdout.buffer.write(msgpack.packb((1, msgid, None, res)))
  43. sys.stdout.flush()
  44. if es:
  45. return
  46. def negotiate(self, versions):
  47. return 1
  48. def open(self, path, create=False):
  49. path = os.fsdecode(path)
  50. if path.startswith('/~'):
  51. path = path[1:]
  52. self.repository = Repository(os.path.expanduser(path), create)
  53. return self.repository.id
  54. class RemoteRepository(object):
  55. class RPCError(Exception):
  56. def __init__(self, name):
  57. self.name = name
  58. def __init__(self, location, create=False):
  59. self.repository_url = '%s@%s:%s' % (location.user, location.host, location.path)
  60. self.p = None
  61. self.cache = LRUCache(256)
  62. self.to_send = b''
  63. self.extra = {}
  64. self.pending = {}
  65. self.unpacker = msgpack.Unpacker(use_list=False)
  66. self.msgid = 0
  67. self.received_msgid = 0
  68. if location.host == '__testsuite__':
  69. args = [sys.executable, '-m', 'attic.archiver', 'serve']
  70. else:
  71. args = ['ssh',]
  72. if location.port:
  73. args += ['-p', str(location.port)]
  74. if location.user:
  75. args.append('%s@%s' % (location.user, location.host))
  76. else:
  77. args.append('%s' % location.host)
  78. args += ['attic', 'serve']
  79. self.p = Popen(args, bufsize=0, stdin=PIPE, stdout=PIPE)
  80. self.stdin_fd = self.p.stdin.fileno()
  81. self.stdout_fd = self.p.stdout.fileno()
  82. fcntl.fcntl(self.stdin_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdin_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  83. fcntl.fcntl(self.stdout_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdout_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  84. self.r_fds = [self.stdout_fd]
  85. self.x_fds = [self.stdin_fd, self.stdout_fd]
  86. version = self.call('negotiate', (1,))
  87. if version != 1:
  88. raise Exception('Server insisted on using unsupported protocol version %d' % version)
  89. try:
  90. self.id = self.call('open', (location.path, create))
  91. except self.RPCError as e:
  92. if e.name == b'DoesNotExist':
  93. raise Repository.DoesNotExist(self.repository_url)
  94. elif e.name == b'AlreadyExists':
  95. raise Repository.AlreadyExists(self.repository_url)
  96. def __del__(self):
  97. self.close()
  98. def call(self, cmd, args, wait=True):
  99. self.msgid += 1
  100. to_send = msgpack.packb((1, self.msgid, cmd, args))
  101. w_fds = [self.stdin_fd]
  102. while wait or to_send:
  103. r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
  104. if x:
  105. raise Exception('FD exception occured')
  106. if r:
  107. data = os.read(self.stdout_fd, BUFSIZE)
  108. if not data:
  109. raise ConnectionClosed()
  110. self.unpacker.feed(data)
  111. for type, msgid, error, res in self.unpacker:
  112. if msgid == self.msgid:
  113. self.received_msgid = msgid
  114. if error:
  115. raise self.RPCError(error)
  116. else:
  117. return res
  118. else:
  119. args = self.pending.pop(msgid, None)
  120. if args is not None:
  121. self.cache[args] = msgid, res, error
  122. if w:
  123. if to_send:
  124. n = os.write(self.stdin_fd, to_send)
  125. assert n > 0
  126. to_send = memoryview(to_send)[n:]
  127. if not to_send:
  128. w_fds = []
  129. def _read(self):
  130. data = os.read(self.stdout_fd, BUFSIZE)
  131. if not data:
  132. raise Exception('Remote host closed connection')
  133. self.unpacker.feed(data)
  134. to_yield = []
  135. for type, msgid, error, res in self.unpacker:
  136. self.received_msgid = msgid
  137. args = self.pending.pop(msgid, None)
  138. if args is not None:
  139. self.cache[args] = msgid, res, error
  140. for args, resp, error in self.extra.pop(msgid, []):
  141. if not resp and not error:
  142. resp, error = self.cache[args][1:]
  143. to_yield.append((resp, error))
  144. for res, error in to_yield:
  145. if error:
  146. raise self.RPCError(error)
  147. else:
  148. yield res
  149. def gen_request(self, cmd, argsv, wait):
  150. data = []
  151. m = self.received_msgid
  152. for args in argsv:
  153. # Make sure to invalidate any existing cache entries for non-get requests
  154. if not args in self.cache:
  155. self.msgid += 1
  156. msgid = self.msgid
  157. self.pending[msgid] = args
  158. self.cache[args] = msgid, None, None
  159. data.append(msgpack.packb((1, msgid, cmd, args)))
  160. if wait:
  161. msgid, resp, error = self.cache[args]
  162. m = max(m, msgid)
  163. self.extra.setdefault(m, []).append((args, resp, error))
  164. return b''.join(data)
  165. def gen_cache_requests(self, cmd, peek):
  166. data = []
  167. while True:
  168. try:
  169. args = (peek()[0],)
  170. except StopIteration:
  171. break
  172. if args in self.cache:
  173. continue
  174. self.msgid += 1
  175. msgid = self.msgid
  176. self.pending[msgid] = args
  177. self.cache[args] = msgid, None, None
  178. data.append(msgpack.packb((1, msgid, cmd, args)))
  179. return b''.join(data)
  180. def call_multi(self, cmd, argsv, wait=True, peek=None):
  181. w_fds = [self.stdin_fd]
  182. left = len(argsv)
  183. data = self.gen_request(cmd, argsv, wait)
  184. self.to_send += data
  185. for args, resp, error in self.extra.pop(self.received_msgid, []):
  186. left -= 1
  187. if not resp and not error:
  188. resp, error = self.cache[args][1:]
  189. if error:
  190. raise self.RPCError(error)
  191. else:
  192. yield resp
  193. while left:
  194. r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
  195. if x:
  196. raise Exception('FD exception occured')
  197. if r:
  198. for res in self._read():
  199. left -= 1
  200. yield res
  201. if w:
  202. if not self.to_send and peek:
  203. self.to_send = self.gen_cache_requests(cmd, peek)
  204. if self.to_send:
  205. n = os.write(self.stdin_fd, self.to_send)
  206. assert n > 0
  207. # self.to_send = memoryview(self.to_send)[n:]
  208. self.to_send = self.to_send[n:]
  209. else:
  210. w_fds = []
  211. if not wait:
  212. return
  213. def commit(self, *args):
  214. self.call('commit', args)
  215. def rollback(self, *args):
  216. self.cache.clear()
  217. self.pending.clear()
  218. self.extra.clear()
  219. return self.call('rollback', args)
  220. def get(self, id):
  221. try:
  222. for res in self.call_multi('get', [(id, )]):
  223. return res
  224. except self.RPCError as e:
  225. if e.name == b'DoesNotExist':
  226. raise Repository.DoesNotExist(self.repository_url)
  227. raise
  228. def get_many(self, ids, peek=None):
  229. return self.call_multi('get', [(id, ) for id in ids], peek=peek)
  230. def _invalidate(self, id):
  231. key = (id, )
  232. if key in self.cache:
  233. self.pending.pop(self.cache.pop(key)[0], None)
  234. def put(self, id, data, wait=True):
  235. resp = self.call('put', (id, data), wait=wait)
  236. self._invalidate(id)
  237. return resp
  238. def delete(self, id, wait=True):
  239. resp = self.call('delete', (id, ), wait=wait)
  240. self._invalidate(id)
  241. return resp
  242. def close(self):
  243. if self.p:
  244. self.p.stdin.close()
  245. self.p.stdout.close()
  246. self.p.wait()
  247. self.p = None