remote.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. import errno
  2. import fcntl
  3. import msgpack
  4. import os
  5. import select
  6. import shutil
  7. from subprocess import Popen, PIPE
  8. import sys
  9. import tempfile
  10. import traceback
  11. from attic import __version__
  12. from .hashindex import NSIndex
  13. from .helpers import Error, IntegrityError
  14. from .repository import Repository
  # Maximum number of bytes requested from a pipe by a single os.read() call
  # (10 MiB); used by both the server loop and the client.
  BUFSIZE = 10 * 1024 * 1024
  # Raised by RemoteRepository.call_many() when reading from the server's
  # stdout pipe returns no data.
  class ConnectionClosed(Error):
      """Connection closed by remote host"""
  # Raised by RepositoryServer.open() when the requested path lies outside
  # every entry of restrict_to_paths; RemoteRepository.call_many() raises it
  # again on the client side when the server reports it by name.
  class PathNotAllowed(Error):
      """Repository path not allowed"""
  # Raised by RepositoryServer.serve() when a request names a method that is
  # not listed in RepositoryServer.rpc_methods.
  class InvalidRPCMethod(Error):
      """RPC method is not valid"""
  class RepositoryServer:
      """Serve repository RPC requests read from stdin, replying on stdout.

      A request is a msgpack-encoded 4-tuple ``(type, msgid, method, args)``.
      Every reply is the 4-tuple ``(1, msgid, error_name, payload)`` where
      ``error_name`` is ``None`` on success and the exception class name
      otherwise.
      """

      # Whitelist of method names a client may invoke.  A name is looked up on
      # this object first and on the opened repository otherwise.
      rpc_methods = (
          '__len__',
          'check',
          'commit',
          'delete',
          # RemoteRepository.destroy() issues this call; without the entry
          # every such request was answered with InvalidRPCMethod.
          'destroy',
          'get',
          'list',
          'negotiate',
          'open',
          'put',
          'repair',
          'rollback',
      )

      def __init__(self, restrict_to_paths):
          """Create a server that has not opened a repository yet.

          :param restrict_to_paths: iterable of directories ``open`` may
              access; a falsy value disables the restriction.
          """
          self.repository = None  # assigned by open()
          self.restrict_to_paths = restrict_to_paths

      def serve(self):
          """Run the request loop until stdin reaches end-of-file.

          :raises Exception: if a decoded message is not a 4-tuple.
          """
          stdin_fd = sys.stdin.fileno()
          stdout_fd = sys.stdout.fileno()
          # Make stdin non-blocking
          fl = fcntl.fcntl(stdin_fd, fcntl.F_GETFL)
          fcntl.fcntl(stdin_fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)
          # Make stdout blocking
          fl = fcntl.fcntl(stdout_fd, fcntl.F_GETFL)
          fcntl.fcntl(stdout_fd, fcntl.F_SETFL, fl & ~os.O_NONBLOCK)
          unpacker = msgpack.Unpacker(use_list=False)
          while True:
              r, w, es = select.select([stdin_fd], [], [], 10)
              if r:
                  data = os.read(stdin_fd, BUFSIZE)
                  if not data:
                      # End-of-file: the client closed its end of the pipe.
                      return
                  unpacker.feed(data)
                  for unpacked in unpacker:
                      if not (isinstance(unpacked, tuple) and len(unpacked) == 4):
                          raise Exception("Unexpected RPC data format.")
                      # The first field (message type) is not interpreted.
                      _msg_type, msgid, method, args = unpacked
                      method = method.decode('ascii')
                      try:
                          if method not in self.rpc_methods:
                              raise InvalidRPCMethod(method)
                          try:
                              f = getattr(self, method)
                          except AttributeError:
                              f = getattr(self.repository, method)
                          res = f(*args)
                      except BaseException as e:
                          # Every failure is reported to the client as the
                          # exception class name plus a formatted traceback.
                          exc = "Remote Traceback by Borg %s%s%s" % (__version__, os.linesep, traceback.format_exc())
                          os.write(stdout_fd, msgpack.packb((1, msgid, e.__class__.__name__, exc)))
                      else:
                          os.write(stdout_fd, msgpack.packb((1, msgid, None, res)))
              if es:
                  return

      def negotiate(self, versions):
          """Return the protocol version to use (always 1); *versions* is ignored."""
          return 1

      @staticmethod
      def _is_path_allowed(path, restrict_to_paths):
          """Tell whether *path* equals or lies below one of *restrict_to_paths*.

          *path* is expected to be resolved already; each restriction is
          resolved with ``os.path.realpath`` here.  The comparison is made on a
          path-separator boundary so that allowing ``/srv/repo`` does not also
          admit ``/srv/repo-other``.
          """
          for restrict_to_path in restrict_to_paths:
              allowed = os.path.realpath(restrict_to_path)
              if path == allowed or path.startswith(os.path.join(allowed, '')):
                  return True
          return False

      def open(self, path, create=False):
          """Open (or create) the repository at *path* and return its id.

          :param path: repository path as received from the client; decoded
              with ``os.fsdecode``.
          :param create: passed through to ``Repository``.
          :raises PathNotAllowed: if ``restrict_to_paths`` is set and the
              resolved path is outside all of its entries.
          """
          path = os.fsdecode(path)
          # A leading '/~' is reduced to '~' so expanduser() can resolve it.
          if path.startswith('/~'):
              path = path[1:]
          path = os.path.realpath(os.path.expanduser(path))
          if self.restrict_to_paths and not self._is_path_allowed(path, self.restrict_to_paths):
              raise PathNotAllowed(path)
          self.repository = Repository(path, create)
          return self.repository.id
  class RemoteRepository:
      """Client-side proxy for a repository served by a ``serve`` subprocess.

      Requests and replies are msgpack-encoded 4-tuples exchanged over the
      child's stdin/stdout pipes; see ``call_many`` for the transfer loop.
      """

      # Extra command-line arguments appended to the server command when the
      # location host is '__testsuite__'.
      extra_test_args = []

      class RPCError(Exception):
          """Server-side error that has no dedicated local exception class."""

          def __init__(self, name):
              super().__init__(name)
              self.name = name

      def __init__(self, location, create=False):
          """Spawn the server, negotiate the protocol and open the repository.

          :param location: object providing ``host``, ``port``, ``user``,
              ``path`` and ``orig`` attributes.
          :param create: passed through to the server's ``open`` call.
          :raises Exception: if the server answers the negotiation with a
              protocol version other than 1.
          """
          self.location = location
          self.preload_ids = []           # ids queued by preload(), not yet requested
          self.msgid = 0                  # id of the most recently created request
          self.to_send = b''              # unsent remainder of the current request
          self.cache = {}                 # args tuple -> msgids of issued preload requests
          self.ignore_responses = set()   # msgids whose replies are dropped on arrival
          self.responses = {}             # msgid -> (error, result), not yet consumed
          self.unpacker = msgpack.Unpacker(use_list=False)
          self.p = None
          if location.host == '__testsuite__':
              args = [sys.executable, '-m', 'attic.archiver', 'serve'] + self.extra_test_args
          else:
              args = ['ssh']
              if location.port:
                  args += ['-p', str(location.port)]
              if location.user:
                  args.append('%s@%s' % (location.user, location.host))
              else:
                  args.append('%s' % location.host)
              args += ['borg', 'serve']
          self.p = Popen(args, bufsize=0, stdin=PIPE, stdout=PIPE)
          self.stdin_fd = self.p.stdin.fileno()
          self.stdout_fd = self.p.stdout.fileno()
          # Both pipe ends are switched to non-blocking mode; call_many()
          # multiplexes them with select().
          fcntl.fcntl(self.stdin_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdin_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
          fcntl.fcntl(self.stdout_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdout_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
          self.r_fds = [self.stdout_fd]
          self.x_fds = [self.stdin_fd, self.stdout_fd]
          version = self.call('negotiate', 1)
          if version != 1:
              raise Exception('Server insisted on using unsupported protocol version %d' % version)
          self.id = self.call('open', location.path, create)

      def __del__(self):
          self.close()

      def call(self, cmd, *args, **kw):
          """Issue a single RPC and return its first reply.

          Keyword arguments are forwarded to ``call_many``.  ``None`` is
          returned when ``call_many`` yields nothing (e.g. ``wait=False``).
          """
          for resp in self.call_many(cmd, [args], **kw):
              return resp

      @staticmethod
      def _error_args(res):
          """Normalise an error payload into a tuple of exception arguments.

          The server in this module sends a formatted traceback, which arrives
          here as one bytes object; splatting that with ``*res`` would produce
          one integer per byte.  Tuple/list payloads are passed through.
          """
          if isinstance(res, (tuple, list)):
              return tuple(res)
          if isinstance(res, bytes):
              return (res.decode('utf-8', 'replace'),)
          return (res,)

      def _remote_exception(self, error, res):
          """Return the local exception for the server-side error named *error*."""
          if error == b'DoesNotExist':
              return Repository.DoesNotExist(self.location.orig)
          if error == b'AlreadyExists':
              return Repository.AlreadyExists(self.location.orig)
          if error == b'CheckNeeded':
              return Repository.CheckNeeded(self.location.orig)
          if error == b'IntegrityError':
              return IntegrityError(res)
          if error == b'PathNotAllowed':
              return PathNotAllowed(*self._error_args(res))
          if error == b'ObjectNotFound':
              return Repository.ObjectNotFound(self._error_args(res)[0], self.location.orig)
          if error == b'InvalidRPCMethod':
              return InvalidRPCMethod(*self._error_args(res))
          return self.RPCError(res.decode('utf-8'))

      def call_many(self, cmd, calls, wait=True, is_preloaded=False):
          """Send *cmd* once per entry of *calls* and yield the replies in order.

          :param cmd: RPC method name.
          :param calls: iterable of argument tuples, one per request.
          :param wait: when true, keep running until every reply has been
              yielded; when false, stop once all requests are written and
              mark the replies still outstanding to be ignored.
          :param is_preloaded: when true, each call is matched against a
              request already issued through ``preload`` instead of being
              sent again.
          :raises ConnectionClosed: if the server's stdout reaches end-of-file.
          :raises Exception: on a malformed message or an exceptional fd
              condition; server-reported errors are raised as the matching
              local exception or ``RPCError``.
          """
          if not calls:
              return

          def fetch_from_cache(args):
              # Take the oldest preload msgid recorded for these arguments.
              msgid = self.cache[args].pop(0)
              if not self.cache[args]:
                  del self.cache[args]
              return msgid

          calls = list(calls)
          waiting_for = []
          w_fds = [self.stdin_fd]
          while wait or calls:
              # Hand out replies strictly in request order.
              while waiting_for:
                  try:
                      error, res = self.responses.pop(waiting_for[0])
                  except KeyError:
                      # The oldest outstanding reply has not arrived yet.
                      break
                  waiting_for.pop(0)
                  if error:
                      raise self._remote_exception(error, res)
                  yield res
                  if not waiting_for and not calls:
                      return
              r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
              if x:
                  raise Exception('FD exception occured')
              if r:
                  data = os.read(self.stdout_fd, BUFSIZE)
                  if not data:
                      raise ConnectionClosed()
                  self.unpacker.feed(data)
                  for unpacked in self.unpacker:
                      if not (isinstance(unpacked, tuple) and len(unpacked) == 4):
                          raise Exception("Unexpected RPC data format.")
                      # The first field (message type) is not interpreted.
                      _msg_type, msgid, error, res = unpacked
                      if msgid in self.ignore_responses:
                          self.ignore_responses.remove(msgid)
                      else:
                          self.responses[msgid] = error, res
              if w:
                  # Queue one request at a time, with at most 100 outstanding.
                  while not self.to_send and (calls or self.preload_ids) and len(waiting_for) < 100:
                      if calls:
                          if is_preloaded:
                              if calls[0] in self.cache:
                                  waiting_for.append(fetch_from_cache(calls.pop(0)))
                              elif not self.preload_ids:
                                  # Not preloaded and nothing left in the
                                  # preload queue that could still match:
                                  # request it directly instead of spinning.
                                  self.msgid += 1
                                  waiting_for.append(self.msgid)
                                  self.to_send = msgpack.packb((1, self.msgid, cmd, calls.pop(0)))
                          else:
                              args = calls.pop(0)
                              if cmd == 'get' and args in self.cache:
                                  waiting_for.append(fetch_from_cache(args))
                              else:
                                  self.msgid += 1
                                  waiting_for.append(self.msgid)
                                  self.to_send = msgpack.packb((1, self.msgid, cmd, args))
                      if not self.to_send and self.preload_ids:
                          args = (self.preload_ids.pop(0),)
                          self.msgid += 1
                          self.cache.setdefault(args, []).append(self.msgid)
                          # Cached replies are only ever consumed as 'get'
                          # results, so a preload is always a 'get' request
                          # regardless of the command currently being sent.
                          self.to_send = msgpack.packb((1, self.msgid, 'get', args))
                  if self.to_send:
                      try:
                          self.to_send = self.to_send[os.write(self.stdin_fd, self.to_send):]
                      except OSError as e:
                          # io.write might raise EAGAIN even though select indicates
                          # that the fd should be writable
                          if e.errno != errno.EAGAIN:
                              raise
                  if not self.to_send and not (calls or self.preload_ids):
                      # Nothing left to write: stop polling for writability.
                      w_fds = []
          self.ignore_responses |= set(waiting_for)

      def check(self, repair=False):
          return self.call('check', repair)

      def commit(self, *args):
          return self.call('commit')

      def rollback(self, *args):
          return self.call('rollback')

      def destroy(self):
          return self.call('destroy')

      def __len__(self):
          return self.call('__len__')

      def list(self, limit=None, marker=None):
          return self.call('list', limit, marker)

      def get(self, id_):
          """Return the reply to a single 'get' request for *id_*."""
          for resp in self.get_many([id_]):
              return resp

      def get_many(self, ids, is_preloaded=False):
          """Yield the 'get' reply for each of *ids*, in order."""
          for resp in self.call_many('get', [(id_,) for id_ in ids], is_preloaded=is_preloaded):
              yield resp

      def put(self, id_, data, wait=True):
          return self.call('put', id_, data, wait=wait)

      def delete(self, id_, wait=True):
          return self.call('delete', id_, wait=wait)

      def close(self):
          """Close both pipes and wait for the server process to exit.

          Safe to call repeatedly and on an object whose ``__init__`` did not
          get as far as assigning ``self.p``.
          """
          p = getattr(self, 'p', None)
          if p:
              p.stdin.close()
              p.stdout.close()
              p.wait()
              self.p = None

      def preload(self, ids):
          """Queue *ids* to be requested in the background by later calls."""
          self.preload_ids += ids
  class RepositoryCache:
      """A caching Repository wrapper

      Caches Repository GET operations using a temporary file
      """

      def __init__(self, repository):
          """Wrap *repository* and set up an empty on-disk cache.

          :param repository: object providing ``get_many(keys)``.
          """
          self.tmppath = None     # temporary directory holding the data file
          self.index = None       # key -> (offset, size) within the data file
          self.data_fd = None     # data file, opened for append and read
          self.repository = repository
          self.entries = {}
          self.initialize()

      def __del__(self):
          self.cleanup()

      def initialize(self):
          """Create the temporary directory, the index and the data file."""
          self.tmppath = tempfile.mkdtemp()
          self.index = NSIndex()
          self.data_fd = open(os.path.join(self.tmppath, 'data'), 'a+b')

      def cleanup(self):
          """Release the index, close the data file and remove the directory.

          Idempotent: ``__del__`` calls it again after an explicit cleanup, so
          each resource is reset once it has been released.
          """
          self.index = None
          if self.data_fd:
              self.data_fd.close()
              self.data_fd = None
          if self.tmppath:
              shutil.rmtree(self.tmppath)
              self.tmppath = None

      def load_object(self, offset, size):
          """Return *size* bytes read from the data file at *offset*."""
          self.data_fd.seek(offset)
          data = self.data_fd.read(size)
          assert len(data) == size
          return data

      def store_object(self, key, data):
          """Append *data* to the data file and record its location under *key*."""
          self.data_fd.seek(0, os.SEEK_END)
          self.data_fd.write(data)
          offset = self.data_fd.tell()
          self.index[key] = offset - len(data), len(data)

      def get(self, key):
          """Return the object stored under *key*."""
          return next(self.get_many([key]))

      def get_many(self, keys):
          """Yield the object for each of *keys*, in order.

          Keys missing from the index are fetched from the wrapped repository
          in one ``get_many`` call and stored in the cache as they arrive.
          """
          # Materialise first: the keys are walked twice below, which would
          # silently yield nothing for a one-shot iterator.
          keys = list(keys)
          unknown_keys = [key for key in keys if key not in self.index]
          repository_iterator = zip(unknown_keys, self.repository.get_many(unknown_keys))
          for key in keys:
              try:
                  entry = self.index[key]
              except KeyError:
                  # Not cached yet: advance the repository stream up to this
                  # key; entries for other keys met on the way are dropped.
                  for key_, data in repository_iterator:
                      if key_ == key:
                          self.store_object(key, data)
                          yield data
                          break
              else:
                  yield self.load_object(*entry)
          # Consume any pending requests
          for _ in repository_iterator:
              pass
  def cache_if_remote(repository):
      """Return *repository* wrapped in a RepositoryCache if it is remote.

      Any object that is not a RemoteRepository is returned unchanged.
      """
      if isinstance(repository, RemoteRepository):
          return RepositoryCache(repository)
      return repository
  301. return repository