#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>

#include <mtd/mtd-abi.h>
#include <errno.h>

#include <sys/ioctl.h>
#include "mtd_api.h"

#define MTD_PATH_LENGTH_MAX (256)

static ssize_t readn(int fd, void *vptr, size_t n);
static ssize_t writen(int fd, const void *vptr, size_t n);

/*******************************************************************************
 *                       内部函数/mtd_private.h                                  *
 *******************************************************************************/
/**
 * @brief Read up to n bytes from a file descriptor into a buffer.
 * @param fd                 [in]  file descriptor to read from
 * @param vptr               [out] destination buffer, at least n bytes long
 * @param n                  [in]  number of bytes requested
 * @return number of bytes actually read (fewer than n only when end of file
 *         is reached first), or -1 on a read error other than EINTR
 * @note Wraps read(): calls interrupted by a signal (EINTR) are retried and
 *       partial reads are continued until n bytes arrive or EOF is hit.
 */
static ssize_t readn(int fd, void *vptr, size_t n)
{
    char *cursor = vptr;
    size_t remaining = n;

    while (remaining > 0)
    {
        ssize_t got = read(fd, cursor, remaining);

        if (got == 0)
        {
            break; /* end of file: report what has been read so far */
        }
        if (got < 0)
        {
            if (errno != EINTR)
            {
                return -1;
            }
            continue; /* interrupted before any data arrived: try again */
        }

        remaining -= (size_t)got;
        cursor += got;
    }

    return (ssize_t)(n - remaining);
}

/**
 * @brief Write n bytes from a buffer to a file descriptor.
 * @param fd                 [in] file descriptor to write to
 * @param vptr               [in] source buffer holding at least n bytes
 * @param n                  [in] number of bytes to write
 * @return number of bytes written (fewer than n only if write() reports that
 *         it wrote 0 bytes), or -1 on a write error other than EINTR
 * @note Wraps write(): calls interrupted by a signal (EINTR) are retried and
 *       partial writes are continued until all n bytes are written.
 */
static ssize_t writen(int fd, const void *vptr, size_t n)
{
    const char *src = vptr;
    size_t done = 0;

    while (done < n)
    {
        ssize_t put = write(fd, src + done, n - done);

        if (put > 0)
        {
            done += (size_t)put;
        }
        else if (put == 0)
        {
            break; /* no progress possible: return the short count */
        }
        else if (errno != EINTR)
        {
            return -1; /* genuine write error */
        }
        /* put < 0 with EINTR: retry the same remaining chunk */
    }

    return (ssize_t)done;
}

/*******************************************************************************
 *                                外部函数定义                                  *
 *******************************************************************************/
/**
 * @brief Find an MTD partition by name in /proc/mtd and build its device path.
 * @param i_parti_name   [in]  partition name as it appears (in quotes) in /proc/mtd
 * @param device_type    [in]  node flavour to build: DEVICE_MTD_BLOCK -> /dev/mtdblockN,
 *                             DEVICE_MTD -> /dev/mtdN, DEVICE_ZFTL -> /dev/zftlN
 * @param o_mtd_path     [out] receives the NUL-terminated device path
 * @param o_mtd_path_len [in]  size of o_mtd_path in bytes; must be non-zero
 * @return 0 on success; -1 on a NULL/zero-length argument, when /proc/mtd cannot
 *         be opened or read, when the partition is not listed, when device_type
 *         is unknown, or when o_mtd_path is too small to hold the path
 * @note Matching lines look like:  mtd5: 00100000 00020000 "fotaflag"
 *       Lines that do not fit that pattern (e.g. the column header) are skipped.
 */
int mtd_find(const char *i_parti_name, device_type_t device_type, char *o_mtd_path, unsigned int o_mtd_path_len)
{
    FILE *fp_mtd = NULL;
    char buf[128];
    int mtdnum = -1;
    int found = 0;
    int read_failed = 0;
    int saved_errno = 0;
    int written = 0;

    if (NULL == i_parti_name || NULL == o_mtd_path || 0 == o_mtd_path_len)
    {
        return -1;
    }

    /* /proc/mtd is read-only: opening it "r+" would needlessly demand write access. */
    fp_mtd = fopen("/proc/mtd", "r");
    if (NULL == fp_mtd)
    {
        printf("[libmtd]: mtd_find, open file error:%s\n", strerror(errno));
        return -1;
    }

    while (!found && NULL != fgets(buf, sizeof(buf), fp_mtd))
    {
        char mtdname[64] = {0};
        int num = 0;
        unsigned int mtdsize = 0;
        unsigned int mtderasesize = 0;

        /* %63[^\"] stops at the closing quote and cannot overflow mtdname. */
        if (4 == sscanf(buf, "mtd%d: %x %x \"%63[^\"]", &num, &mtdsize, &mtderasesize, mtdname) &&
            0 == strcmp(mtdname, i_parti_name))
        {
            mtdnum = num;
            found = 1;
        }
    }

    /* Distinguish a genuine read error from simply reaching end of file;
     * errno is only meaningful in the former case. */
    read_failed = ferror(fp_mtd);
    saved_errno = errno;
    fclose(fp_mtd);

    if (!found)
    {
        if (read_failed)
        {
            printf("[libmtd]: mtd_find, get info from mtd error:%s\n", strerror(saved_errno));
        }
        else
        {
            printf("[libmtd]: mtd_find, partition %s not found\n", i_parti_name);
        }
        return -1;
    }

    memset(o_mtd_path, 0, o_mtd_path_len);
    if (device_type == DEVICE_MTD_BLOCK)
    {
        written = snprintf(o_mtd_path, o_mtd_path_len, "/dev/mtdblock%d", mtdnum);
    }
    else if (device_type == DEVICE_MTD)
    {
        written = snprintf(o_mtd_path, o_mtd_path_len, "/dev/mtd%d", mtdnum);
    }
    else if (device_type == DEVICE_ZFTL)
    {
        written = snprintf(o_mtd_path, o_mtd_path_len, "/dev/zftl%d", mtdnum);
    }
    else
    {
        printf("[libmtd]: mtd_find, unknown device type %d\n", device_type);
        return -1;
    }

    /* A truncated path would name a different (or nonexistent) device node. */
    if (written < 0 || (unsigned int)written >= o_mtd_path_len)
    {
        printf("[libmtd]: mtd_find, path buffer too small for %s\n", i_parti_name);
        return -1;
    }
    return 0;
}

/**
 * @brief Query the geometry of an opened MTD device with the MEMGETINFO ioctl.
 * @param fd   [in]  descriptor of an opened MTD device node
 * @param info [out] filled in by the kernel on success
 * @return 0 on success, -1 if the ioctl fails (an error message is printed)
 */
int mtd_get_info(int fd, struct mtd_info_user *info)
{
    int rc = ioctl(fd, MEMGETINFO, info);

    if (rc == 0)
    {
        return 0;
    }

    printf("[libmtd]: mtd_get_info, get fd(%d) info error, %s\n", fd, strerror(errno));
    return -1;
}

/**
 * @brief Ask the MTD driver whether the eraseblock containing offset is bad.
 * @param fd     [in] descriptor of an opened MTD device node
 * @param offset [in] byte offset inside the device
 * @return 1 if the block is marked bad, 0 if it is good, -1 if the ioctl fails
 * @note MEMGETBADBLOCK takes a pointer to a 64-bit __kernel_loff_t, but off_t is
 *       only 32 bits wide on 32-bit builds without _FILE_OFFSET_BITS=64. The
 *       offset is therefore copied into a long long first; handing &offset to
 *       the kernel directly would let it read 8 bytes from a 4-byte object.
 */
int mtd_block_isbad(int fd, off_t offset)
{
    long long ofs = (long long)offset;
    int ret = ioctl(fd, MEMGETBADBLOCK, &ofs);

    if (ret > 0)
    {
        printf("[libmtd]: mtd_block_isbad, bad block at 0x%llx, ret = %d\n", (unsigned long long)ofs, ret);
        ret = 1;
    }
    else if (ret < 0)
    {
        printf("[libmtd]: mtd_block_isbad, ioctl(MEMGETBADBLOCK) error at 0x%llx, %s, ret = %d\n",
               (unsigned long long)ofs, strerror(errno), ret);
        ret = -1;
    }
    return ret;
}

/**
 * @brief Erase every good eraseblock of the named MTD partition.
 * @param partition_name [in] partition name as listed in /proc/mtd
 * @return 0 on success; -1 if the partition cannot be found, opened or queried,
 *         reports a zero erasesize, a bad-block query fails, or a block whose
 *         erase failed cannot be marked bad
 * @note Blocks already marked bad are skipped. A block whose erase fails is
 *       marked bad (MEMSETBADBLOCK) and erasing continues with the next block.
 */
int mtd_erase_partition(const char *partition_name)
{
    int ret = -1;
    char mtd_path[MTD_PATH_LENGTH_MAX] = {0};
    int fd_mtd = -1;

    struct mtd_info_user meminfo = {0};
    struct erase_info_user64 erase_info = {0};

    if (NULL == partition_name)
    {
        return -1;
    }

    if (mtd_find(partition_name, DEVICE_MTD, mtd_path, MTD_PATH_LENGTH_MAX) < 0)
    {
        printf("[libmtd]: mtd_erase_partition, mtd_find %s failed\n", partition_name);
        return -1;
    }
    fd_mtd = open(mtd_path, O_RDWR);
    if (fd_mtd < 0)
    {
        printf("[libmtd]: mtd_erase_partition, open %s error, %s\n", partition_name, strerror(errno));
        return -1;
    }
    if (mtd_get_info(fd_mtd, &meminfo) < 0)
    {
        printf("[libmtd]: mtd_erase_partition, get %s info error, %s\n", partition_name, strerror(errno));
        goto out;
    }

    /* A zero erasesize would make the loop below never advance. */
    if (0 == meminfo.erasesize)
    {
        printf("[libmtd]: mtd_erase_partition, %s reports zero erasesize\n", partition_name);
        goto out;
    }

    erase_info.length = meminfo.erasesize;
    for (erase_info.start = 0; erase_info.start < meminfo.size; erase_info.start += meminfo.erasesize)
    {
        int bad = mtd_block_isbad(fd_mtd, (off_t)erase_info.start);

        if (1 == bad)
        {
            continue; /* already bad: never erase it */
        }
        if (-1 == bad)
        {
            printf("[libmtd]: mtd_erase_partition, mtd_block_isbad %s error\n", partition_name);
            goto out;
        }

        if (0 != ioctl(fd_mtd, MEMERASE64, &erase_info))
        {
            printf("[libmtd]: mtd_erase_partition, erasing %s failure at 0x%llx\n",
                   partition_name, (unsigned long long)erase_info.start);
            /* Retire the failing block so later writers skip it. */
            if (ioctl(fd_mtd, MEMSETBADBLOCK, &(erase_info.start)) < 0)
            {
                printf("[libmtd]: mtd_erase_partition, mark %s bad block error, %s\n", partition_name, strerror(errno));
                goto out;
            }
        }
    }
    ret = 0;
out:
    if (fd_mtd >= 0)
    {
        close(fd_mtd);
    }
    return ret;
}

/**
 * @brief Program an image file into the named MTD partition, skipping bad blocks.
 * @param partition_name [in] partition name as listed in /proc/mtd
 * @param image_file     [in] path of the image to write
 * @return 0 on success; -1 on any error, including a short read/write or the
 *         image not fitting into the good blocks of the partition
 * @note The image is copied one eraseblock-sized chunk at a time; blocks marked
 *       bad are skipped, so consecutive chunks land on consecutive *good* blocks.
 *       The final chunk is written with its exact (possibly partial) length.
 * @warning This function does not erase anything; presumably the caller erases
 *          the partition first (e.g. mtd_erase_partition) — confirm with callers.
 */
int mtd_write_partition(const char *partition_name, const char *image_file)
{
    int ret = -1;
    ssize_t rw_len;
    char mtd_path[MTD_PATH_LENGTH_MAX] = {0};
    int fd_mtd = -1;
    struct mtd_info_user meminfo = {0};

    off_t data_size = 0;
    off_t index = 0;
    off_t fd_img_index = 0;
    int fd_img = -1;
    char *buf = NULL;
    struct stat statbuff = {0};

    if (NULL == partition_name || NULL == image_file)
    {
        return -1;
    }

    if (mtd_find(partition_name, DEVICE_MTD, mtd_path, MTD_PATH_LENGTH_MAX) < 0)
    {
        printf("[libmtd]: mtd_write_partition, mtd_find %s failed\n", partition_name);
        return -1;
    }
    fd_mtd = open(mtd_path, O_RDWR);
    if (fd_mtd < 0)
    {
        printf("[libmtd]: mtd_write_partition, open %s error, %s\n", partition_name, strerror(errno));
        return -1;
    }
    if (mtd_get_info(fd_mtd, &meminfo) < 0)
    {
        printf("[libmtd]: mtd_write_partition, get %s info error, %s\n", partition_name, strerror(errno));
        goto out;
    }
    /* A zero erasesize would mean malloc(0) and a loop that never advances. */
    if (0 == meminfo.erasesize)
    {
        printf("[libmtd]: mtd_write_partition, %s reports zero erasesize\n", partition_name);
        goto out;
    }

    fd_img = open(image_file, O_RDONLY);
    if (fd_img < 0)
    {
        printf("[libmtd]: mtd_write_partition, open %s failed, %s\n", image_file, strerror(errno));
        goto out;
    }
    /* fstat the opened descriptor so the size belongs to the file actually read. */
    if (fstat(fd_img, &statbuff) < 0)
    {
        printf("[libmtd]: mtd_write_partition, stat %s failed, %s\n", image_file, strerror(errno));
        goto out;
    }

    buf = malloc(meminfo.erasesize);
    if (NULL == buf)
    {
        printf("[libmtd]: mtd_write_partition, malloc failed\n");
        goto out;
    }

    for (index = 0; index < meminfo.size && fd_img_index < statbuff.st_size; index += meminfo.erasesize)
    {
        int bad = mtd_block_isbad(fd_mtd, index);

        if (1 == bad)
        {
            continue; /* skip bad block; the same image chunk goes to the next one */
        }
        if (-1 == bad)
        {
            printf("[libmtd]: mtd_write_partition, mtd_block_isbad %s error,at 0x%llx\n",
                   partition_name, (unsigned long long)index);
            goto out;
        }

        /* Chunk = one eraseblock, or whatever remains of the image. */
        data_size = statbuff.st_size - fd_img_index;
        if (data_size > (off_t)meminfo.erasesize)
        {
            data_size = (off_t)meminfo.erasesize;
        }

        if (lseek(fd_img, fd_img_index, SEEK_SET) < 0)
        {
            printf("[libmtd]: mtd_write_partition, lseek %s error = %s!\n", image_file, strerror(errno));
            goto out;
        }
        rw_len = readn(fd_img, buf, (size_t)data_size);
        if (rw_len < 0)
        {
            printf("[libmtd]: mtd_write_partition, read %s error, %s\n", image_file, strerror(errno));
            goto out;
        }
        /* The file shrank under us: never program uninitialized buffer bytes. */
        if (rw_len != (ssize_t)data_size)
        {
            printf("[libmtd]: mtd_write_partition, short read from %s at 0x%llx\n",
                   image_file, (unsigned long long)fd_img_index);
            goto out;
        }

        if (lseek(fd_mtd, index, SEEK_SET) < 0)
        {
            printf("[libmtd]: mtd_write_partition, lseek %s error = %s!\n", partition_name, strerror(errno));
            goto out;
        }
        rw_len = writen(fd_mtd, buf, (size_t)data_size);
        if (rw_len < 0)
        {
            printf("[libmtd]: mtd_write_partition, write %s error, %s\n", partition_name, strerror(errno));
            goto out;
        }
        if (rw_len != (ssize_t)data_size)
        {
            printf("[libmtd]: mtd_write_partition, short write to %s at 0x%llx\n",
                   partition_name, (unsigned long long)index);
            goto out;
        }
        fd_img_index += data_size;
    }
    if (fd_img_index < statbuff.st_size)
    {
        printf("[libmtd]: mtd_write_partition, No space left, writelen=0x%llx, filesize=0x%llx\n",
               (unsigned long long)fd_img_index, (unsigned long long)statbuff.st_size);
        goto out;
    }
    ret = 0;
out:
    if (fd_mtd >= 0)
    {
        close(fd_mtd);
    }
    free(buf);
    if (fd_img >= 0)
    {
        close(fd_img);
    }
    return ret;
}

/**
 * @brief Erase the single eraseblock that starts at offset.
 * @param fd     [in] descriptor of an opened MTD device node
 * @param offset [in] byte offset; must be non-negative and eraseblock aligned
 * @return 0 on success; -1 if fd is invalid, the device cannot be queried or
 *         reports a zero erasesize, offset is negative or misaligned, the block
 *         is bad (or the bad-block query fails), or the erase fails
 * @note When the erase fails, the block is marked bad before returning -1.
 */
int mtd_erase_offset(int fd, off_t offset)
{
    struct mtd_info_user meminfo = {0};
    struct erase_info_user64 erase_info = {0};

    if (fd < 0)
    {
        printf("[libmtd]: mtd_erase_offset, fd(%d) error\n", fd);
        return -1;
    }

    if (0 != mtd_get_info(fd, &meminfo))
    {
        printf("[libmtd]: mtd_erase_offset, get fd(%d) info error, %s\n", fd, strerror(errno));
        return -1;
    }

    /* Guards the modulo below against a zero divisor, and rejects offsets that
     * would wrap to a huge value once stored in the unsigned 64-bit start. */
    if (0 == meminfo.erasesize || offset < 0)
    {
        printf("[libmtd]: mtd_erase_offset, invalid offset or erasesize(%u)\n", meminfo.erasesize);
        return -1;
    }

    if (0 != (offset % meminfo.erasesize))
    {
        printf("[libmtd]: mtd_erase_offset, not at the beginning of a block, erasesize is %u\n", meminfo.erasesize);
        return -1;
    }

    erase_info.length = meminfo.erasesize;
    erase_info.start = (unsigned long long)offset;

    /* Non-zero means bad block (1) or query failure (-1); both abort. */
    if (mtd_block_isbad(fd, offset))
    {
        return -1;
    }
    if (0 != ioctl(fd, MEMERASE64, &erase_info))
    {
        printf("[libmtd]: mtd_erase_offset, erasing failure at 0x%llx\n", (unsigned long long)erase_info.start);
        if (ioctl(fd, MEMSETBADBLOCK, &(erase_info.start)) < 0)
        {
            printf("[libmtd]: mtd_erase_offset, mark bad block error, %s\n", strerror(errno));
        }
        return -1;
    }

    return 0;
}

/**
 * @brief Write up to one eraseblock of data at an eraseblock-aligned offset.
 * @param fd     [in] descriptor of an opened MTD device node
 * @param offset [in] byte offset; must be non-negative and eraseblock aligned
 * @param buf    [in] data to write (count bytes); must not be NULL
 * @param count  [in] number of bytes, between 1 and erasesize inclusive
 * @return count on success; -1 on invalid arguments, a bad block (or failed
 *         bad-block query), or an I/O error / short write
 * @note Nothing is erased here; presumably the caller erased the block first
 *       (e.g. with mtd_erase_offset) — confirm against callers.
 */
ssize_t mtd_write_offset(int fd, off_t offset, const void *buf, size_t count)
{
    ssize_t writen_len;
    struct mtd_info_user meminfo = {0};

    if (fd < 0 || NULL == buf || offset < 0)
    {
        return -1;
    }

    if (0 != mtd_get_info(fd, &meminfo))
    {
        printf("[libmtd]: mtd_write_offset, get fd(%d) info error, %s\n", fd, strerror(errno));
        return -1;
    }
    /* Also rejects a zero erasesize, so the modulo below never divides by 0. */
    if (0 == count || count > meminfo.erasesize)
    {
        printf("[libmtd]: mtd_write_offset, count(0x%zx), less than 0, or larger than erasesize(0x%x)\n", count, meminfo.erasesize);
        return -1;
    }
    if (0 != (offset % meminfo.erasesize))
    {
        printf("[libmtd]: mtd_write_offset, not at the beginning of a block, erasesize is 0x%x\n", meminfo.erasesize);
        return -1;
    }

    /* Non-zero means bad block (1) or query failure (-1); both abort. */
    if (mtd_block_isbad(fd, offset))
    {
        return -1;
    }

    if (lseek(fd, offset, SEEK_SET) < 0)
    {
        printf("[libmtd]: mtd_write_offset, lseek error = %s!\n", strerror(errno));
        return -1;
    }
    writen_len = writen(fd, buf, count);
    /* count <= erasesize, so the cast cannot overflow. */
    if (writen_len != (ssize_t)count)
    {
        if (-1 == writen_len)
        {
            printf("[libmtd]: mtd_write_offset, write error, %s\n", strerror(errno));
        }
        return -1;
    }

    return writen_len;
}

/**
 * @brief Read up to one eraseblock of data from an eraseblock-aligned offset.
 * @param fd     [in]  descriptor of an opened MTD device node
 * @param offset [in]  byte offset; must be non-negative and eraseblock aligned
 * @param buf    [out] destination buffer of at least count bytes; must not be NULL
 * @param count  [in]  number of bytes, between 1 and erasesize inclusive
 * @return count on success; -1 on invalid arguments, a bad block (or failed
 *         bad-block query), or an I/O error / short read
 */
ssize_t mtd_read_offset(int fd, off_t offset, void *buf, size_t count)
{
    ssize_t readn_len;
    struct mtd_info_user meminfo = {0};

    if (fd < 0 || NULL == buf || offset < 0)
    {
        return -1;
    }

    if (0 != mtd_get_info(fd, &meminfo))
    {
        printf("[libmtd]: mtd_read_offset, get fd(%d) info error, %s\n", fd, strerror(errno));
        return -1;
    }
    /* Also rejects a zero erasesize, so the modulo below never divides by 0. */
    if (0 == count || count > meminfo.erasesize)
    {
        printf("[libmtd]: mtd_read_offset, count(0x%zx), less than 0, or larger than erasesize(0x%x)\n", count, meminfo.erasesize);
        return -1;
    }
    if (0 != (offset % meminfo.erasesize))
    {
        printf("[libmtd]: mtd_read_offset, not at the beginning of a block, erasesize is %x\n", meminfo.erasesize);
        return -1;
    }

    /* Non-zero means bad block (1) or query failure (-1); both abort. */
    if (mtd_block_isbad(fd, offset))
    {
        return -1;
    }

    if (lseek(fd, offset, SEEK_SET) < 0)
    {
        printf("[libmtd]: mtd_read_offset, lseek error = %s!\n", strerror(errno));
        return -1;
    }
    readn_len = readn(fd, buf, count);
    /* count <= erasesize, so the cast cannot overflow. */
    if (readn_len != (ssize_t)count)
    {
        if (-1 == readn_len)
        {
            printf("[libmtd]: mtd_read_offset, read error, %s\n", strerror(errno));
        }
        return -1;
    }

    return readn_len;
}

/**
 * @brief Compare the start of a ZFTL device byte-for-byte against an image file.
 * @param partition_name [in] partition name as listed in /proc/mtd
 * @param image_file     [in] path of the reference image
 * @return 0 if the first st_size bytes of the ZFTL device equal the image;
 *         -1 on mismatch, short read, or any other error
 * @note Once both files are open, a "verify sucess"/"verify fail" line is
 *       always printed before returning.
 */
static int zftl_verify_partition(const char *partition_name, const char *image_file)
{
    int ret = -1;
    ssize_t rd_len;
    size_t read_size = 4096;
    char zftl_path[MTD_PATH_LENGTH_MAX] = {0};
    int fd_zftl = -1;

    off_t data_size = 0;
    off_t fd_img_index = 0;
    int fd_img = -1;
    char *buf = NULL;
    char *buf2 = NULL;
    struct stat statbuff = {0};

    if (NULL == partition_name || NULL == image_file)
    {
        return -1;
    }

    if (mtd_find(partition_name, DEVICE_ZFTL, zftl_path, MTD_PATH_LENGTH_MAX) < 0)
    {
        printf("[libmtd]: zftl_verify_partition, mtd_find %s failed\n", partition_name);
        return -1;
    }
    fd_zftl = open(zftl_path, O_RDONLY);
    if (fd_zftl < 0)
    {
        printf("[libmtd]: zftl_verify_partition, open %s error, %s\n", partition_name, strerror(errno));
        return -1;
    }

    fd_img = open(image_file, O_RDONLY);
    if (fd_img < 0)
    {
        printf("[libmtd]: zftl_verify_partition, open %s failed, %s\n", image_file, strerror(errno));
        goto out;
    }
    /* fstat the opened descriptor so the size belongs to the file actually read. */
    if (fstat(fd_img, &statbuff) < 0)
    {
        printf("[libmtd]: zftl_verify_partition, stat %s failed, %s\n", image_file, strerror(errno));
        goto out;
    }

    buf = malloc(read_size);
    if (NULL == buf)
    {
        printf("[libmtd]: zftl_verify_partition, malloc failed\n");
        goto out;
    }
    buf2 = malloc(read_size);
    if (NULL == buf2)
    {
        printf("[libmtd]: zftl_verify_partition, malloc2 failed\n");
        goto out;
    }

    while (fd_img_index < statbuff.st_size)
    {
        /* Chunk = read_size, or whatever remains of the image. */
        data_size = statbuff.st_size - fd_img_index;
        if (data_size > (off_t)read_size)
        {
            data_size = (off_t)read_size;
        }

        rd_len = readn(fd_img, buf, (size_t)data_size);
        if (rd_len < 0)
        {
            printf("[libmtd]: zftl_verify_partition, read image file %s error, %s\n", image_file, strerror(errno));
            goto out;
        }
        if (rd_len != (ssize_t)data_size)
        {
            printf("[libmtd]: zftl_verify_partition, short read from image file %s\n", image_file);
            goto out;
        }

        rd_len = readn(fd_zftl, buf2, (size_t)data_size);
        if (rd_len < 0)
        {
            printf("[libmtd]: zftl_verify_partition, read zftl %s error, %s\n", partition_name, strerror(errno));
            goto out;
        }
        /* A device shorter than the image would otherwise compare stale bytes. */
        if (rd_len != (ssize_t)data_size)
        {
            printf("[libmtd]: zftl_verify_partition, short read from zftl %s\n", partition_name);
            goto out;
        }

        if (memcmp(buf, buf2, (size_t)data_size) != 0)
        {
            printf("[libmtd]: zftl_verify_partition, data memcmp %s error\n", partition_name);
            goto out;
        }
        fd_img_index += data_size;
    }

    ret = 0;
out:
    if (fd_zftl >= 0)
    {
        close(fd_zftl);
    }
    free(buf);
    free(buf2);
    if (fd_img >= 0)
    {
        close(fd_img);
    }

    if (ret == 0)
    {
        printf("[libmtd]: zftl %s verify sucess\n", partition_name);
    }
    else
    {
        printf("[libmtd]: zftl %s verify fail\n", partition_name);
    }

    return ret;
}

/**
 * @brief Copy an image file onto the ZFTL device of the named partition, then verify it.
 * @param partition_name [in] partition name as listed in /proc/mtd
 * @param image_file     [in] path of the image to write
 * @return 0 if the image was written, flushed with fsync(), and verified by
 *         zftl_verify_partition(); -1 on any error (short read, short write /
 *         no space left, fsync failure, or verification mismatch)
 * @note The image is written sequentially from the start of the device in
 *       4096-byte chunks; the last chunk is written with its exact length.
 */
int zftl_write_partition(const char *partition_name, const char *image_file)
{
    int ret = -1;
    ssize_t wr_len;
    size_t write_size = 4096;
    char zftl_path[MTD_PATH_LENGTH_MAX] = {0};
    int fd_zftl = -1;

    off_t data_size = 0;
    off_t fd_img_index = 0;
    int fd_img = -1;
    char *buf = NULL;
    struct stat statbuff = {0};

    if (NULL == partition_name || NULL == image_file)
    {
        return -1;
    }

    if (mtd_find(partition_name, DEVICE_ZFTL, zftl_path, MTD_PATH_LENGTH_MAX) < 0)
    {
        printf("[libmtd]: zftl_write_partition, mtd_find %s failed\n", partition_name);
        return -1;
    }
    fd_zftl = open(zftl_path, O_RDWR);
    if (fd_zftl < 0)
    {
        printf("[libmtd]: zftl_write_partition, open %s error, %s\n", partition_name, strerror(errno));
        return -1;
    }

    fd_img = open(image_file, O_RDONLY);
    if (fd_img < 0)
    {
        printf("[libmtd]: zftl_write_partition, open %s failed, %s\n", image_file, strerror(errno));
        goto out;
    }
    /* fstat the opened descriptor so the size belongs to the file actually read. */
    if (fstat(fd_img, &statbuff) < 0)
    {
        printf("[libmtd]: zftl_write_partition, stat %s failed, %s\n", image_file, strerror(errno));
        goto out;
    }

    buf = malloc(write_size);
    if (NULL == buf)
    {
        printf("[libmtd]: zftl_write_partition, malloc failed\n");
        goto out;
    }

    while (fd_img_index < statbuff.st_size)
    {
        /* Chunk = write_size, or whatever remains of the image. */
        data_size = statbuff.st_size - fd_img_index;
        if (data_size > (off_t)write_size)
        {
            data_size = (off_t)write_size;
        }

        wr_len = readn(fd_img, buf, (size_t)data_size);
        if (wr_len < 0)
        {
            printf("[libmtd]: zftl_write_partition, read %s error, %s\n", image_file, strerror(errno));
            goto out;
        }
        /* The file shrank under us: do not write bytes that were never read. */
        if (wr_len != (ssize_t)data_size)
        {
            printf("[libmtd]: zftl_write_partition, short read from %s\n", image_file);
            goto out;
        }

        wr_len = writen(fd_zftl, buf, (size_t)data_size);
        if (wr_len < 0)
        {
            printf("[libmtd]: zftl_write_partition, write %s error, %s\n", partition_name, strerror(errno));
            goto out;
        }
        /* writen() only returns a short count when the device accepts no more data. */
        if (wr_len != (ssize_t)data_size)
        {
            printf("[libmtd]: zftl_write_partition, No space left, writelen=0x%llx, filesize=0x%llx\n",
                   (unsigned long long)(fd_img_index + wr_len), (unsigned long long)statbuff.st_size);
            goto out;
        }
        fd_img_index += data_size;
    }
    ret = 0;
out:
    if (fd_zftl >= 0)
    {
        /* A failed flush means the data may not have reached the device:
         * report it instead of proceeding to verification. */
        if (0 != fsync(fd_zftl) && 0 == ret)
        {
            printf("[libmtd]: zftl_write_partition, fsync %s error, %s\n", partition_name, strerror(errno));
            ret = -1;
        }
        close(fd_zftl);
    }
    free(buf);
    if (fd_img >= 0)
    {
        close(fd_img);
    }

    if (ret == 0)
    {
        /* write success, then read the device back and compare */
        ret = zftl_verify_partition(partition_name, image_file);
    }

    return ret;
}