remote.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. import fcntl
  2. import msgpack
  3. import os
  4. import paramiko
  5. import select
  6. import sys
  7. import getpass
  8. from .store import Store
  9. from .helpers import Counter
  # Maximum number of bytes requested per os.read()/recv()/recv_stderr() call (1 MiB).
  BUFSIZE = 1 << 20
  class ChannelNotifyer(object):
      """Event-like object that wakes up code sleeping on a paramiko channel.

      RemoteStore installs it with ``set_event()`` on the channel's stdout and
      stderr receive buffers.  It is presumably triggered by paramiko when data
      is fed into those buffers (TODO confirm against the paramiko version in
      use).  ``set()`` then notifies every waiter on ``channel.out_buffer_cv``
      (the condition ``RemoteStore.wait`` sleeps on), but only while
      ``enabled`` is positive.
      """

      def __init__(self, channel):
          self.channel = channel
          # Incremented per queued request and decremented per reply by
          # RemoteStore; notifications are sent only while it is > 0.
          self.enabled = Counter()

      def set(self):
          """Wake all waiters on the channel's output condition variable."""
          if self.enabled > 0:
              with self.channel.lock:
                  # notify_all is the supported spelling; notifyAll is a
                  # deprecated alias.
                  self.channel.out_buffer_cv.notify_all()

      def clear(self):
          """Do nothing; present only so the object satisfies the event interface."""
          pass
  class StoreServer(object):
      """Remote end of the RPC protocol (started on the server as ``darc serve``).

      Reads msgpack-encoded ``(type, msgid, method, args)`` requests from
      stdin and writes ``(1, msgid, error, result)`` replies to stdout.
      ``method`` is looked up on this object first, then on the opened
      ``Store``.  When the call fails, ``error`` is the exception's class name
      and ``result`` is ``None``.
      """

      def __init__(self):
          self.store = None  # set by open()

      def serve(self):
          """Answer requests from stdin until stdin reaches end-of-file."""
          fd = sys.stdin.fileno()
          # Make stdin non-blocking
          flags = fcntl.fcntl(fd, fcntl.F_GETFL)
          fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
          unpacker = msgpack.Unpacker()
          while True:
              # Only readability is watched; the old "exceptional" check could
              # never fire because an empty list was passed for it.
              readable, _, _ = select.select([sys.stdin], [], [], 10)
              if not readable:
                  continue
              data = os.read(fd, BUFSIZE)
              if not data:
                  return
              unpacker.feed(data)
              for _, msgid, method, args in unpacker:
                  try:
                      try:
                          f = getattr(self, method)
                      except AttributeError:
                          f = getattr(self.store, method)
                      # Encode inside the try: a result msgpack cannot
                      # serialize becomes an error reply instead of
                      # crashing the server.
                      reply = msgpack.packb((1, msgid, None, f(*args)))
                  except Exception as e:
                      reply = msgpack.packb((1, msgid, e.__class__.__name__, None))
                  sys.stdout.write(reply)
                  sys.stdout.flush()

      def open(self, path, create=False):
          """Open (or create) the store at ``path`` and return ``(id, tid)``.

          A leading ``/~`` is reduced to ``~`` so that a path such as
          ``/~/archive`` is expanded relative to the home directory.
          """
          if path.startswith('/~'):
              path = path[1:]
          self.store = Store(os.path.expanduser(path), create)
          return self.store.id, self.store.tid
  class RemoteStore(object):
      """Client side of the ``darc serve`` RPC protocol, spoken over SSH.

      Requests are msgpack-encoded ``(0, msgid, method, args)`` tuples
      written to an SSH channel running ``darc serve``.  Replies arrive as
      ``(type, msgid, error, result)`` tuples.  Calls made with a callback
      are pipelined: they return once queued and written, and the callback is
      later invoked as ``callback(result, error, callback_data)`` when the
      reply arrives.
      """

      class DoesNotExist(Exception):
          pass

      class AlreadyExists(Exception):
          pass

      class RPCError(Exception):
          """The remote call raised; ``name`` is the remote exception's class name."""

          def __init__(self, name):
              # Forward the name so str(exc) and tracebacks show it.
              Exception.__init__(self, name)
              self.name = name

      def __init__(self, location, create=False):
          """Connect to ``location`` over SSH, start ``darc serve`` and open the store.

          ``location`` must provide ``user``, ``host``, ``port`` and ``path``.
          If connecting fails with an authentication/SSH error, the user is
          prompted once for a password and the connection is retried.  A
          second failure is re-raised.
          """
          self.client = paramiko.SSHClient()
          # NOTE(review): AutoAddPolicy trusts unknown host keys without
          # asking, which exposes the first connection to MITM attacks.
          self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
          params = {'username': location.user or getpass.getuser(),
                    'hostname': location.host, 'port': location.port}
          while True:
              try:
                  self.client.connect(**params)
                  break
              except (paramiko.PasswordRequiredException,
                      paramiko.AuthenticationException,
                      paramiko.SSHException):
                  if 'password' not in params:
                      params['password'] = getpass.getpass('Password for %(username)s@%(hostname)s:' % params)
                  else:
                      raise
          self.unpacker = msgpack.Unpacker()
          self.transport = self.client.get_transport()
          self.channel = self.transport.open_session()
          # Both receive buffers report incoming data to the notifier, which
          # wakes wait() up early.
          self.notifier = ChannelNotifyer(self.channel)
          self.channel.in_buffer.set_event(self.notifier)
          self.channel.in_stderr_buffer.set_event(self.notifier)
          self.channel.exec_command('darc serve')
          self.callbacks = {}  # msgid -> (callback, callback_data)
          self.msgid = 0       # id of the most recently queued request
          self.recursion = 0   # nesting depth of cmd(); callbacks may re-enter it
          self.odata = b''     # encoded requests not yet written to the channel
          self.id, self.tid = self.cmd('open', (location.path, create))

      def wait(self, write=True):
          """Sleep for up to one second when the channel has nothing for us to do.

          Sleeps only if nothing is buffered on stdout/stderr and, when
          ``write`` is true, the send window is exhausted.  Reads paramiko's
          private ``_buffer`` attributes.
          """
          with self.channel.lock:
              if ((not write or self.channel.out_window_size == 0) and
                      len(self.channel.in_buffer._buffer) == 0 and
                      len(self.channel.in_stderr_buffer._buffer) == 0):
                  self.channel.out_buffer_cv.wait(1)

      def _print_remote_stderr(self):
          """Copy pending remote stderr output to local stderr."""
          sys.stderr.write('remote stderr: %s\n' % self.channel.recv_stderr(BUFSIZE))

      def _send_pending(self):
          """Write as much of the queued request data as the channel accepts."""
          n = self.channel.send(self.odata)
          if n > 0:
              self.odata = self.odata[n:]

      def cmd(self, cmd, args, callback=None, callback_data=None):
          """Queue an RPC request and, for synchronous calls, wait for its reply.

          Without ``callback``: blocks until this request's reply arrives and
          returns its result, raising ``RPCError`` if the remote call failed.
          With ``callback``: returns ``None`` once all queued data has been
          written.  The callback runs later, from cmd() or flush_rpc().
          While waiting, replies to earlier pipelined requests are dispatched
          to their callbacks.  Re-entrant calls (e.g. from inside a callback)
          only queue the request and return ``None`` immediately.

          Raises ``Exception('Connection closed')`` if the channel closes.
          """
          self.msgid += 1
          # Keep our own id: a callback run below may queue more requests and
          # bump self.msgid while we are still waiting for this reply.
          msgid = self.msgid
          self.notifier.enabled.inc()
          self.odata += msgpack.packb((0, msgid, cmd, args))
          if callback:
              self.callbacks[msgid] = callback, callback_data
          if self.recursion > 0:
              # Nested call: the outer cmd() loop sends the queued data.
              return
          self.recursion += 1
          try:
              while True:
                  if self.channel.closed:
                      raise Exception('Connection closed')
                  elif self.channel.recv_stderr_ready():
                      self._print_remote_stderr()
                  elif self.channel.recv_ready():
                      self.unpacker.feed(self.channel.recv(BUFSIZE))
                      for _, rmsgid, error, res in self.unpacker:
                          self.notifier.enabled.dec()
                          if rmsgid == msgid and not callback:
                              if error:
                                  raise self.RPCError(error)
                              return res
                          c, d = self.callbacks.pop(rmsgid, (None, None))
                          if c:
                              c(res, error, d)
                  elif self.odata and self.channel.send_ready():
                      self._send_pending()
                      if not self.odata and callback:
                          return
                  else:
                      # Pending output counts as "want to write".
                      self.wait(self.odata)
          finally:
              # Always unwind, including on RPCError or a failing callback;
              # otherwise every later call would be treated as nested.
              self.recursion -= 1

      def commit(self, *args):
          """Commit the remote transaction and advance the local ``tid``."""
          self.cmd('commit', args)
          self.tid += 1

      def rollback(self, *args):
          return self.cmd('rollback', args)

      def get(self, ns, id, callback=None, callback_data=None):
          """Fetch ``id`` from namespace ``ns``.

          Raises ``DoesNotExist`` when the remote store reports so.  Other
          remote errors propagate as ``RPCError``.
          """
          try:
              return self.cmd('get', (ns, id), callback, callback_data)
          except self.RPCError as e:
              if e.name == 'DoesNotExist':
                  raise self.DoesNotExist
              raise

      def put(self, ns, id, data, callback=None, callback_data=None):
          """Store ``data`` under ``id`` in namespace ``ns``.

          Raises ``AlreadyExists`` when the remote store reports so.  Other
          remote errors propagate as ``RPCError`` instead of being silently
          dropped.
          """
          try:
              return self.cmd('put', (ns, id, data), callback, callback_data)
          except self.RPCError as e:
              if e.name == 'AlreadyExists':
                  raise self.AlreadyExists
              raise

      def delete(self, ns, id, callback=None, callback_data=None):
          return self.cmd('delete', (ns, id), callback, callback_data)

      def list(self, *args):
          return self.cmd('list', args)

      def flush_rpc(self, counter=None, backlog=0):
          """Pump the channel until no more than ``backlog`` requests are outstanding.

          ``counter`` defaults to the notifier's count of unanswered requests.
          Replies are dispatched to their callbacks.  Returns early once the
          reply to the most recently queued request has been handled.
          """
          counter = counter or self.notifier.enabled
          while counter > backlog:
              if self.channel.closed:
                  raise Exception('Connection closed')
              elif self.odata and self.channel.send_ready():
                  self._send_pending()
              elif self.channel.recv_stderr_ready():
                  self._print_remote_stderr()
              elif self.channel.recv_ready():
                  self.unpacker.feed(self.channel.recv(BUFSIZE))
                  for _, msgid, error, res in self.unpacker:
                      self.notifier.enabled.dec()
                      c, d = self.callbacks.pop(msgid, (None, None))
                      if c:
                          c(res, error, d)
                      if msgid == self.msgid:
                          return
              else:
                  self.wait(self.odata)