remote.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. import errno
  2. import fcntl
  3. import os
  4. import select
  5. import shlex
  6. from subprocess import Popen, PIPE
  7. import sys
  8. import tempfile
  9. import traceback
  10. from . import __version__
  11. from .helpers import Error, IntegrityError, have_cython
  12. from .repository import Repository
  13. if have_cython():
  14. import msgpack
  # Upper bound on bytes requested per os.read() from an RPC pipe (10 MiB).
  BUFSIZE = 10 << 20
  class ConnectionClosed(Error):
      """Connection closed by remote host"""
      # Raised by RemoteRepository when reading the server's output hits EOF.
      # NOTE(review): the docstring presumably doubles as the user-facing
      # message of this Error subclass - keep its text unchanged.
  class ConnectionClosedWithHint(ConnectionClosed):
      """Connection closed by remote host. {}"""
      # Variant of ConnectionClosed carrying an extra hint for the user; the
      # '{}' placeholder presumably gets filled from the exception args
      # (assumes Error formats __doc__ with them - confirm in .helpers).
  class PathNotAllowed(Error):
      """Repository path not allowed"""
      # Raised by the server when a requested repository lies outside the
      # configured restrict-to paths, and re-raised on the client side.
  class InvalidRPCMethod(Error):
      """RPC method {} is not valid"""
      # Raised by the server for a method name outside its whitelist and
      # re-raised by the client; the '{}' placeholder presumably receives the
      # offending method name (assumes Error formats __doc__ with the args).
  class RepositoryServer:  # pragma: no cover
      """Serve a local Repository to a RemoteRepository client over stdin/stdout.

      Requests are msgpack-encoded 4-tuples ``(type, msgid, method, args)``.
      Each reply is ``(1, msgid, None, result)`` on success or
      ``(1, msgid, exception_class_name, traceback_text)`` on failure.
      """
      # Whitelist of methods a client may invoke; any other name is answered
      # with InvalidRPCMethod.
      rpc_methods = (
          '__len__',
          'check',
          'commit',
          'delete',
          'destroy',
          'get',
          'list',
          'negotiate',
          'open',
          'put',
          'repair',
          'rollback',
          'save_key',
          'load_key',
      )

      def __init__(self, restrict_to_paths):
          """Create a server.

          :param restrict_to_paths: iterable of directories that ``open`` may
              serve repositories from; a falsy value means no restriction.
          """
          self.repository = None
          self.restrict_to_paths = restrict_to_paths

      def serve(self):
          """Answer RPC requests read from stdin until stdin reaches EOF."""
          stdin_fd = sys.stdin.fileno()
          stdout_fd = sys.stdout.fileno()
          # Make stdin non-blocking
          fl = fcntl.fcntl(stdin_fd, fcntl.F_GETFL)
          fcntl.fcntl(stdin_fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)
          # Make stdout blocking, so the os.write() calls below wait for the
          # pipe instead of failing with EAGAIN.
          fl = fcntl.fcntl(stdout_fd, fcntl.F_GETFL)
          fcntl.fcntl(stdout_fd, fcntl.F_SETFL, fl & ~os.O_NONBLOCK)
          unpacker = msgpack.Unpacker(use_list=False)
          while True:
              # Only stdin readability is watched; on the 10s timeout the loop
              # simply polls again.
              readable, _, _ = select.select([stdin_fd], [], [], 10)
              if not readable:
                  continue
              data = os.read(stdin_fd, BUFSIZE)
              if not data:
                  return
              unpacker.feed(data)
              for unpacked in unpacker:
                  if not (isinstance(unpacked, tuple) and len(unpacked) == 4):
                      raise Exception("Unexpected RPC data format.")
                  _type, msgid, method, args = unpacked
                  method = method.decode('ascii')
                  try:
                      if method not in self.rpc_methods:
                          raise InvalidRPCMethod(method)
                      try:
                          f = getattr(self, method)
                      except AttributeError:
                          # Not implemented by the server itself: forward to
                          # the opened repository.
                          f = getattr(self.repository, method)
                      res = f(*args)
                  except BaseException as e:
                      # Any exception (even BaseException subclasses) is sent
                      # back as (class name, traceback text); serving goes on.
                      exc = "Remote Traceback by Borg %s%s%s" % (__version__, os.linesep, traceback.format_exc())
                      os.write(stdout_fd, msgpack.packb((1, msgid, e.__class__.__name__, exc)))
                  else:
                      os.write(stdout_fd, msgpack.packb((1, msgid, None, res)))

      def negotiate(self, versions):
          """Return the protocol version this server speaks (always 1)."""
          return 1

      def open(self, path, create=False):
          """Open (or, with *create*, create) the repository at *path*.

          :returns: the repository id.
          :raises PathNotAllowed: if *path* resolves outside every entry of
              restrict_to_paths.
          """
          path = os.fsdecode(path)
          # '/~/...' paths are taken relative to the serving user's home.
          if path.startswith('/~'):
              path = path[1:]
          path = os.path.realpath(os.path.expanduser(path))
          if not self._path_allowed(path):
              raise PathNotAllowed(path)
          self.repository = Repository(path, create)
          return self.repository.id

      def _path_allowed(self, path):
          """Return True if resolved *path* is inside some restrict_to_paths entry.

          Always True when no restriction is configured.
          """
          if not self.restrict_to_paths:
              return True
          # Compare with a trailing separator on both sides so that allowing
          # '/data/repo' does not also allow the sibling '/data/repo2'.
          # os.path.join(p, '') adds the separator only when missing ('/' stays '/').
          path_with_sep = os.path.join(path, '')
          return any(
              path_with_sep.startswith(os.path.join(os.path.realpath(restrict_to_path), ''))
              for restrict_to_path in self.restrict_to_paths
          )
  class RemoteRepository:
      """Repository proxy that forwards calls to a ``borg serve`` subprocess.

      The subprocess is reached through ssh (see ssh_cmd), or spawned locally
      with the current interpreter when location.host is '__testsuite__'.
      Requests are pipelined over non-blocking pipes; see call_many.
      """
      extra_test_args = []
      remote_path = 'borg'
      # default umask, overridden by --umask, defaults to read/write only for owner
      umask = 0o077

      class RPCError(Exception):
          """Remote error that has no dedicated local exception class."""
          def __init__(self, name):
              self.name = name

      def __init__(self, location, create=False):
          """Start the server, check its protocol version and open the
          repository at location.path (creating it when *create* is true).

          :raises ConnectionClosedWithHint: if the server goes away during
              version negotiation.
          """
          self.location = location
          # ids queued by preload(); requested lazily from call_many
          self.preload_ids = []
          self.msgid = 0
          # bytes of the current request not yet written to the server
          self.to_send = b''
          # args tuple of a preloaded request -> msgids answering it, oldest first
          self.cache = {}
          # msgids whose responses are dropped as soon as they arrive
          self.ignore_responses = set()
          # msgid -> (error, result) received but not yet consumed
          self.responses = {}
          self.unpacker = msgpack.Unpacker(use_list=False)
          self.p = None
          # XXX: ideally, the testsuite would subclass Repository and
          # override ssh_cmd() instead of this crude hack, although
          # __testsuite__ is not a valid domain name so this is pretty
          # safe.
          if location.host == '__testsuite__':
              args = [sys.executable, '-m', 'borg.archiver', 'serve'] + self.extra_test_args
          else:  # pragma: no cover
              args = self.ssh_cmd(location)
          self.p = Popen(args, bufsize=0, stdin=PIPE, stdout=PIPE)
          self.stdin_fd = self.p.stdin.fileno()
          self.stdout_fd = self.p.stdout.fileno()
          # Both pipe ends are non-blocking; call_many multiplexes them with select().
          fcntl.fcntl(self.stdin_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdin_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
          fcntl.fcntl(self.stdout_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdout_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
          self.r_fds = [self.stdout_fd]
          self.x_fds = [self.stdin_fd, self.stdout_fd]
          try:
              version = self.call('negotiate', 1)
          except ConnectionClosed:
              raise ConnectionClosedWithHint('Is borg working on the server?')
          if version != 1:
              raise Exception('Server insisted on using unsupported protocol version %d' % version)
          self.id = self.call('open', location.path, create)

      def __del__(self):
          self.close()

      def __repr__(self):
          return '<%s %s>' % (self.__class__.__name__, self.location.canonical_path())

      def umask_flag(self):
          """Return the command-line flag passing self.umask (as octal) to the server."""
          return ['--umask', '%03o' % self.umask]

      def ssh_cmd(self, location):
          """Build the argv that runs ``<remote_path> serve`` on location.host.

          The ssh command comes from $BORG_RSH (default 'ssh'); the umask is
          forwarded via umask_flag().
          """
          args = shlex.split(os.environ.get('BORG_RSH', 'ssh'))
          if location.port:
              args += ['-p', str(location.port)]
          if location.user:
              args.append('%s@%s' % (location.user, location.host))
          else:
              args.append('%s' % location.host)
          # use local umask also for the remote process
          args += [self.remote_path, 'serve'] + self.umask_flag()
          return args

      def call(self, cmd, *args, **kw):
          """Invoke *cmd* once and return its result (None when wait=False)."""
          for resp in self.call_many(cmd, [args], **kw):
              return resp

      def call_many(self, cmd, calls, wait=True, is_preloaded=False):
          """Send *cmd* once per argument tuple in *calls*, pipelining requests.

          Yields the results in order when *wait* is true; with wait=False the
          requests are only sent and their responses discarded.  With
          is_preloaded=True, results are taken from responses to requests
          issued earlier for preload()ed ids.  Remote errors are re-raised as
          local exceptions.

          If the caller stops iterating early (or an error is raised), the
          responses still outstanding are discarded instead of accumulating
          in self.responses.
          """
          if not calls:
              return

          def fetch_from_cache(args):
              # Hand out the oldest preloaded msgid for these args.
              msgid = self.cache[args].pop(0)
              if not self.cache[args]:
                  del self.cache[args]
              return msgid

          calls = list(calls)
          waiting_for = []
          w_fds = [self.stdin_fd]
          try:
              while wait or calls:
                  # Yield, in order, every awaited result that already arrived.
                  while waiting_for:
                      try:
                          error, res = self.responses.pop(waiting_for[0])
                          waiting_for.pop(0)
                          if error:
                              # NOTE(review): RepositoryServer.serve sends the
                              # formatted traceback text as *res* for errors, so
                              # the `*res` / `res[0]` below operate on that text,
                              # not on the original exception arguments.
                              if error == b'DoesNotExist':
                                  raise Repository.DoesNotExist(self.location.orig)
                              elif error == b'AlreadyExists':
                                  raise Repository.AlreadyExists(self.location.orig)
                              elif error == b'CheckNeeded':
                                  raise Repository.CheckNeeded(self.location.orig)
                              elif error == b'IntegrityError':
                                  raise IntegrityError(res)
                              elif error == b'PathNotAllowed':
                                  raise PathNotAllowed(*res)
                              elif error == b'ObjectNotFound':
                                  raise Repository.ObjectNotFound(res[0], self.location.orig)
                              elif error == b'InvalidRPCMethod':
                                  raise InvalidRPCMethod(*res)
                              else:
                                  raise self.RPCError(res.decode('utf-8'))
                          else:
                              yield res
                              if not waiting_for and not calls:
                                  return
                      except KeyError:
                          # The next expected response has not arrived yet.
                          break
                  r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
                  if x:
                      raise Exception('FD exception occurred')
                  if r:
                      data = os.read(self.stdout_fd, BUFSIZE)
                      if not data:
                          raise ConnectionClosed()
                      self.unpacker.feed(data)
                      for unpacked in self.unpacker:
                          if not (isinstance(unpacked, tuple) and len(unpacked) == 4):
                              raise Exception("Unexpected RPC data format.")
                          _type, msgid, error, res = unpacked
                          if msgid in self.ignore_responses:
                              self.ignore_responses.remove(msgid)
                          else:
                              self.responses[msgid] = error, res
                  if w:
                      # Queue new requests while nothing is pending, keeping at
                      # most 100 awaited responses in flight.
                      while not self.to_send and (calls or self.preload_ids) and len(waiting_for) < 100:
                          if calls:
                              if is_preloaded:
                                  # NOTE(review): if calls[0] has no cached msgid and
                                  # no preload_ids remain, this loop makes no
                                  # progress; callers presumably only pass ids
                                  # previously given to preload().
                                  if calls[0] in self.cache:
                                      waiting_for.append(fetch_from_cache(calls.pop(0)))
                              else:
                                  args = calls.pop(0)
                                  if cmd == 'get' and args in self.cache:
                                      waiting_for.append(fetch_from_cache(args))
                                  else:
                                      self.msgid += 1
                                      waiting_for.append(self.msgid)
                                      self.to_send = msgpack.packb((1, self.msgid, cmd, args))
                          if not self.to_send and self.preload_ids:
                              args = (self.preload_ids.pop(0),)
                              self.msgid += 1
                              self.cache.setdefault(args, []).append(self.msgid)
                              self.to_send = msgpack.packb((1, self.msgid, cmd, args))
                      if self.to_send:
                          try:
                              self.to_send = self.to_send[os.write(self.stdin_fd, self.to_send):]
                          except OSError as e:
                              # io.write might raise EAGAIN even though select indicates
                              # that the fd should be writable
                              if e.errno != errno.EAGAIN:
                                  raise
                      if not self.to_send and not (calls or self.preload_ids):
                          # Nothing left to send: stop polling stdin for writability.
                          w_fds = []
          finally:
              # Runs on normal exit, on errors and when the consumer abandons
              # the generator.  A response that already arrived is dropped now;
              # one still in flight is dropped by the reader when it arrives.
              for msgid in waiting_for:
                  if self.responses.pop(msgid, None) is None:
                      self.ignore_responses.add(msgid)

      # Thin RPC wrappers mirroring the methods the server exposes.

      def check(self, repair=False):
          return self.call('check', repair)

      def commit(self, *args):
          return self.call('commit')

      def rollback(self, *args):
          return self.call('rollback')

      def destroy(self):
          return self.call('destroy')

      def __len__(self):
          return self.call('__len__')

      def list(self, limit=None, marker=None):
          return self.call('list', limit, marker)

      def get(self, id_):
          for resp in self.get_many([id_]):
              return resp

      def get_many(self, ids, is_preloaded=False):
          for resp in self.call_many('get', [(id_,) for id_ in ids], is_preloaded=is_preloaded):
              yield resp

      def put(self, id_, data, wait=True):
          # wait=False sends the request without waiting for its reply.
          return self.call('put', id_, data, wait=wait)

      def delete(self, id_, wait=True):
          # wait=False sends the request without waiting for its reply.
          return self.call('delete', id_, wait=wait)

      def save_key(self, keydata):
          return self.call('save_key', keydata)

      def load_key(self):
          return self.call('load_key')

      def close(self):
          """Close the pipes to the server and reap it.

          Idempotent, and safe on an instance whose __init__ failed early.
          """
          # getattr: __init__ may have raised before self.p was assigned, and
          # __del__ still calls close().
          p = getattr(self, 'p', None)
          if p:
              self.p = None
              try:
                  p.stdin.close()
              finally:
                  try:
                      p.stdout.close()
                  finally:
                      p.wait()

      def preload(self, ids):
          """Queue *ids* to be requested ahead of time by later call_many runs
          (presumably followed by get_many(..., is_preloaded=True))."""
          self.preload_ids += ids
  class RepositoryCache:
      """A caching Repository wrapper.

      Caches Repository GET operations using a local temporary Repository,
      which is destroyed when the wrapper is garbage collected.
      """
      def __init__(self, repository):
          self.repository = repository
          tmppath = tempfile.mkdtemp(prefix='borg-tmp')
          try:
              self.caching_repo = Repository(tmppath, create=True, exclusive=True)
          except BaseException:
              # Don't leave the fresh temp dir behind if the local repository
              # could not be created.
              import shutil
              shutil.rmtree(tmppath, ignore_errors=True)
              raise

      def __del__(self):
          # getattr: __init__ may have raised before caching_repo was assigned.
          caching_repo = getattr(self, 'caching_repo', None)
          if caching_repo is not None:
              caching_repo.destroy()

      def get(self, key):
          """Return the data stored under *key*."""
          return next(self.get_many([key]))

      def get_many(self, keys):
          """Yield the data for each of *keys*, in order.

          Keys missing from the local cache are fetched from the wrapped
          repository in one pipelined get_many() call and stored locally.
          """
          # keys is iterated twice below; materialize it so a one-shot
          # iterator/generator is not exhausted by the first pass.
          keys = list(keys)
          unknown_keys = [key for key in keys if key not in self.caching_repo]
          repository_iterator = zip(unknown_keys, self.repository.get_many(unknown_keys))
          for key in keys:
              try:
                  yield self.caching_repo.get(key)
              except Repository.ObjectNotFound:
                  for key_, data in repository_iterator:
                      if key_ == key:
                          self.caching_repo.put(key, data)
                          yield data
                          break
          # Consume any pending requests
          for _ in repository_iterator:
              pass
  def cache_if_remote(repository):
      """Wrap *repository* in a RepositoryCache if it is a RemoteRepository.

      Any other repository object is handed back unchanged.
      """
      is_remote = isinstance(repository, RemoteRepository)
      return RepositoryCache(repository) if is_remote else repository