// SPDX-License-Identifier: MIT
/*
 * Copyright (c) 2016 MediaTek Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining
 * a copy of this software and associated documentation files
 * (the "Software"), to deal in the Software without restriction,
 * including without limitation the rights to use, copy, modify, merge,
 * publish, distribute, sublicense, and/or sell copies of the Software,
 * and to permit persons to whom the Software is furnished to do so,
 * subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be
 * included in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

#include <arch/ops.h>
#include <errno.h>
#include <kernel/thread.h>
#include <kernel/vm.h>
#include <lib/bio.h>
#include <libfdt.h>
#include <lib/decompress.h>
#include <lib/mempool.h>
#include <list.h>
#include <string.h>
#include <trace.h>

#include "fit.h"
#include "image.h"

#define LOCAL_TRACE 0

/*
 * uswap_32() - reverse the byte order of a 32-bit value.
 *
 * Used below to turn big-endian FDT cells into host order; this assumes a
 * little-endian CPU. The argument is evaluated four times, so only pass
 * expressions without side effects.
 */
#define uswap_32(x) \
    ((((x) >> 24) & 0x000000ff) | \
     (((x) >>  8) & 0x0000ff00) | \
     (((x) <<  8) & 0x00ff0000) | \
     (((x) << 24) & 0xff000000))

/*
 * fit_image_get_node() - find an image subnode below FIT_IMAGES_PATH
 *
 * @fit:         fit image start address
 * @image_uname: unit name of the wanted image subnode
 *
 * returns:
 *    node offset (>= 0), on success
 *    negative libfdt error code, on failure (the failure is logged)
 */
int fit_image_get_node(const void *fit, const char *image_uname)
{
    int parent = fdt_path_offset(fit, FIT_IMAGES_PATH);
    int node;

    if (parent < 0) {
        dprintf(CRITICAL,"Can't find images parent node '%s' (%s)\n",
                FIT_IMAGES_PATH, fdt_strerror(parent));
        return parent;
    }

    node = fdt_subnode_offset(fit, parent, image_uname);
    if (node >= 0)
        return node;

    dprintf(CRITICAL,"Can't get node offset for image name: '%s' (%s)\n",
            image_uname, fdt_strerror(node));
    return node;
}

/*
 * fit_image_get_data() - fetch the FDT_DATA_NODE property of an image node
 *
 * @fit:     fit image start address
 * @noffset: offset of the image node
 * @data:    receives the property value pointer (NULL when absent)
 * @size:    receives the property length, only on success
 *
 * returns:
 *    0, on success
 *    -1, if the property cannot be read
 */
int fit_image_get_data(const void *fit, int noffset,
                       const void **data, uint32_t *size)
{
    int prop_len;
    const void *prop = fdt_getprop(fit, noffset, FDT_DATA_NODE, &prop_len);

    *data = prop;
    if (!prop)
        return -1;

    *size = (uint32_t)prop_len;
    return 0;
}

/*
 * fit_conf_get_prop_node() - resolve a configuration property to an image node
 *
 * Reads property @prop_name (e.g. the "kernel" entry) of configuration node
 * @noffset, treats its value as an image unit name and looks that image up
 * with fit_image_get_node().
 *
 * returns:
 *    image node offset (>= 0), on success
 *    negative error code from fdt_getprop() or fit_image_get_node(), on failure
 */
int fit_conf_get_prop_node(const void *fit, int noffset,
                           const char *prop_name)
{
    int len;
    const char *uname = fdt_getprop(fit, noffset, prop_name, &len);

    /* when the property is missing, len carries fdt_getprop()'s error code */
    return uname ? fit_image_get_node(fit, uname) : len;
}

/**
 * fit_get_img_subnode_offset() - get the offset of an image subnode by name
 *
 * Looks up node "/images" and then its direct subnode called @image_name.
 * Each failure is logged.
 *
 * @fit:        fit image start address
 * @image_name: image name, e.g. "kernel", "fdt" or "ramdisk"
 *
 * returns:
 *    node offset (>= 0), on success
 *    negative libfdt error code, on failure
 */
static int fit_get_img_subnode_offset(void *fit, const char *image_name)
{
    int images = fdt_path_offset(fit, "/images");
    int sub;

    if (images < 0) {
        dprintf(CRITICAL, "Can't find image node(%s)\n", fdt_strerror(images));
        return images;
    }

    sub = fdt_subnode_offset(fit, images, image_name);
    if (sub < 0) {
        dprintf(CRITICAL, "Can't get node offset for image name: '%s' (%s)\n",
                image_name, fdt_strerror(sub));
    }

    return sub;
}

/**
 * fit_get_def_cfg_offset() - get a configuration subnode of "/configurations"
 *
 * Finds the subnode named @conf inside "/configurations". When @conf is
 * NULL, the name is taken from that node's "default" property instead.
 *
 * @fit:    fit image start address
 * @conf:   configuration name, or NULL to use the default one
 *
 * returns:
 *    configuration node offset (>= 0), on success
 *    negative libfdt error code, on failure
 */
static int fit_get_def_cfg_offset(void *fit, const char *conf)
{
    int len;
    int cfg_noffset;
    int configs = fdt_path_offset(fit, "/configurations");

    if (configs < 0) {
        dprintf(CRITICAL, "can't find configuration node\n");
        return configs;
    }

    /* no explicit configuration requested: fall back to "default" */
    if (!conf) {
        conf = fdt_getprop(fit, configs, "default", &len);
        if (!conf) {
            dprintf(CRITICAL, "Can't get default conf name\n");
            return len;
        }
        dprintf(SPEW, "got default conf: %s\n", conf);
    }

    cfg_noffset = fdt_subnode_offset(fit, configs, conf);
    if (cfg_noffset >= 0)
        dprintf(SPEW, "got conf: %s subnode\n", conf);
    else
        dprintf(CRITICAL, "Can't get conf subnode\n");

    return cfg_noffset;
}

/*
 * fit_check_unit_address() - reject node names that contain a unit address
 *
 * Walks every node below @parent depth-first with fdt_next_node() and fails
 * if any node name contains '@' or cannot be read. When @parent is not the
 * root, @parent's own name is checked as well; the root name is skipped
 * because fdt_check_full() has already validated it.
 *
 * The walk uses no heap allocation, so it cannot fail for lack of memory.
 *
 * @fit:    fit image start address
 * @parent: offset of the node whose subtree is checked (0 for the root)
 *
 * returns:
 *    0, if no node name contains '@'
 *    -1, on a bad node name or if the structure walk fails
 */
static int fit_check_unit_address(const void *fit, int parent)
{
    const char *name;
    int offset;
    int depth = 0;

    /* root name has been checked in fdt_check_full(), skip the check */
    if (parent != 0) {
        name = fdt_get_name(fit, parent, NULL);
        if (!name || strchr(name, '@')) {
            LTRACEF("Node name error: offset=0x%x, name='%s'.\n",
                    parent, name ? name : "(null)");
            return -1;
        }
    }

    /*
     * fdt_next_node() increments depth when it enters a child and
     * decrements it when it leaves one, so depth drops below 1 once the
     * walk has left @parent's subtree.
     */
    for (offset = fdt_next_node(fit, parent, &depth);
         offset >= 0 && depth > 0;
         offset = fdt_next_node(fit, offset, &depth)) {
        name = fdt_get_name(fit, offset, NULL);
        if (!name || strchr(name, '@')) {
            LTRACEF("Node name error: offset=0x%x, name='%s'.\n",
                    offset, name ? name : "(null)");
            return -1;
        }
    }

    /* a negative offset means the walk itself hit a structure error */
    if (offset < 0) {
        LTRACEF("Node walk error: %s\n", fdt_strerror(offset));
        return -1;
    }

    return 0;
}

/*
 * fdt_check_format() - validate a flattened device tree blob
 *
 * Runs fdt_check_header() first and, only if that passes, fdt_check_full()
 * over the size recorded in the header.
 *
 * @fit: blob start address
 *
 * returns:
 *    0, if both checks pass
 *    the negative libfdt error of the first failing check, otherwise
 */
static int fdt_check_format(const void *fit)
{
    int err = fdt_check_header(fit);

    if (err != 0) {
        LTRACEF("Error! Not valid fdt format, ret=%d, fit=%p\n", err, fit);
        return err;
    }

    /* size the full structural check by the header's totalsize */
    err = fdt_check_full(fit, (size_t)fdt_totalsize(fit));
    if (err != 0)
        LTRACEF("Error! Fdt full check fail, ret=%d, fit=%p\n", err, fit);

    return err;
}

/*
 * fit_check_format() - validate that a blob is a well-formed FIT image
 *
 * The blob must be a valid FDT whose node names contain no '@'. It must
 * also have the root properties FIT_DESC_PROP and FIT_TIMESTAMP_PROP, and a
 * FIT_IMAGES_PATH node.
 *
 * returns:
 *    0, on success
 *    -1, if the FDT checks fail or a node name contains '@'
 *    -2, if FIT_DESC_PROP is missing
 *    -3, if FIT_TIMESTAMP_PROP is missing
 *    -4, if FIT_IMAGES_PATH is missing
 */
static int fit_check_format(const void *fit)
{
    /* mandatory root properties, in check order, with their error codes */
    const struct {
        const char *prop;
        int err;
    } required[] = {
        { FIT_DESC_PROP, -2 },
        { FIT_TIMESTAMP_PROP, -3 },
    };
    unsigned int i;
    int ret = fdt_check_format(fit);

    if (ret != 0) {
        LTRACEF("Error! fdt_check_format fail, ret=%d\n", ret);
        return -1;
    }

    ret = fit_check_unit_address(fit, 0);
    if (ret != 0) {
        LTRACEF("Error! Fit node names have '@' symbol, not allowed.\n");
        return ret;
    }

    for (i = 0; i < sizeof(required) / sizeof(required[0]); i++) {
        if (fdt_getprop(fit, 0, required[i].prop, NULL) == NULL) {
            LTRACEF("Error! FIT without '%s' property.\n", required[i].prop);
            return required[i].err;
        }
    }

    /* mandatory '/images' node */
    if (fdt_path_offset(fit, FIT_IMAGES_PATH) < 0) {
        LTRACEF("Error! FIT without '%s' subnode.\n", FIT_IMAGES_PATH);
        return -4;
    }

    return 0;
}

/**
 * fit_get_image() - read a FIT (or plain FDT) image from a partition
 *
 * Opens partition @label (by label first, then by name) and reads the fdt
 * header to learn the image's total size. It then reads the whole image
 * into a buffer from mempool_alloc() and checks its format. A blob that has
 * a FIT_IMAGES_PATH or FIT_CONFIGS_PATH node is checked as FIT; any other
 * blob (e.g. a legacy dtbo) is checked as plain FDT.
 *
 * @label:    partition label or name
 * @load_buf: receives the image buffer, only on success; the caller owns it
 *            and releases it with mempool_free()
 *
 * returns:
 *    0, on success
 *    -ENODEV, if the partition cannot be opened
 *    -EIO, on a read error or short read
 *    -ENOMEM, if the buffer cannot be allocated
 *    negative header/format check error, otherwise
 */
int fit_get_image(const char *label, void **load_buf)
{
    bdev_t *bdev;
    struct fdt_header fdt = { 0 };
    size_t totalsize;
    ssize_t nread;
    int ret = 0;
    void *fit_buf = NULL;

    bdev = bio_open_by_label(label) ? : bio_open(label);
    if (!bdev) {
        dprintf(CRITICAL, "Partition [%s] is not exist.\n", label);
        return -ENODEV;
    }

    nread = bio_read(bdev, &fdt, 0, sizeof(fdt));
    if (nread < 0 || (size_t)nread < sizeof(fdt)) {
        ret = -EIO;
        goto closebdev;
    }

    ret = fdt_check_header(&fdt);
    if (ret) {
        dprintf(CRITICAL, "[%s] check header failed\n", label);
        goto closebdev;
    }

    totalsize = fdt_totalsize(&fdt);
    fit_buf = mempool_alloc(totalsize, MEMPOOL_ANY);
    if (!fit_buf) {
        ret = -ENOMEM;
        goto closebdev;
    }

    /*
     * bio_read() returns a signed count. Test for a negative error before
     * comparing with the unsigned size; otherwise the error value would be
     * converted to a huge unsigned number and pass the short-read check.
     */
    nread = bio_read(bdev, fit_buf, 0, totalsize);
    if (nread < 0 || (size_t)nread < totalsize) {
        ret = -EIO;
        goto closebdev;
    }

    /* check fit or fdt format */
    if ((fdt_path_offset((const void *)fit_buf, FIT_IMAGES_PATH) >= 0) ||
        (fdt_path_offset((const void *)fit_buf, FIT_CONFIGS_PATH) >= 0)) {
        /*
         * An fdt image (such as a legacy dtbo image) has neither an
         * '/images' nor a '/configurations' node; if the image has either
         * one, run the fit image check.
         */
        ret = fit_check_format((const void *)fit_buf);
        if (ret)
            dprintf(CRITICAL, "%s: %s check format failed, ret=%d.\n", label, "fit", ret);
    } else {
        ret = fdt_check_format((const void *)fit_buf);
        if (ret)
            dprintf(CRITICAL, "%s: %s check format failed, ret=%d.\n", label, "fdt", ret);
    }

    if (ret == 0)
        *load_buf = fit_buf;

closebdev:
    bio_close(bdev);
    /* on failure the buffer never reached the caller, so release it here */
    if ((ret != 0) && (fit_buf != NULL))
        mempool_free(fit_buf);

    return ret;
}

/*
 * fit_prop_to_addr() - decode an address property made of @ac cells
 *
 * Each cell is byte-swapped with uswap_32(), which assumes a little-endian
 * CPU. With @ac == 2 the first cell is the high word.
 *
 * @prop: property value; the caller guarantees at least @ac cells
 * @ac:   number of address cells, 1 or 2
 */
static uint64_t fit_prop_to_addr(const uint32_t *prop, int ac)
{
    uint64_t addr = uswap_32(prop[0]);

    if (ac == 2)
        addr = (addr << 32) | uswap_32(prop[1]);

    return addr;
}

/**
 * fit_processing_data() - place one image's payload at its load address
 *
 * Reads the "data", "compression", "type", "load" and (when @entry is
 * requested) "entry" properties of image node @noffset. It then copies
 * ("none") or lz4-decompresses ("lz4") the payload to the load address.
 * With WITH_KERNEL_VM, the loaded range is cleaned from the cache afterwards.
 *
 * @fit:        fit image start address
 * @image_name: image name, used for log messages only
 * @noffset:    offset of the image node
 * @load:       optional; receives the load address
 * @load_size:  optional; receives the loaded (decompressed) size
 * @entry:      optional; receives the entry address
 *
 * returns:
 *    0, on success
 *    negative libfdt error, if the "data" property is missing
 *    -EINVAL, on missing, too short or unsupported properties
 *    -LZ4_FAIL, if lz4 decompression fails
 */
int fit_processing_data(void *fit, const char *image_name, int noffset,
                        addr_t *load, size_t *load_size, paddr_t *entry)
{
    int len, load_len, ret, ac;
    size_t size;
    const char *type;
    const char *compression;
    const void *data;
    const uint32_t *load_prop, *entry_prop = NULL;
    addr_t load_addr;
    paddr_t entry_addr;
    uint32_t lz4_out_size;

    data = fdt_getprop(fit, noffset, "data", &len);
    if (!data) {
        dprintf(CRITICAL, "%s can't get prop data\n", image_name);
        return len;
    }
    size = len;

    compression = fdt_getprop(fit, noffset, "compression", &len);
    if (!compression) {
        dprintf(CRITICAL, "%s compression is not specified\n", image_name);
        return -EINVAL;
    }

    type = fdt_getprop(fit, noffset, "type", &len);
    if (!type) {
        dprintf(CRITICAL, "%s image type is not specified\n", image_name);
        return -EINVAL;
    }

    /* read address-cells from root; more cells than fit in a ulong is invalid */
    ac = fdt_address_cells(fit, 0);
    if (ac <= 0 || (ac > sizeof(ulong) / sizeof(uint))) {
        LTRACEF("%s #address-cells with a bad format or value\n", image_name);
        return -EINVAL;
    }

    load_prop = fdt_getprop(fit, noffset, "load", &load_len);
    if (!load_prop &&
            (!strcmp(type, "kernel") || (!strcmp(type, "loadable")))) {
        dprintf(CRITICAL, "%s need load addr\n", image_name);
        return -EINVAL;
    }
    /* decoding reads ac cells; a shorter property would be over-read */
    if (load_prop && load_len < ac * (int)sizeof(uint32_t)) {
        dprintf(CRITICAL, "%s load addr too short\n", image_name);
        return -EINVAL;
    }

    /* validate "entry" up front so a bad one fails before anything is loaded */
    if (entry) {
        entry_prop = fdt_getprop(fit, noffset, "entry", &len);
        if (entry_prop && len < ac * (int)sizeof(uint32_t)) {
            dprintf(CRITICAL, "%s entry addr too short\n", image_name);
            return -EINVAL;
        }
    }

    /* load address determination:
     *   1. "load" property exist: use address in "load" property
     *   2. "load" property not exist: use runtime address of "data" property
     */
    load_addr = (addr_t)data;
    if (load_prop) {
        load_addr = (addr_t)fit_prop_to_addr(load_prop, ac);
#if WITH_KERNEL_VM
        load_addr = (addr_t)paddr_to_kvaddr(load_addr);
#endif
    }

    if (!strcmp(compression, "lz4")) {
        /*
         * In an lz4 kernel image, the last four bytes hold the uncompressed
         * size. Reject data too short to contain it (size - 4 would wrap).
         * Copy the trailer out with memcpy, because it may be unaligned,
         * and do so before decompressing.
         */
        if (size < sizeof(lz4_out_size)) {
            dprintf(CRITICAL, "%s lz4 data too short\n", image_name);
            return -EINVAL;
        }
        memcpy(&lz4_out_size,
               (const uint8_t *)data + size - sizeof(lz4_out_size),
               sizeof(lz4_out_size));

        ret = unlz4(data, size - sizeof(lz4_out_size), (void *)(load_addr));
        if (ret != LZ4_OK) {
            dprintf(ALWAYS, "lz4 decompress failure\n");
            return -LZ4_FAIL;
        }
        size = lz4_out_size;
    } else if (!strcmp(compression, "none")) {
        /* memmove: the load range may overlap the FIT buffer */
        memmove((void *)(load_addr), data, size);
    } else {
        dprintf(CRITICAL, "%s compression does not support\n", image_name);
        return -EINVAL;
    }

#if WITH_KERNEL_VM
    /* always flush cache to PoC */
    arch_clean_cache_range(load_addr, size);
#endif

    LTRACEF("[%s] load_addr 0x%lx\n", image_name, load_addr);
    LTRACEF("[%s] fit = %p\n", image_name, fit);
    LTRACEF("[%s] data = %p\n", image_name, data);
    LTRACEF("[%s] size = %zu\n", image_name, size);

    /* return load, load_size and entry address if the caller specified them */
    if (load)
        *load = load_addr;

    if (load_size)
        *load_size = size;

    if (entry) {
        /*
         * entry address determination:
         *   1. "entry" property not exist: entry address = load address
         *   2. "entry" & "load" properties both exist: "entry" property
         *      contains the absolute address of entry, thus
         *      entry address = "entry"
         *   3. only "entry" property exist: "entry" property contains the
         *      entry offset to load address, thus
         *      entry address = "entry" + load address
         */

#if WITH_KERNEL_VM
        load_addr = kvaddr_to_paddr((void *)load_addr);
#endif
        entry_addr = load_addr;
        if (entry_prop) {
            entry_addr = (paddr_t)fit_prop_to_addr(entry_prop, ac);
            entry_addr += load_prop ? 0 : load_addr;
        }
        *entry = entry_addr;

        LTRACEF("[%s] entry_addr 0x%lx\n", image_name, *entry);
    }

    return 0;
}

/*
 * fit_load_loadable_image() - load the "/images" subnode @sub_node_name
 *
 * When hash_check_enabled() reports true, the image must first pass
 * fit_image_integrity_verify(). The payload is then placed by
 * fit_processing_data(); only the load address is reported back.
 *
 * returns:
 *    result of fit_processing_data(), on success
 *    negative libfdt error, if the subnode is missing
 *    -EACCES, if integrity verification fails
 */
int fit_load_loadable_image(void *fit, const char *sub_node_name, addr_t *load)
{
    int noffset = fit_get_img_subnode_offset(fit, sub_node_name);

    if (noffset < 0) {
        LTRACEF("%s: fit_get_img_subnode_offset fail\n", sub_node_name);
        return noffset;
    }

    if (hash_check_enabled()) {
        int verify_ret = fit_image_integrity_verify(fit, noffset);
        bool failed = (verify_ret != 0);

        LTRACEF("%s: integrity check %s\n",
                sub_node_name, failed ? "fail" : "pass");
        if (failed)
            return -EACCES;
    }

    /* loadables need no size or entry address */
    return fit_processing_data(fit, sub_node_name, noffset, load, NULL, NULL);
}

/*
 * fit_conf_verify_sig() - verify the signature of a configuration node
 *
 * Resolves configuration @conf (or the default one when @conf is NULL).
 * When rsa_check_enabled() reports true, its signature is checked with
 * fit_verify_sign(); otherwise no check is done.
 *
 * returns:
 *    0, if verification passes or is disabled
 *    negative libfdt error, if the configuration cannot be found
 *    -EACCES, if the signature check fails
 */
int fit_conf_verify_sig(const char *conf, void *fit)
{
    int cfg = fit_get_def_cfg_offset(fit, conf);
    int ret;

    if (cfg < 0)
        return cfg;

    if (!rsa_check_enabled())
        return 0;

    ret = fit_verify_sign(fit, cfg);
    dprintf(ALWAYS, "Verify sign: %s\n", ret ? "fail" : "pass");

    return ret ? -EACCES : 0;
}

static int fit_image_integrity_check_process(void *arg)
{
    int ret;
    struct verify_data *verify_info;

    verify_info = (struct verify_data *)arg;
    ret = fit_image_integrity_verify(verify_info->fit_image,
            verify_info->noffset);

    return ret;
}

/**
 * fit_load_image() - load one sub-image referenced by a configuration
 *
 * Resolves configuration @conf (or the default one when @conf is NULL),
 * reads the image unit name stored in its @img_pro property, and loads that
 * image with fit_processing_data().
 *
 * When hash_check_enabled() and @need_verified are both true, the image's
 * integrity is verified as well. On WITH_SMP builds the check runs on a
 * worker thread in parallel with the load. On other builds, or if the
 * thread cannot be created, it runs inline before the load.
 *
 * @conf:          configuration name, or NULL for the default one
 * @img_pro:       configuration property naming the image (e.g. "kernel")
 * @fit:           fit image start address
 * @load:          optional; see fit_processing_data()
 * @load_size:     optional; see fit_processing_data()
 * @entry:         optional; see fit_processing_data()
 * @need_verified: request integrity verification for this image
 *
 * returns:
 *    result of fit_processing_data(), if verification passed or was skipped
 *    -ENOENT, if @img_pro is missing from the configuration
 *    -EACCES, if integrity verification fails (any nonzero verifier result)
 *    negative libfdt error, if a node lookup fails
 */
int fit_load_image(const char *conf, const char *img_pro, void *fit,
                   addr_t *load, size_t *load_size, paddr_t *entry,
                   bool need_verified)
{
    int noffset, len, cfg_noffset;
    int ret, rc;
    const char *image_name;
    thread_t *integrity_verify_t = NULL;
#if WITH_SMP
    /*
     * Read by the worker thread until thread_join() below, so it must live
     * at function scope; an object scoped to the verify block would already
     * be dead while the thread still uses it.
     */
    struct verify_data verify_info;
#endif

    /* get default configuration offset (conf_1, conf_2,...or conf_n) */
    cfg_noffset = fit_get_def_cfg_offset(fit, conf);
    if (cfg_noffset < 0)
        return cfg_noffset;

    /* unit name: fdt_1, kernel_2, ramdisk_3 and so on */
    image_name = fdt_getprop(fit, cfg_noffset, img_pro, &len);
    if (image_name == NULL) {
        LTRACEF("%s get image name failed\n", img_pro);
        return -ENOENT;
    }

    /* get this sub image node offset */
    noffset = fit_get_img_subnode_offset(fit, image_name);
    if (noffset < 0) {
        dprintf(CRITICAL, "get sub image node (%s) failed\n", image_name);
        return noffset;
    }

    /* verify integrity of this image */
    if (hash_check_enabled() && need_verified) {
#if WITH_SMP
        verify_info.fit_image = fit;
        verify_info.noffset = noffset;

        integrity_verify_t = thread_create("integrity_verify_t",
            &fit_image_integrity_check_process, &verify_info,
            DEFAULT_PRIORITY, DEFAULT_STACK_SIZE);
        if (integrity_verify_t) {
            /* weak symbol: assign to an active cpu only if the platform has it */
            extern __WEAK void plat_mp_assign_workcpu(thread_t *t);

            if (plat_mp_assign_workcpu)
                plat_mp_assign_workcpu(integrity_verify_t);
            thread_resume(integrity_verify_t);
        }
#endif
        /* no worker thread (UP build or thread_create() failed): verify inline */
        if (!integrity_verify_t) {
            ret = fit_image_integrity_verify(fit, noffset);
            LTRACEF_LEVEL(CRITICAL, "check %s integrity: %s\n",
                    image_name, ret ? "fail" : "pass");
            /* any nonzero result is a failure, matching the log above */
            if (ret)
                return -EACCES;
        }
    } /* verify end */

    rc = fit_processing_data(fit, image_name, noffset, load, load_size, entry);

#if WITH_SMP
    if (integrity_verify_t) {
        /* a failed join leaves ret unset; treat it as a failed check */
        if (thread_join(integrity_verify_t, &ret, INFINITE_TIME) != 0)
            ret = -1;
        LTRACEF_LEVEL(CRITICAL, "check %s integrity: %s\n",
                image_name, ret ? "fail" : "pass");
        if (ret)
            return -EACCES;
    }
#endif

    return rc;
}
