crypto.pyx 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. """A thin OpenSSL wrapper
  2. This could be replaced by PyCrypto maybe?
  3. """
import hashlib
import hmac
from math import ceil
from libc.limits cimport INT_MAX
from libc.stdlib cimport malloc, free
# Version tag of this module's API.
# NOTE(review): presumably compared by importers to detect a stale compiled
# extension — confirm against the caller.
API_VERSION = '1.0_01'

# OpenSSL random number generator.
# NOTE(review): RAND_bytes is not referenced by the code visible in this file.
cdef extern from "openssl/rand.h":
    int RAND_bytes(unsigned char *buf, int num)

# The subset of the OpenSSL EVP (high-level cipher) API used by the AES class
# in this file. The structs are declared opaque (``pass``): the Cython code only
# ever handles them through pointers. Per the OpenSSL EVP_EncryptInit docs, the
# *Init_ex / *Update / *Final_ex calls return 1 on success and 0 on failure.
cdef extern from "openssl/evp.h":
    # NOTE(review): EVP_MD is not referenced by the code visible in this file.
    ctypedef struct EVP_MD:
        pass
    ctypedef struct EVP_CIPHER:
        pass
    ctypedef struct EVP_CIPHER_CTX:
        pass
    ctypedef struct ENGINE:
        pass
    const EVP_CIPHER *EVP_aes_256_ctr()
    EVP_CIPHER_CTX *EVP_CIPHER_CTX_new()
    void EVP_CIPHER_CTX_free(EVP_CIPHER_CTX *a)
    # A NULL cipher / key / iv in the *Init_ex calls means "not set by this
    # call"; AES.__cinit__ sets the cipher and AES.reset sets key and IV.
    int EVP_EncryptInit_ex(EVP_CIPHER_CTX *ctx, const EVP_CIPHER *cipher, ENGINE *impl,
                           const unsigned char *key, const unsigned char *iv)
    int EVP_DecryptInit_ex(EVP_CIPHER_CTX *ctx, const EVP_CIPHER *cipher, ENGINE *impl,
                           const unsigned char *key, const unsigned char *iv)
    # ``in_`` carries a trailing underscore because ``in`` is a Python keyword.
    int EVP_EncryptUpdate(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl,
                          const unsigned char *in_, int inl)
    int EVP_DecryptUpdate(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl,
                          const unsigned char *in_, int inl)
    int EVP_EncryptFinal_ex(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl)
    int EVP_DecryptFinal_ex(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl)
  33. import struct
  34. _int = struct.Struct('>I')
  35. _long = struct.Struct('>Q')
  36. _2long = struct.Struct('>QQ')
  37. bytes_to_int = lambda x, offset=0: _int.unpack_from(x, offset)[0]
  38. bytes_to_long = lambda x, offset=0: _long.unpack_from(x, offset)[0]
  39. long_to_bytes = lambda x: _long.pack(x)
  40. def bytes16_to_int(b, offset=0):
  41. h, l = _2long.unpack_from(b, offset)
  42. return (h << 64) + l
  43. def int_to_bytes16(i):
  44. max_uint64 = 0xffffffffffffffff
  45. l = i & max_uint64
  46. h = (i >> 64) & max_uint64
  47. return _2long.pack(h, l)
  48. def increment_iv(iv, amount=1):
  49. """
  50. Increment the IV by the given amount (default 1).
  51. :param iv: input IV, 16 bytes (128 bit)
  52. :param amount: increment value
  53. :return: input_IV + amount, 16 bytes (128 bit)
  54. """
  55. assert len(iv) == 16
  56. iv = bytes16_to_int(iv)
  57. iv += amount
  58. iv = int_to_bytes16(iv)
  59. return iv
  60. def num_aes_blocks(int length):
  61. """Return the number of AES blocks required to encrypt/decrypt *length* bytes of data.
  62. Note: this is only correct for modes without padding, like AES-CTR.
  63. """
  64. return (length + 15) // 16
  65. cdef class AES:
  66. """A thin wrapper around the OpenSSL EVP cipher API
  67. """
  68. cdef EVP_CIPHER_CTX *ctx
  69. cdef int is_encrypt
  70. cdef unsigned char iv_orig[16]
  71. cdef long long blocks
  72. def __cinit__(self, is_encrypt, key, iv=None):
  73. self.ctx = EVP_CIPHER_CTX_new()
  74. self.is_encrypt = is_encrypt
  75. # Set cipher type and mode
  76. cipher_mode = EVP_aes_256_ctr()
  77. if self.is_encrypt:
  78. if not EVP_EncryptInit_ex(self.ctx, cipher_mode, NULL, NULL, NULL):
  79. raise Exception('EVP_EncryptInit_ex failed')
  80. else: # decrypt
  81. if not EVP_DecryptInit_ex(self.ctx, cipher_mode, NULL, NULL, NULL):
  82. raise Exception('EVP_DecryptInit_ex failed')
  83. self.reset(key, iv)
  84. def __dealloc__(self):
  85. EVP_CIPHER_CTX_free(self.ctx)
  86. def reset(self, key=None, iv=None):
  87. cdef const unsigned char *key2 = NULL
  88. cdef const unsigned char *iv2 = NULL
  89. if key:
  90. key2 = key
  91. if iv:
  92. iv2 = iv
  93. assert isinstance(iv, bytes) and len(iv) == 16
  94. for i in range(16):
  95. self.iv_orig[i] = iv[i]
  96. self.blocks = 0 # number of AES blocks encrypted starting with iv_orig
  97. # Initialise key and IV
  98. if self.is_encrypt:
  99. if not EVP_EncryptInit_ex(self.ctx, NULL, NULL, key2, iv2):
  100. raise Exception('EVP_EncryptInit_ex failed')
  101. else: # decrypt
  102. if not EVP_DecryptInit_ex(self.ctx, NULL, NULL, key2, iv2):
  103. raise Exception('EVP_DecryptInit_ex failed')
  104. @property
  105. def iv(self):
  106. return increment_iv(self.iv_orig[:16], self.blocks)
  107. def encrypt(self, data):
  108. cdef int inl = len(data)
  109. cdef int ctl = 0
  110. cdef int outl = 0
  111. # note: modes that use padding, need up to one extra AES block (16b)
  112. cdef unsigned char *out = <unsigned char *>malloc(inl+16)
  113. if not out:
  114. raise MemoryError
  115. try:
  116. if not EVP_EncryptUpdate(self.ctx, out, &outl, data, inl):
  117. raise Exception('EVP_EncryptUpdate failed')
  118. ctl = outl
  119. if not EVP_EncryptFinal_ex(self.ctx, out+ctl, &outl):
  120. raise Exception('EVP_EncryptFinal failed')
  121. ctl += outl
  122. self.blocks += num_aes_blocks(ctl)
  123. return out[:ctl]
  124. finally:
  125. free(out)
  126. def decrypt(self, data):
  127. cdef int inl = len(data)
  128. cdef int ptl = 0
  129. cdef int outl = 0
  130. # note: modes that use padding, need up to one extra AES block (16b).
  131. # This is what the openssl docs say. I am not sure this is correct,
  132. # but OTOH it will not cause any harm if our buffer is a little bigger.
  133. cdef unsigned char *out = <unsigned char *>malloc(inl+16)
  134. if not out:
  135. raise MemoryError
  136. try:
  137. if not EVP_DecryptUpdate(self.ctx, out, &outl, data, inl):
  138. raise Exception('EVP_DecryptUpdate failed')
  139. ptl = outl
  140. if EVP_DecryptFinal_ex(self.ctx, out+ptl, &outl) <= 0:
  141. # this error check is very important for modes with padding or
  142. # authentication. for them, a failure here means corrupted data.
  143. # CTR mode does not use padding nor authentication.
  144. raise Exception('EVP_DecryptFinal failed')
  145. ptl += outl
  146. self.blocks += num_aes_blocks(inl)
  147. return out[:ptl]
  148. finally:
  149. free(out)
  150. def hkdf_hmac_sha512(ikm, salt, info, output_length):
  151. """
  152. Compute HKDF-HMAC-SHA512 with input key material *ikm*, *salt* and *info* to produce *output_length* bytes.
  153. This is the "HMAC-based Extract-and-Expand Key Derivation Function (HKDF)" (RFC 5869)
  154. instantiated with HMAC-SHA512.
  155. *output_length* must not be greater than 64 * 255 bytes.
  156. """
  157. digest_length = 64
  158. assert output_length <= (255 * digest_length), 'output_length must be <= 255 * 64 bytes'
  159. # Step 1. HKDF-Extract (ikm, salt) -> prk
  160. if salt is None:
  161. salt = bytes(64)
  162. prk = hmac.HMAC(salt, ikm, hashlib.sha512).digest()
  163. # Step 2. HKDF-Expand (prk, info, output_length) -> output key
  164. n = ceil(output_length / digest_length)
  165. t_n = b''
  166. output = b''
  167. for i in range(n):
  168. msg = t_n + info + (i + 1).to_bytes(1, 'little')
  169. t_n = hmac.HMAC(prk, msg, hashlib.sha512).digest()
  170. output += t_n
  171. return output[:output_length]