remote.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463
  1. import errno
  2. import fcntl
  3. import logging
  4. import os
  5. import select
  6. import shlex
  7. from subprocess import Popen, PIPE
  8. import sys
  9. import tempfile
  10. from . import __version__
  11. from .helpers import Error, IntegrityError, get_home_dir, sysinfo, bin_to_hex
  12. from .repository import Repository
  13. import msgpack
  14. RPC_PROTOCOL_VERSION = 2
  15. BUFSIZE = 10 * 1024 * 1024
# Raised by RemoteRepository when a read from the server's stdout or stderr
# returns no data (EOF), i.e. the remote side went away.
# NOTE(review): the docstring appears to double as the user-facing message of
# helpers.Error — confirm there before editing it.
class ConnectionClosed(Error):
    """Connection closed by remote host"""
# ConnectionClosed variant that carries an extra hint for the user; RemoteRepository
# raises it with 'Is borg working on the server?' when the connection closes during
# the initial negotiate call.
# NOTE(review): the '{}' in the docstring looks like a message template filled from the
# constructor arguments by helpers.Error — confirm there.
class ConnectionClosedWithHint(ConnectionClosed):
    """Connection closed by remote host. {}"""
# Raised by RepositoryServer.open() when the resolved repository path is not inside any
# of the server's restrict_to_paths; RemoteRepository re-raises it when the server
# reports this error name.
# NOTE(review): the docstring appears to double as the user-facing message of helpers.Error.
class PathNotAllowed(Error):
    """Repository path not allowed"""
# Raised by RepositoryServer when a client requests a method that is not listed in
# RepositoryServer.rpc_methods; RemoteRepository re-raises it from the server's reply.
# NOTE(review): the '{}' in the docstring looks like a message template filled from the
# constructor arguments by helpers.Error — confirm there.
class InvalidRPCMethod(Error):
    """RPC method {} is not valid"""
  24. class RepositoryServer: # pragma: no cover
  25. rpc_methods = (
  26. '__len__',
  27. 'check',
  28. 'commit',
  29. 'delete',
  30. 'destroy',
  31. 'get',
  32. 'list',
  33. 'negotiate',
  34. 'open',
  35. 'put',
  36. 'rollback',
  37. 'save_key',
  38. 'load_key',
  39. 'break_lock',
  40. )
  41. def __init__(self, restrict_to_paths):
  42. self.repository = None
  43. self.restrict_to_paths = restrict_to_paths
  44. def serve(self):
  45. stdin_fd = sys.stdin.fileno()
  46. stdout_fd = sys.stdout.fileno()
  47. stderr_fd = sys.stdout.fileno()
  48. # Make stdin non-blocking
  49. fl = fcntl.fcntl(stdin_fd, fcntl.F_GETFL)
  50. fcntl.fcntl(stdin_fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)
  51. # Make stdout blocking
  52. fl = fcntl.fcntl(stdout_fd, fcntl.F_GETFL)
  53. fcntl.fcntl(stdout_fd, fcntl.F_SETFL, fl & ~os.O_NONBLOCK)
  54. # Make stderr blocking
  55. fl = fcntl.fcntl(stderr_fd, fcntl.F_GETFL)
  56. fcntl.fcntl(stderr_fd, fcntl.F_SETFL, fl & ~os.O_NONBLOCK)
  57. unpacker = msgpack.Unpacker(use_list=False)
  58. while True:
  59. r, w, es = select.select([stdin_fd], [], [], 10)
  60. if r:
  61. data = os.read(stdin_fd, BUFSIZE)
  62. if not data:
  63. self.repository.close()
  64. return
  65. unpacker.feed(data)
  66. for unpacked in unpacker:
  67. if not (isinstance(unpacked, tuple) and len(unpacked) == 4):
  68. self.repository.close()
  69. raise Exception("Unexpected RPC data format.")
  70. type, msgid, method, args = unpacked
  71. method = method.decode('ascii')
  72. try:
  73. if method not in self.rpc_methods:
  74. raise InvalidRPCMethod(method)
  75. try:
  76. f = getattr(self, method)
  77. except AttributeError:
  78. f = getattr(self.repository, method)
  79. res = f(*args)
  80. except BaseException as e:
  81. # These exceptions are reconstructed on the client end in RemoteRepository.call_many(),
  82. # and will be handled just like locally raised exceptions. Suppress the remote traceback
  83. # for these, except ErrorWithTraceback, which should always display a traceback.
  84. if not isinstance(e, (Repository.DoesNotExist, Repository.AlreadyExists, PathNotAllowed)):
  85. logging.exception('Borg %s: exception in RPC call:', __version__)
  86. logging.error(sysinfo())
  87. exc = "Remote Exception (see remote log for the traceback)"
  88. os.write(stdout_fd, msgpack.packb((1, msgid, e.__class__.__name__, exc)))
  89. else:
  90. os.write(stdout_fd, msgpack.packb((1, msgid, None, res)))
  91. if es:
  92. self.repository.close()
  93. return
  94. def negotiate(self, versions):
  95. return RPC_PROTOCOL_VERSION
  96. def open(self, path, create=False, lock_wait=None, lock=True):
  97. path = os.fsdecode(path)
  98. if path.startswith('/~'):
  99. path = os.path.join(get_home_dir(), path[2:])
  100. path = os.path.realpath(path)
  101. if self.restrict_to_paths:
  102. for restrict_to_path in self.restrict_to_paths:
  103. if path.startswith(os.path.realpath(restrict_to_path)):
  104. break
  105. else:
  106. raise PathNotAllowed(path)
  107. self.repository = Repository(path, create, lock_wait=lock_wait, lock=lock)
  108. self.repository.__enter__() # clean exit handled by serve() method
  109. return self.repository.id
  110. class RemoteRepository:
  111. extra_test_args = []
  112. class RPCError(Exception):
  113. def __init__(self, name):
  114. self.name = name
  115. def __init__(self, location, create=False, lock_wait=None, lock=True, args=None):
  116. self.location = self._location = location
  117. self.preload_ids = []
  118. self.msgid = 0
  119. self.to_send = b''
  120. self.cache = {}
  121. self.ignore_responses = set()
  122. self.responses = {}
  123. self.unpacker = msgpack.Unpacker(use_list=False)
  124. self.p = None
  125. testing = location.host == '__testsuite__'
  126. borg_cmd = self.borg_cmd(args, testing)
  127. env = dict(os.environ)
  128. if not testing:
  129. borg_cmd = self.ssh_cmd(location) + borg_cmd
  130. # pyinstaller binary adds LD_LIBRARY_PATH=/tmp/_ME... but we do not want
  131. # that the system's ssh binary picks up (non-matching) libraries from there
  132. env.pop('LD_LIBRARY_PATH', None)
  133. self.p = Popen(borg_cmd, bufsize=0, stdin=PIPE, stdout=PIPE, stderr=PIPE, env=env)
  134. self.stdin_fd = self.p.stdin.fileno()
  135. self.stdout_fd = self.p.stdout.fileno()
  136. self.stderr_fd = self.p.stderr.fileno()
  137. fcntl.fcntl(self.stdin_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdin_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  138. fcntl.fcntl(self.stdout_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdout_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  139. fcntl.fcntl(self.stderr_fd, fcntl.F_SETFL, fcntl.fcntl(self.stderr_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  140. self.r_fds = [self.stdout_fd, self.stderr_fd]
  141. self.x_fds = [self.stdin_fd, self.stdout_fd, self.stderr_fd]
  142. try:
  143. version = self.call('negotiate', RPC_PROTOCOL_VERSION)
  144. except ConnectionClosed:
  145. raise ConnectionClosedWithHint('Is borg working on the server?') from None
  146. if version != RPC_PROTOCOL_VERSION:
  147. raise Exception('Server insisted on using unsupported protocol version %d' % version)
  148. try:
  149. self.id = self.call('open', self.location.path, create, lock_wait, lock)
  150. except Exception:
  151. self.close()
  152. raise
  153. def __del__(self):
  154. if self.p:
  155. self.close()
  156. assert False, "cleanup happened in Repository.__del__"
  157. def __repr__(self):
  158. return '<%s %s>' % (self.__class__.__name__, self.location.canonical_path())
  159. def __enter__(self):
  160. return self
  161. def __exit__(self, exc_type, exc_val, exc_tb):
  162. if exc_type is not None:
  163. self.rollback()
  164. self.close()
  165. @property
  166. def id_str(self):
  167. return bin_to_hex(self.id)
  168. def borg_cmd(self, args, testing):
  169. """return a borg serve command line"""
  170. # give some args/options to "borg serve" process as they were given to us
  171. opts = []
  172. if args is not None:
  173. opts.append('--umask=%03o' % args.umask)
  174. root_logger = logging.getLogger()
  175. if root_logger.isEnabledFor(logging.DEBUG):
  176. opts.append('--debug')
  177. elif root_logger.isEnabledFor(logging.INFO):
  178. opts.append('--info')
  179. elif root_logger.isEnabledFor(logging.WARNING):
  180. pass # warning is default
  181. elif root_logger.isEnabledFor(logging.ERROR):
  182. opts.append('--error')
  183. elif root_logger.isEnabledFor(logging.CRITICAL):
  184. opts.append('--critical')
  185. else:
  186. raise ValueError('log level missing, fix this code')
  187. if testing:
  188. return [sys.executable, '-m', 'borg.archiver', 'serve'] + opts + self.extra_test_args
  189. else: # pragma: no cover
  190. return [args.remote_path, 'serve'] + opts
  191. def ssh_cmd(self, location):
  192. """return a ssh command line that can be prefixed to a borg command line"""
  193. args = shlex.split(os.environ.get('BORG_RSH', 'ssh'))
  194. if location.port:
  195. args += ['-p', str(location.port)]
  196. if location.user:
  197. args.append('%s@%s' % (location.user, location.host))
  198. else:
  199. args.append('%s' % location.host)
  200. return args
  201. def call(self, cmd, *args, **kw):
  202. for resp in self.call_many(cmd, [args], **kw):
  203. return resp
  204. def call_many(self, cmd, calls, wait=True, is_preloaded=False):
  205. if not calls:
  206. return
  207. def fetch_from_cache(args):
  208. msgid = self.cache[args].pop(0)
  209. if not self.cache[args]:
  210. del self.cache[args]
  211. return msgid
  212. calls = list(calls)
  213. waiting_for = []
  214. w_fds = [self.stdin_fd]
  215. while wait or calls:
  216. while waiting_for:
  217. try:
  218. error, res = self.responses.pop(waiting_for[0])
  219. waiting_for.pop(0)
  220. if error:
  221. if error == b'DoesNotExist':
  222. raise Repository.DoesNotExist(self.location.orig)
  223. elif error == b'AlreadyExists':
  224. raise Repository.AlreadyExists(self.location.orig)
  225. elif error == b'CheckNeeded':
  226. raise Repository.CheckNeeded(self.location.orig)
  227. elif error == b'IntegrityError':
  228. raise IntegrityError(res)
  229. elif error == b'PathNotAllowed':
  230. raise PathNotAllowed(*res)
  231. elif error == b'ObjectNotFound':
  232. raise Repository.ObjectNotFound(res[0], self.location.orig)
  233. elif error == b'InvalidRPCMethod':
  234. raise InvalidRPCMethod(*res)
  235. else:
  236. raise self.RPCError(res.decode('utf-8'))
  237. else:
  238. yield res
  239. if not waiting_for and not calls:
  240. return
  241. except KeyError:
  242. break
  243. r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
  244. if x:
  245. raise Exception('FD exception occurred')
  246. for fd in r:
  247. if fd is self.stdout_fd:
  248. data = os.read(fd, BUFSIZE)
  249. if not data:
  250. raise ConnectionClosed()
  251. self.unpacker.feed(data)
  252. for unpacked in self.unpacker:
  253. if not (isinstance(unpacked, tuple) and len(unpacked) == 4):
  254. raise Exception("Unexpected RPC data format.")
  255. type, msgid, error, res = unpacked
  256. if msgid in self.ignore_responses:
  257. self.ignore_responses.remove(msgid)
  258. else:
  259. self.responses[msgid] = error, res
  260. elif fd is self.stderr_fd:
  261. data = os.read(fd, 32768)
  262. if not data:
  263. raise ConnectionClosed()
  264. data = data.decode('utf-8')
  265. for line in data.splitlines(keepends=True):
  266. if line.startswith('$LOG '):
  267. _, level, msg = line.split(' ', 2)
  268. level = getattr(logging, level, logging.CRITICAL) # str -> int
  269. if msg.startswith('Remote:'):
  270. # server format: '$LOG <level> Remote: <msg>'
  271. logging.log(level, msg.rstrip())
  272. else:
  273. # server format '$LOG <level> <logname> Remote: <msg>'
  274. logname, msg = msg.split(' ', 1)
  275. logging.getLogger(logname).log(level, msg.rstrip())
  276. else:
  277. sys.stderr.write("Remote: " + line)
  278. if w:
  279. while not self.to_send and (calls or self.preload_ids) and len(waiting_for) < 100:
  280. if calls:
  281. if is_preloaded:
  282. if calls[0] in self.cache:
  283. waiting_for.append(fetch_from_cache(calls.pop(0)))
  284. else:
  285. args = calls.pop(0)
  286. if cmd == 'get' and args in self.cache:
  287. waiting_for.append(fetch_from_cache(args))
  288. else:
  289. self.msgid += 1
  290. waiting_for.append(self.msgid)
  291. self.to_send = msgpack.packb((1, self.msgid, cmd, args))
  292. if not self.to_send and self.preload_ids:
  293. args = (self.preload_ids.pop(0),)
  294. self.msgid += 1
  295. self.cache.setdefault(args, []).append(self.msgid)
  296. self.to_send = msgpack.packb((1, self.msgid, cmd, args))
  297. if self.to_send:
  298. try:
  299. self.to_send = self.to_send[os.write(self.stdin_fd, self.to_send):]
  300. except OSError as e:
  301. # io.write might raise EAGAIN even though select indicates
  302. # that the fd should be writable
  303. if e.errno != errno.EAGAIN:
  304. raise
  305. if not self.to_send and not (calls or self.preload_ids):
  306. w_fds = []
  307. self.ignore_responses |= set(waiting_for)
  308. def check(self, repair=False, save_space=False):
  309. return self.call('check', repair, save_space)
  310. def commit(self, save_space=False):
  311. return self.call('commit', save_space)
  312. def rollback(self, *args):
  313. return self.call('rollback')
  314. def destroy(self):
  315. return self.call('destroy')
  316. def __len__(self):
  317. return self.call('__len__')
  318. def list(self, limit=None, marker=None):
  319. return self.call('list', limit, marker)
  320. def get(self, id_):
  321. for resp in self.get_many([id_]):
  322. return resp
  323. def get_many(self, ids, is_preloaded=False):
  324. for resp in self.call_many('get', [(id_,) for id_ in ids], is_preloaded=is_preloaded):
  325. yield resp
  326. def put(self, id_, data, wait=True):
  327. return self.call('put', id_, data, wait=wait)
  328. def delete(self, id_, wait=True):
  329. return self.call('delete', id_, wait=wait)
  330. def save_key(self, keydata):
  331. return self.call('save_key', keydata)
  332. def load_key(self):
  333. return self.call('load_key')
  334. def break_lock(self):
  335. return self.call('break_lock')
  336. def close(self):
  337. if self.p:
  338. self.p.stdin.close()
  339. self.p.stdout.close()
  340. self.p.wait()
  341. self.p = None
  342. def preload(self, ids):
  343. self.preload_ids += ids
  344. class RepositoryNoCache:
  345. """A not caching Repository wrapper, passes through to repository.
  346. Just to have same API (including the context manager) as RepositoryCache.
  347. """
  348. def __init__(self, repository):
  349. self.repository = repository
  350. def close(self):
  351. pass
  352. def __enter__(self):
  353. return self
  354. def __exit__(self, exc_type, exc_val, exc_tb):
  355. self.close()
  356. def get(self, key):
  357. return next(self.get_many([key]))
  358. def get_many(self, keys):
  359. for data in self.repository.get_many(keys):
  360. yield data
  361. class RepositoryCache(RepositoryNoCache):
  362. """A caching Repository wrapper
  363. Caches Repository GET operations using a local temporary Repository.
  364. """
  365. # maximum object size that will be cached, 64 kiB.
  366. THRESHOLD = 2**16
  367. def __init__(self, repository):
  368. super().__init__(repository)
  369. tmppath = tempfile.mkdtemp(prefix='borg-tmp')
  370. self.caching_repo = Repository(tmppath, create=True, exclusive=True)
  371. self.caching_repo.__enter__() # handled by context manager in base class
  372. def close(self):
  373. if self.caching_repo is not None:
  374. self.caching_repo.destroy()
  375. self.caching_repo = None
  376. def get_many(self, keys):
  377. unknown_keys = [key for key in keys if key not in self.caching_repo]
  378. repository_iterator = zip(unknown_keys, self.repository.get_many(unknown_keys))
  379. for key in keys:
  380. try:
  381. yield self.caching_repo.get(key)
  382. except Repository.ObjectNotFound:
  383. for key_, data in repository_iterator:
  384. if key_ == key:
  385. if len(data) <= self.THRESHOLD:
  386. self.caching_repo.put(key, data)
  387. yield data
  388. break
  389. # Consume any pending requests
  390. for _ in repository_iterator:
  391. pass
  392. def cache_if_remote(repository):
  393. if isinstance(repository, RemoteRepository):
  394. return RepositoryCache(repository)
  395. else:
  396. return RepositoryNoCache(repository)