remote.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. import fcntl
  2. import msgpack
  3. import os
  4. import select
  5. import shutil
  6. from subprocess import Popen, PIPE
  7. import sys
  8. import tempfile
  9. from .hashindex import NSIndex
  10. from .helpers import Error, IntegrityError
  11. from .repository import Repository
# Maximum number of bytes requested by each os.read() call on the RPC pipes
# (10 MiB); used by both the server loop and the client.
BUFSIZE = 10 * 1024 * 1024
# Raised by RemoteRepository.call_many when a read from the server's stdout
# pipe returns no data (EOF).
# NOTE(review): the docstring below may double as the user-facing message via
# the Error base class (defined in .helpers, not visible here) -- confirm
# before rewording it.
class ConnectionClosed(Error):
    """Connection closed by remote host"""
# Raised by RepositoryServer.open when the resolved repository path is not
# permitted by restrict_to_paths, and re-raised on the client by
# RemoteRepository.call_many when the server reports b'PathNotAllowed'.
# NOTE(review): the docstring below may double as the user-facing message via
# the Error base class (defined in .helpers, not visible here) -- confirm
# before rewording it.
class PathNotAllowed(Error):
    """Repository path not allowed"""
  17. class RepositoryServer(object):
  18. def __init__(self, restrict_to_paths):
  19. self.repository = None
  20. self.restrict_to_paths = restrict_to_paths
  21. def serve(self):
  22. # Make stdin non-blocking
  23. fl = fcntl.fcntl(sys.stdin.fileno(), fcntl.F_GETFL)
  24. fcntl.fcntl(sys.stdin.fileno(), fcntl.F_SETFL, fl | os.O_NONBLOCK)
  25. # Make stdout blocking
  26. fl = fcntl.fcntl(sys.stdout.fileno(), fcntl.F_GETFL)
  27. fcntl.fcntl(sys.stdout.fileno(), fcntl.F_SETFL, fl & ~os.O_NONBLOCK)
  28. unpacker = msgpack.Unpacker(use_list=False)
  29. while True:
  30. r, w, es = select.select([sys.stdin], [], [], 10)
  31. if r:
  32. data = os.read(sys.stdin.fileno(), BUFSIZE)
  33. if not data:
  34. return
  35. unpacker.feed(data)
  36. for type, msgid, method, args in unpacker:
  37. method = method.decode('ascii')
  38. try:
  39. try:
  40. f = getattr(self, method)
  41. except AttributeError:
  42. f = getattr(self.repository, method)
  43. res = f(*args)
  44. except Exception as e:
  45. sys.stdout.buffer.write(msgpack.packb((1, msgid, e.__class__.__name__, e.args)))
  46. else:
  47. sys.stdout.buffer.write(msgpack.packb((1, msgid, None, res)))
  48. sys.stdout.flush()
  49. if es:
  50. return
  51. def negotiate(self, versions):
  52. return 1
  53. def open(self, path, create=False):
  54. path = os.fsdecode(path)
  55. if path.startswith('/~'):
  56. path = path[1:]
  57. path = os.path.realpath(os.path.expanduser(path))
  58. if self.restrict_to_paths:
  59. for restrict_to_path in self.restrict_to_paths:
  60. if path.startswith(os.path.realpath(restrict_to_path)):
  61. break
  62. else:
  63. raise PathNotAllowed(path)
  64. self.repository = Repository(path, create)
  65. return self.repository.id
  66. class RemoteRepository(object):
  67. extra_test_args = []
  68. class RPCError(Exception):
  69. def __init__(self, name):
  70. self.name = name
  71. def __init__(self, location, create=False):
  72. self.location = location
  73. self.preload_ids = []
  74. self.msgid = 0
  75. self.to_send = b''
  76. self.cache = {}
  77. self.ignore_responses = set()
  78. self.responses = {}
  79. self.unpacker = msgpack.Unpacker(use_list=False)
  80. self.p = None
  81. if location.host == '__testsuite__':
  82. args = [sys.executable, '-m', 'attic.archiver', 'serve'] + self.extra_test_args
  83. else:
  84. args = ['ssh']
  85. if location.port:
  86. args += ['-p', str(location.port)]
  87. if location.user:
  88. args.append('%s@%s' % (location.user, location.host))
  89. else:
  90. args.append('%s' % location.host)
  91. args += ['attic', 'serve']
  92. self.p = Popen(args, bufsize=0, stdin=PIPE, stdout=PIPE)
  93. self.stdin_fd = self.p.stdin.fileno()
  94. self.stdout_fd = self.p.stdout.fileno()
  95. fcntl.fcntl(self.stdin_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdin_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  96. fcntl.fcntl(self.stdout_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdout_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  97. self.r_fds = [self.stdout_fd]
  98. self.x_fds = [self.stdin_fd, self.stdout_fd]
  99. version = self.call('negotiate', 1)
  100. if version != 1:
  101. raise Exception('Server insisted on using unsupported protocol version %d' % version)
  102. self.id = self.call('open', location.path, create)
  103. def __del__(self):
  104. self.close()
  105. def call(self, cmd, *args, **kw):
  106. for resp in self.call_many(cmd, [args], **kw):
  107. return resp
  108. def call_many(self, cmd, calls, wait=True, is_preloaded=False):
  109. if not calls:
  110. return
  111. def fetch_from_cache(args):
  112. msgid = self.cache[args].pop(0)
  113. if not self.cache[args]:
  114. del self.cache[args]
  115. return msgid
  116. calls = list(calls)
  117. waiting_for = []
  118. w_fds = [self.stdin_fd]
  119. while wait or calls:
  120. while waiting_for:
  121. try:
  122. error, res = self.responses.pop(waiting_for[0])
  123. waiting_for.pop(0)
  124. if error:
  125. if error == b'DoesNotExist':
  126. raise Repository.DoesNotExist(self.location.orig)
  127. elif error == b'AlreadyExists':
  128. raise Repository.AlreadyExists(self.location.orig)
  129. elif error == b'CheckNeeded':
  130. raise Repository.CheckNeeded(self.location.orig)
  131. elif error == b'IntegrityError':
  132. raise IntegrityError(res)
  133. elif error == b'PathNotAllowed':
  134. raise PathNotAllowed(*res)
  135. raise self.RPCError(error)
  136. else:
  137. yield res
  138. if not waiting_for and not calls:
  139. return
  140. except KeyError:
  141. break
  142. r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
  143. if x:
  144. raise Exception('FD exception occured')
  145. if r:
  146. data = os.read(self.stdout_fd, BUFSIZE)
  147. if not data:
  148. raise ConnectionClosed()
  149. self.unpacker.feed(data)
  150. for type, msgid, error, res in self.unpacker:
  151. if msgid in self.ignore_responses:
  152. self.ignore_responses.remove(msgid)
  153. else:
  154. self.responses[msgid] = error, res
  155. if w:
  156. while not self.to_send and (calls or self.preload_ids) and len(waiting_for) < 100:
  157. if calls:
  158. if is_preloaded:
  159. if calls[0] in self.cache:
  160. waiting_for.append(fetch_from_cache(calls.pop(0)))
  161. else:
  162. args = calls.pop(0)
  163. if cmd == 'get' and args in self.cache:
  164. waiting_for.append(fetch_from_cache(args))
  165. else:
  166. self.msgid += 1
  167. waiting_for.append(self.msgid)
  168. self.to_send = msgpack.packb((1, self.msgid, cmd, args))
  169. if not self.to_send and self.preload_ids:
  170. args = (self.preload_ids.pop(0),)
  171. self.msgid += 1
  172. self.cache.setdefault(args, []).append(self.msgid)
  173. self.to_send = msgpack.packb((1, self.msgid, cmd, args))
  174. if self.to_send:
  175. self.to_send = self.to_send[os.write(self.stdin_fd, self.to_send):]
  176. if not self.to_send and not (calls or self.preload_ids):
  177. w_fds = []
  178. self.ignore_responses |= set(waiting_for)
  179. def check(self, repair=False):
  180. return self.call('check', repair)
  181. def commit(self, *args):
  182. return self.call('commit')
  183. def rollback(self, *args):
  184. return self.call('rollback')
  185. def __len__(self):
  186. return self.call('__len__')
  187. def list(self, limit=None, marker=None):
  188. return self.call('list', limit, marker)
  189. def get(self, id_):
  190. for resp in self.get_many([id_]):
  191. return resp
  192. def get_many(self, ids, is_preloaded=False):
  193. for resp in self.call_many('get', [(id_,) for id_ in ids], is_preloaded=is_preloaded):
  194. yield resp
  195. def put(self, id_, data, wait=True):
  196. return self.call('put', id_, data, wait=wait)
  197. def delete(self, id_, wait=True):
  198. return self.call('delete', id_, wait=wait)
  199. def close(self):
  200. if self.p:
  201. self.p.stdin.close()
  202. self.p.stdout.close()
  203. self.p.wait()
  204. self.p = None
  205. def preload(self, ids):
  206. self.preload_ids += ids
  207. class RepositoryCache:
  208. """A caching Repository wrapper
  209. Caches Repository GET operations using a temporary file
  210. """
  211. def __init__(self, repository):
  212. self.tmppath = None
  213. self.index = None
  214. self.data_fd = None
  215. self.repository = repository
  216. self.entries = {}
  217. self.initialize()
  218. def __del__(self):
  219. self.cleanup()
  220. def initialize(self):
  221. self.tmppath = tempfile.mkdtemp()
  222. self.index = NSIndex.create(os.path.join(self.tmppath, 'index'))
  223. self.data_fd = open(os.path.join(self.tmppath, 'data'), 'a+b')
  224. def cleanup(self):
  225. del self.index
  226. if self.data_fd:
  227. self.data_fd.close()
  228. if self.tmppath:
  229. shutil.rmtree(self.tmppath)
  230. def load_object(self, offset, size):
  231. self.data_fd.seek(offset)
  232. data = self.data_fd.read(size)
  233. assert len(data) == size
  234. return data
  235. def store_object(self, key, data):
  236. self.data_fd.seek(0, os.SEEK_END)
  237. self.data_fd.write(data)
  238. offset = self.data_fd.tell()
  239. self.index[key] = offset - len(data), len(data)
  240. def get(self, key):
  241. return next(self.get_many([key]))
  242. def get_many(self, keys):
  243. unknown_keys = [key for key in keys if not key in self.index]
  244. repository_iterator = zip(unknown_keys, self.repository.get_many(unknown_keys))
  245. for key in keys:
  246. try:
  247. yield self.load_object(*self.index[key])
  248. except KeyError:
  249. for key_, data in repository_iterator:
  250. if key_ == key:
  251. self.store_object(key, data)
  252. yield data
  253. break
  254. # Consume any pending requests
  255. for _ in repository_iterator:
  256. pass
  257. def cache_if_remote(repository):
  258. if isinstance(repository, RemoteRepository):
  259. return RepositoryCache(repository)
  260. return repository