remote.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. from __future__ import with_statement
  2. import fcntl
  3. import msgpack
  4. import os
  5. import paramiko
  6. import select
  7. import sys
  8. import getpass
  9. from .store import Store
  10. from .helpers import Counter
  # Upper bound, in bytes, for a single read from stdin or from the SSH channel.
  BUFSIZE = 1 << 20
  class ChannelNotifyer(object):
      """Event-like object installed on a channel's input buffers via set_event().

      When notifications are enabled, set() wakes anything blocked on the
      channel's ``out_buffer_cv`` condition.
      """

      def __init__(self, channel):
          self.channel = channel
          # RemoteStore increments this per request sent and decrements it per
          # reply received; while it is not > 0, set() does nothing.
          self.enabled = Counter()

      def set(self):
          """Notify all waiters on the channel condition, if enabled."""
          if not (self.enabled > 0):
              return
          chan = self.channel
          with chan.lock:
              chan.out_buffer_cv.notifyAll()

      def clear(self):
          """Do nothing (presumably present only to satisfy the event interface)."""
          return None
  class StoreServer(object):
      """Serve a Store over stdin/stdout using msgpack-encoded RPC messages.

      Each request read from stdin is a ``(type, msgid, method, args)`` tuple.
      Each reply written to stdout is ``(1, msgid, error, result)``, where
      ``error`` is the raised exception's class name, or None on success.
      """

      # Public names that must never be invoked remotely. Names starting
      # with an underscore are rejected as well.
      _FORBIDDEN = frozenset(['serve'])

      def __init__(self):
          self.store = None

      def serve(self):
          """Handle requests from stdin until stdin reaches end-of-file."""
          fd = sys.stdin.fileno()
          # Make stdin non-blocking
          flags = fcntl.fcntl(fd, fcntl.F_GETFL)
          fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
          unpacker = msgpack.Unpacker()
          while True:
              # The exceptional-condition list passed to select() is empty,
              # so only readability matters; the 10 s timeout just loops.
              readable, _, _ = select.select([sys.stdin], [], [], 10)
              if not readable:
                  continue
              data = os.read(fd, BUFSIZE)
              if not data:
                  return
              unpacker.feed(data)
              for _kind, msgid, method, args in unpacker:
                  try:
                      res = self._call(method, args)
                  except Exception:
                      # sys.exc_info() keeps this valid on Python 2 and 3.
                      error = sys.exc_info()[0].__name__
                      sys.stdout.write(msgpack.packb((1, msgid, error, None)))
                  else:
                      sys.stdout.write(msgpack.packb((1, msgid, None, res)))
                  sys.stdout.flush()

      def _call(self, method, args):
          """Invoke ``method(*args)`` on the server, falling back to the store.

          Raises AttributeError for unknown names, and for names that must not
          be reachable remotely (underscore-prefixed names and ``serve``). This
          stops a client from calling e.g. ``__init__`` to reset the server.
          """
          if method.startswith('_') or method in self._FORBIDDEN:
              raise AttributeError(method)
          try:
              f = getattr(self, method)
          except AttributeError:
              f = getattr(self.store, method)
          return f(*args)

      def negotiate(self, versions):
          """Return the protocol version to use; this server only speaks 1."""
          return 1

      def open(self, path, create=False):
          """Open (or create, if *create*) the Store at *path* and return its id.

          A leading ``/~`` is reduced to ``~`` before home-directory expansion.
          """
          if path.startswith('/~'):
              path = path[1:]
          self.store = Store(os.path.expanduser(path), create)
          return self.store.id
  class RemoteStore(object):
      """Client-side proxy for a store reached by running ``darc serve`` over SSH.

      Requests are sent on an SSH exec channel as msgpack-encoded
      ``(1, msgid, method, args)`` tuples. Replies arrive as
      ``(type, msgid, error, result)``, where ``error`` is the remote exception
      class name or a false value.
      """

      class DoesNotExist(Exception):
          """Raised by get() when the remote store reports DoesNotExist."""

      class AlreadyExists(Exception):
          """Raised by put() when the remote store reports AlreadyExists."""

      class RPCError(Exception):
          """A remote call failed; ``name`` is the remote exception class name."""

          def __init__(self, name):
              Exception.__init__(self, name)
              self.name = name

      def __init__(self, location, create=False):
          """Connect over SSH, start ``darc serve`` and open the store at *location*.

          *location* must provide ``user``, ``host``, ``port`` and ``path``.
          If connecting raises an SSH error, the user is prompted for a password
          once and the connection is retried; a second failure propagates.
          Raises Exception if the server picks an unsupported protocol version.
          """
          self.client = paramiko.SSHClient()
          # NOTE(security): AutoAddPolicy accepts unknown host keys without any
          # verification, so connections are open to MITM. Consider
          # load_system_host_keys() plus a stricter policy.
          self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
          params = {'username': location.user or getpass.getuser(),
                    'hostname': location.host, 'port': location.port}
          while True:
              try:
                  self.client.connect(**params)
                  break
              except (paramiko.PasswordRequiredException,
                      paramiko.AuthenticationException,
                      paramiko.SSHException):
                  # The generic SSHException is caught too, presumably so that
                  # "no usable auth method" also leads to a password prompt.
                  if 'password' not in params:
                      params['password'] = getpass.getpass('Password for %(username)s@%(hostname)s:' % params)
                  else:
                      raise
          self.unpacker = msgpack.Unpacker()
          self.transport = self.client.get_transport()
          self.channel = self.transport.open_session()
          # Route the input buffers' events through the notifier, which wakes
          # wait() (blocked on out_buffer_cv) while requests are outstanding.
          self.notifier = ChannelNotifyer(self.channel)
          self.channel.in_buffer.set_event(self.notifier)
          self.channel.in_stderr_buffer.set_event(self.notifier)
          self.channel.exec_command('darc serve')
          self.callbacks = {}  # msgid -> [(callback, callback_data), ...]
          self.msgid = 0       # id of the most recently issued request
          self.recursion = 0   # nesting depth of active cmd() calls
          self.odata = []      # packed requests (or remainders) not yet sent
          # Negotiate protocol version
          version = self.cmd('negotiate', (1,))
          if version != 1:
              # %s rather than %d so that a non-int reply cannot mask this error.
              raise Exception('Server insisted on using unsupported protocol version %s' % (version,))
          self.id = self.cmd('open', (location.path, create))

      def wait(self, write=True):
          """Block for up to one second when there is nothing to do on the channel.

          Sleeps only if no stdout/stderr data is buffered, and either *write*
          is false or the outgoing window is exhausted.
          """
          with self.channel.lock:
              # NOTE: reads paramiko's private BufferedPipe._buffer; this may
              # break with other paramiko versions.
              if ((not write or self.channel.out_window_size == 0) and
                      len(self.channel.in_buffer._buffer) == 0 and
                      len(self.channel.in_stderr_buffer._buffer) == 0):
                  self.channel.out_buffer_cv.wait(1)

      def _send_pending(self):
          """Send the oldest queued chunk, re-queueing any part the channel rejects."""
          data = self.odata.pop(0)
          n = self.channel.send(data)
          if n != len(data):
              self.odata.insert(0, data[n:])

      def _log_remote_stderr(self):
          """Copy one chunk of the server's stderr to the local stderr."""
          sys.stderr.write('remote stderr: %s\n' % (self.channel.recv_stderr(BUFSIZE),))

      def cmd(self, cmd, args, callback=None, callback_data=None):
          """Queue the remote call ``cmd(*args)`` and pump the channel.

          Without *callback*: block until the reply to this request arrives,
          then return its result. Raises RPCError if the server reported an error.

          With *callback*: register ``callback(result, error, callback_data)``
          for this request, and return None once all queued data has been sent.

          If called while another cmd() is active (e.g. from a callback), the
          request is only queued, and None is returned immediately.

          Raises Exception('Connection closed') if the channel closes.
          """
          self.msgid += 1
          self.notifier.enabled.inc()
          self.odata.append(msgpack.packb((1, self.msgid, cmd, args)))
          self.recursion += 1
          # try/finally so a raising callback or channel error cannot leave
          # self.recursion stuck > 0, which would make every later cmd() a
          # silent no-op.
          try:
              if callback:
                  self.add_callback(callback, callback_data)
              if self.recursion > 1:
                  return
              while True:
                  if self.channel.closed:
                      raise Exception('Connection closed')
                  elif self.channel.recv_stderr_ready():
                      self._log_remote_stderr()
                  elif self.channel.recv_ready():
                      self.unpacker.feed(self.channel.recv(BUFSIZE))
                      for _kind, msgid, error, res in self.unpacker:
                          self.notifier.enabled.dec()
                          # NOTE(review): self.msgid can be advanced by
                          # cmd() calls made from callbacks below, so a reply
                          # may be matched to a newer request. Confirm that
                          # callers never mix blocking and nested calls.
                          if msgid == self.msgid:
                              if error:
                                  raise self.RPCError(error)
                              return res
                          for c, d in self.callbacks.pop(msgid, []):
                              c(res, error, d)
                  elif self.odata and self.channel.send_ready():
                      self._send_pending()
                      if not self.odata and callback:
                          return
                  else:
                      self.wait(bool(self.odata))
          finally:
              self.recursion -= 1

      def commit(self, *args):
          """Commit the remote store and return the server's result."""
          return self.cmd('commit', args)

      def rollback(self, *args):
          """Roll back the remote store and return the server's result."""
          return self.cmd('rollback', args)

      def get(self, id, callback=None, callback_data=None):
          """Fetch object *id*. Raises DoesNotExist if the server reports it missing."""
          try:
              return self.cmd('get', (id, ), callback, callback_data)
          except self.RPCError:
              if sys.exc_info()[1].name == 'DoesNotExist':
                  raise self.DoesNotExist
              raise

      def put(self, id, data, callback=None, callback_data=None):
          """Store *data* under *id*.

          Raises AlreadyExists if the server reports *id* as present. Any
          other remote error propagates as RPCError; previously it was
          silently swallowed.
          """
          try:
              return self.cmd('put', (id, data), callback, callback_data)
          except self.RPCError:
              if sys.exc_info()[1].name == 'AlreadyExists':
                  raise self.AlreadyExists
              raise

      def delete(self, id, callback=None, callback_data=None):
          """Delete object *id* on the remote store."""
          return self.cmd('delete', (id, ), callback, callback_data)

      def add_callback(self, cb, data):
          """Register ``cb(result, error, data)`` for the most recently issued request."""
          self.callbacks.setdefault(self.msgid, []).append((cb, data))

      def flush_rpc(self, counter=None, backlog=0):
          """Pump the channel while ``counter > backlog``.

          Sends queued requests and dispatches replies to their callbacks.
          *counter* defaults to ``self.notifier.enabled``, which cmd()
          increments per request; each reply received here decrements it.
          Returns early once the reply to the most recent request is dispatched.
          Raises Exception('Connection closed') if the channel closes.
          """
          counter = counter or self.notifier.enabled
          while counter > backlog:
              if self.channel.closed:
                  raise Exception('Connection closed')
              elif self.odata and self.channel.send_ready():
                  # odata is a list of packed messages: send them one chunk at a
                  # time, exactly as cmd() does, instead of passing the list.
                  self._send_pending()
              elif self.channel.recv_stderr_ready():
                  self._log_remote_stderr()
              elif self.channel.recv_ready():
                  self.unpacker.feed(self.channel.recv(BUFSIZE))
                  for _kind, msgid, error, res in self.unpacker:
                      self.notifier.enabled.dec()
                      for c, d in self.callbacks.pop(msgid, []):
                          c(res, error, d)
                      if msgid == self.msgid:
                          return
              else:
                  self.wait(bool(self.odata))