keychain.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. from getpass import getpass
  2. import hashlib
  3. import os
  4. import msgpack
  5. import zlib
  6. from pbkdf2 import pbkdf2
  7. from Crypto.Cipher import AES
  8. from Crypto.Hash import SHA256, HMAC
  9. from Crypto.PublicKey import RSA
  10. from Crypto.Util import Counter
  11. from Crypto.Util.number import bytes_to_long
  12. from .helpers import IntegrityError, zero_pad
  13. from .oaep import OAEP
  class Keychain(object):
      """Password-protected store of the keys used to encrypt and identify data.

      A keychain file holds ``aes_id`` (the HMAC key used by id_hash) and two
      RSA keys, ``rsa_read`` and ``rsa_create``.  Each instance also generates
      two random 32-byte AES session keys; once a keychain is opened these are
      wrapped with the matching RSA key and the wrapped form is prepended to
      every object this instance encrypts.
      """

      FILE_ID = 'DARC KEYCHAIN'  # magic prefix at the start of a keychain file
      CREATE = '\1'  # pack type byte: session key wrapped with rsa_create
      READ = '\2'  # pack type byte: session key wrapped with rsa_read

      # Byte offsets inside an encrypted object: 1 pack-type byte, then the
      # 256-byte wrapped session key, then the 32-byte id hash, then payload.
      _KEY_END = 1 + 256
      _HASH_END = _KEY_END + 32

      def __init__(self, path=None):
          """Create a keychain with fresh random session keys.

          :param path: optional keychain file to load immediately via open().
          """
          self._key_cache = {}  # wrapped session key -> unwrapped session key
          self.read_key = os.urandom(32)
          self.create_key = os.urandom(32)
          self.aes_id = self.rsa_read = self.rsa_create = None
          # Populated by open(); None until a keychain file has been loaded.
          self.password = None
          self.read_encrypted = self.create_encrypted = None
          self.path = path
          if path:
              self.open(path)

      def get_chunkify_seed(self):
          """Return an integer seed taken from the first 4 bytes of aes_id."""
          return bytes_to_long(self.aes_id[:4])

      @staticmethod
      def _wrap_session_key(rsa_key, session_key):
          """OAEP-encode *session_key*, RSA-encrypt it and zero-pad to 256 bytes."""
          encoded = OAEP(256, hash=SHA256).encode(session_key, os.urandom(32))
          return zero_pad(rsa_key.encrypt(encoded, '')[0], 256)

      @staticmethod
      def _digests_equal(a, b):
          """Return True if the byte strings *a* and *b* are equal.

          The comparison does not stop at the first differing byte, so a failed
          MAC check does not reveal how much of the MAC matched.  Uses
          hmac.compare_digest when available, else a manual XOR fold.
          """
          import hmac
          compare_digest = getattr(hmac, 'compare_digest', None)
          if compare_digest is not None:
              return compare_digest(a, b)
          if len(a) != len(b):
              return False
          diff = 0
          for x, y in zip(a, b):
              diff |= ord(x) ^ ord(y)
          return diff == 0

      @staticmethod
      def _prompt_new_password(prompt, confirm_prompt):
          """Ask for a password twice until both entries match, then return it."""
          while True:
              password = getpass(prompt)
              if password == getpass(confirm_prompt):
                  return password
              print('Passwords do not match')

      def open(self, path):
          """Load and decrypt the keychain file at *path*.

          An empty password is tried first; if that does not decrypt the file
          the user is prompted until it does.  On success the stored keys are
          loaded and this instance's session keys are wrapped with them.

          :raises ValueError: if the file is not a keychain or its version is
              not supported.
          :raises Exception: if the user enters an empty password.
          """
          print('Opening keychain "%s"' % path)
          with open(path, 'rb') as fd:
              if fd.read(len(self.FILE_ID)) != self.FILE_ID:
                  raise ValueError('Not a keychain')
              cdata = fd.read()
          self.password = ''
          data = self.decrypt_keychain(cdata, '')
          while data is None:
              self.password = getpass('Keychain password: ')
              if not self.password:
                  raise Exception('Keychain decryption failed')
              data = self.decrypt_keychain(cdata, self.password)
              if data is None:
                  print('Incorrect password')
          chain = msgpack.unpackb(data)
          # Explicit check instead of assert: asserts vanish under ``python -O``.
          if chain.get('version') != 1:
              raise ValueError('Unsupported keychain version')
          self.aes_id = chain['aes_id']
          self.rsa_read = RSA.importKey(chain['rsa_read'])
          self.rsa_create = RSA.importKey(chain['rsa_create'])
          self.read_encrypted = self._wrap_session_key(self.rsa_read, self.read_key)
          self.create_encrypted = self._wrap_session_key(self.rsa_create, self.create_key)

      def encrypt(self, data, password, iterations=2000):
          """Encrypt *data* under a key derived from *password*.

          :param iterations: PBKDF2 iteration count.  It is stored in the output,
              so decrypt_keychain() reads it back instead of assuming it.
          :returns: msgpack-encoded dict with version, salt, iterations,
              algorithm, an HMAC-SHA256 of the plaintext and the ciphertext.
          """
          salt = os.urandom(32)
          key = pbkdf2(password, salt, 32, iterations, hashlib.sha256)
          # NOTE: the same derived key feeds both HMAC and AES; the on-disk
          # format relies on this, so it is kept as-is.
          digest = HMAC.new(key, data, SHA256).digest()
          # The counter always starts at Counter.new's default value; the random
          # salt gives a fresh key on every call, so keystreams are not reused.
          cdata = AES.new(key, AES.MODE_CTR, counter=Counter.new(128)).encrypt(data)
          return msgpack.packb({
              'version': 1,
              'salt': salt,
              'iterations': iterations,
              'algorithm': 'SHA256',
              'hash': digest,
              'data': cdata,
          })

      def decrypt_keychain(self, data, password):
          """Reverse encrypt(): return the plaintext, or None if the MAC fails.

          :raises ValueError: on an unsupported version or algorithm.
          """
          d = msgpack.unpackb(data)
          if d.get('version') != 1:
              raise ValueError('Unsupported keychain version')
          if d.get('algorithm') != 'SHA256':
              raise ValueError('Unsupported keychain algorithm')
          key = pbkdf2(password, d['salt'], 32, d['iterations'], hashlib.sha256)
          plaintext = AES.new(key, AES.MODE_CTR, counter=Counter.new(128)).decrypt(d['data'])
          if not self._digests_equal(HMAC.new(key, plaintext, SHA256).digest(), d['hash']):
              return None
          return plaintext

      def save(self, path, password):
          """Serialise aes_id and both RSA keys, encrypt them and write *path*."""
          chain = {
              'version': 1,
              'aes_id': self.aes_id,
              'rsa_read': self.rsa_read.exportKey('PEM'),
              'rsa_create': self.rsa_create.exportKey('PEM'),
          }
          data = self.encrypt(msgpack.packb(chain), password)
          with open(path, 'wb') as fd:
              fd.write(self.FILE_ID)
              fd.write(data)
          print('Key chain "%s" saved' % path)

      def restrict(self, path):
          """Write a copy of this keychain that holds only the public rsa_read.

          The copy is protected with the current password.  This instance keeps
          its full rsa_read key, so it can still unwrap READ session keys.

          :returns: 1 if *path* already exists (nothing written), else 0.
          """
          if os.path.exists(path):
              print('%s already exists' % path)
              return 1
          full_read_key = self.rsa_read
          self.rsa_read = full_read_key.publickey()
          try:
              self.save(path, self.password)
          finally:
              # save() reads self.rsa_read, so swap temporarily and restore.
              self.rsa_read = full_read_key
          return 0

      def chpass(self):
          """Prompt for a new password and re-save the keychain at self.path."""
          password = self._prompt_new_password('New password: ', 'New password again: ')
          self.save(self.path, password)
          # Keep the in-memory password in sync with the file just written.
          self.password = password
          return 0

      @staticmethod
      def generate(path):
          """Create a new keychain with random keys and save it to *path*.

          :returns: 1 if *path* already exists (nothing written), else 0.
          """
          if os.path.exists(path):
              print('%s already exists' % path)
              return 1
          password = Keychain._prompt_new_password('Keychain password: ',
                                                   'Keychain password again: ')
          chain = Keychain()
          print('Generating keychain')
          chain.aes_id = os.urandom(32)
          chain.rsa_read = RSA.generate(2048)
          chain.rsa_create = RSA.generate(2048)
          chain.save(path, password)
          return 0

      def id_hash(self, data):
          """Return the HMAC-SHA256 of *data* keyed with aes_id."""
          return HMAC.new(self.aes_id, data, SHA256).digest()

      def _encrypt(self, data, pack_type, session_key, wrapped_key):
          """Compress, id-hash and AES-CTR encrypt *data* under *session_key*.

          :returns: (type byte + wrapped key + id hash + ciphertext, id hash);
              the id hash covers the compressed bytes.
          """
          compressed = zlib.compress(data)
          digest = self.id_hash(compressed)
          # Counter seeded from the id hash: identical input under the same
          # session key yields identical output.
          counter = Counter.new(128, initial_value=bytes_to_long(digest[:16]),
                                allow_wraparound=True)
          payload = AES.new(session_key, AES.MODE_CTR, counter=counter).encrypt(compressed)
          return ''.join((pack_type, wrapped_key, digest, payload)), digest

      def encrypt_read(self, data):
          """Encrypt *data* as a READ object (session key wrapped with rsa_read)."""
          return self._encrypt(data, self.READ, self.read_key, self.read_encrypted)

      def encrypt_create(self, data):
          """Encrypt *data* as a CREATE object (session key wrapped with rsa_create)."""
          return self._encrypt(data, self.CREATE, self.create_key, self.create_encrypted)

      def decrypt_key(self, data, rsa_key):
          """Unwrap a wrapped session key with *rsa_key*, memoising the result."""
          if data not in self._key_cache:
              self._key_cache[data] = OAEP(256, hash=SHA256).decode(rsa_key.decrypt(data))
          return self._key_cache[data]

      def decrypt(self, data):
          """Decrypt an object produced by encrypt_read() or encrypt_create().

          :returns: (decompressed plaintext, id hash).
          :raises IntegrityError: if the id hash of the decrypted (compressed)
              bytes does not match the stored one.
          :raises Exception: on an unknown pack type byte.
          """
          pack_type = data[0]
          if pack_type == self.READ:
              rsa_key = self.rsa_read
          elif pack_type == self.CREATE:
              rsa_key = self.rsa_create
          else:
              raise Exception('Unknown pack type %d found' % ord(pack_type))
          key = self.decrypt_key(data[1:self._KEY_END], rsa_key)
          digest = data[self._KEY_END:self._HASH_END]
          counter = Counter.new(128, initial_value=bytes_to_long(digest[:16]),
                                allow_wraparound=True)
          compressed = AES.new(key, AES.MODE_CTR, counter=counter).decrypt(data[self._HASH_END:])
          if not self._digests_equal(self.id_hash(compressed), digest):
              raise IntegrityError('decryption failed')
          return zlib.decompress(compressed), digest