#include <common.h>
#include <tee.h>
#include <malloc.h>
#include <asm/errno.h>
#include <asm/arch/cpu.h>
#include "cipher_optee.h"
#include "../geu/asr_geu.h"

/* Bundles the TEE device with the open-session arguments used to talk to it. */
struct tee_contex {
	/* device handle obtained from tee_find_device() */
	struct tee_device *tdev;
	/* arg.uuid is filled in before the session is opened; arg.session is read for every invoke/close */
	struct tee_open_session_arg arg;
};

/*
 * Key selection, as implemented by aes_encrypt_optee():
 *   1) hardware key (rkek) burned and the caller asks for it: use the
 *      hardware key;
 *   2) otherwise, if the caller supplied a key: use that key;
 *   3) hardware key not burned and no key supplied: fall back to
 *      default_key below.
 */
static const char default_key[64] = {"asr-aes-default-key-without-rkek"};

/*
 * Run one AES operation in the OP-TEE AES PTA using a single registered
 * shared-memory buffer.
 *
 * Layout of the buffer (byte offsets):
 *   [0, size)                          input data
 *   [size, 2*size)                     output data
 *   [2*size, 2*size + key_len)         caller key (stays zero when key is NULL)
 *   [2*size + key_len, ... + 16)       IV (only written when iv is non-NULL)
 *
 * @tee_ctx: TEE device plus an already opened session
 * @op_mode: passed through to the PTA as a value parameter
 * @iv:      16-byte IV; non-NULL selects a CBC command, NULL an ECB command
 * @key:     key bytes; NULL selects the hardware-key (HWKEY) commands
 * @key_len: key length in bytes
 * @in:      input data, @size bytes
 * @out:     receives @size bytes of output on success; untouched on failure
 * @size:    number of data bytes to process
 *
 * Return: 0 on success, -1 on any failure.
 */
static int cipher_optee(struct tee_contex *tee_ctx, int op_mode, uint8_t *iv,
                uint8_t *key, uint32_t key_len, void *in, void *out, uint32_t size)
{
	const uint32_t iv_size = 16;
	const uint32_t u32_max = ~(uint32_t)0;
	struct tee_invoke_arg arg_func = {0};
	struct tee_param param[4] = {0};
	struct tee_shm *shm;
	uint32_t num_params, ma_len;
	uint8_t *ma;
	int result = -1;
	int ret;

	/* Refuse sizes for which 2*size + key_len + iv_size would wrap around. */
	if (key_len > u32_max - iv_size ||
	    size > (u32_max - iv_size - key_len) / 2) {
		printf("AES data size too large\n");
		return -1;
	}

	ma_len = 2 * size + key_len + iv_size;
	ma = malloc(ma_len);
	if (!ma) {
		printf("No enough memory\n");
		return -1;
	}
	memset(ma, 0, ma_len);

	ret = tee_shm_register(tee_ctx->tdev, (void *)(ulong)ma, ma_len, 0x0, &shm);
	if (ret < 0) {
		printf("Cannot register output shared memory 0x%X\n", ret);
		free(ma);
		return -1;
	}

	/* param[0]: input data at the start of the buffer */
	memcpy(ma, in, size);
	param[0].attr = TEE_PARAM_ATTR_TYPE_MEMREF_INPUT;
	param[0].u.memref.shm = shm;
	param[0].u.memref.shm_offs = 0;
	param[0].u.memref.size = size;

	/* param[1]: output data directly behind the input */
	param[1].attr = TEE_PARAM_ATTR_TYPE_MEMREF_OUTPUT;
	param[1].u.memref.shm = shm;
	param[1].u.memref.shm_offs = size;
	param[1].u.memref.size = size;

	/* param[2]: value pair carrying op_mode and (in most cases) key_len */
	param[2].attr = TEE_PARAM_ATTR_TYPE_VALUE_INPUT;

	if (key) {
		/* caller-supplied key: param[3] covers the key, plus the IV for CBC */
		memcpy(ma + 2 * size, key, key_len);
		param[2].u.value.a = op_mode;

		param[3].attr = TEE_PARAM_ATTR_TYPE_MEMREF_INPUT;
		param[3].u.memref.shm = shm;
		param[3].u.memref.shm_offs = 2 * size;
		num_params = 4;

		if (iv) {
			param[2].u.value.b = key_len;
			memcpy(ma + 2 * size + key_len, iv, iv_size);
			param[3].u.memref.size = key_len + iv_size;
			arg_func.func = CMD_AES_CBC;
		} else {
			param[3].u.memref.size = key_len;
			arg_func.func = CMD_AES_ECB;
		}
	} else {
		/*
		 * Hardware key: no key bytes are sent.
		 * NOTE(review): value.a/value.b are swapped relative to the
		 * caller-key commands; presumably this matches the PTA's
		 * HWKEY command layout - confirm against the PTA.
		 */
		param[2].u.value.a = key_len;
		param[2].u.value.b = op_mode;

		if (iv) {
			memcpy(ma + 2 * size + key_len, iv, iv_size);
			param[3].attr = TEE_PARAM_ATTR_TYPE_MEMREF_INPUT;
			param[3].u.memref.shm = shm;
			param[3].u.memref.shm_offs = 2 * size + key_len;
			param[3].u.memref.size = iv_size;
			num_params = 4;
			arg_func.func = CMD_AES_HWKEY_CBC;
		} else {
			num_params = 3;
			arg_func.func = CMD_AES_HWKEY_ECB;
		}
	}

	arg_func.session = tee_ctx->arg.session;

	ret = tee_invoke_func(tee_ctx->tdev, &arg_func, num_params, param);
	if (ret) {
		/* transport failure: the output region holds no valid data */
		printf("Cannot invoke AES command in OP-TEE PTA_AES\n");
		goto out;
	}
	if (arg_func.ret != 0) {
		/* the PTA itself rejected or failed the request */
		printf("PTA_AES command failed 0x%X\n", arg_func.ret);
		goto out;
	}

	memcpy(out, ma + size, size);
	result = 0;

out:
	/* single exit so the shared memory and the buffer are released on every path */
	tee_shm_free(shm);
	free(ma);

	return result;
}

/*
 * Ask the AES PTA whether the hardware key (rkek) has been burned.
 *
 * @tee_ctx:     TEE device plus an already opened session
 * @rkek_burned: set to the value returned by the PTA (non-zero is treated
 *               as "burned" by the caller); only written on success
 *
 * Return: 0 on success, a negative value on failure.
 */
static int get_rkek_state(struct tee_contex *tee_ctx, int *rkek_burned)
{
	struct tee_invoke_arg arg_func = {0};
	struct tee_param param[1] = {0};
	int ret;

	param[0].attr = TEE_PARAM_ATTR_TYPE_VALUE_OUTPUT;
	arg_func.func = CMD_AES_HWKEY_STATUS;
	arg_func.session = tee_ctx->arg.session;

	ret = tee_invoke_func(tee_ctx->tdev, &arg_func, 1, param);
	if (ret) {
		printf("Cannot get hardware key status from OP-TEE PTA_AES\n");
		/* the caller only tests for < 0, so never hand back a positive code */
		return ret < 0 ? ret : -1;
	}

	/*
	 * A TA-level failure leaves param[0] at its zero initializer, which
	 * would otherwise be reported as "not burned".
	 */
	if (arg_func.ret != 0) {
		printf("PTA_AES hardware key status failed 0x%X\n", arg_func.ret);
		return -1;
	}

	*rkek_burned = param[0].u.value.a;

	return 0;
}

/*
 * Common entry point for all AES operations routed through the OP-TEE AES PTA.
 *
 * Opens a session, queries whether the hardware key (rkek) is burned, picks
 * the key to use, runs the cipher and closes the session again.
 *
 * Key selection:
 *   - rkek burned and @use_rkek set: hardware key (no key bytes are sent);
 *   - rkek burned and @use_rkek clear: @key;
 *   - rkek not burned and @key non-NULL: @key;
 *   - rkek not burned and @key NULL: the built-in default_key.
 *
 * @op_mode:  forwarded to the PTA (the wrappers below pass 1 for the
 *            *_encrypt_* functions and 0 for the *_decrypt_* functions)
 * @iv:       16-byte IV for CBC, NULL for ECB
 * @key:      key bytes, may be NULL only when @use_rkek is true
 * @use_rkek: request the hardware key
 * @key_len:  key length in bytes, must be within 16..32
 * @in:       input data, @size bytes
 * @out:      output buffer, @size bytes
 * @size:     number of data bytes to process
 *
 * Return: 0 on success, -1 on failure.
 */
int aes_encrypt_optee(int op_mode, uint8_t *iv, uint8_t *key, bool use_rkek,
                uint32_t key_len, void *in, void *out, uint32_t size)
{
	const struct tee_optee_ta_uuid uuid = OPTEE_AES_ACCESS_UUID;
	struct tee_contex tee_ctx = {0};
	uint8_t *use_key;
	int rkek_burned = 0;
	int result = -1;
	int ret;

	/* must be "||": with "&&" the test can never be true and nothing is rejected */
	if ((key_len > 32) || (key_len < 16)) {
		printf("err: aes encrypt key len %u\n", key_len);
		return -1;
	}

	if ((key == NULL) && (use_rkek == false)) {
		printf("%s error: key can't NULL when not use rkek\n", __func__);
		return -1;
	}

	tee_ctx.tdev = tee_find_device(NULL, NULL, NULL, NULL);
	if (!tee_ctx.tdev) {
		printf("Cannot get OP-TEE device\n");
		return -1;
	}

	/* Set TA UUID */
	tee_optee_ta_uuid_to_octets(tee_ctx.arg.uuid, &uuid);

	/* Open TA session */
	ret = tee_open_session(tee_ctx.tdev, &tee_ctx.arg, 0, NULL);
	if (ret < 0) {
		printf("Cannot open session with PTA Blob 0x%X\n", ret);
		return -1;
	}

	/* From here on every exit has to go through close_session. */
	ret = get_rkek_state(&tee_ctx, &rkek_burned);
	if (ret < 0) {
		printf("Cannot get rkek burned state 0x%X\n", ret);
		goto close_session;
	}

	if (rkek_burned) {
		/* NULL tells cipher_optee() to use the hardware-key commands */
		use_key = use_rkek ? NULL : key;
	} else {
		use_key = key ? key : (uint8_t *)default_key;
	}

	ret = cipher_optee(&tee_ctx, op_mode, iv, use_key, key_len, in, out, size);
	if (ret >= 0)
		result = 0;

close_session:
	ret = tee_close_session(tee_ctx.tdev, tee_ctx.arg.session);
	if (ret < 0) {
		printf("Cannot close session with PTA_AES 0x%X\n", ret);
		result = -1;
	}

	return result;
}

/*
 * AES-ECB encrypt wrapper: forwards to aes_encrypt_optee() with op_mode 1
 * and no IV.
 *
 * Return: whatever aes_encrypt_optee() returns.
 */
int aes_ecb_encrypt_optee(uint8_t *key, uint32_t key_len, bool use_rkek,
                        void *in, void *out, uint32_t size)
{
	const int mode = 1;
	uint8_t *const no_iv = NULL;	/* ECB carries no IV */

	return aes_encrypt_optee(mode, no_iv, key, use_rkek, key_len,
				 in, out, size);
}

/*
 * AES-ECB decrypt wrapper: forwards to aes_encrypt_optee() with op_mode 0
 * and no IV.
 *
 * Return: whatever aes_encrypt_optee() returns.
 */
int aes_ecb_decrypt_optee(uint8_t *key, uint32_t key_len, bool use_rkek,
                        void *in, void *out, uint32_t size)
{
	const int mode = 0;
	uint8_t *const no_iv = NULL;	/* ECB carries no IV */

	return aes_encrypt_optee(mode, no_iv, key, use_rkek, key_len,
				 in, out, size);
}

/*
 * AES-CBC encrypt wrapper: requires an IV, then forwards to
 * aes_encrypt_optee() with op_mode 1.
 *
 * Return: -1 when @iv is NULL, otherwise whatever aes_encrypt_optee() returns.
 */
int aes_cbc_encrypt_optee(uint8_t *iv, uint8_t *key, uint32_t key_len,
                        bool use_rkek, void *in, void *out, uint32_t size)
{
	const int mode = 1;

	return iv ? aes_encrypt_optee(mode, iv, key, use_rkek, key_len,
				      in, out, size)
		  : -1;
}

/*
 * AES-CBC decrypt wrapper: requires an IV, then forwards to
 * aes_encrypt_optee() with op_mode 0.
 *
 * Return: -1 when @iv is NULL, otherwise whatever aes_encrypt_optee() returns.
 */
int aes_cbc_decrypt_optee(uint8_t *iv, uint8_t *key, uint32_t key_len,
                        bool use_rkek, void *in, void *out, uint32_t size)
{
	const int mode = 0;

	return iv ? aes_encrypt_optee(mode, iv, key, use_rkek, key_len,
				      in, out, size)
		  : -1;
}
