oaep.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. from Crypto.Util.number import long_to_bytes
  2. from Crypto.Hash import SHA
  3. from .helpers import IntegrityError
  4. def _xor_bytes(a, b):
  5. return ''.join(chr(ord(x[0]) ^ ord(x[1])) for x in zip(a, b))
  6. def MGF1(seed, mask_len, hash=SHA):
  7. """MGF1 is a Mask Generation Function based on hash function
  8. """
  9. T = ''.join(hash.new(seed + long_to_bytes(c, 4)).digest()
  10. for c in range(1 + mask_len / hash.digest_size))
  11. return T[:mask_len]
  12. class OAEP(object):
  13. """Optimal Asymmetric Encryption Padding
  14. """
  15. def __init__(self, k, hash=SHA, MGF=MGF1):
  16. self.k = k
  17. self.hash = hash
  18. self.MGF = MGF
  19. def encode(self, msg, seed, label=''):
  20. # FIXME: length checks
  21. if len(msg) > self.k - 2 * self.hash.digest_size - 2:
  22. raise ValueError('message too long')
  23. label_hash = self.hash.new(label).digest()
  24. padding = '\0' * (self.k - len(msg) - 2 * self.hash.digest_size - 2)
  25. datablock = '%s%s\1%s' % (label_hash, padding, msg)
  26. datablock_mask = self.MGF(seed, self.k - self.hash.digest_size - 1, self.hash)
  27. masked_db = _xor_bytes(datablock, datablock_mask)
  28. seed_mask = self.MGF(masked_db, self.hash.digest_size, self.hash)
  29. masked_seed = _xor_bytes(seed, seed_mask)
  30. return '\0%s%s' % (masked_seed, masked_db)
  31. def decode(self, ciphertext, label=''):
  32. if len(ciphertext) < self.k:
  33. ciphertext = ('\0' * (self.k - len(ciphertext))) + ciphertext
  34. label_hash = self.hash.new(label).digest()
  35. masked_seed = ciphertext[1:self.hash.digest_size + 1]
  36. masked_db = ciphertext[-(self.k - self.hash.digest_size - 1):]
  37. seed_mask = self.MGF(masked_db, self.hash.digest_size, self.hash)
  38. seed = _xor_bytes(masked_seed, seed_mask)
  39. datablock_mask = self.MGF(seed, self.k - self.hash.digest_size - 1, self.hash)
  40. datablock = _xor_bytes(masked_db, datablock_mask)
  41. label_hash2 = datablock[:self.hash.digest_size]
  42. data = datablock[self.hash.digest_size:].lstrip('\0')
  43. if (ciphertext[0] != '\0' or
  44. label_hash != label_hash2 or
  45. data[0] != '\1'):
  46. raise IntegrityError('decryption error')
  47. return data[1:]
  48. def test():
  49. from Crypto.Hash import SHA256
  50. import os
  51. import random
  52. oaep = OAEP(256, SHA256)
  53. for x in range(1000):
  54. M = os.urandom(random.randint(0, 100))
  55. EM = oaep.encode(M, os.urandom(32))
  56. assert len(EM) == oaep.k
  57. assert oaep.decode(EM) == M
  58. if __name__ == '__main__':
  59. test()