Преглед на файлове

Verify hash before decompressing with PlaintextKey

Jonas Borgström преди 12 години
родител
ревизия
a887fa2065
променени са 1 файла, в които са добавени 25 реда и са изтрити 19 реда
  1. 25 19
      darc/key.py

+ 25 - 19
darc/key.py

@@ -12,10 +12,6 @@ from .helpers import IntegrityError, get_keys_dir
 
 PREFIX = b'\0' * 8
 
-KEYFILE = b'\0'
-PASSPHRASE = b'\1'
-PLAINTEXT = b'\2'
-
 
 class HMAC(hmac.HMAC):
 
@@ -33,11 +29,11 @@ def key_creator(repository, args):
 
 
 def key_factory(repository, manifest_data):
-    if manifest_data[:1] == KEYFILE:
+    if manifest_data[0] == KeyfileKey.TYPE:
         return KeyfileKey.detect(repository, manifest_data)
-    elif manifest_data[:1] == PASSPHRASE:
+    elif manifest_data[0] == PassphraseKey.TYPE:
         return PassphraseKey.detect(repository, manifest_data)
-    elif manifest_data[:1] == PLAINTEXT:
+    elif manifest_data[0] == PlaintextKey.TYPE:
         return PlaintextKey.detect(repository, manifest_data)
     else:
         raise Exception('Unkown Key type %d' % ord(manifest_data[0]))
@@ -45,6 +41,9 @@ def key_factory(repository, manifest_data):
 
 class KeyBase(object):
 
+    def __init__(self):
+        self.TYPE_STR = bytes([self.TYPE])
+
     def id_hash(self, data):
         """Return HMAC hash using the "id" HMAC key
         """
@@ -57,7 +56,7 @@ class KeyBase(object):
 
 
 class PlaintextKey(KeyBase):
-    TYPE = PLAINTEXT
+    TYPE = 0x02
 
     chunk_seed = 0
 
@@ -74,12 +73,19 @@ class PlaintextKey(KeyBase):
         return sha256(data).digest()
 
     def encrypt(self, data):
-        return b''.join([self.TYPE, zlib.compress(data)])
+        cdata = zlib.compress(data)
+        hash = sha256(cdata).digest()
+        return b''.join([self.TYPE_STR, hash, cdata])
 
     def decrypt(self, id, data):
-        if data[:1] != self.TYPE:
+        if data[0] != self.TYPE:
             raise IntegrityError('Invalid encryption envelope')
-        data = zlib.decompress(memoryview(data)[1:])
+        # This is just a hash and not a hmac but it will at least
+        # stop unintentionally corrupted data from hitting zlib.decompress()
+        hash = memoryview(data)[1:33]
+        if memoryview(sha256(memoryview(data)[33:]).digest()) != hash:
+            raise IntegrityError('Payload checksum mismatch')
+        data = zlib.decompress(memoryview(data)[33:])
         if id and sha256(data).digest() != id:
             raise IntegrityError('Chunk id verification failed')
         return data
@@ -96,14 +102,14 @@ class AESKeyBase(KeyBase):
         data = zlib.compress(data)
         self.enc_cipher.reset()
         data = b''.join((self.enc_cipher.iv[8:], self.enc_cipher.encrypt(data)))
-        hash = HMAC(self.enc_hmac_key, data, sha256).digest()
-        return b''.join((self.TYPE, hash, data))
+        hmac = HMAC(self.enc_hmac_key, data, sha256).digest()
+        return b''.join((self.TYPE_STR, hmac, data))
 
     def decrypt(self, id, data):
-        if data[:1] != self.TYPE:
+        if data[0] != self.TYPE:
             raise IntegrityError('Invalid encryption envelope')
-        hash = memoryview(data)[1:33]
-        if memoryview(HMAC(self.enc_hmac_key, memoryview(data)[33:], sha256).digest()) != hash:
+        hmac = memoryview(data)[1:33]
+        if memoryview(HMAC(self.enc_hmac_key, memoryview(data)[33:], sha256).digest()) != hmac:
             raise IntegrityError('Encryption envelope checksum mismatch')
         self.dec_cipher.reset(iv=PREFIX + data[33:41])
         data = zlib.decompress(self.dec_cipher.decrypt(data[41:]))  # should use memoryview
@@ -112,7 +118,7 @@ class AESKeyBase(KeyBase):
         return data
 
     def extract_iv(self, payload):
-        if payload[:1] != self.TYPE:
+        if payload[0] != self.TYPE:
             raise IntegrityError('Invalid encryption envelope')
         nonce = bytes_to_long(payload[33:41])
         return nonce
@@ -132,7 +138,7 @@ class AESKeyBase(KeyBase):
 
 
 class PassphraseKey(AESKeyBase):
-    TYPE = PASSPHRASE
+    TYPE = 0x01
     iterations = 100000
 
     @classmethod
@@ -179,7 +185,7 @@ class PassphraseKey(AESKeyBase):
 
 class KeyfileKey(AESKeyBase):
     FILE_ID = 'DARC KEY'
-    TYPE = KEYFILE
+    TYPE = 0x00
 
     @classmethod
     def detect(cls, repository, manifest_data):