remote.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. from __future__ import with_statement
  2. import fcntl
  3. import msgpack
  4. import os
  5. import paramiko
  6. import select
  7. import sys
  8. import getpass
  9. from .store import Store
  10. from .helpers import Counter
  11. BUFSIZE = 1024 * 1024
  12. class ChannelNotifyer(object):
  13. def __init__(self, channel):
  14. self.channel = channel
  15. self.enabled = Counter()
  16. def set(self):
  17. if self.enabled > 0:
  18. with self.channel.lock:
  19. self.channel.out_buffer_cv.notifyAll()
  20. def clear(self):
  21. pass
  22. class StoreServer(object):
  23. def __init__(self):
  24. self.store = None
  25. def serve(self):
  26. # Make stdin non-blocking
  27. fl = fcntl.fcntl(sys.stdin.fileno(), fcntl.F_GETFL)
  28. fcntl.fcntl(sys.stdin.fileno(), fcntl.F_SETFL, fl | os.O_NONBLOCK)
  29. unpacker = msgpack.Unpacker()
  30. while True:
  31. r, w, es = select.select([sys.stdin], [], [], 10)
  32. if r:
  33. data = os.read(sys.stdin.fileno(), BUFSIZE)
  34. if not data:
  35. return
  36. unpacker.feed(data)
  37. for type, msgid, method, args in unpacker:
  38. try:
  39. try:
  40. f = getattr(self, method)
  41. except AttributeError:
  42. f = getattr(self.store, method)
  43. res = f(*args)
  44. except Exception, e:
  45. sys.stdout.write(msgpack.packb((1, msgid, e.__class__.__name__, None)))
  46. else:
  47. sys.stdout.write(msgpack.packb((1, msgid, None, res)))
  48. sys.stdout.flush()
  49. if es:
  50. return
  51. def open(self, path, create=False):
  52. if path.startswith('/~'):
  53. path = path[1:]
  54. self.store = Store(os.path.expanduser(path), create)
  55. return self.store.id, self.store.tid
  56. class RemoteStore(object):
  57. class DoesNotExist(Exception):
  58. pass
  59. class AlreadyExists(Exception):
  60. pass
  61. class RPCError(Exception):
  62. def __init__(self, name):
  63. self.name = name
  64. def __init__(self, location, create=False):
  65. self.client = paramiko.SSHClient()
  66. self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
  67. params = {'username': location.user or getpass.getuser(),
  68. 'hostname': location.host, 'port': location.port}
  69. while True:
  70. try:
  71. self.client.connect(**params)
  72. break
  73. except (paramiko.PasswordRequiredException,
  74. paramiko.AuthenticationException,
  75. paramiko.SSHException):
  76. if not 'password' in params:
  77. params['password'] = getpass.getpass('Password for %(username)s@%(hostname)s:' % params)
  78. else:
  79. raise
  80. self.unpacker = msgpack.Unpacker()
  81. self.transport = self.client.get_transport()
  82. self.channel = self.transport.open_session()
  83. self.notifier = ChannelNotifyer(self.channel)
  84. self.channel.in_buffer.set_event(self.notifier)
  85. self.channel.in_stderr_buffer.set_event(self.notifier)
  86. self.channel.exec_command('darc serve')
  87. self.callbacks = {}
  88. self.msgid = 0
  89. self.recursion = 0
  90. self.odata = []
  91. self.id, self.tid = self.cmd('open', (location.path, create))
  92. def wait(self, write=True):
  93. with self.channel.lock:
  94. if ((not write or self.channel.out_window_size == 0) and
  95. len(self.channel.in_buffer._buffer) == 0 and
  96. len(self.channel.in_stderr_buffer._buffer) == 0):
  97. self.channel.out_buffer_cv.wait(1)
  98. def cmd(self, cmd, args, callback=None, callback_data=None):
  99. self.msgid += 1
  100. self.notifier.enabled.inc()
  101. self.odata.append(msgpack.packb((1, self.msgid, cmd, args)))
  102. self.recursion += 1
  103. if callback:
  104. self.callbacks[self.msgid] = callback, callback_data
  105. if self.recursion > 1:
  106. self.recursion -= 1
  107. return
  108. while True:
  109. if self.channel.closed:
  110. self.recursion -= 1
  111. raise Exception('Connection closed')
  112. elif self.channel.recv_stderr_ready():
  113. print >> sys.stderr, 'remote stderr:', self.channel.recv_stderr(BUFSIZE)
  114. elif self.channel.recv_ready():
  115. self.unpacker.feed(self.channel.recv(BUFSIZE))
  116. for type, msgid, error, res in self.unpacker:
  117. self.notifier.enabled.dec()
  118. if msgid == self.msgid:
  119. if error:
  120. self.recursion -= 1
  121. raise self.RPCError(error)
  122. self.recursion -= 1
  123. return res
  124. else:
  125. c, d = self.callbacks.pop(msgid, (None, None))
  126. if c:
  127. c(res, error, d)
  128. elif self.odata and self.channel.send_ready():
  129. data = self.odata.pop(0)
  130. n = self.channel.send(data)
  131. if n != len(data):
  132. self.odata.insert(0, data[n:])
  133. if not self.odata and callback:
  134. self.recursion -= 1
  135. return
  136. else:
  137. self.wait(self.odata)
  138. def commit(self, *args):
  139. self.cmd('commit', args)
  140. self.tid += 1
  141. def rollback(self, *args):
  142. return self.cmd('rollback', args)
  143. def get(self, ns, id, callback=None, callback_data=None):
  144. try:
  145. return self.cmd('get', (ns, id), callback, callback_data)
  146. except self.RPCError, e:
  147. if e.name == 'DoesNotExist':
  148. raise self.DoesNotExist
  149. raise
  150. def put(self, ns, id, data, callback=None, callback_data=None):
  151. try:
  152. return self.cmd('put', (ns, id, data), callback, callback_data)
  153. except self.RPCError, e:
  154. if e.name == 'AlreadyExists':
  155. raise self.AlreadyExists
  156. def delete(self, ns, id, callback=None, callback_data=None):
  157. return self.cmd('delete', (ns, id), callback, callback_data)
  158. def list(self, *args):
  159. return self.cmd('list', args)
  160. def flush_rpc(self, counter=None, backlog=0):
  161. counter = counter or self.notifier.enabled
  162. while counter > backlog:
  163. if self.channel.closed:
  164. raise Exception('Connection closed')
  165. elif self.odata and self.channel.send_ready():
  166. n = self.channel.send(self.odata)
  167. if n > 0:
  168. self.odata = self.odata[n:]
  169. elif self.channel.recv_stderr_ready():
  170. print >> sys.stderr, 'remote stderr:', self.channel.recv_stderr(BUFSIZE)
  171. elif self.channel.recv_ready():
  172. self.unpacker.feed(self.channel.recv(BUFSIZE))
  173. for type, msgid, error, res in self.unpacker:
  174. self.notifier.enabled.dec()
  175. c, d = self.callbacks.pop(msgid, (None, None))
  176. if c:
  177. c(res, error, d)
  178. if msgid == self.msgid:
  179. return
  180. else:
  181. self.wait(self.odata)