/*-----------------------------------------------------------------------------------------------*/
/**
  @file mbtk_openssl.c
  @brief OPENSSL API
*/
/*-----------------------------------------------------------------------------------------------*/

/*-------------------------------------------------------------------------------------------------
  Copyright (c) 2024 mobiletek Wireless Solution, Co., Ltd. All Rights Reserved.
  mobiletek Wireless Solution Proprietary and Confidential.
-------------------------------------------------------------------------------------------------*/

/*-------------------------------------------------------------------------------------------------
  EDIT HISTORY
  This section contains comments describing changes made to the file.
  Notice that changes are listed in reverse chronological order.
  $Header: $
  when       who          what, where, why
  --------   ---------    -----------------------------------------------------------------
  20250410    yq.wang      Created .
-------------------------------------------------------------------------------------------------*/
#ifdef MBTK_OPENSSL_V3_0_0_SUPPORT
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/socket.h>

#include <sys/select.h>
#include <sys/time.h>

#include "mbtk_openssl.h"
#include "mbtk_log.h"

/* Descriptor value stored in mbtk_openssl_info_s::fd when no socket is attached.
   Parenthesized so the negative literal expands safely inside any expression. */
#define MBTK_SSL_INFO_FD_DEFAULT (-1)

// X509 * SSL_get1_peer_certificate(SSL *ssl);

/**
 * @brief Block until @p sockfd becomes readable (or writable) or a timeout expires.
 *
 * Used by the non-blocking handshake loop after SSL_connect() reports
 * SSL_ERROR_WANT_READ / SSL_ERROR_WANT_WRITE.
 *
 * @param sockfd   Socket descriptor to wait on; must be in [0, FD_SETSIZE).
 * @param is_read  true to wait for readability, false to wait for writability.
 * @return MBTK_OPENSSL_RESULT_SUCCESS when the socket is ready,
 *         MBTK_OPENSSL_RESULT_FAIL on invalid descriptor, timeout or select() error.
 */
static mbtk_openssl_result_e mbtk_openssl_wait_for_socket(int sockfd, bool is_read)
{
    enum { MBTK_OPENSSL_WAIT_TIMEOUT_SEC = 5 };
    int ret = -1;
    fd_set fds;
    struct timeval tv;

    // FD_SET() on a descriptor outside [0, FD_SETSIZE) writes past the fd_set
    // (undefined behaviour), so reject such descriptors up front.
    if(sockfd < 0 || sockfd >= FD_SETSIZE)
    {
        LOGE("[%s] invalid sockfd.[%d]", __func__, sockfd);
        return MBTK_OPENSSL_RESULT_FAIL;
    }

    // A signal interrupting select() is not a socket failure: re-arm the set
    // and the timeout (select() may modify both) and wait again.
    do
    {
        FD_ZERO(&fds);
        FD_SET(sockfd, &fds);
        tv.tv_sec = MBTK_OPENSSL_WAIT_TIMEOUT_SEC;
        tv.tv_usec = 0;
        ret = select(sockfd + 1, is_read ? &fds : NULL, is_read ? NULL : &fds, NULL, &tv);
    } while(ret < 0 && EINTR == errno);

    if(0 == ret)
    {
        LOGE("[%s] select() timeout.", __func__);
        return MBTK_OPENSSL_RESULT_FAIL;
    }
    if(0 > ret)
    {
        LOGE("[%s] select() fail.[%d]", __func__, ret);
        return MBTK_OPENSSL_RESULT_FAIL;
    }
    return MBTK_OPENSSL_RESULT_SUCCESS;
}

/**
 * @brief Populate @p opt with the default client options.
 *
 * Defaults: no certificate material (load_cert false, all file paths NULL),
 * PEM file type, MBTK_OPENSSL_VERIFY_PEER with no verify callback,
 * security level MBTK_OPENSSL_SAFETY_LEVEL_2, and init flags that load the
 * SSL/crypto error strings plus all ciphers and digests.
 *
 * @param opt  Structure to fill; must not be NULL.
 * @return MBTK_OPENSSL_RESULT_SUCCESS, or MBTK_OPENSSL_RESULT_FAIL when @p opt is NULL.
 */
mbtk_openssl_result_e mbtk_openssl_options_default(mbtk_openssl_options_s *opt)
{
    if(NULL != opt)
    {
        /* Certificate material: nothing is loaded by default. */
        opt->load_cert = false;
        opt->ca_file = NULL;
        opt->crt_file = NULL;
        opt->key_file = NULL;
        opt->ssl_filetype = MBTK_OPENSSL_FILETYPE_PEM;

        /* Peer verification policy. */
        opt->verify_mode = MBTK_OPENSSL_VERIFY_PEER;
        opt->verify_cb = NULL;
        opt->safety_level = MBTK_OPENSSL_SAFETY_LEVEL_2;

        /* Flags later handed to OPENSSL_init_ssl() by mbtk_openssl_init(). */
        opt->init_opts = MBTK_OPENSSL_INIT_LOAD_SSL_STRINGS
                       | MBTK_OPENSSL_INIT_LOAD_CRYPTO_STRINGS
                       | MBTK_OPENSSL_INIT_ADD_ALL_CIPHERS
                       | MBTK_OPENSSL_INIT_ADD_ALL_DIGESTS;

        return MBTK_OPENSSL_RESULT_SUCCESS;
    }

    LOGE("[%s] opt [NULL]", __func__);
    return MBTK_OPENSSL_RESULT_FAIL;
}

/**
 * @brief Write up to @p len bytes from @p buf over an established TLS session.
 *
 * @param ssl  Session returned in mbtk_openssl_info_s by mbtk_openssl_init().
 * @param buf  Data to send.
 * @param len  Number of bytes to send.
 * @return The SSL_write() result (> 0 bytes written, <= 0 on failure; query
 *         SSL_get_error() for details), or -1 when @p ssl is NULL.
 */
int mbtk_openssl_write(SSL *ssl, const void *buf, int len)
{
    // Reject a missing session here instead of handing NULL to OpenSSL.
    if(NULL == ssl)
    {
        LOGE("[%s] ssl [NULL]", __func__);
        return -1;
    }

    return SSL_write(ssl, buf, len);
}

/**
 * @brief Read up to @p len bytes into @p buf from an established TLS session.
 *
 * @param ssl  Session returned in mbtk_openssl_info_s by mbtk_openssl_init().
 * @param buf  Destination buffer of at least @p len bytes.
 * @param len  Maximum number of bytes to read.
 * @return The SSL_read() result (> 0 bytes read, <= 0 on failure or close;
 *         query SSL_get_error() for details), or -1 when @p ssl is NULL.
 */
int mbtk_openssl_read(SSL *ssl, void *buf, int len)
{
    // Reject a missing session here instead of handing NULL to OpenSSL.
    if(NULL == ssl)
    {
        LOGE("[%s] ssl [NULL]", __func__);
        return -1;
    }

    return SSL_read(ssl, buf, len);
}

mbtk_openssl_result_e mbtk_openssl_init(int fd, mbtk_openssl_options_s *opt, mbtk_openssl_info_s *inter_info)
{
    int ret = -1;
    int ssl_error = -1;
    long verify_res = -1;
    char *line = NULL;
    X509 *cert = NULL;
    mbtk_openssl_result_e mbtk_ssl_ret = MBTK_OPENSSL_RESULT_SUCCESS;
    const SSL_METHOD *method = NULL;
    mbtk_openssl_info_s temp_inter_info = {0};
    mbtk_openssl_options_s temp_opt = {0};

    if(NULL == inter_info)
    {
        LOGE("[%s] inter_info [NULL]", __func__);
        return MBTK_OPENSSL_RESULT_FAIL;
    }

    if(NULL == opt)
    {
        mbtk_openssl_options_default(&temp_opt);
    }
    else
    {
        memset(&temp_opt, 0x00, sizeof(mbtk_openssl_options_s));
        memcpy(&temp_opt, opt, sizeof(mbtk_openssl_options_s));
    }
    
    //1.Initializes the OPENSSL library
    OPENSSL_init_ssl(temp_opt.init_opts, NULL);

    memset(&temp_inter_info, 0x00, sizeof(mbtk_openssl_info_s));
    //2.Create an SSL/TLS context object
    method = TLS_client_method();
    temp_inter_info.ctx = SSL_CTX_new(method);
    if(NULL == temp_inter_info.ctx)
    {
        LOGE("[%s] SSL_CTX_new() fail", __func__);
        goto error;
    }

    //3.Load certificate
    if(temp_opt.load_cert)
    {
        //3.1-Set the certificate security level
        SSL_CTX_set_security_level(temp_inter_info.ctx, temp_opt.safety_level);
        
        //3.2-Loading a CA Certificate
        if(NULL != temp_opt.ca_file)
        {
            ret = SSL_CTX_load_verify_locations(temp_inter_info.ctx, temp_opt.ca_file, NULL);
            if(1 != ret)
            {
                LOGE("[%s] SSL_CTX_load_verify_locations() fail.[%d]", __func__, ret);
                goto error;
            }
        }

        //3.3-Load the client public key
        if(NULL != temp_opt.crt_file)
        {
            ret = SSL_CTX_use_certificate_file(temp_inter_info.ctx, temp_opt.crt_file, temp_opt.ssl_filetype);
            if(1 != ret)
            {
                LOGE("[%s] SSL_CTX_use_certificate_file() fail.[%d]", __func__, ret);
                goto error;
            }
        }

        //3.4-Load the client private key
        if(NULL != temp_opt.key_file)
        {
            ret = SSL_CTX_use_PrivateKey_file(temp_inter_info.ctx, temp_opt.key_file, temp_opt.ssl_filetype);
            if(1 != ret)
            {
                LOGE("[%s] SSL_CTX_use_PrivateKey_file() fail.[%d]", __func__, ret);
                goto error;
            }
        }

        //3.5-Verify the private key matching certificate
        ret = SSL_CTX_check_private_key(temp_inter_info.ctx);
        if (1 != ret)
        {
            LOGE("[%s] SSL_CTX_check_private_key() fail.[%d]", __func__, ret);
            goto error;
        }

        //3.6-Set verification mode
        SSL_CTX_set_verify(temp_inter_info.ctx, temp_opt.verify_mode, temp_opt.verify_cb);
    }

    //4.Creates and initializes a new SSL/TLS session object
    temp_inter_info.ssl = SSL_new(temp_inter_info.ctx);
    if(NULL == temp_inter_info.ssl)
    {
        LOGE("[%s] SSL_new() fail", __func__);
        goto error;
    }
    SSL_set_fd(temp_inter_info.ssl, fd);

    LOGD("[%s] Performing the SSL/TLS handshake...", __func__);
    //5.Executive handshake
    //SSL_set_connect_state(temp_inter_info.ssl);
    while((ret = SSL_connect(temp_inter_info.ssl)) <= 0)
    {
        ssl_error = SSL_get_error(temp_inter_info.ssl, ret);  
        if(ssl_error == SSL_ERROR_WANT_READ)
        {    
            mbtk_ssl_ret = mbtk_openssl_wait_for_socket(fd, true);
            if(MBTK_OPENSSL_RESULT_SUCCESS != mbtk_ssl_ret)
            {
                LOGE("[%s] mbtk_openssl_wait_for_socket() fail", __func__);
                goto error;
            }
        }
        else if(ssl_error == SSL_ERROR_WANT_WRITE)
        {
            mbtk_ssl_ret = mbtk_openssl_wait_for_socket(fd, false);
            if(MBTK_OPENSSL_RESULT_SUCCESS != mbtk_ssl_ret)
            {
                LOGE("[%s] mbtk_openssl_wait_for_socket() fail", __func__);
                goto error;
            }
        }
        else
        {
            LOGE("[%s] SSL_connect() fail.[%d]", __func__, ssl_error);
            goto error;
        }
    }

    LOGD("[%s] SSL connect ok: Protocol[%s], Ciphersuite[%s]", __func__, SSL_get_version(temp_inter_info.ssl), SSL_get_cipher_name(temp_inter_info.ssl));

    //6.Verification certificate
    if(temp_opt.load_cert)
    {
        cert = SSL_get1_peer_certificate(temp_inter_info.ssl);
        if(NULL != cert)
        {
            verify_res = SSL_get_verify_result(temp_inter_info.ssl);
            if(X509_V_OK != verify_res)
            {
                LOGE("[%s] SSL_get_verify_result() fail.[%s]", __func__, X509_verify_cert_error_string(verify_res));
                goto error;
            }

            LOGD("[%s] Digital certificate information:", __func__);
            
            line = X509_NAME_oneline(X509_get_subject_name(cert), 0, 0);
            LOGD("[%s] certificate: [%s]", __func__, line);
            free(line);
            line = NULL;
            
            line = X509_NAME_oneline(X509_get_issuer_name(cert), 0, 0);
            LOGD("[%s] issuer: [%s]", __func__, line);
            free(line);
            line = NULL;
            
            X509_free(cert);
            cert = NULL;
        }
        else
        {
            LOGD("[%s] No server certificate received", __func__);
            if(temp_opt.verify_mode != MBTK_OPENSSL_VERIFY_NONE)
            {
                LOGE("[%s] Verification fail", __func__);
                goto error;
            }
        }
    }

    temp_inter_info.fd = fd;
    memcpy(inter_info, &temp_inter_info, sizeof(mbtk_openssl_info_s));
    
    return MBTK_OPENSSL_RESULT_SUCCESS;
error:
    if(NULL != cert)
    {
        X509_free(cert);
        cert = NULL;
    }
    if(NULL != temp_inter_info.ssl)
    {
        SSL_shutdown(temp_inter_info.ssl);
        SSL_free(temp_inter_info.ssl);
        temp_inter_info.ssl = NULL;
    }

    if(NULL != temp_inter_info.ctx)
    {
        SSL_CTX_free(temp_inter_info.ctx);
        temp_inter_info.ctx = NULL;
    }

    return MBTK_OPENSSL_RESULT_FAIL;
}

/**
 * @brief Release the TLS session and context held in @p inter_info.
 *
 * Calls SSL_shutdown() and SSL_free() on the session, then SSL_CTX_free() on
 * the context, clearing each pointer, and finally resets fd to
 * MBTK_SSL_INFO_FD_DEFAULT. The socket descriptor itself is not closed here.
 *
 * @param inter_info  Structure filled by mbtk_openssl_init(); must not be NULL.
 * @return MBTK_OPENSSL_RESULT_SUCCESS, or MBTK_OPENSSL_RESULT_FAIL when @p inter_info is NULL.
 */
mbtk_openssl_result_e mbtk_openssl_deinit(mbtk_openssl_info_s *inter_info)
{
    SSL *session = NULL;
    SSL_CTX *context = NULL;

    if(NULL == inter_info)
    {
        LOGE("[%s] inter_info [NULL]", __func__);
        return MBTK_OPENSSL_RESULT_FAIL;
    }

    /* Session first: it references the context. */
    session = inter_info->ssl;
    if(session)
    {
        SSL_shutdown(session);
        SSL_free(session);
        inter_info->ssl = NULL;
    }

    context = inter_info->ctx;
    if(context)
    {
        SSL_CTX_free(context);
        inter_info->ctx = NULL;
    }

    inter_info->fd = MBTK_SSL_INFO_FD_DEFAULT;
    return MBTK_OPENSSL_RESULT_SUCCESS;
}

#endif
