|
@@ -2,31 +2,14 @@ from __future__ import with_statement
|
|
|
import fcntl
|
|
|
import msgpack
|
|
|
import os
|
|
|
-import paramiko
|
|
|
import select
|
|
|
+from subprocess import Popen, PIPE
|
|
|
import sys
|
|
|
import getpass
|
|
|
|
|
|
from .store import Store
|
|
|
-from .helpers import Counter
|
|
|
|
|
|
-
|
|
|
-BUFSIZE = 1024 * 1024
|
|
|
-
|
|
|
-
|
|
|
-class ChannelNotifyer(object):
|
|
|
-
|
|
|
- def __init__(self, channel):
|
|
|
- self.channel = channel
|
|
|
- self.enabled = Counter()
|
|
|
-
|
|
|
- def set(self):
|
|
|
- if self.enabled > 0:
|
|
|
- with self.channel.lock:
|
|
|
- self.channel.out_buffer_cv.notifyAll()
|
|
|
-
|
|
|
- def clear(self):
|
|
|
- pass
|
|
|
+BUFSIZE = 10 * 1024 * 1024
|
|
|
|
|
|
|
|
|
class StoreServer(object):
|
|
@@ -87,134 +70,73 @@ class RemoteStore(object):
|
|
|
def __init__(self, name):
|
|
|
self.name = name
|
|
|
|
|
|
-
|
|
|
def __init__(self, location, create=False):
|
|
|
- self.client = paramiko.SSHClient()
|
|
|
- self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
|
|
- params = {'username': location.user or getpass.getuser(),
|
|
|
- 'hostname': location.host, 'port': location.port}
|
|
|
- while True:
|
|
|
- try:
|
|
|
- self.client.connect(**params)
|
|
|
- break
|
|
|
- except (paramiko.PasswordRequiredException,
|
|
|
- paramiko.AuthenticationException,
|
|
|
- paramiko.SSHException):
|
|
|
- if not 'password' in params:
|
|
|
- params['password'] = getpass.getpass('Password for %(username)s@%(hostname)s:' % params)
|
|
|
- else:
|
|
|
- raise
|
|
|
-
|
|
|
self.unpacker = msgpack.Unpacker()
|
|
|
- self.transport = self.client.get_transport()
|
|
|
- self.channel = self.transport.open_session()
|
|
|
- self.notifier = ChannelNotifyer(self.channel)
|
|
|
- self.channel.in_buffer.set_event(self.notifier)
|
|
|
- self.channel.in_stderr_buffer.set_event(self.notifier)
|
|
|
- self.channel.exec_command('darc serve')
|
|
|
- self.callbacks = {}
|
|
|
self.msgid = 0
|
|
|
- self.recursion = 0
|
|
|
- self.odata = []
|
|
|
- # Negotiate protocol version
|
|
|
- version = self.cmd('negotiate', (1,))
|
|
|
+ args = ['ssh', '-p', str(location.port), '%s@%s' % (location.user or getpass.getuser(), location.host), 'darc', 'serve']
|
|
|
+ self.p = Popen(args, bufsize=0, stdin=PIPE, stdout=PIPE)
|
|
|
+ self.stdout_fd = self.p.stdout.fileno()
|
|
|
+ version = self.call('negotiate', (1,))
|
|
|
if version != 1:
|
|
|
raise Exception('Server insisted on using unsupported protocol version %d' % version)
|
|
|
- self.id = self.cmd('open', (location.path, create))
|
|
|
-
|
|
|
- def wait(self, write=True):
|
|
|
- with self.channel.lock:
|
|
|
- if ((not write or self.channel.out_window_size == 0) and
|
|
|
- len(self.channel.in_buffer._buffer) == 0 and
|
|
|
- len(self.channel.in_stderr_buffer._buffer) == 0):
|
|
|
- self.channel.out_buffer_cv.wait(1)
|
|
|
-
|
|
|
- def cmd(self, cmd, args, callback=None, callback_data=None):
|
|
|
- self.msgid += 1
|
|
|
- self.notifier.enabled.inc()
|
|
|
- self.odata.append(msgpack.packb((1, self.msgid, cmd, args)))
|
|
|
- self.recursion += 1
|
|
|
- if callback:
|
|
|
- self.add_callback(callback, callback_data)
|
|
|
- if self.recursion > 1:
|
|
|
- self.recursion -= 1
|
|
|
- return
|
|
|
- while True:
|
|
|
- if self.channel.closed:
|
|
|
- self.recursion -= 1
|
|
|
- raise Exception('Connection closed')
|
|
|
- elif self.channel.recv_stderr_ready():
|
|
|
- print >> sys.stderr, 'remote stderr:', self.channel.recv_stderr(BUFSIZE)
|
|
|
- elif self.channel.recv_ready():
|
|
|
- self.unpacker.feed(self.channel.recv(BUFSIZE))
|
|
|
- for type, msgid, error, res in self.unpacker:
|
|
|
- self.notifier.enabled.dec()
|
|
|
- if msgid == self.msgid:
|
|
|
- if error:
|
|
|
- self.recursion -= 1
|
|
|
- raise self.RPCError(error)
|
|
|
- self.recursion -= 1
|
|
|
- return res
|
|
|
- else:
|
|
|
- for c, d in self.callbacks.pop(msgid, []):
|
|
|
- c(res, error, d)
|
|
|
- elif self.odata and self.channel.send_ready():
|
|
|
- data = self.odata.pop(0)
|
|
|
- n = self.channel.send(data)
|
|
|
- if n != len(data):
|
|
|
- self.odata.insert(0, data[n:])
|
|
|
- if not self.odata and callback:
|
|
|
- self.recursion -= 1
|
|
|
- return
|
|
|
- else:
|
|
|
- self.wait(self.odata)
|
|
|
+ self.id = self.call('open', (location.path, create))
|
|
|
+
|
|
|
+ def __del__(self):
|
|
|
+ self.p.stdin.close()
|
|
|
+ self.p.stdout.close()
|
|
|
+ self.p.wait()
|
|
|
+
|
|
|
+ def _read(self, msgids):
|
|
|
+ data = os.read(self.stdout_fd, BUFSIZE)
|
|
|
+ self.unpacker.feed(data)
|
|
|
+ for type, msgid, error, res in self.unpacker:
|
|
|
+ if error:
|
|
|
+ raise self.RPCError(error)
|
|
|
+ if msgid in msgids:
|
|
|
+ msgids.remove(msgid)
|
|
|
+ yield res
|
|
|
+
|
|
|
+ def call(self, cmd, args, wait=True):
|
|
|
+ for res in self.call_multi(cmd, [args], wait=wait):
|
|
|
+ return res
|
|
|
+
|
|
|
+ def call_multi(self, cmd, argsv, wait=True):
|
|
|
+ msgids = set()
|
|
|
+ for args in argsv:
|
|
|
+ if select.select([self.stdout_fd], [], [], 0)[0]:
|
|
|
+ for res in self._read(msgids):
|
|
|
+ yield res
|
|
|
+ self.msgid += 1
|
|
|
+ msgid = self.msgid
|
|
|
+ msgids.add(msgid)
|
|
|
+ self.p.stdin.write(msgpack.packb((1, msgid, cmd, args)))
|
|
|
+ while msgids and wait:
|
|
|
+ for res in self._read(msgids):
|
|
|
+ yield res
|
|
|
|
|
|
def commit(self, *args):
|
|
|
- self.cmd('commit', args)
|
|
|
+ self.call('commit', args)
|
|
|
|
|
|
def rollback(self, *args):
|
|
|
- return self.cmd('rollback', args)
|
|
|
+ return self.call('rollback', args)
|
|
|
|
|
|
- def get(self, id, callback=None, callback_data=None):
|
|
|
+ def get(self, id):
|
|
|
try:
|
|
|
- return self.cmd('get', (id, ), callback, callback_data)
|
|
|
+ return self.call('get', (id, ))
|
|
|
except self.RPCError, e:
|
|
|
if e.name == 'DoesNotExist':
|
|
|
raise self.DoesNotExist
|
|
|
raise
|
|
|
|
|
|
- def put(self, id, data, callback=None, callback_data=None):
|
|
|
+ def get_many(self, ids):
|
|
|
+ return self.call_multi('get', [(id, ) for id in ids])
|
|
|
+
|
|
|
+ def put(self, id, data, wait=True):
|
|
|
try:
|
|
|
- return self.cmd('put', (id, data), callback, callback_data)
|
|
|
+ return self.call('put', (id, data), wait=wait)
|
|
|
except self.RPCError, e:
|
|
|
if e.name == 'AlreadyExists':
|
|
|
raise self.AlreadyExists
|
|
|
|
|
|
- def delete(self, id, callback=None, callback_data=None):
|
|
|
- return self.cmd('delete', (id, ), callback, callback_data)
|
|
|
-
|
|
|
- def add_callback(self, cb, data):
|
|
|
- self.callbacks.setdefault(self.msgid, []).append((cb, data))
|
|
|
-
|
|
|
- def flush_rpc(self, counter=None, backlog=0):
|
|
|
- counter = counter or self.notifier.enabled
|
|
|
- while counter > backlog:
|
|
|
- if self.channel.closed:
|
|
|
- raise Exception('Connection closed')
|
|
|
- elif self.odata and self.channel.send_ready():
|
|
|
- n = self.channel.send(self.odata)
|
|
|
- if n > 0:
|
|
|
- self.odata = self.odata[n:]
|
|
|
- elif self.channel.recv_stderr_ready():
|
|
|
- print >> sys.stderr, 'remote stderr:', self.channel.recv_stderr(BUFSIZE)
|
|
|
- elif self.channel.recv_ready():
|
|
|
- self.unpacker.feed(self.channel.recv(BUFSIZE))
|
|
|
- for type, msgid, error, res in self.unpacker:
|
|
|
- self.notifier.enabled.dec()
|
|
|
- for c, d in self.callbacks.pop(msgid, []):
|
|
|
- c(res, error, d)
|
|
|
- if msgid == self.msgid:
|
|
|
- return
|
|
|
- else:
|
|
|
- self.wait(self.odata)
|
|
|
-
|
|
|
+ def delete(self, id, wait=True):
|
|
|
+ return self.call('delete', (id, ), wait=wait)
|