crypto.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. import os
  2. import logging
  3. import msgpack
  4. import zlib
  5. from Crypto.Cipher import AES
  6. from Crypto.Hash import SHA256, HMAC
  7. from Crypto.PublicKey import RSA
  8. from Crypto.Util import Counter
  9. from Crypto.Util.number import bytes_to_long
  10. from .helpers import IntegrityError
  11. from .oaep import OAEP
  class KeyChain(object):
      """Long-lived key material, persisted as a msgpack-encoded file.

      Attributes (all ``None`` until `open()` or `generate()` fills them):
          aes_id:     32 random bytes (see `generate()`).
          rsa_read:   RSA key object for the "read" role.
          rsa_create: RSA key object for the "create" role.
      """

      def __init__(self, path=None):
          """Create an empty key chain; load it from *path* when one is given."""
          self.aes_id = None
          self.rsa_read = None
          self.rsa_create = None
          if path:
              self.open(path)

      @staticmethod
      def _require_version(chain):
          """Reject a decoded key chain mapping whose version is not 1.

          Raised explicitly instead of via an ``assert`` statement so the
          check is still performed when Python runs with ``-O``.  The
          exception type is kept as AssertionError so existing callers see
          the same error as before.
          """
          version = chain['version']
          if version != 1:
              raise AssertionError(
                  'Unsupported key chain version: %r' % (version,))

      def open(self, path):
          """Replace this chain's keys with the ones stored in file *path*.

          The file is expected to hold the msgpack mapping written by
          `save()`.  Raises AssertionError when its ``version`` is not 1.
          """
          with open(path, 'rb') as fd:
              chain = msgpack.unpackb(fd.read())
          logging.info('Key chain "%s" opened', path)
          self._require_version(chain)
          self.aes_id = chain['aes_id']
          self.rsa_read = RSA.importKey(chain['rsa_read'])
          self.rsa_create = RSA.importKey(chain['rsa_create'])

      def save(self, path):
          """Write this chain to file *path* as a msgpack mapping.

          Both RSA keys are exported in PEM form.  The chain must be
          populated first: the RSA attributes are dereferenced here.
          """
          chain = {
              'version': 1,
              'aes_id': self.aes_id,
              'rsa_read': self.rsa_read.exportKey('PEM'),
              'rsa_create': self.rsa_create.exportKey('PEM'),
          }
          with open(path, 'wb') as fd:
              fd.write(msgpack.packb(chain))
          logging.info('Key chain "%s" saved', path)

      @classmethod
      def generate(cls):
          """Return a new chain holding freshly generated key material.

          Uses 32 random bytes for ``aes_id`` and two independent 2048-bit
          RSA keys.  Nothing is written to disk; call `save()` for that.
          """
          chain = cls()
          chain.aes_id = os.urandom(32)
          chain.rsa_read = RSA.generate(2048)
          chain.rsa_create = RSA.generate(2048)
          return chain
  class CryptoManager(object):
      """Encrypt and decrypt data packs with per-instance AES session keys.

      Every instance draws two random 32-byte AES keys ("read" and
      "create") and wraps each one with the matching RSA key of the given
      key chain.

      Pack layout, as produced by `encrypt_read()` / `encrypt_create()` and
      parsed by `decrypt()`:

          offset   0       type tag (`READ` or `CREATE`)
          offset   1..256  RSA-wrapped AES session key
          offset 257..288  SHA256 digest of the compressed payload
          offset 289..     AES-CTR ciphertext of the compressed payload
      """

      CREATE = '\1'
      READ = '\2'

      def __init__(self, keychain):
          """Draw fresh session keys and wrap them with *keychain*'s RSA keys."""
          # Maps wrapped-key blob -> (rsa_key, session_key); see decrypt_key().
          self._key_cache = {}
          self.keychain = keychain
          self.read_key = os.urandom(32)
          self.create_key = os.urandom(32)
          self.read_encrypted = self._wrap_key(self.read_key,
                                               keychain.rsa_read)
          self.create_encrypted = self._wrap_key(self.create_key,
                                                 keychain.rsa_create)

      @staticmethod
      def _wrap_key(session_key, rsa_key):
          """Return *session_key* OAEP-encoded and encrypted with *rsa_key*."""
          padded = OAEP(256, hash=SHA256).encode(session_key, os.urandom(32))
          # NOTE(review): the '' argument and the [0] index follow the
          # original call sites; presumably the legacy PyCrypto raw RSA API,
          # which returns a tuple -- confirm before changing.
          return rsa_key.encrypt(padded, '')[0]

      @staticmethod
      def _cipher(key, digest):
          """Return an AES-CTR cipher whose counter starts at *digest*[:16].

          The counter start is derived from the payload digest, so the same
          key and payload always yield the same ciphertext.
          """
          counter = Counter.new(128,
                                initial_value=bytes_to_long(digest[:16]),
                                allow_wraparound=True)
          return AES.new(key, AES.MODE_CTR, '', counter=counter)

      def _encrypt(self, tag, key, wrapped_key, data):
          """Compress and encrypt *data*, returning a complete pack."""
          compressed = zlib.compress(data)
          digest = SHA256.new(compressed).digest()
          ciphertext = self._cipher(key, digest).encrypt(compressed)
          return ''.join((tag, wrapped_key, digest, ciphertext))

      def id_hash(self, data):
          """Return the HMAC-SHA256 of *data* keyed with the chain's aes_id."""
          return HMAC.new(self.keychain.aes_id, data, SHA256).digest()

      def encrypt_read(self, data):
          """Return *data* packed and encrypted under the "read" session key."""
          return self._encrypt(self.READ, self.read_key,
                               self.read_encrypted, data)

      def encrypt_create(self, data):
          """Return *data* packed and encrypted under the "create" session key."""
          return self._encrypt(self.CREATE, self.create_key,
                               self.create_encrypted, data)

      def decrypt_key(self, data, rsa_key):
          """Unwrap the AES session key in blob *data* using *rsa_key*.

          Results are cached per blob.  A cached entry is reused only when
          it was produced with this very *rsa_key* object; otherwise the
          blob is unwrapped again.  This keeps a key obtained through one
          RSA key from being handed out for a request naming another one.
          """
          entry = self._key_cache.get(data)
          if entry is not None and entry[0] is rsa_key:
              return entry[1]
          padded = rsa_key.decrypt(data)
          session_key = OAEP(256, hash=SHA256).decode(padded)
          # Keep a reference to rsa_key so the identity test above stays
          # valid for as long as the cache entry exists.
          self._key_cache[data] = (rsa_key, session_key)
          return session_key

      def decrypt(self, data):
          """Return the original payload of pack *data*.

          Raises:
              IntegrityError: the decrypted payload does not match the
                  SHA256 digest stored in the pack.
              Exception: the leading type tag is neither READ nor CREATE.
          """
          pack_type = data[0]
          if pack_type == self.READ:
              rsa_key = self.keychain.rsa_read
          elif pack_type == self.CREATE:
              rsa_key = self.keychain.rsa_create
          else:
              raise Exception('Unknown pack type %d found' % ord(pack_type))
          key = self.decrypt_key(data[1:257], rsa_key)
          digest = data[257:289]
          compressed = self._cipher(key, digest).decrypt(data[289:])
          if SHA256.new(compressed).digest() != digest:
              raise IntegrityError('decryption failed')
          return zlib.decompress(compressed)