remote.py 9.8 KB


  1. import fcntl
  2. import msgpack
  3. import os
  4. import select
  5. import shutil
  6. from subprocess import Popen, PIPE
  7. import sys
  8. import tempfile
  9. from .hashindex import NSIndex
  10. from .helpers import Error, IntegrityError
  11. from .repository import Repository
# Maximum number of bytes requested per os.read() call on the RPC pipes (10 MiB).
BUFSIZE = 10 * 1024 * 1024
class ConnectionClosed(Error):
    """Connection closed by remote host"""
    # Raised by RemoteRepository.call_many when reading the server's stdout
    # returns no data (end of file on the pipe).
    # NOTE(review): the Error base class lives in .helpers and is not visible
    # here; if it builds the user-facing message from this docstring, rewording
    # the docstring changes runtime output -- confirm before editing it.
  15. class RepositoryServer(object):
  16. def __init__(self):
  17. self.repository = None
  18. def serve(self):
  19. # Make stdin non-blocking
  20. fl = fcntl.fcntl(sys.stdin.fileno(), fcntl.F_GETFL)
  21. fcntl.fcntl(sys.stdin.fileno(), fcntl.F_SETFL, fl | os.O_NONBLOCK)
  22. # Make stdout blocking
  23. fl = fcntl.fcntl(sys.stdout.fileno(), fcntl.F_GETFL)
  24. fcntl.fcntl(sys.stdout.fileno(), fcntl.F_SETFL, fl & ~os.O_NONBLOCK)
  25. unpacker = msgpack.Unpacker(use_list=False)
  26. while True:
  27. r, w, es = select.select([sys.stdin], [], [], 10)
  28. if r:
  29. data = os.read(sys.stdin.fileno(), BUFSIZE)
  30. if not data:
  31. return
  32. unpacker.feed(data)
  33. for type, msgid, method, args in unpacker:
  34. method = method.decode('ascii')
  35. try:
  36. try:
  37. f = getattr(self, method)
  38. except AttributeError:
  39. f = getattr(self.repository, method)
  40. res = f(*args)
  41. except Exception as e:
  42. sys.stdout.buffer.write(msgpack.packb((1, msgid, e.__class__.__name__, e.args)))
  43. else:
  44. sys.stdout.buffer.write(msgpack.packb((1, msgid, None, res)))
  45. sys.stdout.flush()
  46. if es:
  47. return
  48. def negotiate(self, versions):
  49. return 1
  50. def open(self, path, create=False):
  51. path = os.fsdecode(path)
  52. if path.startswith('/~'):
  53. path = path[1:]
  54. self.repository = Repository(os.path.expanduser(path), create)
  55. return self.repository.id
  56. class RemoteRepository(object):
  57. class RPCError(Exception):
  58. def __init__(self, name):
  59. self.name = name
  60. def __init__(self, location, create=False):
  61. self.location = location
  62. self.preload_ids = []
  63. self.msgid = 0
  64. self.to_send = b''
  65. self.cache = {}
  66. self.ignore_responses = set()
  67. self.responses = {}
  68. self.unpacker = msgpack.Unpacker(use_list=False)
  69. self.p = None
  70. if location.host == '__testsuite__':
  71. args = [sys.executable, '-m', 'attic.archiver', 'serve']
  72. else:
  73. args = ['ssh']
  74. if location.port:
  75. args += ['-p', str(location.port)]
  76. if location.user:
  77. args.append('%s@%s' % (location.user, location.host))
  78. else:
  79. args.append('%s' % location.host)
  80. args += ['attic', 'serve']
  81. self.p = Popen(args, bufsize=0, stdin=PIPE, stdout=PIPE)
  82. self.stdin_fd = self.p.stdin.fileno()
  83. self.stdout_fd = self.p.stdout.fileno()
  84. fcntl.fcntl(self.stdin_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdin_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  85. fcntl.fcntl(self.stdout_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdout_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  86. self.r_fds = [self.stdout_fd]
  87. self.x_fds = [self.stdin_fd, self.stdout_fd]
  88. version = self.call('negotiate', 1)
  89. if version != 1:
  90. raise Exception('Server insisted on using unsupported protocol version %d' % version)
  91. self.id = self.call('open', location.path, create)
  92. def __del__(self):
  93. self.close()
  94. def call(self, cmd, *args, **kw):
  95. for resp in self.call_many(cmd, [args], **kw):
  96. return resp
  97. def call_many(self, cmd, calls, wait=True, is_preloaded=False):
  98. if not calls:
  99. return
  100. def fetch_from_cache(args):
  101. msgid = self.cache[args].pop(0)
  102. if not self.cache[args]:
  103. del self.cache[args]
  104. return msgid
  105. calls = list(calls)
  106. waiting_for = []
  107. w_fds = [self.stdin_fd]
  108. while wait or calls:
  109. while waiting_for:
  110. try:
  111. error, res = self.responses.pop(waiting_for[0])
  112. waiting_for.pop(0)
  113. if error:
  114. if error == b'DoesNotExist':
  115. raise Repository.DoesNotExist(self.location.orig)
  116. elif error == b'AlreadyExists':
  117. raise Repository.AlreadyExists(self.location.orig)
  118. elif error == b'CheckNeeded':
  119. raise Repository.CheckNeeded(self.location.orig)
  120. elif error == b'IntegrityError':
  121. raise IntegrityError(res)
  122. raise self.RPCError(error)
  123. else:
  124. yield res
  125. if not waiting_for and not calls:
  126. return
  127. except KeyError:
  128. break
  129. r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
  130. if x:
  131. raise Exception('FD exception occured')
  132. if r:
  133. data = os.read(self.stdout_fd, BUFSIZE)
  134. if not data:
  135. raise ConnectionClosed()
  136. self.unpacker.feed(data)
  137. for type, msgid, error, res in self.unpacker:
  138. if msgid in self.ignore_responses:
  139. self.ignore_responses.remove(msgid)
  140. else:
  141. self.responses[msgid] = error, res
  142. if w:
  143. while not self.to_send and (calls or self.preload_ids) and len(waiting_for) < 100:
  144. if calls:
  145. if is_preloaded:
  146. if calls[0] in self.cache:
  147. waiting_for.append(fetch_from_cache(calls.pop(0)))
  148. else:
  149. args = calls.pop(0)
  150. if cmd == 'get' and args in self.cache:
  151. waiting_for.append(fetch_from_cache(args))
  152. else:
  153. self.msgid += 1
  154. waiting_for.append(self.msgid)
  155. self.to_send = msgpack.packb((1, self.msgid, cmd, args))
  156. if not self.to_send and self.preload_ids:
  157. args = (self.preload_ids.pop(0),)
  158. self.msgid += 1
  159. self.cache.setdefault(args, []).append(self.msgid)
  160. self.to_send = msgpack.packb((1, self.msgid, cmd, args))
  161. if self.to_send:
  162. self.to_send = self.to_send[os.write(self.stdin_fd, self.to_send):]
  163. if not self.to_send and not (calls or self.preload_ids):
  164. w_fds = []
  165. self.ignore_responses |= set(waiting_for)
  166. def check(self, repair=False):
  167. return self.call('check', repair)
  168. def commit(self, *args):
  169. return self.call('commit')
  170. def rollback(self, *args):
  171. return self.call('rollback')
  172. def __len__(self):
  173. return self.call('__len__')
  174. def list(self, limit=None, marker=None):
  175. return self.call('list', limit, marker)
  176. def get(self, id_):
  177. for resp in self.get_many([id_]):
  178. return resp
  179. def get_many(self, ids, is_preloaded=False):
  180. for resp in self.call_many('get', [(id_,) for id_ in ids], is_preloaded=is_preloaded):
  181. yield resp
  182. def put(self, id_, data, wait=True):
  183. return self.call('put', id_, data, wait=wait)
  184. def delete(self, id_, wait=True):
  185. return self.call('delete', id_, wait=wait)
  186. def close(self):
  187. if self.p:
  188. self.p.stdin.close()
  189. self.p.stdout.close()
  190. self.p.wait()
  191. self.p = None
  192. def preload(self, ids):
  193. self.preload_ids += ids
  194. class RepositoryCache:
  195. """A caching Repository wrapper
  196. Caches Repository GET operations using a temporary file
  197. """
  198. def __init__(self, repository):
  199. self.tmppath = None
  200. self.index = None
  201. self.data_fd = None
  202. self.repository = repository
  203. self.entries = {}
  204. self.initialize()
  205. def __del__(self):
  206. self.cleanup()
  207. def initialize(self):
  208. self.tmppath = tempfile.mkdtemp()
  209. self.index = NSIndex.create(os.path.join(self.tmppath, 'index'))
  210. self.data_fd = open(os.path.join(self.tmppath, 'data'), 'a+b')
  211. def cleanup(self):
  212. del self.index
  213. if self.data_fd:
  214. self.data_fd.close()
  215. if self.tmppath:
  216. shutil.rmtree(self.tmppath)
  217. def load_object(self, offset, size):
  218. self.data_fd.seek(offset)
  219. data = self.data_fd.read(size)
  220. assert len(data) == size
  221. return data
  222. def store_object(self, key, data):
  223. self.data_fd.seek(0, os.SEEK_END)
  224. self.data_fd.write(data)
  225. offset = self.data_fd.tell()
  226. self.index[key] = offset - len(data), len(data)
  227. def get(self, key):
  228. return next(self.get_many([key]))
  229. def get_many(self, keys):
  230. unknown_keys = [key for key in keys if not key in self.index]
  231. repository_iterator = zip(unknown_keys, self.repository.get_many(unknown_keys))
  232. for key in keys:
  233. try:
  234. yield self.load_object(*self.index[key])
  235. except KeyError:
  236. for key_, data in repository_iterator:
  237. if key_ == key:
  238. self.store_object(key, data)
  239. yield data
  240. break
  241. # Consume any pending requests
  242. for _ in repository_iterator:
  243. pass