helpers.py

import argparse
import binascii
import fcntl
import grp
import msgpack
import os
import pwd
import re
import stat
import sys
import time
from datetime import datetime, timezone
from fnmatch import fnmatchcase
from operator import attrgetter

class Error(Exception):
    """Error base class"""

    exit_code = 1

    def get_message(self):
        return 'Error: ' + type(self).__doc__.format(*self.args)

class UpgradableLock:

    class LockUpgradeFailed(Error):
        """Failed to acquire write lock on {}"""

    def __init__(self, path, exclusive=False):
        self.path = path
        try:
            self.fd = open(path, 'r+')
        except IOError:
            # Fall back to a read-only handle; only shared locking will work then.
            self.fd = open(path, 'r')
        if exclusive:
            fcntl.lockf(self.fd, fcntl.LOCK_EX)
        else:
            fcntl.lockf(self.fd, fcntl.LOCK_SH)
        self.is_exclusive = exclusive

    def upgrade(self):
        try:
            fcntl.lockf(self.fd, fcntl.LOCK_EX)
        except OSError:
            raise self.LockUpgradeFailed(self.path)
        self.is_exclusive = True

    def release(self):
        fcntl.lockf(self.fd, fcntl.LOCK_UN)
        self.fd.close()
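
# Minimal usage sketch (illustrative only; the path is hypothetical and the
# lock file is expected to exist already):
#
#     lock = UpgradableLock('/tmp/repo.lock')   # takes a shared (read) lock
#     try:
#         lock.upgrade()   # escalate to an exclusive lock; raises LockUpgradeFailed
#                          # when that is not possible (e.g. file only openable read-only)
#     finally:
#         lock.release()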

class Manifest:

    MANIFEST_ID = b'\0' * 32

    def __init__(self):
        self.archives = {}
        self.config = {}

    @classmethod
    def load(cls, repository):
        from .key import key_factory
        manifest = cls()
        manifest.repository = repository
        cdata = repository.get(manifest.MANIFEST_ID)
        manifest.key = key = key_factory(repository, cdata)
        data = key.decrypt(None, cdata)
        manifest.id = key.id_hash(data)
        m = msgpack.unpackb(data)
        if m.get(b'version') != 1:
            raise ValueError('Invalid manifest version')
        manifest.archives = dict((k.decode('utf-8'), v) for k, v in m[b'archives'].items())
        manifest.timestamp = m.get(b'timestamp')
        if manifest.timestamp:
            manifest.timestamp = manifest.timestamp.decode('ascii')
        manifest.config = m[b'config']
        return manifest, key

    def write(self):
        self.timestamp = datetime.utcnow().isoformat()
        data = msgpack.packb({
            'version': 1,
            'archives': self.archives,
            'timestamp': self.timestamp,
            'config': self.config,
        })
        self.id = self.key.id_hash(data)
        self.repository.put(self.MANIFEST_ID, self.key.encrypt(data))
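
# Expected call pattern (illustrative sketch; `repository` stands in for an
# attic Repository instance, which is not defined in this module):
#
#     manifest, key = Manifest.load(repository)
#     # ... add or remove entries in manifest.archives ...
#     manifest.write()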

def prune_split(archives, pattern, n, skip=[]):
    items = {}
    keep = []
    for a in archives:
        key = to_localtime(a.ts).strftime(pattern)
        items.setdefault(key, [])
        items[key].append(a)
    for key, values in sorted(items.items(), reverse=True):
        if n and values[0] not in skip:
            values.sort(key=attrgetter('ts'), reverse=True)
            keep.append(values[0])
            n -= 1
    return keep
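
# How prune_split groups archives (illustrative; `A` below is a stand-in with a
# timezone-aware `ts` attribute, not the real attic Archive class):
#
#     from collections import namedtuple
#     A = namedtuple('A', 'name ts')
#     archives = [A('a1', datetime(2014, 1, 1, tzinfo=timezone.utc)),
#                 A('a2', datetime(2014, 1, 2, tzinfo=timezone.utc)),
#                 A('a3', datetime(2014, 1, 2, 12, tzinfo=timezone.utc))]
#     daily = prune_split(archives, '%Y-%m-%d', 2)            # newest archive per day, 2 days
#     weekly = prune_split(archives, '%Y-%W', 1, skip=daily)  # skips archives already kept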

class Statistics:

    def __init__(self):
        self.osize = self.csize = self.usize = self.nfiles = 0

    def update(self, size, csize, unique):
        self.osize += size
        self.csize += csize
        if unique:
            self.usize += csize

    def print_(self):
        print('Number of files: %d' % self.nfiles)
        print('Original size: %d (%s)' % (self.osize, format_file_size(self.osize)))
        print('Compressed size: %d (%s)' % (self.csize, format_file_size(self.csize)))
        print('Unique data: %d (%s)' % (self.usize, format_file_size(self.usize)))

def get_keys_dir():
    """Determine where to store repository keys"""
    return os.environ.get('ATTIC_KEYS_DIR',
                          os.path.join(os.path.expanduser('~'), '.attic', 'keys'))


def get_cache_dir():
    """Determine where to store the repository cache"""
    return os.environ.get('ATTIC_CACHE_DIR',
                          os.path.join(os.path.expanduser('~'), '.cache', 'attic'))

def to_localtime(ts):
    """Convert datetime object from UTC to local time zone"""
    return datetime(*time.localtime((ts - datetime(1970, 1, 1, tzinfo=timezone.utc)).total_seconds())[:6])

def adjust_patterns(paths, excludes):
    if paths:
        return (excludes or []) + [IncludePattern(path) for path in paths] + [ExcludePattern('*')]
    else:
        return excludes


def exclude_path(path, patterns):
    """Used by create and extract sub-commands to determine
    if an item should be processed or not
    """
    for pattern in (patterns or []):
        if pattern.match(path):
            return isinstance(pattern, ExcludePattern)
    return False

class IncludePattern:
    """--include PATTERN
    """
    def __init__(self, pattern):
        self.pattern = pattern

    def match(self, path):
        dir, name = os.path.split(path)
        return (path == self.pattern
                or (dir + os.path.sep).startswith(self.pattern))

    def __repr__(self):
        return '%s(%s)' % (type(self), self.pattern)


class ExcludePattern(IncludePattern):
    """--exclude PATTERN

    Matches the path itself, anything below it, and any file name
    matching PATTERN as a shell glob.
    """
    def __init__(self, pattern):
        self.pattern = self.dirpattern = pattern
        if not pattern.endswith(os.path.sep):
            self.dirpattern += os.path.sep

    def match(self, path):
        dir, name = os.path.split(path)
        return (path == self.pattern
                or (dir + os.path.sep).startswith(self.dirpattern)
                or fnmatchcase(name, self.pattern))

    def __repr__(self):
        return '%s(%s)' % (type(self), self.pattern)
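
# Combining patterns as the create/extract commands do (illustrative):
#
#     patterns = adjust_patterns(['/home/user'], [ExcludePattern('*.tmp')])
#     exclude_path('/home/user/notes.txt', patterns)   # False -> item is processed
#     exclude_path('/home/user/cache.tmp', patterns)   # True  -> item is skipped
#     exclude_path('/etc/passwd', patterns)            # True  -> outside the included paths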

def walk_path(path, skip_inodes=None):
    st = os.lstat(path)
    if skip_inodes and (st.st_ino, st.st_dev) in skip_inodes:
        return
    yield path, st
    if stat.S_ISDIR(st.st_mode):
        for f in os.listdir(path):
            for x in walk_path(os.path.join(path, f), skip_inodes):
                yield x

def format_time(t):
    """Format datetime suitable for fixed length list output
    """
    if (datetime.now() - t).days < 365:
        return t.strftime('%b %d %H:%M')
    else:
        # two spaces before the year keep the output the same width as the %H:%M form
        return t.strftime('%b %d  %Y')

def format_timedelta(td):
    """Format timedelta in a human friendly format
    """
    # Since td.total_seconds() requires python 2.7
    ts = (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10 ** 6) / float(10 ** 6)
    s = ts % 60
    m = int(ts / 60) % 60
    h = int(ts / 3600) % 24
    txt = '%.2f seconds' % s
    if m:
        txt = '%d minutes %s' % (m, txt)
    if h:
        txt = '%d hours %s' % (h, txt)
    if td.days:
        txt = '%d days %s' % (td.days, txt)
    return txt

def format_file_mode(mod):
    """Format file mode bits for list output
    """
    def x(v):
        return ''.join(v & m and s or '-'
                       for m, s in ((4, 'r'), (2, 'w'), (1, 'x')))
    return '%s%s%s' % (x(mod // 64), x(mod // 8), x(mod))


def format_file_size(v):
    """Format file size into a human friendly format
    """
    if v > 1024 * 1024 * 1024:
        return '%.2f GB' % (v / 1024. / 1024. / 1024.)
    elif v > 1024 * 1024:
        return '%.2f MB' % (v / 1024. / 1024.)
    elif v > 1024:
        return '%.2f kB' % (v / 1024.)
    else:
        return '%d B' % v
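
# A few illustrative values for the formatting helpers above:
#
#     format_file_mode(0o755)    == 'rwxr-xr-x'
#     format_file_size(1234567)  == '1.18 MB'
#     format_file_size(999)      == '999 B'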

class IntegrityError(Exception):
    """Signals a data integrity error"""

def memoize(function):
    cache = {}

    def decorated_function(*args):
        try:
            return cache[args]
        except KeyError:
            val = function(*args)
            cache[args] = val
            return val
    return decorated_function

@memoize
def uid2user(uid):
    try:
        return pwd.getpwuid(uid).pw_name
    except KeyError:
        return None


@memoize
def user2uid(user):
    try:
        return user and pwd.getpwnam(user).pw_uid
    except KeyError:
        return None


@memoize
def gid2group(gid):
    try:
        return grp.getgrgid(gid).gr_name
    except KeyError:
        return None


@memoize
def group2gid(group):
    try:
        return group and grp.getgrnam(group).gr_gid
    except KeyError:
        return None

class Location:
    """Object representing a repository / archive location
    """
    proto = user = host = port = path = archive = None
    ssh_re = re.compile(r'(?P<proto>ssh)://(?:(?P<user>[^@]+)@)?'
                        r'(?P<host>[^:/#]+)(?::(?P<port>\d+))?'
                        r'(?P<path>[^:]+)(?:::(?P<archive>.+))?')
    file_re = re.compile(r'(?P<proto>file)://'
                         r'(?P<path>[^:]+)(?:::(?P<archive>.+))?')
    scp_re = re.compile(r'((?:(?P<user>[^@]+)@)?(?P<host>[^:/]+):)?'
                        r'(?P<path>[^:]+)(?:::(?P<archive>.+))?')

    def __init__(self, text):
        self.orig = text
        if not self.parse(text):
            raise ValueError

    def parse(self, text):
        m = self.ssh_re.match(text)
        if m:
            self.proto = m.group('proto')
            self.user = m.group('user')
            self.host = m.group('host')
            self.port = m.group('port') and int(m.group('port')) or None
            self.path = m.group('path')
            self.archive = m.group('archive')
            return True
        m = self.file_re.match(text)
        if m:
            self.proto = m.group('proto')
            self.path = m.group('path')
            self.archive = m.group('archive')
            return True
        m = self.scp_re.match(text)
        if m:
            self.user = m.group('user')
            self.host = m.group('host')
            self.path = m.group('path')
            self.archive = m.group('archive')
            self.proto = self.host and 'ssh' or 'file'
            return True
        return False

    def __str__(self):
        items = []
        items.append('proto=%r' % self.proto)
        items.append('user=%r' % self.user)
        items.append('host=%r' % self.host)
        items.append('port=%r' % self.port)
        items.append('path=%r' % self.path)
        items.append('archive=%r' % self.archive)
        return ', '.join(items)

    def to_key_filename(self):
        name = re.sub(r'[^\w]', '_', self.path).strip('_')
        if self.proto != 'file':
            name = self.host + '__' + name
        return os.path.join(get_keys_dir(), name)

    def __repr__(self):
        return "Location(%s)" % self
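
# Forms accepted by the parser above (hypothetical values):
#
#     Location('ssh://user@host:1234/srv/attic/repo::backup-2014-01-01')
#         # proto='ssh', user='user', host='host', port=1234,
#         # path='/srv/attic/repo', archive='backup-2014-01-01'
#     Location('/var/backup/repo.attic')
#         # proto='file', path='/var/backup/repo.attic', archive=None
#     Location('user@host:repo.attic::archive')
#         # scp-style form; proto='ssh', path='repo.attic', archive='archive'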

def location_validator(archive=None):
    def validator(text):
        try:
            loc = Location(text)
        except ValueError:
            raise argparse.ArgumentTypeError('Invalid location format: "%s"' % text)
        if archive is True and not loc.archive:
            raise argparse.ArgumentTypeError('"%s": No archive specified' % text)
        elif archive is False and loc.archive:
            raise argparse.ArgumentTypeError('"%s": No archive can be specified' % text)
        return loc
    return validator

def read_msgpack(filename):
    with open(filename, 'rb') as fd:
        return msgpack.unpack(fd)


def write_msgpack(filename, d):
    # Write to a temporary file, fsync it, then rename it into place so the
    # target file is never left half-written.
    with open(filename + '.tmp', 'wb') as fd:
        msgpack.pack(d, fd)
        fd.flush()
        os.fsync(fd)
    os.rename(filename + '.tmp', filename)

def decode_dict(d, keys, encoding='utf-8', errors='surrogateescape'):
    for key in keys:
        if isinstance(d.get(key), bytes):
            d[key] = d[key].decode(encoding, errors)
    return d


def remove_surrogates(s, errors='replace'):
    """Replace surrogates generated by fsdecode with '?'
    """
    return s.encode('utf-8', errors).decode('utf-8')

_safe_re = re.compile(r'^((..)?/+)+')


def make_path_safe(path):
    """Make path safe by making it relative and local
    """
    return _safe_re.sub('', path) or '.'
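
# Illustrative inputs and outputs for make_path_safe:
#
#     make_path_safe('/etc/passwd')     == 'etc/passwd'
#     make_path_safe('../../tmp/file')  == 'tmp/file'
#     make_path_safe('/')               == '.'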

def daemonize():
    """Detach process from controlling terminal and run in background
    """
    # Standard double-fork: the first fork returns control to the caller,
    # setsid() starts a new session, and the second fork ensures the daemon
    # can never reacquire a controlling terminal.
    pid = os.fork()
    if pid:
        os._exit(0)
    os.setsid()
    pid = os.fork()
    if pid:
        os._exit(0)
    os.chdir('/')
    os.close(0)
    os.close(1)
    os.close(2)
    # Reopen stdin/stdout/stderr on /dev/null.
    fd = os.open('/dev/null', os.O_RDWR)
    os.dup2(fd, 0)
    os.dup2(fd, 1)
    os.dup2(fd, 2)

# Compare version_info tuples instead of the version string: string comparison
# would misclassify e.g. Python 3.10 as older than 3.3.
if sys.version_info < (3, 3):
    # st_mtime_ns attribute only available in 3.3+
    def st_mtime_ns(st):
        return int(st.st_mtime * 1e9)

    # unhexlify in < 3.3 incorrectly only accepts bytes input
    def unhexlify(data):
        if isinstance(data, str):
            data = data.encode('ascii')
        return binascii.unhexlify(data)
else:
    def st_mtime_ns(st):
        return st.st_mtime_ns

    unhexlify = binascii.unhexlify