# keychain.py
from getpass import getpass
import hashlib
import hmac
import os
import zlib

import msgpack
from pbkdf2 import pbkdf2
from Crypto.Cipher import AES
from Crypto.Hash import SHA256, HMAC
from Crypto.PublicKey import RSA
from Crypto.Util import Counter
from Crypto.Util.number import bytes_to_long, long_to_bytes

from .helpers import IntegrityError, zero_pad
from .oaep import OAEP


  14. class Keychain(object):
  15. FILE_ID = 'DARC KEYCHAIN'
  16. CREATE = '\1'
  17. READ = '\2'
  18. def __init__(self, path=None):
  19. self._key_cache = {}
  20. self.read_key = os.urandom(32)
  21. self.create_key = os.urandom(32)
  22. self.counter = Counter.new(64, prefix='\0' * 8)
  23. self.aes_id = self.rsa_read = self.rsa_create = None
  24. self.path = path
  25. if path:
  26. self.open(path)
  27. def get_chunkify_seed(self):
  28. return bytes_to_long(self.aes_id[:4])
  29. def open(self, path):
  30. print 'Opening keychain "%s"' % path
  31. with open(path, 'rb') as fd:
  32. if fd.read(len(self.FILE_ID)) != self.FILE_ID:
  33. raise ValueError('Not a keychain')
  34. cdata = fd.read()
  35. self.password = ''
  36. data = self.decrypt_keychain(cdata, '')
  37. while not data:
  38. self.password = getpass('Keychain password: ')
  39. if not self.password:
  40. raise Exception('Keychain decryption failed')
  41. data = self.decrypt_keychain(cdata, self.password)
  42. if not data:
  43. print 'Incorrect password'
  44. chain = msgpack.unpackb(data)
  45. assert chain['version'] == 1
  46. self.aes_id = chain['aes_id']
  47. self.rsa_read = RSA.importKey(chain['rsa_read'])
  48. self.rsa_create = RSA.importKey(chain['rsa_create'])
  49. self.read_encrypted = OAEP(256, hash=SHA256).encode(self.read_key, os.urandom(32))
  50. self.read_encrypted = zero_pad(self.rsa_read.encrypt(self.read_encrypted, '')[0], 256)
  51. self.create_encrypted = OAEP(256, hash=SHA256).encode(self.create_key, os.urandom(32))
  52. self.create_encrypted = zero_pad(self.rsa_create.encrypt(self.create_encrypted, '')[0], 256)
  53. def encrypt_keychain(self, data, password):
  54. salt = os.urandom(32)
  55. iterations = 2000
  56. key = pbkdf2(password, salt, 32, iterations, hashlib.sha256)
  57. hash = HMAC.new(key, data, SHA256).digest()
  58. cdata = AES.new(key, AES.MODE_CTR, counter=Counter.new(128)).encrypt(data)
  59. d = {
  60. 'version': 1,
  61. 'salt': salt,
  62. 'iterations': iterations,
  63. 'algorithm': 'SHA256',
  64. 'hash': hash,
  65. 'data': cdata,
  66. }
  67. return msgpack.packb(d)
  68. def decrypt_keychain(self, data, password):
  69. d = msgpack.unpackb(data)
  70. assert d['version'] == 1
  71. assert d['algorithm'] == 'SHA256'
  72. key = pbkdf2(password, d['salt'], 32, d['iterations'], hashlib.sha256)
  73. data = AES.new(key, AES.MODE_CTR, counter=Counter.new(128)).decrypt(d['data'])
  74. if HMAC.new(key, data, SHA256).digest() != d['hash']:
  75. return None
  76. return data
  77. def save(self, path, password):
  78. chain = {
  79. 'version': 1,
  80. 'aes_id': self.aes_id,
  81. 'rsa_read': self.rsa_read.exportKey('PEM'),
  82. 'rsa_create': self.rsa_create.exportKey('PEM'),
  83. }
  84. data = self.encrypt_keychain(msgpack.packb(chain), password)
  85. with open(path, 'wb') as fd:
  86. fd.write(self.FILE_ID)
  87. fd.write(data)
  88. print 'Key chain "%s" saved' % path
  89. def restrict(self, path):
  90. if os.path.exists(path):
  91. print '%s already exists' % path
  92. return 1
  93. self.rsa_read = self.rsa_read.publickey()
  94. self.save(path, self.password)
  95. return 0
  96. def chpass(self):
  97. password, password2 = 1, 2
  98. while password != password2:
  99. password = getpass('New password: ')
  100. password2 = getpass('New password again: ')
  101. if password != password2:
  102. print 'Passwords do not match'
  103. self.save(self.path, password)
  104. return 0
  105. @staticmethod
  106. def generate(path):
  107. if os.path.exists(path):
  108. print '%s already exists' % path
  109. return 1
  110. password, password2 = 1, 2
  111. while password != password2:
  112. password = getpass('Keychain password: ')
  113. password2 = getpass('Keychain password again: ')
  114. if password != password2:
  115. print 'Passwords do not match'
  116. chain = Keychain()
  117. print 'Generating keychain'
  118. chain.aes_id = os.urandom(32)
  119. chain.rsa_read = RSA.generate(2048)
  120. chain.rsa_create = RSA.generate(2048)
  121. chain.save(path, password)
  122. return 0
  123. def id_hash(self, data):
  124. """Return HMAC hash using the "id" AES key
  125. """
  126. return HMAC.new(self.aes_id, data, SHA256).digest()
  127. def _encrypt(self, id, rsa_key, key, data):
  128. """Helper function used by `encrypt_read` and `encrypt_create`
  129. """
  130. data = zlib.compress(data)
  131. nonce = long_to_bytes(self.counter.next_value(), 8)
  132. data = nonce + rsa_key + AES.new(key, AES.MODE_CTR, '', counter=self.counter).encrypt(data)
  133. hash = self.id_hash(data)
  134. return ''.join((id, hash, data)), hash
  135. def encrypt_read(self, data):
  136. """Encrypt `data` using the AES "read" key
  137. An RSA encrypted version of the AES key is included in the header
  138. """
  139. return self._encrypt(self.READ, self.read_encrypted, self.read_key, data)
  140. def encrypt_create(self, data, iv=None):
  141. """Encrypt `data` using the AES "create" key
  142. An RSA encrypted version of the AES key is included in the header
  143. """
  144. return self._encrypt(self.CREATE, self.create_encrypted, self.create_key, data)
  145. def _decrypt_key(self, data, rsa_key):
  146. """Helper function used by `decrypt`
  147. """
  148. try:
  149. return self._key_cache[data]
  150. except KeyError:
  151. self._key_cache[data] = OAEP(256, hash=SHA256).decode(rsa_key.decrypt(data))
  152. return self._key_cache[data]
  153. def decrypt(self, data):
  154. """Decrypt `data` previously encrypted by `encrypt_create` or `encrypt_read`
  155. """
  156. type = data[0]
  157. hash = data[1:33]
  158. if self.id_hash(data[33:]) != hash:
  159. raise IntegrityError('Encryption integrity error')
  160. nonce = bytes_to_long(data[33:41])
  161. counter = Counter.new(64, prefix='\0' * 8, initial_value=nonce)
  162. if type == self.READ:
  163. key = self._decrypt_key(data[41:297], self.rsa_read)
  164. elif type == self.CREATE:
  165. key = self.decrypt_key(data[41:297], self.rsa_create)
  166. else:
  167. raise Exception('Unknown pack type %d found' % ord(type))
  168. data = AES.new(key, AES.MODE_CTR, counter=counter).decrypt(data[297:])
  169. return zlib.decompress(data), hash