ソースを参照

refactor AES class to new api

Thomas Waldmann 8 年 前
コミット
de0707d3dd
2 ファイル変更75 行追加73 行削除
  1. 2 2
      src/borg/crypto/key.py
  2. 73 71
      src/borg/crypto/low_level.pyx

+ 2 - 2
src/borg/crypto/key.py

@@ -590,7 +590,7 @@ class KeyfileKeyBase(AESKeyBase):
         assert enc_key.version == 1
         assert enc_key.algorithm == 'sha256'
         key = passphrase.kdf(enc_key.salt, enc_key.iterations, 32)
-        data = AES(is_encrypt=False, key=key).decrypt(enc_key.data)
+        data = AES(key, b'\0'*16).decrypt(enc_key.data)
         if hmac_sha256(key, data) == enc_key.hash:
             return data
 
@@ -599,7 +599,7 @@ class KeyfileKeyBase(AESKeyBase):
         iterations = PBKDF2_ITERATIONS
         key = passphrase.kdf(salt, iterations, 32)
         hash = hmac_sha256(key, data)
-        cdata = AES(is_encrypt=True, key=key).encrypt(data)
+        cdata = AES(key, b'\0'*16).encrypt(data)
         enc_key = EncryptedKey(
             version=1,
             salt=salt,

+ 73 - 71
src/borg/crypto/low_level.pyx

@@ -38,8 +38,6 @@ import hashlib
 import hmac
 from math import ceil
 
-from libc.stdlib cimport malloc, free
-
 from cpython cimport PyMem_Malloc, PyMem_Free
 from cpython.buffer cimport PyBUF_SIMPLE, PyObject_GetBuffer, PyBuffer_Release
 from cpython.bytes cimport PyBytes_FromStringAndSize
@@ -563,98 +561,102 @@ cdef class CHACHA20_POLY1305(_CHACHA_BASE):
 
 cdef class AES:
     """A thin wrapper around the OpenSSL EVP cipher API - for legacy code, like key file encryption"""
+    cdef CIPHER cipher
     cdef EVP_CIPHER_CTX *ctx
-    cdef int is_encrypt
-    cdef unsigned char iv_orig[16]
+    cdef unsigned char *enc_key
+    cdef int cipher_blk_len
+    cdef int iv_len
+    cdef unsigned char iv[16]
     cdef long long blocks
 
-    def __cinit__(self, is_encrypt, key, iv=None):
+    def __init__(self, enc_key, iv=None):
+        assert isinstance(enc_key, bytes) and len(enc_key) == 32
+        self.enc_key = enc_key
+        self.iv_len = 16
+        assert sizeof(self.iv) == self.iv_len
+        self.cipher = EVP_aes_256_ctr
+        self.cipher_blk_len = 16
+        if iv is not None:
+            self.set_iv(iv)
+        else:
+            self.blocks = -1  # make sure set_iv is called before encrypt
+
+    def __cinit__(self, enc_key, iv=None):
         self.ctx = EVP_CIPHER_CTX_new()
-        self.is_encrypt = is_encrypt
-        # Set cipher type and mode
-        cipher_mode = EVP_aes_256_ctr()
-        if self.is_encrypt:
-            if not EVP_EncryptInit_ex(self.ctx, cipher_mode, NULL, NULL, NULL):
-                raise Exception('EVP_EncryptInit_ex failed')
-        else:  # decrypt
-            if not EVP_DecryptInit_ex(self.ctx, cipher_mode, NULL, NULL, NULL):
-                raise Exception('EVP_DecryptInit_ex failed')
-        self.reset(key, iv)
 
     def __dealloc__(self):
         EVP_CIPHER_CTX_free(self.ctx)
 
-    def reset(self, key=None, iv=None):
-        cdef const unsigned char *key2 = NULL
-        cdef const unsigned char *iv2 = NULL
-        if key:
-            key2 = key
-        if iv:
-            iv2 = iv
-            assert isinstance(iv, bytes) and len(iv) == 16
-            for i in range(16):
-                self.iv_orig[i] = iv[i]
-            self.blocks = 0  # number of AES blocks encrypted starting with iv_orig
-        # Initialise key and IV
-        if self.is_encrypt:
-            if not EVP_EncryptInit_ex(self.ctx, NULL, NULL, key2, iv2):
-                raise Exception('EVP_EncryptInit_ex failed')
-        else:  # decrypt
-            if not EVP_DecryptInit_ex(self.ctx, NULL, NULL, key2, iv2):
-                raise Exception('EVP_DecryptInit_ex failed')
-
-    @property
-    def iv(self):
-        return increment_iv(self.iv_orig[:16], self.blocks)
-
-    def encrypt(self, data):
-        cdef Py_buffer data_buf = ro_buffer(data)
-        cdef int inl = len(data)
-        cdef int ctl = 0
-        cdef int outl = 0
-        # note: modes that use padding, need up to one extra AES block (16b)
-        cdef unsigned char *out = <unsigned char *>malloc(inl+16)
-        if not out:
+    def encrypt(self, data, iv=None):
+        if iv is not None:
+            self.set_iv(iv)
+        assert self.blocks == 0, 'iv needs to be set before encrypt is called'
+        cdef Py_buffer idata = ro_buffer(data)
+        cdef int ilen = len(data)
+        cdef int offset
+        cdef int olen
+        cdef unsigned char *odata = <unsigned char *>PyMem_Malloc(ilen + self.cipher_blk_len)
+        if not odata:
             raise MemoryError
         try:
-            if not EVP_EncryptUpdate(self.ctx, out, &outl, <const unsigned char*> data_buf.buf, inl):
+            if not EVP_EncryptInit_ex(self.ctx, self.cipher(), NULL, self.enc_key, self.iv):
+                raise Exception('EVP_EncryptInit_ex failed')
+            offset = 0
+            if not EVP_EncryptUpdate(self.ctx, odata, &olen, <const unsigned char*> idata.buf, ilen):
                 raise Exception('EVP_EncryptUpdate failed')
-            ctl = outl
-            if not EVP_EncryptFinal_ex(self.ctx, out+ctl, &outl):
+            offset += olen
+            if not EVP_EncryptFinal_ex(self.ctx, odata+offset, &olen):
                 raise Exception('EVP_EncryptFinal failed')
-            ctl += outl
-            self.blocks += num_aes_blocks(ctl)
-            return out[:ctl]
+            offset += olen
+            self.blocks = self.block_count(offset)
+            return odata[:offset]
         finally:
-            free(out)
-            PyBuffer_Release(&data_buf)
+            PyMem_Free(odata)
+            PyBuffer_Release(&idata)
 
     def decrypt(self, data):
-        cdef Py_buffer data_buf = ro_buffer(data)
-        cdef int inl = len(data)
-        cdef int ptl = 0
-        cdef int outl = 0
-        # note: modes that use padding, need up to one extra AES block (16b).
-        # This is what the openssl docs say. I am not sure this is correct,
-        # but OTOH it will not cause any harm if our buffer is a little bigger.
-        cdef unsigned char *out = <unsigned char *>malloc(inl+16)
-        if not out:
+        cdef Py_buffer idata = ro_buffer(data)
+        cdef int ilen = len(data)
+        cdef int offset
+        cdef int olen
+        cdef unsigned char *odata = <unsigned char *>PyMem_Malloc(ilen + self.cipher_blk_len)
+        if not odata:
             raise MemoryError
         try:
-            if not EVP_DecryptUpdate(self.ctx, out, &outl, <const unsigned char*> data_buf.buf, inl):
+            # Set cipher type and mode
+            if not EVP_DecryptInit_ex(self.ctx, self.cipher(), NULL, self.enc_key, self.iv):
+                raise Exception('EVP_DecryptInit_ex failed')
+            offset = 0
+            if not EVP_DecryptUpdate(self.ctx, odata, &olen, <const unsigned char*> idata.buf, ilen):
                 raise Exception('EVP_DecryptUpdate failed')
-            ptl = outl
-            if EVP_DecryptFinal_ex(self.ctx, out+ptl, &outl) <= 0:
+            offset += olen
+            if EVP_DecryptFinal_ex(self.ctx, odata+offset, &olen) <= 0:
                 # this error check is very important for modes with padding or
                 # authentication. for them, a failure here means corrupted data.
                 # CTR mode does not use padding nor authentication.
                 raise Exception('EVP_DecryptFinal failed')
-            ptl += outl
-            self.blocks += num_aes_blocks(inl)
-            return out[:ptl]
+            offset += olen
+            self.blocks = self.block_count(ilen)
+            return odata[:offset]
         finally:
-            free(out)
-            PyBuffer_Release(&data_buf)
+            PyMem_Free(odata)
+            PyBuffer_Release(&idata)
+
+    def block_count(self, length):
+        # number of cipher blocks needed for data of length bytes
+        return (length + self.cipher_blk_len - 1) // self.cipher_blk_len
+
+    def set_iv(self, iv):
+        # set_iv needs to be called before each encrypt() call,
+        # because encrypt does a full initialisation of the cipher context.
+        assert isinstance(iv, bytes) and len(iv) == self.iv_len
+        self.blocks = 0  # number of cipher blocks encrypted with this IV
+        for i in range(self.iv_len):
+            self.iv[i] = iv[i]
+
+    def next_iv(self):
+        # call this after encrypt() to get the next iv for the next encrypt() call
+        return increment_iv(self.iv[:self.iv_len], self.blocks)
 
 
 def hmac_sha256(key, data):