/*
 * Copyright (c) 2018 MediaTek Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining
 * a copy of this software and associated documentation files
 * (the "Software"), to deal in the Software without restriction,
 * including without limitation the rights to use, copy, modify, merge,
 * publish, distribute, sublicense, and/or sell copies of the Software,
 * and to permit persons to whom the Software is furnished to do so,
 * subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be
 * included in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

#include <assert.h>
#include <debug.h>
#include <err.h>
#include <kernel/mutex.h>
#include <kernel/thread.h>
#include <lib/kcmdline.h>
#include <libfdt.h>
#include <list.h>
#include <malloc.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <trace.h>

#if LK_DEBUGLEVEL > 0
#define LOCAL_TRACE 1
#else
#define LOCAL_TRACE 0
#endif

/* Message logged, then passed to panic(), when cmdline_buf would overflow. */
#define CMDLINE_OVERFLOW_STR    "[ERROR] CMDLINE overflow"

/*
 * One substitution queued by kcmdline_subst() and consumed (then freed) by
 * kcmdline_finalized(). Both strings are strdup() copies owned by the entry.
 */
struct subst_entry {
    /* link in subst_list */
    struct list_node node;
    /* argument to look for in the final command line */
    char *old_arg;
    /* replacement argument */
    char *new_arg;
};

/* variable for keeping substitution args: struct subst_entry, in registration order */
struct list_node subst_list = LIST_INITIAL_VALUE(subst_list);

/*
 * variable for keeping append arg: CMDLINE_LEN bytes allocated by
 * kcmdline_init(); NULL before that and after a successful
 * kcmdline_finalized(). cmdline_tail is the current end of the string (next
 * write position) and cmdline_end is one past the last byte of the buffer.
 */
static char *cmdline_buf;
static char *cmdline_tail;
static char *cmdline_end;

/* Taken around accesses to cmdline_buf/cmdline_tail and subst_list. */
static mutex_t lock = MUTEX_INITIAL_VALUE(lock);

/*
 * Overflow guard for writes into cmdline_buf.
 *
 * @tail: position the caller computed for its pending write.
 * @end:  one past the last usable byte of the buffer.
 *
 * Returns normally only while @tail is strictly below @end; otherwise the
 * overflow is logged and the system panics.
 */
static inline void validate_cmdline_boundary(const char *tail, const char *end)
{
    if (tail < end)
        return;

    dprintf(CRITICAL, CMDLINE_OVERFLOW_STR"\n");
    panic(CMDLINE_OVERFLOW_STR);
}

/*
 * Debug helper: trace the length and contents of the "bootargs" property of
 * the /chosen node in @fdt. Lookup failures are traced and otherwise ignored.
 */
static void dump_cmdline(void *fdt)
{
    int len;
    int chosen_node_offset;
    const char *cmdline;

    chosen_node_offset = fdt_path_offset(fdt, "/chosen");
    if (chosen_node_offset < 0) {
        LTRACEF("can't find chosen node.\n");
        return;
    }

    cmdline = fdt_getprop(fdt, chosen_node_offset, "bootargs", &len);
    if (!cmdline) {
        LTRACEF("fdt_getprop bootargs failed.\n");
        return;
    }

    /* Never run strlen() past the end of a property that lacks its NUL. */
    if (len <= 0 || cmdline[len - 1] != '\0') {
        LTRACEF("bootargs is not a NUL-terminated string (len=%d).\n", len);
        return;
    }

    /* strlen() yields size_t: %zu, not %zd. */
    LTRACEF("cmdline len=%zu, str=\"%s\"\n", strlen(cmdline), cmdline);
}

/*
 * Allocate the zero-filled, CMDLINE_LEN-byte buffer that collects appended
 * arguments, and point the write cursor at its start.
 *
 * Returns NO_ERROR on success, ERR_ALREADY_EXISTS if the buffer is already
 * allocated, or ERR_NO_MEMORY if the allocation fails.
 */
int kcmdline_init(void)
{
    char *buf;

    if (cmdline_buf != NULL)
        return ERR_ALREADY_EXISTS;

    /* calloc hands back the buffer already zeroed */
    buf = calloc(1, CMDLINE_LEN);
    if (buf == NULL)
        return ERR_NO_MEMORY;

    cmdline_buf = buf;
    cmdline_tail = buf;
    cmdline_end = buf + CMDLINE_LEN;

    return NO_ERROR;
}

int kcmdline_finalized(void *fdt, size_t size)
{
    int n, ret;
    int len;
    int offset;
    const void *fdt_bootargs;
    char *temp_ptr = NULL;
    size_t append_arg_len = 0;
    struct subst_entry *entry;
    struct subst_entry *temp;

    if (!cmdline_buf)
        return ERR_NOT_READY;

    if (!fdt || !size) {
        LTRACEF("Invalid args: fdt(%p), size(%zd)\n", fdt, size);
        return ERR_INVALID_ARGS;
    }

    ret = NO_ERROR;
    mutex_acquire(&lock);

    if ((*cmdline_buf == 0x0) && list_is_empty(&subst_list))
        goto exit;

    /* cmdline_buf is filled with bootargs before kcmdline_finalize is called */
    if (*cmdline_buf != 0x0) {
        append_arg_len = strlen(cmdline_buf);
        temp_ptr = (char *)malloc(append_arg_len + 1);
        if (!temp_ptr) {
            ret = ERR_NO_MEMORY;
            goto exit;
        }
        n = snprintf(temp_ptr, (append_arg_len + 1), "%s", cmdline_buf);
        if (n < 0 || n >= (append_arg_len + 1)) {
            ret = ERR_NO_MEMORY;
            goto exit;
        }
        LTRACEF("append_arg_len:%zu cmdline_buf:%s\n",
                append_arg_len, cmdline_buf);
    }

    ret = fdt_open_into(fdt, fdt, size);
    if (ret) {
        LTRACEF("fdt_open_into failed\n");
        goto exit;
    }
    ret = fdt_check_header(fdt);
    if (ret) {
        LTRACEF("fdt_check_header failed\n");
        goto exit;
    }

    /* Reset cmdline_tail */
    cmdline_tail = cmdline_buf;
    offset = fdt_path_offset(fdt, "/chosen");
    if (offset < 0 ) {
        LTRACEF_LEVEL(CRITICAL, "Can't find chosen node\n");
        ret = ERR_NOT_FOUND;
        goto exit;
    }

    fdt_bootargs = fdt_getprop(fdt, offset, "bootargs", &len);
    if (!fdt_bootargs) {
        ret = ERR_NOT_FOUND;
        goto exit;
    }

    /* add appended string to final string */
    validate_cmdline_boundary(cmdline_tail + append_arg_len +
                              strlen(fdt_bootargs) + 2, cmdline_end);
    cmdline_tail += snprintf(cmdline_tail, CMDLINE_LEN, "%s",
                             (char *)fdt_bootargs);
    if (temp_ptr) {
        cmdline_tail += snprintf(cmdline_tail, cmdline_end - cmdline_tail,
                                 " %s", temp_ptr);
    }

    /* subst string in final string */
    list_for_every_entry_safe(&subst_list, entry, temp,
                              struct subst_entry, node) {
        char *pos = strstr(cmdline_buf, entry->old_arg);
        size_t old_len = strlen(entry->old_arg);
        if (pos && ((*(pos + old_len) == ' ') || (*(pos + old_len) == '\0'))) {
            size_t new_len = strlen(entry->new_arg);
            char *p;

            /* erase old arg with space */
            memset(pos, ' ', old_len);
            if (old_len >= new_len) {
                p = pos; /* replace old arg with new arg */
            } else {
                /* append new arg in the end of cmdline */
                validate_cmdline_boundary(cmdline_tail + new_len + 2,
                                          cmdline_end);
                p = cmdline_tail;
                cmdline_tail += (new_len + 1);
                *p++ = ' ';
            }
            memcpy(p, entry->new_arg, new_len);
        }

        /* free memory and delete node */
        free(entry->old_arg);
        free(entry->new_arg);
        list_delete(&entry->node);
        free(entry);
        entry = NULL;
    }

    ret = fdt_setprop(fdt, offset, "bootargs", cmdline_buf,
                      strlen(cmdline_buf) + 1);
    if (ret != 0) {
        dprintf(CRITICAL, "fdt_setprop error.\n");
        ret = ERR_GENERIC;
        goto exit;
    }

    ret = fdt_pack(fdt);
    if (ret != 0) {
        dprintf(CRITICAL, "fdt_pack error.\n");
        ret = ERR_GENERIC;
        goto exit;
    }

    free(cmdline_buf);
    cmdline_buf = NULL;

#if LOCAL_TRACE
    dump_cmdline(fdt);
#endif

exit:
    free(temp_ptr);
    mutex_release(&lock);

    return ret;
}

/*
 * Trace the pending appended arguments and the queued substitutions.
 * Does nothing while the append buffer is not allocated.
 */
void kcmdline_print(void)
{
    struct subst_entry *entry;

    mutex_acquire(&lock);

    /* Checked under the lock: kcmdline_finalized() frees cmdline_buf. */
    if (!cmdline_buf) {
        mutex_release(&lock);
        return;
    }

    LTRACEF("append cmdline: %s\n", cmdline_buf);
    LTRACEF("append cmdline size: %zu\n", strlen(cmdline_buf));
    LTRACEF("subst list:\n");

    /* traverse list to show subst_list  */
    list_for_every_entry(&subst_list, entry, struct subst_entry, node) {
        LTRACEF("old_arg: %s, new_arg: %s\n", entry->old_arg, entry->new_arg);
    }

    mutex_release(&lock);
}

/*
 * Queue @append_arg to be added (space-separated) after the FDT bootargs
 * when kcmdline_finalized() runs. Panics if the buffer would overflow.
 *
 * Returns NO_ERROR, ERR_NOT_READY if the buffer is not allocated, or
 * ERR_INVALID_ARGS if @append_arg is NULL.
 */
int kcmdline_append(const char *append_arg)
{
    size_t append_arg_len;

    if (!cmdline_buf)
        return ERR_NOT_READY;

    if (!append_arg)
        return ERR_INVALID_ARGS;

    append_arg_len = strlen(append_arg);

    mutex_acquire(&lock);

    /*
     * Re-check under the lock: kcmdline_finalized() may have freed the
     * buffer since the check above, leaving cmdline_tail dangling.
     */
    if (!cmdline_buf) {
        mutex_release(&lock);
        return ERR_NOT_READY;
    }

    /* " <arg>" plus its NUL terminator must end before cmdline_end. */
    validate_cmdline_boundary(cmdline_tail + append_arg_len + 1, cmdline_end);
    cmdline_tail += snprintf(cmdline_tail, cmdline_end - cmdline_tail, " %s",
                             append_arg);

    mutex_release(&lock);

    return NO_ERROR;
}

/*
 * Queue a substitution of @old_arg by @new_arg, applied to the final command
 * line by kcmdline_finalized(). Both strings are copied.
 *
 * Returns NO_ERROR, ERR_NOT_READY if the buffer is not allocated,
 * ERR_INVALID_ARGS if either argument is NULL, or ERR_NO_MEMORY.
 */
int kcmdline_subst(const char *old_arg, const char *new_arg)
{
    struct subst_entry *entry;
    int ret;

    if (!cmdline_buf)
        return ERR_NOT_READY;

    if (!old_arg || !new_arg) {
        LTRACEF("Invalid args: old_arg(%p), new_arg(%p)\n", old_arg, new_arg);
        return ERR_INVALID_ARGS;
    }

    entry = calloc(1, sizeof(*entry));
    if (!entry)
        return ERR_NO_MEMORY;

    entry->old_arg = strdup(old_arg);
    entry->new_arg = strdup(new_arg);
    if (!entry->old_arg || !entry->new_arg) {
        ret = ERR_NO_MEMORY;
        goto err_free;
    }

    mutex_acquire(&lock);

    /*
     * Re-check under the lock: kcmdline_finalized() may have drained the
     * list and freed the buffer since the check above; honor the
     * ERR_NOT_READY contract instead of queueing into a finished state.
     */
    if (!cmdline_buf) {
        mutex_release(&lock);
        ret = ERR_NOT_READY;
        goto err_free;
    }

    list_add_tail(&subst_list, &entry->node);
    mutex_release(&lock);

    return NO_ERROR;

err_free:
    free(entry->old_arg);
    free(entry->new_arg);
    free(entry);

    return ret;
}

