# remote.py (7.6 KB)
  1. from __future__ import with_statement
  2. import fcntl
  3. import msgpack
  4. import os
  5. import paramiko
  6. import select
  7. import sys
  8. import getpass
  9. from .store import Store
  10. from .helpers import Counter
  11. BUFSIZE = 1024 * 1024
  12. class ChannelNotifyer(object):
  13. def __init__(self, channel):
  14. self.channel = channel
  15. self.enabled = Counter()
  16. def set(self):
  17. if self.enabled > 0:
  18. with self.channel.lock:
  19. self.channel.out_buffer_cv.notifyAll()
  20. def clear(self):
  21. pass
  22. class StoreServer(object):
  23. def __init__(self):
  24. self.store = None
  25. def serve(self):
  26. # Make stdin non-blocking
  27. fl = fcntl.fcntl(sys.stdin.fileno(), fcntl.F_GETFL)
  28. fcntl.fcntl(sys.stdin.fileno(), fcntl.F_SETFL, fl | os.O_NONBLOCK)
  29. # Make stdout blocking
  30. fl = fcntl.fcntl(sys.stdout.fileno(), fcntl.F_GETFL)
  31. fcntl.fcntl(sys.stdout.fileno(), fcntl.F_SETFL, fl & ~os.O_NONBLOCK)
  32. unpacker = msgpack.Unpacker()
  33. while True:
  34. r, w, es = select.select([sys.stdin], [], [], 10)
  35. if r:
  36. data = os.read(sys.stdin.fileno(), BUFSIZE)
  37. if not data:
  38. return
  39. unpacker.feed(data)
  40. for type, msgid, method, args in unpacker:
  41. try:
  42. try:
  43. f = getattr(self, method)
  44. except AttributeError:
  45. f = getattr(self.store, method)
  46. res = f(*args)
  47. except Exception, e:
  48. sys.stdout.write(msgpack.packb((1, msgid, e.__class__.__name__, None)))
  49. else:
  50. sys.stdout.write(msgpack.packb((1, msgid, None, res)))
  51. sys.stdout.flush()
  52. if es:
  53. return
  54. def negotiate(self, versions):
  55. return 1
  56. def open(self, path, create=False):
  57. if path.startswith('/~'):
  58. path = path[1:]
  59. self.store = Store(os.path.expanduser(path), create)
  60. return self.store.id
  61. class RemoteStore(object):
  62. class DoesNotExist(Exception):
  63. pass
  64. class AlreadyExists(Exception):
  65. pass
  66. class RPCError(Exception):
  67. def __init__(self, name):
  68. self.name = name
  69. def __init__(self, location, create=False):
  70. self.client = paramiko.SSHClient()
  71. self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
  72. params = {'username': location.user or getpass.getuser(),
  73. 'hostname': location.host, 'port': location.port}
  74. while True:
  75. try:
  76. self.client.connect(**params)
  77. break
  78. except (paramiko.PasswordRequiredException,
  79. paramiko.AuthenticationException,
  80. paramiko.SSHException):
  81. if not 'password' in params:
  82. params['password'] = getpass.getpass('Password for %(username)s@%(hostname)s:' % params)
  83. else:
  84. raise
  85. self.unpacker = msgpack.Unpacker()
  86. self.transport = self.client.get_transport()
  87. self.channel = self.transport.open_session()
  88. self.notifier = ChannelNotifyer(self.channel)
  89. self.channel.in_buffer.set_event(self.notifier)
  90. self.channel.in_stderr_buffer.set_event(self.notifier)
  91. self.channel.exec_command('darc serve')
  92. self.callbacks = {}
  93. self.msgid = 0
  94. self.recursion = 0
  95. self.odata = []
  96. # Negotiate protocol version
  97. version = self.cmd('negotiate', (1,))
  98. if version != 1:
  99. raise Exception('Server insisted on using unsupported protocol version %d' % version)
  100. self.id = self.cmd('open', (location.path, create))
  101. def wait(self, write=True):
  102. with self.channel.lock:
  103. if ((not write or self.channel.out_window_size == 0) and
  104. len(self.channel.in_buffer._buffer) == 0 and
  105. len(self.channel.in_stderr_buffer._buffer) == 0):
  106. self.channel.out_buffer_cv.wait(1)
  107. def cmd(self, cmd, args, callback=None, callback_data=None):
  108. self.msgid += 1
  109. self.notifier.enabled.inc()
  110. self.odata.append(msgpack.packb((1, self.msgid, cmd, args)))
  111. self.recursion += 1
  112. if callback:
  113. self.add_callback(callback, callback_data)
  114. if self.recursion > 1:
  115. self.recursion -= 1
  116. return
  117. while True:
  118. if self.channel.closed:
  119. self.recursion -= 1
  120. raise Exception('Connection closed')
  121. elif self.channel.recv_stderr_ready():
  122. print >> sys.stderr, 'remote stderr:', self.channel.recv_stderr(BUFSIZE)
  123. elif self.channel.recv_ready():
  124. self.unpacker.feed(self.channel.recv(BUFSIZE))
  125. for type, msgid, error, res in self.unpacker:
  126. self.notifier.enabled.dec()
  127. if msgid == self.msgid:
  128. if error:
  129. self.recursion -= 1
  130. raise self.RPCError(error)
  131. self.recursion -= 1
  132. return res
  133. else:
  134. for c, d in self.callbacks.pop(msgid, []):
  135. c(res, error, d)
  136. elif self.odata and self.channel.send_ready():
  137. data = self.odata.pop(0)
  138. n = self.channel.send(data)
  139. if n != len(data):
  140. self.odata.insert(0, data[n:])
  141. if not self.odata and callback:
  142. self.recursion -= 1
  143. return
  144. else:
  145. self.wait(self.odata)
  146. def commit(self, *args):
  147. self.cmd('commit', args)
  148. def rollback(self, *args):
  149. return self.cmd('rollback', args)
  150. def get(self, id, callback=None, callback_data=None):
  151. try:
  152. return self.cmd('get', (id, ), callback, callback_data)
  153. except self.RPCError, e:
  154. if e.name == 'DoesNotExist':
  155. raise self.DoesNotExist
  156. raise
  157. def put(self, id, data, callback=None, callback_data=None):
  158. try:
  159. return self.cmd('put', (id, data), callback, callback_data)
  160. except self.RPCError, e:
  161. if e.name == 'AlreadyExists':
  162. raise self.AlreadyExists
  163. def delete(self, id, callback=None, callback_data=None):
  164. return self.cmd('delete', (id, ), callback, callback_data)
  165. def add_callback(self, cb, data):
  166. self.callbacks.setdefault(self.msgid, []).append((cb, data))
  167. def flush_rpc(self, counter=None, backlog=0):
  168. counter = counter or self.notifier.enabled
  169. while counter > backlog:
  170. if self.channel.closed:
  171. raise Exception('Connection closed')
  172. elif self.odata and self.channel.send_ready():
  173. n = self.channel.send(self.odata)
  174. if n > 0:
  175. self.odata = self.odata[n:]
  176. elif self.channel.recv_stderr_ready():
  177. print >> sys.stderr, 'remote stderr:', self.channel.recv_stderr(BUFSIZE)
  178. elif self.channel.recv_ready():
  179. self.unpacker.feed(self.channel.recv(BUFSIZE))
  180. for type, msgid, error, res in self.unpacker:
  181. self.notifier.enabled.dec()
  182. for c, d in self.callbacks.pop(msgid, []):
  183. c(res, error, d)
  184. if msgid == self.msgid:
  185. return
  186. else:
  187. self.wait(self.odata)