2
0

remote.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. import errno
  2. import fcntl
  3. import logging
  4. import os
  5. import select
  6. import shlex
  7. from subprocess import Popen, PIPE
  8. import sys
  9. import tempfile
  10. import traceback
  11. from . import __version__
  12. from .helpers import Error, IntegrityError, sysinfo
  13. from .repository import Repository
  14. import msgpack
  15. RPC_PROTOCOL_VERSION = 2
  16. BUFSIZE = 10 * 1024 * 1024
  17. class ConnectionClosed(Error):
  18. """Connection closed by remote host"""
  19. class ConnectionClosedWithHint(ConnectionClosed):
  20. """Connection closed by remote host. {}"""
  21. class PathNotAllowed(Error):
  22. """Repository path not allowed"""
  23. class InvalidRPCMethod(Error):
  24. """RPC method {} is not valid"""
  25. class RepositoryServer: # pragma: no cover
  26. rpc_methods = (
  27. '__len__',
  28. 'check',
  29. 'commit',
  30. 'delete',
  31. 'destroy',
  32. 'get',
  33. 'list',
  34. 'negotiate',
  35. 'open',
  36. 'put',
  37. 'repair',
  38. 'rollback',
  39. 'save_key',
  40. 'load_key',
  41. 'break_lock',
  42. )
  43. def __init__(self, restrict_to_paths):
  44. self.repository = None
  45. self.restrict_to_paths = restrict_to_paths
  46. def serve(self):
  47. stdin_fd = sys.stdin.fileno()
  48. stdout_fd = sys.stdout.fileno()
  49. stderr_fd = sys.stdout.fileno()
  50. # Make stdin non-blocking
  51. fl = fcntl.fcntl(stdin_fd, fcntl.F_GETFL)
  52. fcntl.fcntl(stdin_fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)
  53. # Make stdout blocking
  54. fl = fcntl.fcntl(stdout_fd, fcntl.F_GETFL)
  55. fcntl.fcntl(stdout_fd, fcntl.F_SETFL, fl & ~os.O_NONBLOCK)
  56. # Make stderr blocking
  57. fl = fcntl.fcntl(stderr_fd, fcntl.F_GETFL)
  58. fcntl.fcntl(stderr_fd, fcntl.F_SETFL, fl & ~os.O_NONBLOCK)
  59. unpacker = msgpack.Unpacker(use_list=False)
  60. while True:
  61. r, w, es = select.select([stdin_fd], [], [], 10)
  62. if r:
  63. data = os.read(stdin_fd, BUFSIZE)
  64. if not data:
  65. return
  66. unpacker.feed(data)
  67. for unpacked in unpacker:
  68. if not (isinstance(unpacked, tuple) and len(unpacked) == 4):
  69. raise Exception("Unexpected RPC data format.")
  70. type, msgid, method, args = unpacked
  71. method = method.decode('ascii')
  72. try:
  73. if method not in self.rpc_methods:
  74. raise InvalidRPCMethod(method)
  75. try:
  76. f = getattr(self, method)
  77. except AttributeError:
  78. f = getattr(self.repository, method)
  79. res = f(*args)
  80. except BaseException as e:
  81. logging.exception('Borg %s: exception in RPC call:', __version__)
  82. logging.error(sysinfo())
  83. exc = "Remote Exception (see remote log for the traceback)"
  84. os.write(stdout_fd, msgpack.packb((1, msgid, e.__class__.__name__, exc)))
  85. else:
  86. os.write(stdout_fd, msgpack.packb((1, msgid, None, res)))
  87. if es:
  88. return
  89. def negotiate(self, versions):
  90. return RPC_PROTOCOL_VERSION
  91. def open(self, path, create=False, lock_wait=None, lock=True):
  92. path = os.fsdecode(path)
  93. if path.startswith('/~'):
  94. path = path[1:]
  95. path = os.path.realpath(os.path.expanduser(path))
  96. if self.restrict_to_paths:
  97. for restrict_to_path in self.restrict_to_paths:
  98. if path.startswith(os.path.realpath(restrict_to_path)):
  99. break
  100. else:
  101. raise PathNotAllowed(path)
  102. self.repository = Repository(path, create, lock_wait=lock_wait, lock=lock)
  103. return self.repository.id
  104. class RemoteRepository:
  105. extra_test_args = []
  106. class RPCError(Exception):
  107. def __init__(self, name):
  108. self.name = name
  109. def __init__(self, location, create=False, lock_wait=None, lock=True, args=None):
  110. self.location = location
  111. self.preload_ids = []
  112. self.msgid = 0
  113. self.to_send = b''
  114. self.cache = {}
  115. self.ignore_responses = set()
  116. self.responses = {}
  117. self.unpacker = msgpack.Unpacker(use_list=False)
  118. self.p = None
  119. testing = location.host == '__testsuite__'
  120. borg_cmd = self.borg_cmd(args, testing)
  121. if not testing:
  122. borg_cmd = self.ssh_cmd(location) + borg_cmd
  123. self.p = Popen(borg_cmd, bufsize=0, stdin=PIPE, stdout=PIPE, stderr=PIPE)
  124. self.stdin_fd = self.p.stdin.fileno()
  125. self.stdout_fd = self.p.stdout.fileno()
  126. self.stderr_fd = self.p.stderr.fileno()
  127. fcntl.fcntl(self.stdin_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdin_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  128. fcntl.fcntl(self.stdout_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdout_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  129. fcntl.fcntl(self.stderr_fd, fcntl.F_SETFL, fcntl.fcntl(self.stderr_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
  130. self.r_fds = [self.stdout_fd, self.stderr_fd]
  131. self.x_fds = [self.stdin_fd, self.stdout_fd, self.stderr_fd]
  132. try:
  133. version = self.call('negotiate', RPC_PROTOCOL_VERSION)
  134. except ConnectionClosed:
  135. raise ConnectionClosedWithHint('Is borg working on the server?')
  136. if version != RPC_PROTOCOL_VERSION:
  137. raise Exception('Server insisted on using unsupported protocol version %d' % version)
  138. self.id = self.call('open', location.path, create, lock_wait, lock)
  139. def __del__(self):
  140. self.close()
  141. def __repr__(self):
  142. return '<%s %s>' % (self.__class__.__name__, self.location.canonical_path())
  143. def borg_cmd(self, args, testing):
  144. """return a borg serve command line"""
  145. # give some args/options to "borg serve" process as they were given to us
  146. opts = []
  147. if args is not None:
  148. opts.append('--umask=%03o' % args.umask)
  149. root_logger = logging.getLogger()
  150. if root_logger.isEnabledFor(logging.DEBUG):
  151. opts.append('--debug')
  152. elif root_logger.isEnabledFor(logging.INFO):
  153. opts.append('--info')
  154. elif root_logger.isEnabledFor(logging.WARNING):
  155. pass # warning is default
  156. else:
  157. raise ValueError('log level missing, fix this code')
  158. if testing:
  159. return [sys.executable, '-m', 'borg.archiver', 'serve' ] + opts + self.extra_test_args
  160. else: # pragma: no cover
  161. return [args.remote_path, 'serve'] + opts
  162. def ssh_cmd(self, location):
  163. """return a ssh command line that can be prefixed to a borg command line"""
  164. args = shlex.split(os.environ.get('BORG_RSH', 'ssh'))
  165. if location.port:
  166. args += ['-p', str(location.port)]
  167. if location.user:
  168. args.append('%s@%s' % (location.user, location.host))
  169. else:
  170. args.append('%s' % location.host)
  171. return args
  172. def call(self, cmd, *args, **kw):
  173. for resp in self.call_many(cmd, [args], **kw):
  174. return resp
  175. def call_many(self, cmd, calls, wait=True, is_preloaded=False):
  176. if not calls:
  177. return
  178. def fetch_from_cache(args):
  179. msgid = self.cache[args].pop(0)
  180. if not self.cache[args]:
  181. del self.cache[args]
  182. return msgid
  183. calls = list(calls)
  184. waiting_for = []
  185. w_fds = [self.stdin_fd]
  186. while wait or calls:
  187. while waiting_for:
  188. try:
  189. error, res = self.responses.pop(waiting_for[0])
  190. waiting_for.pop(0)
  191. if error:
  192. if error == b'DoesNotExist':
  193. raise Repository.DoesNotExist(self.location.orig)
  194. elif error == b'AlreadyExists':
  195. raise Repository.AlreadyExists(self.location.orig)
  196. elif error == b'CheckNeeded':
  197. raise Repository.CheckNeeded(self.location.orig)
  198. elif error == b'IntegrityError':
  199. raise IntegrityError(res)
  200. elif error == b'PathNotAllowed':
  201. raise PathNotAllowed(*res)
  202. elif error == b'ObjectNotFound':
  203. raise Repository.ObjectNotFound(res[0], self.location.orig)
  204. elif error == b'InvalidRPCMethod':
  205. raise InvalidRPCMethod(*res)
  206. else:
  207. raise self.RPCError(res.decode('utf-8'))
  208. else:
  209. yield res
  210. if not waiting_for and not calls:
  211. return
  212. except KeyError:
  213. break
  214. r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
  215. if x:
  216. raise Exception('FD exception occurred')
  217. for fd in r:
  218. if fd is self.stdout_fd:
  219. data = os.read(fd, BUFSIZE)
  220. if not data:
  221. raise ConnectionClosed()
  222. self.unpacker.feed(data)
  223. for unpacked in self.unpacker:
  224. if not (isinstance(unpacked, tuple) and len(unpacked) == 4):
  225. raise Exception("Unexpected RPC data format.")
  226. type, msgid, error, res = unpacked
  227. if msgid in self.ignore_responses:
  228. self.ignore_responses.remove(msgid)
  229. else:
  230. self.responses[msgid] = error, res
  231. elif fd is self.stderr_fd:
  232. data = os.read(fd, 32768)
  233. if not data:
  234. raise ConnectionClosed()
  235. data = data.decode('utf-8')
  236. for line in data.splitlines(True): # keepends=True for py3.3+
  237. if line.startswith('$LOG '):
  238. _, level, msg = line.split(' ', 2)
  239. level = getattr(logging, level, logging.CRITICAL) # str -> int
  240. logging.log(level, msg.rstrip())
  241. else:
  242. sys.stderr.write("Remote: " + line)
  243. if w:
  244. while not self.to_send and (calls or self.preload_ids) and len(waiting_for) < 100:
  245. if calls:
  246. if is_preloaded:
  247. if calls[0] in self.cache:
  248. waiting_for.append(fetch_from_cache(calls.pop(0)))
  249. else:
  250. args = calls.pop(0)
  251. if cmd == 'get' and args in self.cache:
  252. waiting_for.append(fetch_from_cache(args))
  253. else:
  254. self.msgid += 1
  255. waiting_for.append(self.msgid)
  256. self.to_send = msgpack.packb((1, self.msgid, cmd, args))
  257. if not self.to_send and self.preload_ids:
  258. args = (self.preload_ids.pop(0),)
  259. self.msgid += 1
  260. self.cache.setdefault(args, []).append(self.msgid)
  261. self.to_send = msgpack.packb((1, self.msgid, cmd, args))
  262. if self.to_send:
  263. try:
  264. self.to_send = self.to_send[os.write(self.stdin_fd, self.to_send):]
  265. except OSError as e:
  266. # io.write might raise EAGAIN even though select indicates
  267. # that the fd should be writable
  268. if e.errno != errno.EAGAIN:
  269. raise
  270. if not self.to_send and not (calls or self.preload_ids):
  271. w_fds = []
  272. self.ignore_responses |= set(waiting_for)
  273. def check(self, repair=False, save_space=False):
  274. return self.call('check', repair, save_space)
  275. def commit(self, save_space=False):
  276. return self.call('commit', save_space)
  277. def rollback(self, *args):
  278. return self.call('rollback')
  279. def destroy(self):
  280. return self.call('destroy')
  281. def __len__(self):
  282. return self.call('__len__')
  283. def list(self, limit=None, marker=None):
  284. return self.call('list', limit, marker)
  285. def get(self, id_):
  286. for resp in self.get_many([id_]):
  287. return resp
  288. def get_many(self, ids, is_preloaded=False):
  289. for resp in self.call_many('get', [(id_,) for id_ in ids], is_preloaded=is_preloaded):
  290. yield resp
  291. def put(self, id_, data, wait=True):
  292. return self.call('put', id_, data, wait=wait)
  293. def delete(self, id_, wait=True):
  294. return self.call('delete', id_, wait=wait)
  295. def save_key(self, keydata):
  296. return self.call('save_key', keydata)
  297. def load_key(self):
  298. return self.call('load_key')
  299. def break_lock(self):
  300. return self.call('break_lock')
  301. def close(self):
  302. if self.p:
  303. self.p.stdin.close()
  304. self.p.stdout.close()
  305. self.p.wait()
  306. self.p = None
  307. def preload(self, ids):
  308. self.preload_ids += ids
  309. class RepositoryCache:
  310. """A caching Repository wrapper
  311. Caches Repository GET operations using a local temporary Repository.
  312. """
  313. def __init__(self, repository):
  314. self.repository = repository
  315. tmppath = tempfile.mkdtemp(prefix='borg-tmp')
  316. self.caching_repo = Repository(tmppath, create=True, exclusive=True)
  317. def __del__(self):
  318. self.caching_repo.destroy()
  319. def get(self, key):
  320. return next(self.get_many([key]))
  321. def get_many(self, keys):
  322. unknown_keys = [key for key in keys if key not in self.caching_repo]
  323. repository_iterator = zip(unknown_keys, self.repository.get_many(unknown_keys))
  324. for key in keys:
  325. try:
  326. yield self.caching_repo.get(key)
  327. except Repository.ObjectNotFound:
  328. for key_, data in repository_iterator:
  329. if key_ == key:
  330. self.caching_repo.put(key, data)
  331. yield data
  332. break
  333. # Consume any pending requests
  334. for _ in repository_iterator:
  335. pass
  336. def cache_if_remote(repository):
  337. if isinstance(repository, RemoteRepository):
  338. return RepositoryCache(repository)
  339. return repository