# crypto.py
import hmac
import logging
import os
import zlib

import msgpack
from Crypto.Cipher import AES
from Crypto.Hash import SHA256, HMAC
from Crypto.PublicKey import RSA
from Crypto.Util import Counter
from Crypto.Util.number import bytes_to_long

from .helpers import IntegrityError
from .oaep import OAEP
  12. class KeyChain(object):
  13. def __init__(self, path=None):
  14. self.aes_id = self.rsa_read = self.rsa_create = None
  15. if path:
  16. self.open(path)
  17. def open(self, path):
  18. with open(path, 'rb') as fd:
  19. chain = msgpack.unpackb(fd.read())
  20. logging.info('Key chain "%s" opened', path)
  21. assert chain['version'] == 1
  22. self.aes_id = chain['aes_id']
  23. self.rsa_read = RSA.importKey(chain['rsa_read'])
  24. self.rsa_create = RSA.importKey(chain['rsa_create'])
  25. def save(self, path):
  26. chain = {
  27. 'version': 1,
  28. 'aes_id': self.aes_id,
  29. 'rsa_read': self.rsa_read.exportKey('PEM'),
  30. 'rsa_create': self.rsa_create.exportKey('PEM'),
  31. }
  32. with open(path, 'wb') as fd:
  33. fd.write(msgpack.packb(chain))
  34. logging.info('Key chain "%s" saved', path)
  35. @staticmethod
  36. def generate():
  37. chain = KeyChain()
  38. chain.aes_id = os.urandom(32)
  39. chain.rsa_read = RSA.generate(2048)
  40. chain.rsa_create = RSA.generate(2048)
  41. return chain
  42. class CryptoManager(object):
  43. CREATE = '\1'
  44. READ = '\2'
  45. def __init__(self, keychain):
  46. self._key_cache = {}
  47. self.keychain = keychain
  48. self.read_key = os.urandom(32)
  49. self.create_key = os.urandom(32)
  50. self.read_encrypted = OAEP(256, hash=SHA256).encode(self.read_key, os.urandom(32))
  51. self.read_encrypted = keychain.rsa_read.encrypt(self.read_encrypted, '')[0]
  52. self.create_encrypted = OAEP(256, hash=SHA256).encode(self.create_key, os.urandom(32))
  53. self.create_encrypted = keychain.rsa_create.encrypt(self.create_encrypted, '')[0]
  54. def id_hash(self, data):
  55. return HMAC.new(self.keychain.aes_id, data, SHA256).digest()
  56. def encrypt_read(self, data):
  57. data = zlib.compress(data)
  58. hash = self.id_hash(data)
  59. counter = Counter.new(128, initial_value=bytes_to_long(hash[:16]), allow_wraparound=True)
  60. data = AES.new(self.read_key, AES.MODE_CTR, '', counter=counter).encrypt(data)
  61. return ''.join((self.READ, self.read_encrypted, hash, data)), hash
  62. def encrypt_create(self, data):
  63. data = zlib.compress(data)
  64. hash = self.id_hash(data)
  65. counter = Counter.new(128, initial_value=bytes_to_long(hash[:16]), allow_wraparound=True)
  66. data = AES.new(self.create_key, AES.MODE_CTR, '', counter=counter).encrypt(data)
  67. return ''.join((self.CREATE, self.create_encrypted, hash, data)), hash
  68. def decrypt_key(self, data, rsa_key):
  69. try:
  70. return self._key_cache[data]
  71. except KeyError:
  72. self._key_cache[data] = OAEP(256, hash=SHA256).decode(rsa_key.decrypt(data))
  73. return self._key_cache[data]
  74. def decrypt(self, data):
  75. type = data[0]
  76. if type == self.READ:
  77. key = self.decrypt_key(data[1:257], self.keychain.rsa_read)
  78. hash = data[257:289]
  79. counter = Counter.new(128, initial_value=bytes_to_long(hash[:16]), allow_wraparound=True)
  80. data = AES.new(key, AES.MODE_CTR, counter=counter).decrypt(data[289:])
  81. if self.id_hash(data) != hash:
  82. raise IntegrityError('decryption failed')
  83. return zlib.decompress(data), hash
  84. elif type == self.CREATE:
  85. key = self.decrypt_key(data[1:257], self.keychain.rsa_create)
  86. hash = data[257:289]
  87. counter = Counter.new(128, initial_value=bytes_to_long(hash[:16]), allow_wraparound=True)
  88. data = AES.new(key, AES.MODE_CTR, '', counter=counter).decrypt(data[289:])
  89. if self.id_hash(data) != hash:
  90. raise IntegrityError('decryption failed')
  91. return zlib.decompress(data), hash
  92. else:
  93. raise Exception('Unknown pack type %d found' % ord(type))