remote.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. import fcntl
  2. import msgpack
  3. import os
  4. import select
  5. from subprocess import Popen, PIPE
  6. import sys
  7. from .helpers import Error, IntegrityError
  8. from .repository import Repository
  9. BUFSIZE = 10 * 1024 * 1024
# Raised by RemoteRepository.call_many() when reading the server's output
# returns no data (EOF). The docstring below is kept verbatim: the Error base
# class (from .helpers, not visible here) may use it as the displayed message.
class ConnectionClosed(Error):
    """Connection closed by remote host"""
  12. class RepositoryServer(object):
  13. def __init__(self):
  14. self.repository = None
  15. def serve(self):
  16. # Make stdin non-blocking
  17. fl = fcntl.fcntl(sys.stdin.fileno(), fcntl.F_GETFL)
  18. fcntl.fcntl(sys.stdin.fileno(), fcntl.F_SETFL, fl | os.O_NONBLOCK)
  19. # Make stdout blocking
  20. fl = fcntl.fcntl(sys.stdout.fileno(), fcntl.F_GETFL)
  21. fcntl.fcntl(sys.stdout.fileno(), fcntl.F_SETFL, fl & ~os.O_NONBLOCK)
  22. unpacker = msgpack.Unpacker(use_list=False)
  23. while True:
  24. r, w, es = select.select([sys.stdin], [], [], 10)
  25. if r:
  26. data = os.read(sys.stdin.fileno(), BUFSIZE)
  27. if not data:
  28. return
  29. unpacker.feed(data)
  30. for type, msgid, method, args in unpacker:
  31. method = method.decode('ascii')
  32. try:
  33. try:
  34. f = getattr(self, method)
  35. except AttributeError:
  36. f = getattr(self.repository, method)
  37. res = f(*args)
  38. except Exception as e:
  39. sys.stdout.buffer.write(msgpack.packb((1, msgid, e.__class__.__name__, e.args)))
  40. else:
  41. sys.stdout.buffer.write(msgpack.packb((1, msgid, None, res)))
  42. sys.stdout.flush()
  43. if es:
  44. return
  45. def negotiate(self, versions):
  46. return 1
  47. def open(self, path, create=False):
  48. path = os.fsdecode(path)
  49. if path.startswith('/~'):
  50. path = path[1:]
  51. self.repository = Repository(os.path.expanduser(path), create)
  52. return self.repository.id
  53. class RemoteRepository(object):
  54. class RPCError(Exception):
  55. def __init__(self, name):
  56. self.name = name
  57. def __init__(self, location, create=False):
  58. self.location = location
  59. self.preload_ids = []
  60. self.msgid = 0
  61. self.to_send = b''
  62. self.cache = {}
  63. self.ignore_responses = set()
  64. self.responses = {}
  65. self.unpacker = msgpack.Unpacker(use_list=False)
  66. self.p = None
  67. if location.host == '__testsuite__':
  68. args = [sys.executable, '-m', 'attic.archiver', 'serve']
  69. else:
  70. args = ['ssh']
  71. if location.port:
  72. args += ['-p', str(location.port)]
  73. if location.user:
  74. args.append('%s@%s' % (location.user, location.host))
  75. else:
  76. args.append('%s' % location.host)
  77. args += ['attic', 'serve']
  78. self.p = Popen(args, bufsize=0, stdin=PIPE, stdout=PIPE)
  79. self.stdin_fd = self.p.stdin.fileno()
  80. self.stdout_fd = self.p.stdout.fileno()
  81. fcntl.fcntl(self.stdin_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdin_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  82. fcntl.fcntl(self.stdout_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdout_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  83. self.r_fds = [self.stdout_fd]
  84. self.x_fds = [self.stdin_fd, self.stdout_fd]
  85. version = self.call('negotiate', 1)
  86. if version != 1:
  87. raise Exception('Server insisted on using unsupported protocol version %d' % version)
  88. self.id = self.call('open', location.path, create)
  89. def __del__(self):
  90. self.close()
  91. def call(self, cmd, *args, **kw):
  92. for resp in self.call_many(cmd, [args], **kw):
  93. return resp
  94. def call_many(self, cmd, calls, wait=True, is_preloaded=False):
  95. if not calls:
  96. return
  97. def fetch_from_cache(args):
  98. msgid = self.cache[args].pop(0)
  99. if not self.cache[args]:
  100. del self.cache[args]
  101. return msgid
  102. calls = list(calls)
  103. waiting_for = []
  104. w_fds = [self.stdin_fd]
  105. while wait or calls:
  106. while waiting_for:
  107. try:
  108. error, res = self.responses.pop(waiting_for[0])
  109. waiting_for.pop(0)
  110. if error:
  111. if error == b'DoesNotExist':
  112. raise Repository.DoesNotExist(self.location.orig)
  113. elif error == b'AlreadyExists':
  114. raise Repository.AlreadyExists(self.location.orig)
  115. elif error == b'CheckNeeded':
  116. raise Repository.CheckNeeded(self.location.orig)
  117. elif error == b'IntegrityError':
  118. raise IntegrityError
  119. raise self.RPCError(error)
  120. else:
  121. yield res
  122. if not waiting_for and not calls:
  123. return
  124. except KeyError:
  125. break
  126. r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
  127. if x:
  128. raise Exception('FD exception occured')
  129. if r:
  130. data = os.read(self.stdout_fd, BUFSIZE)
  131. if not data:
  132. raise ConnectionClosed()
  133. self.unpacker.feed(data)
  134. for type, msgid, error, res in self.unpacker:
  135. if msgid in self.ignore_responses:
  136. self.ignore_responses.remove(msgid)
  137. else:
  138. self.responses[msgid] = error, res
  139. if w:
  140. while not self.to_send and (calls or self.preload_ids) and len(waiting_for) < 100:
  141. if calls:
  142. if is_preloaded:
  143. if calls[0] in self.cache:
  144. waiting_for.append(fetch_from_cache(calls.pop(0)))
  145. else:
  146. args = calls.pop(0)
  147. if cmd == 'get' and args in self.cache:
  148. waiting_for.append(fetch_from_cache(args))
  149. else:
  150. self.msgid += 1
  151. waiting_for.append(self.msgid)
  152. self.to_send = msgpack.packb((1, self.msgid, cmd, args))
  153. if not self.to_send and self.preload_ids:
  154. args = (self.preload_ids.pop(0),)
  155. self.msgid += 1
  156. self.cache.setdefault(args, []).append(self.msgid)
  157. self.to_send = msgpack.packb((1, self.msgid, cmd, args))
  158. if self.to_send:
  159. self.to_send = self.to_send[os.write(self.stdin_fd, self.to_send):]
  160. if not self.to_send and not (calls or self.preload_ids):
  161. w_fds = []
  162. self.ignore_responses |= set(waiting_for)
  163. def check(self, progress=False, repair=False):
  164. return self.call('check', progress, repair)
  165. def commit(self, *args):
  166. return self.call('commit')
  167. def rollback(self, *args):
  168. return self.call('rollback')
  169. def __len__(self):
  170. return self.call('__len__')
  171. def get(self, id_):
  172. for resp in self.get_many([id_]):
  173. return resp
  174. def get_many(self, ids, is_preloaded=False):
  175. for resp in self.call_many('get', [(id_,) for id_ in ids], is_preloaded=is_preloaded):
  176. yield resp
  177. def put(self, id_, data, wait=True):
  178. return self.call('put', id_, data, wait=wait)
  179. def delete(self, id_, wait=True):
  180. return self.call('delete', id_, wait=wait)
  181. def close(self):
  182. if self.p:
  183. self.p.stdin.close()
  184. self.p.stdout.close()
  185. self.p.wait()
  186. self.p = None
  187. def preload(self, ids):
  188. self.preload_ids += ids