"""OAEP (Optimal Asymmetric Encryption Padding) encoding and decoding.

Implements the EME-OAEP encoding steps of PKCS #1 v2 (RFC 8017, section 7.1)
together with the MGF1 mask generation function. All inputs and outputs are
byte strings (``str`` on Python 2, ``bytes`` on Python 3).
"""
import hmac

from Crypto.Util.number import long_to_bytes
from Crypto.Hash import SHA

from .helpers import IntegrityError


def _xor_bytes(a, b):
    """Return the bytewise XOR of *a* and *b*, truncated to the shorter input.

    Iterating over ``bytearray`` views always yields ints, so this works for
    byte strings on both Python 2 (``str``) and Python 3 (``bytes``).
    """
    return bytes(bytearray(x ^ y for x, y in zip(bytearray(a), bytearray(b))))


def MGF1(seed, mask_len, hash=SHA):
    """MGF1 is a Mask Generation Function based on hash function

    Concatenates ``Hash(seed || I2OSP(counter, 4))`` for counter = 0, 1, ...
    and returns the first *mask_len* bytes.

    :param seed: byte string the mask is derived from.
    :param mask_len: desired mask length in bytes.
    :param hash: hash module exposing ``new()`` and ``digest_size``.
    :returns: a byte string of length *mask_len*.
    :raises ValueError: if *mask_len* exceeds ``2**32`` hash outputs (the
        4-byte counter would overflow).
    """
    h_len = hash.digest_size
    if mask_len > (2 ** 32) * h_len:
        raise ValueError('mask too long')
    # Ceiling division: exactly enough blocks to cover mask_len. ``//`` keeps
    # the count an int on Python 3, where ``/`` would produce a float.
    blocks = (mask_len + h_len - 1) // h_len
    T = b''.join(hash.new(seed + long_to_bytes(c, 4)).digest()
                 for c in range(blocks))
    return T[:mask_len]


class OAEP(object):
    """Optimal Asymmetric Encryption Padding

    ``encode`` produces a ``k``-byte block laid out as
    ``0x00 || maskedSeed || maskedDB`` where
    ``DB = Hash(label) || 0x00...00 || 0x01 || message``; ``decode`` reverses
    it and verifies the structure.
    """

    def __init__(self, k, hash=SHA, MGF=MGF1):
        """
        :param k: length in bytes of the encoded block.
        :param hash: hash module exposing ``new()`` and ``digest_size``.
        :param MGF: mask generation function called as
            ``MGF(seed, mask_len, hash)``.
        """
        self.k = k
        self.hash = hash
        self.MGF = MGF

    def encode(self, msg, seed, label=b''):
        """Encode *msg* into a ``k``-byte block.

        :param msg: message bytes, at most ``k - 2*hLen - 2`` long where
            ``hLen`` is the hash digest size.
        :param seed: exactly ``hLen`` random bytes.
        :param label: optional label whose hash is bound into the encoding.
        :returns: the ``k``-byte encoded block.
        :raises ValueError: if *msg* is too long or *seed* is not ``hLen``
            bytes long.
        """
        h_len = self.hash.digest_size
        if len(msg) > self.k - 2 * h_len - 2:
            raise ValueError('message too long')
        # A wrong-length seed would make _xor_bytes silently truncate and
        # yield a block shorter than k.
        if len(seed) != h_len:
            raise ValueError('seed must be %d bytes long' % h_len)
        label_hash = self.hash.new(label).digest()
        padding = b'\x00' * (self.k - len(msg) - 2 * h_len - 2)
        datablock = label_hash + padding + b'\x01' + msg
        datablock_mask = self.MGF(seed, self.k - h_len - 1, self.hash)
        masked_db = _xor_bytes(datablock, datablock_mask)
        seed_mask = self.MGF(masked_db, h_len, self.hash)
        masked_seed = _xor_bytes(seed, seed_mask)
        return b'\x00' + masked_seed + masked_db

    def decode(self, ciphertext, label=b''):
        """Recover the message from an encoded block.

        :param ciphertext: the encoded block. If shorter than ``k`` (e.g.
            because an integer-to-bytes conversion dropped leading zero
            bytes) it is left-padded with zero bytes.
        :param label: the label that was used when encoding.
        :returns: the original message bytes.
        :raises IntegrityError: ``'decryption error'`` for every kind of
            malformed input, so failures cannot be told apart by exception
            type or message.
        """
        h_len = self.hash.digest_size
        # Lengths are public information, so rejecting here reveals nothing
        # about the block's contents.
        if len(ciphertext) > self.k or self.k < 2 * h_len + 2:
            raise IntegrityError('decryption error')
        if len(ciphertext) < self.k:
            ciphertext = b'\x00' * (self.k - len(ciphertext)) + ciphertext
        label_hash = self.hash.new(label).digest()
        masked_seed = ciphertext[1:h_len + 1]
        masked_db = ciphertext[h_len + 1:]
        seed_mask = self.MGF(masked_db, h_len, self.hash)
        seed = _xor_bytes(masked_seed, seed_mask)
        datablock_mask = self.MGF(seed, self.k - h_len - 1, self.hash)
        datablock = _xor_bytes(masked_db, datablock_mask)
        label_hash2 = datablock[:h_len]
        # Skip the zero padding; what remains should be 0x01 || message.
        data = datablock[h_len:].lstrip(b'\x00')
        # Evaluate every check before branching and raise one generic error:
        # distinguishable failures turn OAEP decoding into a padding oracle.
        # Slicing (rather than indexing) keeps an empty ``data`` from raising
        # IndexError.
        bad_leading_byte = ciphertext[:1] != b'\x00'
        bad_label = not hmac.compare_digest(label_hash, label_hash2)
        bad_separator = data[:1] != b'\x01'
        if bad_leading_byte or bad_label or bad_separator:
            raise IntegrityError('decryption error')
        return data[1:]


def test():
    """Self-test: round-trip random messages through OAEP with SHA-256."""
    from Crypto.Hash import SHA256
    import os
    import random
    oaep = OAEP(256, SHA256)
    for _ in range(1000):
        M = os.urandom(random.randint(0, 100))
        EM = oaep.encode(M, os.urandom(32))
        assert len(EM) == oaep.k
        assert oaep.decode(EM) == M
    # A block encoded under one label must not decode under another.
    EM = oaep.encode(b'hello', os.urandom(32), label=b'a')
    assert oaep.decode(EM, label=b'a') == b'hello'
    try:
        oaep.decode(EM, label=b'b')
    except IntegrityError:
        pass
    else:
        raise AssertionError('wrong label was accepted')


if __name__ == '__main__':
    test()