crypto.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. from getpass import getpass
  2. import hashlib
  3. import os
  4. import logging
  5. import msgpack
  6. import zlib
  7. from pbkdf2 import pbkdf2
  8. from Crypto.Cipher import AES
  9. from Crypto.Hash import SHA256, HMAC
  10. from Crypto.PublicKey import RSA
  11. from Crypto.Util import Counter
  12. from Crypto.Util.number import bytes_to_long
  13. from .helpers import IntegrityError
  14. from .oaep import OAEP
  class KeyChain(object):
      """Password-protected store for the repository's secret keys.

      Attributes (all ``None`` until ``open`` or ``generate`` fills them in):
        aes_id     -- secret bytes (32 random bytes when created by
                      ``generate``); CryptoManager.id_hash uses it as its
                      HMAC key.
        rsa_read   -- RSA key object (2048-bit when created by ``generate``).
        rsa_create -- RSA key object (2048-bit when created by ``generate``).

      On disk the keychain is a msgpack-encoded dict (see ``save``) wrapped
      in the password-based envelope produced by ``encrypt``.
      """

      def __init__(self, path=None):
          self.aes_id = self.rsa_read = self.rsa_create = None
          if path:
              self.open(path)

      def open(self, path):
          """Load and decrypt the keychain stored at *path*.

          The empty password is tried first, so a keychain saved without a
          password opens without prompting.  Otherwise the user is prompted
          repeatedly until the correct password is given; entering an empty
          password at the prompt aborts.

          :raises Exception: if the user aborts with an empty password.
          :raises ValueError: if the keychain version is not supported.
          """
          with open(path, 'rb') as fd:
              cdata = fd.read()
          data = self.decrypt(cdata, '')
          # decrypt() signals a wrong password by returning None.
          while data is None:
              password = getpass('Keychain password: ')
              if not password:
                  raise Exception('Keychain decryption failed')
              data = self.decrypt(cdata, password)
              if data is None:
                  logging.error('Incorrect password')
          chain = msgpack.unpackb(data)
          # Explicit check rather than `assert`: asserts are stripped under
          # `python -O`, and this validates data read from disk.
          if chain['version'] != 1:
              raise ValueError('Unsupported key chain version: %r'
                               % (chain['version'],))
          self.aes_id = chain['aes_id']
          self.rsa_read = RSA.importKey(chain['rsa_read'])
          self.rsa_create = RSA.importKey(chain['rsa_create'])
          logging.info('Key chain "%s" opened', path)

      def encrypt(self, data, password):
          """Encrypt *data* under *password*; return a msgpack-encoded envelope.

          A fresh random 32-byte salt is drawn on every call, so each envelope
          is encrypted under a different PBKDF2-derived key.  That is what
          makes it acceptable for AES-CTR to start from the same (default)
          counter value every time.  The salt and iteration count are stored
          in the envelope so ``decrypt`` can re-derive the key.
          """
          salt = os.urandom(32)
          iterations = 2000
          key = pbkdf2(password, salt, 32, iterations, hashlib.sha256)
          # NOTE(review): the same derived key is used for both HMAC and AES,
          # and the MAC covers the plaintext rather than the ciphertext.
          # Changing either would break existing keychain files.
          mac = HMAC.new(key, data, SHA256).digest()
          cdata = AES.new(key, AES.MODE_CTR, counter=Counter.new(128)).encrypt(data)
          return msgpack.packb({
              'version': 1,
              'salt': salt,
              'iterations': iterations,
              'algorithm': 'SHA256',
              'hash': mac,
              'data': cdata,
          })

      def decrypt(self, data, password):
          """Decrypt an envelope produced by ``encrypt``.

          :returns: the plaintext, or ``None`` if the MAC does not match
              (wrong *password*, or the envelope was modified).
          :raises ValueError: if the envelope version or algorithm is not
              supported.
          """
          import hmac  # stdlib module; distinct from Crypto.Hash.HMAC below
          d = msgpack.unpackb(data)
          # Explicit checks rather than `assert` (stripped under `python -O`).
          if d['version'] != 1:
              raise ValueError('Unsupported key chain envelope version: %r'
                               % (d['version'],))
          if d['algorithm'] != 'SHA256':
              raise ValueError('Unsupported key chain algorithm: %r'
                               % (d['algorithm'],))
          key = pbkdf2(password, d['salt'], 32, d['iterations'], hashlib.sha256)
          plain = AES.new(key, AES.MODE_CTR, counter=Counter.new(128)).decrypt(d['data'])
          # Constant-time comparison so timing does not reveal how many
          # leading bytes of the MAC matched.
          if not hmac.compare_digest(HMAC.new(key, plain, SHA256).digest(), d['hash']):
              return None
          return plain

      def save(self, path, password):
          """Encrypt the keychain with *password* and write it to *path*."""
          chain = {
              'version': 1,
              'aes_id': self.aes_id,
              'rsa_read': self.rsa_read.exportKey('PEM'),
              'rsa_create': self.rsa_create.exportKey('PEM'),
          }
          data = self.encrypt(msgpack.packb(chain), password)
          with open(path, 'wb') as fd:
              fd.write(data)
          logging.info('Key chain "%s" saved', path)

      @staticmethod
      def generate(path, password):
          """Create a keychain with fresh random keys, save it to *path*, return it."""
          chain = KeyChain()
          logging.info('Generating keys')
          chain.aes_id = os.urandom(32)
          chain.rsa_read = RSA.generate(2048)
          chain.rsa_create = RSA.generate(2048)
          chain.save(path, password)
          return chain
  class CryptoManager(object):
      """Encrypt and decrypt repository objects using per-instance session keys.

      Each instance draws two random 32-byte AES session keys, one for
      ``encrypt_read`` and one for ``encrypt_create``.  Each session key is
      OAEP-encoded and encrypted with the matching RSA key from the keychain.
      That RSA ciphertext is embedded in every object encrypted under the
      session key, so ``decrypt`` can recover the key with the RSA private key.

      Encrypted object layout::

          data[0]         type byte: READ ('\\2') or CREATE ('\\1')
          data[1:257]     RSA-encrypted, OAEP-encoded session key
          data[257:289]   id_hash of the zlib-compressed plaintext
          data[289:]      AES-CTR ciphertext of the zlib-compressed plaintext

      The AES-CTR counter starts at the first 16 bytes of the id_hash.
      """

      CREATE = '\1'
      READ = '\2'

      # Byte offsets of the layout above.  256 is the RSA ciphertext length
      # for the 2048-bit keys KeyChain.generate creates (and the size passed
      # to OAEP); 32 is the length of a SHA-256 digest.
      _KEY_END = 1 + 256
      _HASH_END = _KEY_END + 32

      def __init__(self, keychain):
          self._key_cache = {}
          self.keychain = keychain
          self.read_key = os.urandom(32)
          self.create_key = os.urandom(32)
          self.read_encrypted = self._wrap_key(self.read_key, keychain.rsa_read)
          self.create_encrypted = self._wrap_key(self.create_key, keychain.rsa_create)

      @staticmethod
      def _wrap_key(key, rsa_key):
          """Return *key* OAEP-encoded (random seed) and encrypted with *rsa_key*."""
          encoded = OAEP(256, hash=SHA256).encode(key, os.urandom(32))
          return rsa_key.encrypt(encoded, '')[0]

      def id_hash(self, data):
          """Return the HMAC-SHA256 of *data*, keyed with the keychain's aes_id."""
          return HMAC.new(self.keychain.aes_id, data, SHA256).digest()

      @staticmethod
      def _counter(object_hash):
          """Return the AES-CTR counter for an object, seeded from its id hash."""
          return Counter.new(128, initial_value=bytes_to_long(object_hash[:16]),
                             allow_wraparound=True)

      def _encrypt(self, data, type_byte, key, encrypted_key):
          """Compress, hash and encrypt *data*; return ``(blob, id_hash)``."""
          data = zlib.compress(data)
          object_hash = self.id_hash(data)
          cipher = AES.new(key, AES.MODE_CTR, counter=self._counter(object_hash))
          blob = ''.join((type_byte, encrypted_key, object_hash, cipher.encrypt(data)))
          return blob, object_hash

      def encrypt_read(self, data):
          """Encrypt *data* under the read session key; return ``(blob, id_hash)``."""
          return self._encrypt(data, self.READ, self.read_key, self.read_encrypted)

      def encrypt_create(self, data):
          """Encrypt *data* under the create session key; return ``(blob, id_hash)``."""
          return self._encrypt(data, self.CREATE, self.create_key, self.create_encrypted)

      def decrypt_key(self, data, rsa_key):
          """Recover a session key from its RSA/OAEP-wrapped form, with caching.

          Results are memoized per wrapped-key blob, so the RSA private-key
          operation runs once per distinct wrapped key rather than once per
          object.
          """
          try:
              return self._key_cache[data]
          except KeyError:
              key = self._key_cache[data] = OAEP(256, hash=SHA256).decode(rsa_key.decrypt(data))
              return key

      def decrypt(self, data):
          """Decrypt a blob produced by ``encrypt_read`` or ``encrypt_create``.

          :returns: ``(plaintext, id_hash)``.
          :raises IntegrityError: if the blob is too short to contain the
              header, or the recomputed id_hash does not match the stored one.
          :raises Exception: if the type byte is unknown.
          """
          import hmac  # stdlib module; distinct from Crypto.Hash.HMAC
          type_byte = data[0]
          if type_byte == self.READ:
              rsa_key = self.keychain.rsa_read
          elif type_byte == self.CREATE:
              rsa_key = self.keychain.rsa_create
          else:
              raise Exception('Unknown pack type %d found' % ord(type_byte))
          # Reject truncated blobs before doing RSA work or caching a bogus key.
          if len(data) < self._HASH_END:
              raise IntegrityError('decryption failed')
          key = self.decrypt_key(data[1:self._KEY_END], rsa_key)
          object_hash = data[self._KEY_END:self._HASH_END]
          cipher = AES.new(key, AES.MODE_CTR, counter=self._counter(object_hash))
          plain = cipher.decrypt(data[self._HASH_END:])
          # Constant-time comparison: don't leak via timing how much of the
          # id hash matched.
          if not hmac.compare_digest(self.id_hash(plain), object_hash):
              raise IntegrityError('decryption failed')
          return zlib.decompress(plain), object_hash