2
0
Эх сурвалжийг харах

Rework the way the encryption works to make it more tamper proof

Jonas Borgström 14 жил өмнө
parent
commit
480c946415
1 өөрчлөгдсөн 39 нэмэгдсэн , 28 устгасан
  1. 39 28
      darc/keychain.py

+ 39 - 28
darc/keychain.py

@@ -9,7 +9,7 @@ from Crypto.Cipher import AES
 from Crypto.Hash import SHA256, HMAC
 from Crypto.PublicKey import RSA
 from Crypto.Util import Counter
-from Crypto.Util.number import bytes_to_long
+from Crypto.Util.number import bytes_to_long, long_to_bytes
 
 from .helpers import IntegrityError, zero_pad
 from .oaep import OAEP
@@ -25,6 +25,7 @@ class Keychain(object):
         self._key_cache = {}
         self.read_key = os.urandom(32)
         self.create_key = os.urandom(32)
+        self.counter = Counter.new(64, prefix='\0' * 8)
         self.aes_id = self.rsa_read = self.rsa_create = None
         self.path = path
         if path:
@@ -58,7 +59,7 @@ class Keychain(object):
         self.create_encrypted = OAEP(256, hash=SHA256).encode(self.create_key, os.urandom(32))
         self.create_encrypted = zero_pad(self.rsa_create.encrypt(self.create_encrypted, '')[0], 256)
 
-    def encrypt(self, data, password):
+    def encrypt_keychain(self, data, password):
         salt = os.urandom(32)
         iterations = 2000
         key = pbkdf2(password, salt, 32, iterations, hashlib.sha256)
@@ -91,7 +92,7 @@ class Keychain(object):
             'rsa_read': self.rsa_read.exportKey('PEM'),
             'rsa_create': self.rsa_create.exportKey('PEM'),
         }
-        data = self.encrypt(msgpack.packb(chain), password)
+        data = self.encrypt_keychain(msgpack.packb(chain), password)
         with open(path, 'wb') as fd:
             fd.write(self.FILE_ID)
             fd.write(data)
@@ -135,23 +136,36 @@ class Keychain(object):
         return 0
 
     def id_hash(self, data):
+        """Return HMAC hash using the "id" AES key
+        """
         return HMAC.new(self.aes_id, data, SHA256).digest()
 
-    def encrypt_read(self, data):
+    def _encrypt(self, id, rsa_key, key, data):
+        """Helper function used by `encrypt_read` and `encrypt_create`
+        """
         data = zlib.compress(data)
+        nonce = long_to_bytes(self.counter.next_value(), 8)
+        data = nonce + rsa_key + AES.new(key, AES.MODE_CTR, '', counter=self.counter).encrypt(data)
         hash = self.id_hash(data)
-        counter = Counter.new(128, initial_value=bytes_to_long(hash[:16]), allow_wraparound=True)
-        data = AES.new(self.read_key, AES.MODE_CTR, '', counter=counter).encrypt(data)
-        return ''.join((self.READ, self.read_encrypted, hash, data)), hash
+        return ''.join((id, hash, data)), hash
 
-    def encrypt_create(self, data):
-        data = zlib.compress(data)
-        hash = self.id_hash(data)
-        counter = Counter.new(128, initial_value=bytes_to_long(hash[:16]), allow_wraparound=True)
-        data = AES.new(self.create_key, AES.MODE_CTR, '', counter=counter).encrypt(data)
-        return ''.join((self.CREATE, self.create_encrypted, hash, data)), hash
+    def encrypt_read(self, data):
+        """Encrypt `data` using the AES "read" key
+
+        An RSA encrypted version of the AES key is included in the header
+        """
+        return self._encrypt(self.READ, self.read_encrypted, self.read_key, data)
+
+    def encrypt_create(self, data, iv=None):
+        """Encrypt `data` using the AES "create" key
+
+        An RSA encrypted version of the AES key is included in the header
+        """
+        return self._encrypt(self.CREATE, self.create_encrypted, self.create_key, data)
 
-    def decrypt_key(self, data, rsa_key):
+    def _decrypt_key(self, data, rsa_key):
+        """Helper function used by `decrypt`
+        """
         try:
             return self._key_cache[data]
         except KeyError:
@@ -159,25 +173,22 @@ class Keychain(object):
             return self._key_cache[data]
 
     def decrypt(self, data):
+        """Decrypt `data` previously encrypted by `encrypt_create` or `encrypt_read`
+        """
         type = data[0]
+        hash = data[1:33]
+        if self.id_hash(data[33:]) != hash:
+            raise IntegrityError('Encryption integrity error')
+        nonce = bytes_to_long(data[33:41])
+        counter = Counter.new(64, prefix='\0' * 8, initial_value=nonce)
         if type == self.READ:
-            key = self.decrypt_key(data[1:257], self.rsa_read)
-            hash = data[257:289]
-            counter = Counter.new(128, initial_value=bytes_to_long(hash[:16]), allow_wraparound=True)
-            data = AES.new(key, AES.MODE_CTR, counter=counter).decrypt(data[289:])
-            if self.id_hash(data) != hash:
-                raise IntegrityError('decryption failed')
-            return zlib.decompress(data), hash
+            key = self._decrypt_key(data[41:297], self.rsa_read)
         elif type == self.CREATE:
-            key = self.decrypt_key(data[1:257], self.rsa_create)
-            hash = data[257:289]
-            counter = Counter.new(128, initial_value=bytes_to_long(hash[:16]), allow_wraparound=True)
-            data = AES.new(key, AES.MODE_CTR, '', counter=counter).decrypt(data[289:])
-            if self.id_hash(data) != hash:
-                raise IntegrityError('decryption failed')
-            return zlib.decompress(data), hash
+            key = self.decrypt_key(data[41:297], self.rsa_create)
         else:
             raise Exception('Unknown pack type %d found' % ord(type))
+        data = AES.new(key, AES.MODE_CTR, counter=counter).decrypt(data[297:])
+        return zlib.decompress(data), hash