import argparse
import binascii
import calendar
import functools
import grp
import msgpack
import os
import pwd
import re
import stat
import sys
import time
from datetime import datetime, timedelta
from fnmatch import fnmatchcase
from operator import attrgetter


class Manifest:
    """Repository manifest: the archive index plus repository configuration.

    The manifest is stored in the repository under the fixed all-zero id
    ``MANIFEST_ID`` and is encrypted with the repository key.
    """

    MANIFEST_ID = b'\0' * 32

    def __init__(self):
        self.archives = {}
        self.config = {}

    @classmethod
    def load(cls, repository):
        """Read, decrypt and decode the manifest stored in *repository*.

        :return: ``(manifest, key)`` where *key* is the object returned by
                 ``key_factory`` for this repository
        :raises ValueError: if the stored manifest is not version 1
        """
        from .key import key_factory
        manifest = cls()
        manifest.repository = repository
        cdata = repository.get(manifest.MANIFEST_ID)
        manifest.key = key = key_factory(repository, cdata)
        data = key.decrypt(None, cdata)
        manifest.id = key.id_hash(data)
        m = msgpack.unpackb(data)
        if m.get(b'version') != 1:
            raise ValueError('Invalid manifest version')
        # Archive names come back as UTF-8 encoded bytes keys.
        manifest.archives = {k.decode('utf-8'): v for k, v in m[b'archives'].items()}
        manifest.config = m[b'config']
        return manifest, key

    def write(self):
        """Encode, encrypt and store the manifest in the repository.

        Also updates ``self.id`` with the id hash of the plaintext data.
        """
        data = msgpack.packb({
            'version': 1,
            'archives': self.archives,
            'config': self.config,
        })
        self.id = self.key.id_hash(data)
        self.repository.put(self.MANIFEST_ID, self.key.encrypt(data))


def prune_split(archives, pattern, n, skip=None):
    """Select at most *n* archives: the newest one of each time period.

    Archives are grouped by their local timestamp formatted with
    ``strftime(pattern)`` (e.g. ``'%Y-%m-%d'`` groups by day). Periods are
    visited in descending order of that string, which assumes *pattern*
    yields strings that sort chronologically. For each period the newest
    archive is kept, unless it is contained in *skip*; such a period is
    passed over and does not count towards *n*.

    :param archives: iterable of objects with a ``ts`` datetime attribute
    :param pattern: strftime format string defining the period size
    :param n: maximum number of archives to keep (0 or None keeps none)
    :param skip: archives that must not be selected (presumably ones already
                 kept by another call - confirm against callers)
    :return: list of kept archives, newest period first
    """
    if skip is None:
        skip = ()
    periods = {}
    for a in archives:
        periods.setdefault(to_localtime(a.ts).strftime(pattern), []).append(a)
    keep = []
    for period in sorted(periods, reverse=True):
        if not n:
            break
        # Pick the period's newest archive *before* consulting skip; the
        # check must be about that archive, not the first one collected.
        newest = max(periods[period], key=attrgetter('ts'))
        if newest not in skip:
            keep.append(newest)
            n -= 1
    return keep


class Statistics:
    """Running totals of file count and original/compressed/unique sizes."""

    def __init__(self):
        self.osize = self.csize = self.usize = self.nfiles = 0

    def update(self, size, csize, unique):
        """Add *size* original and *csize* compressed bytes to the totals.

        When *unique* is true, *csize* is also added to the unique data total.
        """
        self.osize += size
        self.csize += csize
        if unique:
            self.usize += csize

    def print_(self):
        """Print the totals to stdout."""
        print('Number of files: %d' % self.nfiles)
        print('Original size: %d (%s)' % (self.osize, format_file_size(self.osize)))
        print('Compressed size: %d (%s)' % (self.csize, format_file_size(self.csize)))
        print('Unique data: %d (%s)' % (self.usize, format_file_size(self.usize)))


def get_keys_dir():
    """Return the directory where repository key files are stored.

    Defaults to ``~/.attic/keys``; override with ``$ATTIC_KEYS_DIR``.
    """
    return os.environ.get('ATTIC_KEYS_DIR',
                          os.path.join(os.path.expanduser('~'), '.attic', 'keys'))


def get_cache_dir():
    """Return the directory where the local cache is stored.

    Defaults to ``~/.cache/attic``; override with ``$ATTIC_CACHE_DIR``.
    """
    return os.environ.get('ATTIC_CACHE_DIR',
                          os.path.join(os.path.expanduser('~'), '.cache', 'attic'))


def to_localtime(ts):
    """Convert a datetime holding UTC into a naive local-time datetime.

    A naive *ts* is interpreted as UTC; an aware one is converted to UTC
    first. The UTC offset in effect at *ts* itself is used, so timestamps on
    both sides of a daylight saving transition convert correctly.
    Microseconds are preserved.
    """
    seconds = calendar.timegm(ts.utctimetuple())
    return datetime.fromtimestamp(seconds).replace(microsecond=ts.microsecond)


def adjust_patterns(paths, excludes):
    """Build the pattern list consumed by :func:`exclude_path`.

    When *paths* is non-empty, the result is *excludes* followed by one
    IncludePattern per path and a final catch-all ``ExcludePattern('*')``,
    so only the listed paths (minus earlier excludes) are processed.
    Otherwise *excludes* is returned unchanged.
    """
    if paths:
        return ((excludes or [])
                + [IncludePattern(path) for path in paths]
                + [ExcludePattern('*')])
    else:
        return excludes


def exclude_path(path, patterns):
    """Used by create and extract sub-commands to determine
    if an item should be processed or not.

    The first pattern that matches *path* decides: True (skip the item) for
    an ExcludePattern, False for an IncludePattern. No match means False.
    """
    for pattern in (patterns or []):
        if pattern.match(path):
            return isinstance(pattern, ExcludePattern)
    return False


class IncludePattern:
    """--include PATTERN

    Matches the path equal to the pattern and everything below it.
    """
    def __init__(self, pattern):
        self.pattern = pattern
        # Compare against "pattern/" so that e.g. "home/user" does not also
        # match "home/username/...".
        self.dirpattern = pattern
        if not pattern.endswith(os.path.sep):
            self.dirpattern += os.path.sep

    def match(self, path):
        dir, name = os.path.split(path)
        return (path == self.pattern
                or (dir + os.path.sep).startswith(self.dirpattern))

    def __repr__(self):
        return '%s(%r)' % (type(self).__name__, self.pattern)


class ExcludePattern(IncludePattern):
    """--exclude PATTERN

    Like IncludePattern, but additionally matches any path whose final
    component matches the pattern as a case-sensitive shell glob.
    """
    def match(self, path):
        dir, name = os.path.split(path)
        return (path == self.pattern
                or (dir + os.path.sep).startswith(self.dirpattern)
                or fnmatchcase(name, self.pattern))


def walk_path(path, skip_inodes=None):
    """Recursively yield ``(path, stat_result)`` for *path* and its contents.

    Uses ``os.lstat``, so symlinks are reported but not followed. An entry
    whose ``(st_ino, st_dev)`` pair is in *skip_inodes* is skipped together
    with everything below it.
    """
    st = os.lstat(path)
    if skip_inodes and (st.st_ino, st.st_dev) in skip_inodes:
        return
    yield path, st
    if stat.S_ISDIR(st.st_mode):
        for f in os.listdir(path):
            for x in walk_path(os.path.join(path, f), skip_inodes):
                yield x


def format_time(t):
    """Format datetime suitable for fixed length list output.

    Times less than 365 days before now show hour:minute, older ones the year.
    """
    if (datetime.now() - t).days < 365:
        return t.strftime('%b %d %H:%M')
    else:
        return t.strftime('%b %d  %Y')


def format_timedelta(td):
    """Format timedelta in a human friendly format,
    e.g. ``'1 days 2 hours 3 minutes 4.50 seconds'``.
    """
    ts = td.total_seconds()
    s = ts % 60
    m = int(ts / 60) % 60
    h = int(ts / 3600) % 24
    txt = '%.2f seconds' % s
    if m:
        txt = '%d minutes %s' % (m, txt)
    if h:
        txt = '%d hours %s' % (h, txt)
    if td.days:
        txt = '%d days %s' % (td.days, txt)
    return txt


def format_file_mode(mod):
    """Format the low nine permission bits of *mod* as ``rwxrwxrwx``."""
    def x(v):
        return ''.join(s if v & m else '-'
                       for m, s in ((4, 'r'), (2, 'w'), (1, 'x')))
    return '%s%s%s' % (x(mod // 64), x(mod // 8), x(mod))


def format_file_size(v):
    """Format a byte count in a human friendly format (B, kB, MB or GB)."""
    if v > 1024 * 1024 * 1024:
        return '%.2f GB' % (v / 1024. / 1024. / 1024.)
    elif v > 1024 * 1024:
        return '%.2f MB' % (v / 1024. / 1024.)
    elif v > 1024:
        return '%.2f kB' % (v / 1024.)
    else:
        return '%d B' % v


class IntegrityError(Exception):
    """Data integrity error."""


def memoize(function):
    """Cache *function*'s results keyed on its positional arguments.

    The cache is unbounded and lives as long as the decorated function;
    all arguments must be hashable.
    """
    cache = {}

    @functools.wraps(function)
    def decorated_function(*args):
        try:
            return cache[args]
        except KeyError:
            val = function(*args)
            cache[args] = val
            return val
    return decorated_function


@memoize
def uid2user(uid):
    """Return the user name for *uid*, or None if it is unknown."""
    try:
        return pwd.getpwuid(uid).pw_name
    except KeyError:
        return None


@memoize
def user2uid(user):
    """Return the uid for *user*, None if unknown; a falsy *user* is returned as-is."""
    try:
        return user and pwd.getpwnam(user).pw_uid
    except KeyError:
        return None


@memoize
def gid2group(gid):
    """Return the group name for *gid*, or None if it is unknown."""
    try:
        return grp.getgrgid(gid).gr_name
    except KeyError:
        return None


@memoize
def group2gid(group):
    """Return the gid for *group*, None if unknown; a falsy *group* is returned as-is."""
    try:
        return group and grp.getgrnam(group).gr_gid
    except KeyError:
        return None


class Location:
    """Object representing a repository / archive location.

    Accepted forms (``::archive`` suffix optional in all of them):

    * ``ssh://[user@]host[:port]/path``
    * ``file://path``
    * ``[[user@]host:]path`` (scp-style; ssh when a host is given)
    """
    proto = user = host = port = path = archive = None

    ssh_re = re.compile(r'(?P<proto>ssh)://(?:(?P<user>[^@]+)@)?'
                        r'(?P<host>[^:/#]+)(?::(?P<port>\d+))?'
                        r'(?P<path>[^:]+)(?:::(?P<archive>.+))?')
    file_re = re.compile(r'(?P<proto>file)://'
                         r'(?P<path>[^:]+)(?:::(?P<archive>.+))?')
    scp_re = re.compile(r'((?:(?P<user>[^@]+)@)?(?P<host>[^:/]+):)?'
                        r'(?P<path>[^:]+)(?:::(?P<archive>.+))?')

    def __init__(self, text):
        self.orig = text
        if not self.parse(text):
            raise ValueError('Invalid location format: %r' % text)

    def parse(self, text):
        """Populate the attributes from *text*; return False if no form matches."""
        m = self.ssh_re.match(text)
        if m:
            self.proto = m.group('proto')
            self.user = m.group('user')
            self.host = m.group('host')
            self.port = m.group('port') and int(m.group('port')) or 22
            self.path = m.group('path')
            self.archive = m.group('archive')
            return True
        m = self.file_re.match(text)
        if m:
            self.proto = m.group('proto')
            self.path = m.group('path')
            self.archive = m.group('archive')
            return True
        m = self.scp_re.match(text)
        if m:
            self.user = m.group('user')
            self.host = m.group('host')
            self.path = m.group('path')
            self.archive = m.group('archive')
            self.proto = 'ssh' if self.host else 'file'
            if self.proto == 'ssh':
                self.port = 22
            return True
        return False

    def __str__(self):
        items = [
            'proto=%r' % self.proto,
            'user=%r' % self.user,
            'host=%r' % self.host,
            'port=%r' % self.port,
            'path=%r' % self.path,
            'archive=%r' % self.archive,
        ]
        return ', '.join(items)

    def to_key_filename(self):
        """Return the key file path for this location inside get_keys_dir().

        Non-word characters of the path become ``_``; remote locations are
        prefixed with ``host__``.
        """
        name = re.sub(r'[^\w]', '_', self.path).strip('_')
        if self.proto != 'file':
            name = self.host + '__' + name
        return os.path.join(get_keys_dir(), name)

    def __repr__(self):
        return "Location(%s)" % self


def location_validator(archive=None):
    """Return an argparse ``type=`` callable that parses a Location.

    :param archive: True requires an ``::archive`` part, False forbids it,
                    None accepts either
    """
    def validator(text):
        try:
            loc = Location(text)
        except ValueError:
            raise argparse.ArgumentTypeError('Invalid location format: "%s"' % text)
        if archive is True and not loc.archive:
            raise argparse.ArgumentTypeError('"%s": No archive specified' % text)
        elif archive is False and loc.archive:
            raise argparse.ArgumentTypeError('"%s": No archive can be specified' % text)
        return loc
    return validator


def read_msgpack(filename):
    """Load and return the msgpack-encoded object stored in *filename*."""
    with open(filename, 'rb') as fd:
        return msgpack.unpack(fd)


def write_msgpack(filename, d):
    """Write *d* msgpack-encoded to *filename*.

    Data goes to ``filename + '.tmp'`` first, is flushed and fsync'ed, and
    the temporary file is then renamed over *filename*, so *filename* is
    never left partially written (rename replaces the target atomically on
    POSIX).
    """
    with open(filename + '.tmp', 'wb') as fd:
        msgpack.pack(d, fd)
        fd.flush()
        os.fsync(fd)
    os.rename(filename + '.tmp', filename)


def decode_dict(d, keys, encoding='utf-8', errors='surrogateescape'):
    """Decode the bytes values of *d* under *keys* to str, in place; return *d*.

    Keys that are missing or not bytes are left untouched.
    """
    for key in keys:
        if isinstance(d.get(key), bytes):
            d[key] = d[key].decode(encoding, errors)
    return d


def remove_surrogates(s, errors='replace'):
    """Replace surrogates generated by fsdecode with '?'."""
    return s.encode('utf-8', errors).decode('utf-8')


def daemonize():
    """Detach process from controlling terminal and run in background."""
    pid = os.fork()
    if pid:
        os._exit(0)
    # Start a new session, detaching from the controlling terminal.
    os.setsid()
    # Fork again so the surviving process is not the session leader.
    pid = os.fork()
    if pid:
        os._exit(0)
    os.chdir('/')
    # Point stdin/stdout/stderr at /dev/null.
    os.close(0)
    os.close(1)
    os.close(2)
    fd = os.open('/dev/null', os.O_RDWR)
    os.dup2(fd, 0)
    os.dup2(fd, 1)
    os.dup2(fd, 2)


# Compare version_info tuples, not the sys.version string: as strings,
# '3.10' < '3.3' would wrongly select the fallbacks on Python 3.10+.
if sys.version_info < (3, 3):
    # st_mtime_ns attribute only available in 3.3+
    def st_mtime_ns(st):
        return int(st.st_mtime * 10**9)

    # unhexlify in < 3.3 incorrectly only accepts bytes input
    def unhexlify(data):
        if isinstance(data, str):
            data = data.encode('ascii')
        return binascii.unhexlify(data)
else:
    def st_mtime_ns(st):
        return st.st_mtime_ns

    unhexlify = binascii.unhexlify