remote.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. import fcntl
  2. import msgpack
  3. import os
  4. import select
  5. from subprocess import Popen, PIPE
  6. import sys
  7. from .helpers import Error
  8. from .repository import Repository
  9. BUFSIZE = 10 * 1024 * 1024
class ConnectionClosed(Error):
    # Raised by RemoteRepository.call_many when reading the server's stdout
    # returns no data, i.e. the remote end has closed the connection.
    # NOTE(review): the Error base class (defined elsewhere) may use this
    # docstring as the user-facing message — confirm before editing it.
    """Connection closed by remote host"""
  12. class RepositoryServer(object):
  13. def __init__(self):
  14. self.repository = None
  15. def serve(self):
  16. # Make stdin non-blocking
  17. fl = fcntl.fcntl(sys.stdin.fileno(), fcntl.F_GETFL)
  18. fcntl.fcntl(sys.stdin.fileno(), fcntl.F_SETFL, fl | os.O_NONBLOCK)
  19. # Make stdout blocking
  20. fl = fcntl.fcntl(sys.stdout.fileno(), fcntl.F_GETFL)
  21. fcntl.fcntl(sys.stdout.fileno(), fcntl.F_SETFL, fl & ~os.O_NONBLOCK)
  22. unpacker = msgpack.Unpacker(use_list=False)
  23. while True:
  24. r, w, es = select.select([sys.stdin], [], [], 10)
  25. if r:
  26. data = os.read(sys.stdin.fileno(), BUFSIZE)
  27. if not data:
  28. return
  29. unpacker.feed(data)
  30. for type, msgid, method, args in unpacker:
  31. method = method.decode('ascii')
  32. try:
  33. try:
  34. f = getattr(self, method)
  35. except AttributeError:
  36. f = getattr(self.repository, method)
  37. res = f(*args)
  38. except Exception as e:
  39. sys.stdout.buffer.write(msgpack.packb((1, msgid, e.__class__.__name__, e.args)))
  40. else:
  41. sys.stdout.buffer.write(msgpack.packb((1, msgid, None, res)))
  42. sys.stdout.flush()
  43. if es:
  44. return
  45. def negotiate(self, versions):
  46. return 1
  47. def open(self, path, create=False):
  48. path = os.fsdecode(path)
  49. if path.startswith('/~'):
  50. path = path[1:]
  51. self.repository = Repository(os.path.expanduser(path), create)
  52. return self.repository.id
  53. class RemoteRepository(object):
  54. class RPCError(Exception):
  55. def __init__(self, name):
  56. self.name = name
  57. def __init__(self, location, create=False):
  58. self.location = location
  59. self.preload_ids = []
  60. self.msgid = 0
  61. self.to_send = b''
  62. self.cache = {}
  63. self.ignore_responses = set()
  64. self.responses = {}
  65. self.unpacker = msgpack.Unpacker(use_list=False)
  66. self.p = None
  67. if location.host == '__testsuite__':
  68. args = [sys.executable, '-m', 'attic.archiver', 'serve']
  69. else:
  70. args = ['ssh']
  71. if location.port:
  72. args += ['-p', str(location.port)]
  73. if location.user:
  74. args.append('%s@%s' % (location.user, location.host))
  75. else:
  76. args.append('%s' % location.host)
  77. args += ['attic', 'serve']
  78. self.p = Popen(args, bufsize=0, stdin=PIPE, stdout=PIPE)
  79. self.stdin_fd = self.p.stdin.fileno()
  80. self.stdout_fd = self.p.stdout.fileno()
  81. fcntl.fcntl(self.stdin_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdin_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  82. fcntl.fcntl(self.stdout_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdout_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  83. self.r_fds = [self.stdout_fd]
  84. self.x_fds = [self.stdin_fd, self.stdout_fd]
  85. version = self.call('negotiate', 1)
  86. if version != 1:
  87. raise Exception('Server insisted on using unsupported protocol version %d' % version)
  88. self.id = self.call('open', location.path, create)
  89. def __del__(self):
  90. self.close()
  91. def call(self, cmd, *args, **kw):
  92. for resp in self.call_many(cmd, [args], **kw):
  93. return resp
  94. def call_many(self, cmd, calls, wait=True, is_preloaded=False):
  95. def fetch_from_cache(args):
  96. msgid = self.cache[args].pop(0)
  97. if not self.cache[args]:
  98. del self.cache[args]
  99. return msgid
  100. calls = list(calls)
  101. waiting_for = []
  102. w_fds = [self.stdin_fd]
  103. while wait or calls:
  104. while waiting_for:
  105. try:
  106. error, res = self.responses.pop(waiting_for[0])
  107. waiting_for.pop(0)
  108. if error:
  109. if error == b'DoesNotExist':
  110. raise Repository.DoesNotExist(self.location.orig)
  111. elif error == b'AlreadyExists':
  112. raise Repository.AlreadyExists(self.location.orig)
  113. raise self.RPCError(error)
  114. else:
  115. yield res
  116. if not waiting_for and not calls:
  117. return
  118. except KeyError:
  119. break
  120. r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
  121. if x:
  122. raise Exception('FD exception occured')
  123. if r:
  124. data = os.read(self.stdout_fd, BUFSIZE)
  125. if not data:
  126. raise ConnectionClosed()
  127. self.unpacker.feed(data)
  128. for type, msgid, error, res in self.unpacker:
  129. if msgid in self.ignore_responses:
  130. self.ignore_responses.remove(msgid)
  131. else:
  132. self.responses[msgid] = error, res
  133. if w:
  134. while not self.to_send and (calls or self.preload_ids) and len(waiting_for) < 100:
  135. if calls:
  136. if is_preloaded:
  137. if calls[0] in self.cache:
  138. waiting_for.append(fetch_from_cache(calls.pop(0)))
  139. else:
  140. args = calls.pop(0)
  141. if cmd == 'get' and args in self.cache:
  142. waiting_for.append(fetch_from_cache(args))
  143. else:
  144. self.msgid += 1
  145. waiting_for.append(self.msgid)
  146. self.to_send = msgpack.packb((1, self.msgid, cmd, args))
  147. if not self.to_send and self.preload_ids:
  148. args = (self.preload_ids.pop(0),)
  149. self.msgid += 1
  150. self.cache.setdefault(args, []).append(self.msgid)
  151. self.to_send = msgpack.packb((1, self.msgid, cmd, args))
  152. if self.to_send:
  153. self.to_send = self.to_send[os.write(self.stdin_fd, self.to_send):]
  154. if not self.to_send and not (calls or self.preload_ids):
  155. w_fds = []
  156. self.ignore_responses |= set(waiting_for)
  157. def commit(self, *args):
  158. return self.call('commit')
  159. def rollback(self, *args):
  160. return self.call('rollback')
  161. def get(self, id_):
  162. for resp in self.get_many([id_]):
  163. return resp
  164. def get_many(self, ids, is_preloaded=False):
  165. for resp in self.call_many('get', [(id_,) for id_ in ids], is_preloaded=is_preloaded):
  166. yield resp
  167. def put(self, id_, data, wait=True):
  168. return self.call('put', id_, data, wait=wait)
  169. def delete(self, id_, wait=True):
  170. return self.call('delete', id_, wait=wait)
  171. def close(self):
  172. if self.p:
  173. self.p.stdin.close()
  174. self.p.stdout.close()
  175. self.p.wait()
  176. self.p = None
  177. def preload(self, ids):
  178. self.preload_ids += ids