I am trying to solve a challenge regarding a RSA oracle which allows me to encrypt/decrypt any plaintext/ciphertext I want, but there are a few checks that I have to bypass, and my goal is to decrypt the given flag. The strategy I am using is basically trying to get the N by making the oracle encrypt some small numbers, and then just adding this N to the encrypted flag to bypass the check:
if c == flag_encrypted:
My current script works if I remove (in my local version of the oracle) the second check on the used array, but of course I cannot remove it in the remote one, which contains the flag I am trying to decrypt. Do you have any idea on how I can bypass the following check?
for no in used:
if m % no == 0:
print("Wait. That's illegal.")
break
Oracle's code:
#!/usr/bin/env python3
import signal
from binascii import hexlify
from Crypto.PublicKey import RSA
from Crypto.Util.number import *
from random import randint
from secret import FLAG
import string
TIMEOUT = 300 # 5 minutes time-out
def menu():
print()
print('Choice:')
print(' [0] Exit')
print(' [1] Encrypt')
print(' [2] Decrypt')
print('')
return input('> ')
def encrypt(m):
return pow(m, rsa.e, rsa.n)
def decrypt(c):
return pow(c, rsa.d, rsa.n)
rsa = RSA.generate(1024)
flag_encrypted = pow(bytes_to_long(FLAG.encode()), rsa.e, rsa.n)
used = [bytes_to_long(FLAG.encode())]
def handle():
print("================================================================================")
print("= RSA Encryption & Decryption oracle =")
print("= Find the flag! =")
print("================================================================================")
print("")
print("Encrypted flag:", flag_encrypted)
while True:
choice = menu()
# Exit
if choice == '0':
print("Goodbye!")
break
# Encrypt
elif choice == '1':
m = int(input('\nPlaintext > ').strip())
used.append(m)
print('\nEncrypted: ' + str(encrypt(m)))
# Decrypt
elif choice == '2':
c = int(input('\nCiphertext > ').strip())
if c == flag_encrypted:
print("Wait. That's illegal.")
else:
m = decrypt(c)
for no in used:
if m % no == 0:
print("Wait. That's illegal.")
break
else:
print('\nDecrypted: ' + str(m))
# Invalid
else:
print('bye!')
break
if __name__ == "__main__":
signal.alarm(TIMEOUT)
handle()
My current script:
from pwn import *
from Crypto.Util.number import *
from math import gcd
import gmpy2
import sys
r = remote('oracle.challs.cyberchallenge.it', 9042)
r.recvuntil(b'Encrypted flag: ')
encrypted_flag = int(r.recvline().strip().decode())
e = 65537
# Let's first gather the ciphertext of the new num
public_exponent = 65537
numbers = [2,3,4,5,6]
numbers_bytes = [b'\x02',b'\x03',b'\x04',b'\x05',b'\x06']
ciphers = []
diffs = []
for i in range(4):
r.recvuntil(b'>')
r.sendline(b'1')
r.recvuntil(b'Plaintext > ')
r.sendline(str(bytes_to_long(numbers_bytes[i])))
r.recvuntil(b'Encrypted: ')
cipher = int(r.recvline().strip().decode())
ciphers.append(cipher)
diffs.append(gmpy2.sub(pow(numbers[i], public_exponent),cipher))
print(diffs)
common_factor = None
for diff in diffs:
if common_factor is None:
common_factor = diff
else:
common_factor = gmpy2.gcd(common_factor, diff)
print("N: ")
print(common_factor)
encrypted_flag += int(common_factor)
r.recvuntil(b'>')
r.sendline(b'2')
r.recvuntil(b'Ciphertext > ')
r.sendline(str(encrypted_flag))
r.recvuntil('Decrypted: ')
flag = int(r.recvline().decode())
print(long_to_bytes(flag))