import fcntl
import msgpack
import os
import select
from subprocess import Popen, PIPE
import sys
import getpass

from .repository import Repository
from .lrucache import LRUCache

BUFSIZE = 10 * 1024 * 1024


class ConnectionClosed(Exception):
    """Connection closed by remote host
    """


class RepositoryServer(object):
    """Serve repository operations over stdin/stdout using msgpack RPC.

    Requests are msgpack tuples ``(type, msgid, method, args)``.  Every
    request gets a reply ``(1, msgid, error, result)`` where ``error`` is the
    name of the exception class raised by the call (or None) and ``result``
    is its return value (or None on error).
    """

    def __init__(self):
        # Set by open(); None until the client has opened a repository.
        self.repository = None

    def serve(self):
        """Handle requests from stdin until stdin reaches EOF.

        A method name is looked up on this server object first, then on the
        opened repository.
        """
        # Make stdin non-blocking
        fl = fcntl.fcntl(sys.stdin.fileno(), fcntl.F_GETFL)
        fcntl.fcntl(sys.stdin.fileno(), fcntl.F_SETFL, fl | os.O_NONBLOCK)
        # Make stdout blocking
        fl = fcntl.fcntl(sys.stdout.fileno(), fcntl.F_GETFL)
        fcntl.fcntl(sys.stdout.fileno(), fcntl.F_SETFL, fl & ~os.O_NONBLOCK)
        unpacker = msgpack.Unpacker(use_list=False)
        while True:
            # No fds are passed for exceptional conditions, so only the
            # readable list matters here.
            r, _, _ = select.select([sys.stdin], [], [], 10)
            if not r:
                continue
            data = os.read(sys.stdin.fileno(), BUFSIZE)
            if not data:
                # EOF: the client closed its end of the pipe.
                return
            unpacker.feed(data)
            for _type, msgid, method, args in unpacker:
                # NOTE(review): any attribute of the server or repository
                # can be invoked by name, including private/dunder ones.
                # Consider a whitelist of RPC-callable methods.
                try:
                    # Decode inside the try so a malformed name yields an
                    # error reply instead of killing the server.
                    method = method.decode('ascii')
                    try:
                        f = getattr(self, method)
                    except AttributeError:
                        f = getattr(self.repository, method)
                    res = f(*args)
                except Exception as e:
                    sys.stdout.buffer.write(msgpack.packb((1, msgid, e.__class__.__name__, None)))
                else:
                    sys.stdout.buffer.write(msgpack.packb((1, msgid, None, res)))
                sys.stdout.flush()

    def negotiate(self, versions):
        """Return the protocol version used by this server (always 1)."""
        return 1

    def open(self, path, create=False):
        """Open (or create) the repository at *path* and return its id.

        A leading ``/~`` is turned into ``~`` so the path is expanded
        relative to the remote user's home directory.
        """
        path = os.fsdecode(path)
        if path.startswith('/~'):
            path = path[1:]
        self.repository = Repository(os.path.expanduser(path), create)
        return self.repository.id


class RemoteRepository(object):
    """Client-side proxy that forwards repository calls to a RepositoryServer.

    The server runs as a child process: ``attic serve`` over ssh, or the
    local interpreter for the ``__testsuite__`` host.  Messages are exchanged
    over its stdin/stdout pipes.
    """

    class RPCError(Exception):
        """Raised when the server reports an error.

        ``name`` holds the remote exception class name exactly as received
        (compared against bytes such as ``b'DoesNotExist'`` below).
        """

        def __init__(self, name):
            self.name = name

    def __init__(self, location, create=False):
        """Spawn the server, negotiate the protocol and open the repository.

        Raises Repository.DoesNotExist / Repository.AlreadyExists when the
        server reports those errors, and RPCError for any other remote
        failure during open.
        """
        self.p = None
        # args tuple -> (msgid, resp, error); resp/error are None while the
        # request is still in flight.
        self.cache = LRUCache(256)
        # Outgoing bytes not yet written by call_multi().
        self.to_send = b''
        # msgid -> [(args, resp, error), ...] results that call_multi() may
        # yield once the response with that msgid has been read.
        self.extra = {}
        # msgid -> args for outstanding requests whose result gets cached.
        self.pending = {}
        self.unpacker = msgpack.Unpacker(use_list=False)
        self.msgid = 0
        self.received_msgid = 0
        if location.host == '__testsuite__':
            args = [sys.executable, '-m', 'attic.archiver', 'serve']
        else:
            args = ['ssh', '-p', str(location.port),
                    '%s@%s' % (location.user or getpass.getuser(), location.host),
                    'attic', 'serve']
        self.p = Popen(args, bufsize=0, stdin=PIPE, stdout=PIPE)
        self.stdin_fd = self.p.stdin.fileno()
        self.stdout_fd = self.p.stdout.fileno()
        fcntl.fcntl(self.stdin_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdin_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
        fcntl.fcntl(self.stdout_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdout_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
        self.r_fds = [self.stdout_fd]
        self.x_fds = [self.stdin_fd, self.stdout_fd]
        version = self.call('negotiate', (1,))
        if version != 1:
            raise Exception('Server insisted on using unsupported protocol version %d' % version)
        try:
            self.id = self.call('open', (location.path, create))
        except self.RPCError as e:
            if e.name == b'DoesNotExist':
                raise Repository.DoesNotExist
            elif e.name == b'AlreadyExists':
                raise Repository.AlreadyExists
            # Any other remote failure must not be swallowed: previously the
            # constructor "succeeded" here without setting self.id.
            raise

    def __del__(self):
        self.close()

    def call(self, cmd, args, wait=True):
        """Send a single request and, if *wait*, return its result.

        Responses for other outstanding requests (tracked in self.pending)
        that arrive meanwhile are stored in the cache.  With wait=False the
        method returns None as soon as the request has been written.

        Raises RPCError on a remote error, ConnectionClosed on EOF.
        """
        self.msgid += 1
        to_send = msgpack.packb((1, self.msgid, cmd, args))
        w_fds = [self.stdin_fd]
        while wait or to_send:
            r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
            if x:
                raise Exception('FD exception occured')
            if r:
                data = os.read(self.stdout_fd, BUFSIZE)
                if not data:
                    raise ConnectionClosed()
                self.unpacker.feed(data)
                for _type, msgid, error, res in self.unpacker:
                    if msgid == self.msgid:
                        self.received_msgid = msgid
                        if error:
                            raise self.RPCError(error)
                        return res
                    pending_args = self.pending.pop(msgid, None)
                    if pending_args is not None:
                        self.cache[pending_args] = msgid, res, error
            if w:
                if to_send:
                    n = os.write(self.stdin_fd, to_send)
                    assert n > 0
                    to_send = memoryview(to_send)[n:]
                else:
                    # Everything written; stop polling for writability.
                    w_fds = []

    def _read(self):
        """Read available responses and yield results that became ready.

        Generator.  Each response is cached if it was pending, and any
        results registered in self.extra under its msgid are yielded.
        Raises RPCError for a remote error, ConnectionClosed on EOF.
        """
        data = os.read(self.stdout_fd, BUFSIZE)
        if not data:
            # Same exception type as call() uses for EOF (still an Exception
            # subclass, so existing broad handlers keep working).
            raise ConnectionClosed('Remote host closed connection')
        self.unpacker.feed(data)
        to_yield = []
        for _type, msgid, error, res in self.unpacker:
            self.received_msgid = msgid
            args = self.pending.pop(msgid, None)
            if args is not None:
                self.cache[args] = msgid, res, error
                for args, resp, error in self.extra.pop(msgid, []):
                    # Entries registered before their response arrived carry
                    # no result; fetch it from the cache now.
                    if not resp and not error:
                        resp, error = self.cache[args][1:]
                    to_yield.append((resp, error))
        for res, error in to_yield:
            if error:
                raise self.RPCError(error)
            else:
                yield res

    def gen_request(self, cmd, argsv, wait):
        """Return packed requests for the args in *argsv* not yet cached.

        An args tuple already in self.cache (answered or in flight) is not
        requested again.  When *wait* is true, each args is registered in
        self.extra under the msgid after which its result is available.
        That key never decreases, so results are presumably yielded in
        argsv order.
        """
        data = []
        m = self.received_msgid
        for args in argsv:
            # Only issue a request when nothing is cached/in flight for args.
            if args not in self.cache:
                self.msgid += 1
                msgid = self.msgid
                self.pending[msgid] = args
                self.cache[args] = msgid, None, None
                data.append(msgpack.packb((1, msgid, cmd, args)))
            if wait:
                msgid, resp, error = self.cache[args]
                m = max(m, msgid)
                self.extra.setdefault(m, []).append((args, resp, error))
        return b''.join(data)

    def gen_cache_requests(self, cmd, peek):
        """Return packed prefetch requests for upcoming ids obtained from *peek*.

        *peek* is called until it raises StopIteration; ``peek()[0]`` is used
        as the id.  Presumably each call yields the next upcoming item
        (otherwise this would not terminate) — verify against callers.
        """
        data = []
        while True:
            try:
                args = (peek()[0],)
            except StopIteration:
                break
            if args in self.cache:
                continue
            self.msgid += 1
            msgid = self.msgid
            self.pending[msgid] = args
            self.cache[args] = msgid, None, None
            data.append(msgpack.packb((1, msgid, cmd, args)))
        return b''.join(data)

    def call_multi(self, cmd, argsv, wait=True, peek=None):
        """Generator: issue *cmd* for each args in *argsv* and yield results.

        Results already available are yielded first, then the rest as their
        responses arrive.  When the outgoing buffer is empty and *peek* is
        given, idle write slots are used for prefetch requests built by
        gen_cache_requests().  With wait=False, the generator returns once
        all pending bytes are written.  Raises RPCError on a remote error.
        """
        w_fds = [self.stdin_fd]
        left = len(argsv)
        data = self.gen_request(cmd, argsv, wait)
        self.to_send += data
        for args, resp, error in self.extra.pop(self.received_msgid, []):
            left -= 1
            if not resp and not error:
                resp, error = self.cache[args][1:]
            if error:
                raise self.RPCError(error)
            else:
                yield resp
        while left:
            r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
            if x:
                raise Exception('FD exception occured')
            if r:
                for res in self._read():
                    left -= 1
                    yield res
            if w:
                if not self.to_send and peek:
                    self.to_send = self.gen_cache_requests(cmd, peek)
                if self.to_send:
                    n = os.write(self.stdin_fd, self.to_send)
                    assert n > 0
                    # Keep a bytes slice (not a memoryview): self.to_send is
                    # later extended with ``+=`` which memoryview rejects.
                    self.to_send = self.to_send[n:]
                else:
                    w_fds = []
                    if not wait:
                        return

    def commit(self, *args):
        self.call('commit', args)

    def rollback(self, *args):
        """Drop all client-side cached/pending state, then roll back remotely."""
        self.cache.clear()
        self.pending.clear()
        self.extra.clear()
        return self.call('rollback', args)

    def get(self, id):
        """Fetch one object; raises Repository.DoesNotExist if it is missing."""
        try:
            for res in self.call_multi('get', [(id, )]):
                return res
        except self.RPCError as e:
            if e.name == b'DoesNotExist':
                raise Repository.DoesNotExist
            raise

    def get_many(self, ids, peek=None):
        """Return a generator over the objects for *ids* (see call_multi)."""
        return self.call_multi('get', [(id, ) for id in ids], peek=peek)

    def _invalidate(self, id):
        """Forget any cached result for *id* and stop caching its response."""
        key = (id, )
        if key in self.cache:
            self.pending.pop(self.cache.pop(key)[0], None)

    def put(self, id, data, wait=True):
        resp = self.call('put', (id, data), wait=wait)
        self._invalidate(id)
        return resp

    def delete(self, id, wait=True):
        resp = self.call('delete', (id, ), wait=wait)
        self._invalidate(id)
        return resp

    def close(self):
        """Close the pipes and wait for the server process to exit."""
        if self.p:
            self.p.stdin.close()
            self.p.stdout.close()
            self.p.wait()
            self.p = None