remote.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. import errno
  2. import fcntl
  3. import msgpack
  4. import os
  5. import select
  6. import shutil
  7. from subprocess import Popen, PIPE
  8. import sys
  9. import tempfile
  10. from .hashindex import NSIndex
  11. from .helpers import Error, IntegrityError
  12. from .repository import Repository
  13. BUFSIZE = 10 * 1024 * 1024
# Raised by RemoteRepository.call_many() when a read from the server's
# stdout pipe returns no data, i.e. the remote end has gone away.
# NOTE(review): the docstring is left untouched on purpose -- the Error base
# class lives in .helpers (not visible here) and may use it as the
# user-facing message; confirm before rewording.
class ConnectionClosed(Error):
    """Connection closed by remote host"""
# Raised by RepositoryServer.open() with the resolved path when that path is
# outside every entry of restrict_to_paths, and re-raised on the client by
# RemoteRepository.call_many() when the server reports b'PathNotAllowed'.
# NOTE(review): the docstring is left untouched on purpose -- the Error base
# class lives in .helpers (not visible here) and may use it as the
# user-facing message; confirm before rewording.
class PathNotAllowed(Error):
    """Repository path not allowed"""
  18. class RepositoryServer(object):
  19. def __init__(self, restrict_to_paths):
  20. self.repository = None
  21. self.restrict_to_paths = restrict_to_paths
  22. def serve(self):
  23. # Make stdin non-blocking
  24. fl = fcntl.fcntl(sys.stdin.fileno(), fcntl.F_GETFL)
  25. fcntl.fcntl(sys.stdin.fileno(), fcntl.F_SETFL, fl | os.O_NONBLOCK)
  26. # Make stdout blocking
  27. fl = fcntl.fcntl(sys.stdout.fileno(), fcntl.F_GETFL)
  28. fcntl.fcntl(sys.stdout.fileno(), fcntl.F_SETFL, fl & ~os.O_NONBLOCK)
  29. unpacker = msgpack.Unpacker(use_list=False)
  30. while True:
  31. r, w, es = select.select([sys.stdin], [], [], 10)
  32. if r:
  33. data = os.read(sys.stdin.fileno(), BUFSIZE)
  34. if not data:
  35. return
  36. unpacker.feed(data)
  37. for type, msgid, method, args in unpacker:
  38. method = method.decode('ascii')
  39. try:
  40. try:
  41. f = getattr(self, method)
  42. except AttributeError:
  43. f = getattr(self.repository, method)
  44. res = f(*args)
  45. except Exception as e:
  46. sys.stdout.buffer.write(msgpack.packb((1, msgid, e.__class__.__name__, e.args)))
  47. else:
  48. sys.stdout.buffer.write(msgpack.packb((1, msgid, None, res)))
  49. sys.stdout.flush()
  50. if es:
  51. return
  52. def negotiate(self, versions):
  53. return 1
  54. def open(self, path, create=False):
  55. path = os.fsdecode(path)
  56. if path.startswith('/~'):
  57. path = path[1:]
  58. path = os.path.realpath(os.path.expanduser(path))
  59. if self.restrict_to_paths:
  60. for restrict_to_path in self.restrict_to_paths:
  61. if path.startswith(os.path.realpath(restrict_to_path)):
  62. break
  63. else:
  64. raise PathNotAllowed(path)
  65. self.repository = Repository(path, create)
  66. return self.repository.id
  67. class RemoteRepository(object):
  68. extra_test_args = []
  69. class RPCError(Exception):
  70. def __init__(self, name):
  71. self.name = name
  72. def __init__(self, location, create=False):
  73. self.location = location
  74. self.preload_ids = []
  75. self.msgid = 0
  76. self.to_send = b''
  77. self.cache = {}
  78. self.ignore_responses = set()
  79. self.responses = {}
  80. self.unpacker = msgpack.Unpacker(use_list=False)
  81. self.p = None
  82. if location.host == '__testsuite__':
  83. args = [sys.executable, '-m', 'attic.archiver', 'serve'] + self.extra_test_args
  84. else:
  85. args = ['ssh']
  86. if location.port:
  87. args += ['-p', str(location.port)]
  88. if location.user:
  89. args.append('%s@%s' % (location.user, location.host))
  90. else:
  91. args.append('%s' % location.host)
  92. args += ['attic', 'serve']
  93. self.p = Popen(args, bufsize=0, stdin=PIPE, stdout=PIPE)
  94. self.stdin_fd = self.p.stdin.fileno()
  95. self.stdout_fd = self.p.stdout.fileno()
  96. fcntl.fcntl(self.stdin_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdin_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  97. fcntl.fcntl(self.stdout_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdout_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  98. self.r_fds = [self.stdout_fd]
  99. self.x_fds = [self.stdin_fd, self.stdout_fd]
  100. version = self.call('negotiate', 1)
  101. if version != 1:
  102. raise Exception('Server insisted on using unsupported protocol version %d' % version)
  103. self.id = self.call('open', location.path, create)
  104. def __del__(self):
  105. self.close()
  106. def call(self, cmd, *args, **kw):
  107. for resp in self.call_many(cmd, [args], **kw):
  108. return resp
  109. def call_many(self, cmd, calls, wait=True, is_preloaded=False):
  110. if not calls:
  111. return
  112. def fetch_from_cache(args):
  113. msgid = self.cache[args].pop(0)
  114. if not self.cache[args]:
  115. del self.cache[args]
  116. return msgid
  117. calls = list(calls)
  118. waiting_for = []
  119. w_fds = [self.stdin_fd]
  120. while wait or calls:
  121. while waiting_for:
  122. try:
  123. error, res = self.responses.pop(waiting_for[0])
  124. waiting_for.pop(0)
  125. if error:
  126. if error == b'DoesNotExist':
  127. raise Repository.DoesNotExist(self.location.orig)
  128. elif error == b'AlreadyExists':
  129. raise Repository.AlreadyExists(self.location.orig)
  130. elif error == b'CheckNeeded':
  131. raise Repository.CheckNeeded(self.location.orig)
  132. elif error == b'IntegrityError':
  133. raise IntegrityError(res)
  134. elif error == b'PathNotAllowed':
  135. raise PathNotAllowed(*res)
  136. if error == b'ObjectNotFound':
  137. raise Repository.ObjectNotFound(res[0], self.location.orig)
  138. raise self.RPCError(error)
  139. else:
  140. yield res
  141. if not waiting_for and not calls:
  142. return
  143. except KeyError:
  144. break
  145. r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
  146. if x:
  147. raise Exception('FD exception occured')
  148. if r:
  149. data = os.read(self.stdout_fd, BUFSIZE)
  150. if not data:
  151. raise ConnectionClosed()
  152. self.unpacker.feed(data)
  153. for type, msgid, error, res in self.unpacker:
  154. if msgid in self.ignore_responses:
  155. self.ignore_responses.remove(msgid)
  156. else:
  157. self.responses[msgid] = error, res
  158. if w:
  159. while not self.to_send and (calls or self.preload_ids) and len(waiting_for) < 100:
  160. if calls:
  161. if is_preloaded:
  162. if calls[0] in self.cache:
  163. waiting_for.append(fetch_from_cache(calls.pop(0)))
  164. else:
  165. args = calls.pop(0)
  166. if cmd == 'get' and args in self.cache:
  167. waiting_for.append(fetch_from_cache(args))
  168. else:
  169. self.msgid += 1
  170. waiting_for.append(self.msgid)
  171. self.to_send = msgpack.packb((1, self.msgid, cmd, args))
  172. if not self.to_send and self.preload_ids:
  173. args = (self.preload_ids.pop(0),)
  174. self.msgid += 1
  175. self.cache.setdefault(args, []).append(self.msgid)
  176. self.to_send = msgpack.packb((1, self.msgid, cmd, args))
  177. if self.to_send:
  178. try:
  179. self.to_send = self.to_send[os.write(self.stdin_fd, self.to_send):]
  180. except OSError as e:
  181. # io.write might raise EAGAIN even though select indicates
  182. # that the fd should be writable
  183. if e.errno != errno.EAGAIN:
  184. raise
  185. if not self.to_send and not (calls or self.preload_ids):
  186. w_fds = []
  187. self.ignore_responses |= set(waiting_for)
  188. def check(self, repair=False):
  189. return self.call('check', repair)
  190. def commit(self, *args):
  191. return self.call('commit')
  192. def rollback(self, *args):
  193. return self.call('rollback')
  194. def __len__(self):
  195. return self.call('__len__')
  196. def list(self, limit=None, marker=None):
  197. return self.call('list', limit, marker)
  198. def get(self, id_):
  199. for resp in self.get_many([id_]):
  200. return resp
  201. def get_many(self, ids, is_preloaded=False):
  202. for resp in self.call_many('get', [(id_,) for id_ in ids], is_preloaded=is_preloaded):
  203. yield resp
  204. def put(self, id_, data, wait=True):
  205. return self.call('put', id_, data, wait=wait)
  206. def delete(self, id_, wait=True):
  207. return self.call('delete', id_, wait=wait)
  208. def close(self):
  209. if self.p:
  210. self.p.stdin.close()
  211. self.p.stdout.close()
  212. self.p.wait()
  213. self.p = None
  214. def preload(self, ids):
  215. self.preload_ids += ids
  216. class RepositoryCache:
  217. """A caching Repository wrapper
  218. Caches Repository GET operations using a temporary file
  219. """
  220. def __init__(self, repository):
  221. self.tmppath = None
  222. self.index = None
  223. self.data_fd = None
  224. self.repository = repository
  225. self.entries = {}
  226. self.initialize()
  227. def __del__(self):
  228. self.cleanup()
  229. def initialize(self):
  230. self.tmppath = tempfile.mkdtemp()
  231. self.index = NSIndex()
  232. self.data_fd = open(os.path.join(self.tmppath, 'data'), 'a+b')
  233. def cleanup(self):
  234. del self.index
  235. if self.data_fd:
  236. self.data_fd.close()
  237. if self.tmppath:
  238. shutil.rmtree(self.tmppath)
  239. def load_object(self, offset, size):
  240. self.data_fd.seek(offset)
  241. data = self.data_fd.read(size)
  242. assert len(data) == size
  243. return data
  244. def store_object(self, key, data):
  245. self.data_fd.seek(0, os.SEEK_END)
  246. self.data_fd.write(data)
  247. offset = self.data_fd.tell()
  248. self.index[key] = offset - len(data), len(data)
  249. def get(self, key):
  250. return next(self.get_many([key]))
  251. def get_many(self, keys):
  252. unknown_keys = [key for key in keys if not key in self.index]
  253. repository_iterator = zip(unknown_keys, self.repository.get_many(unknown_keys))
  254. for key in keys:
  255. try:
  256. yield self.load_object(*self.index[key])
  257. except KeyError:
  258. for key_, data in repository_iterator:
  259. if key_ == key:
  260. self.store_object(key, data)
  261. yield data
  262. break
  263. # Consume any pending requests
  264. for _ in repository_iterator:
  265. pass
  266. def cache_if_remote(repository):
  267. if isinstance(repository, RemoteRepository):
  268. return RepositoryCache(repository)
  269. return repository