#!/usr/bin/python
# -*- coding: utf8 -*-

import json
import struct
import sys
import binascii
import cryptoSB
import copy

from jsoncomment import JsonComment

# Alignment (in bytes) applied to every image binary before it is packed.
__Align = 16
# When set to 1, my_pack_images() inserts extra padding to a 16-byte boundary
# after the sbc data and again after the auth data.
__Align_16_en = 0

# NOTE(review): not referenced anywhere in this file; presumably a debug switch.
__TEST = 0

# --- mi_header_info: field sizes in bytes -----------------------------------
g_magic_size = 8
g_total_len_size = 4
g_bl_version_size = 4
g_img_number_size = 4
g_ls_cmd_number_size = 4
# Total size of the fixed header (matches the '<IIIIII' struct used when packing).
g_mi_header_info_size = (g_magic_size + g_total_len_size + g_bl_version_size + g_img_number_size + g_ls_cmd_number_size)

# Size of one per-image header: length(4) + offset(4) + hash(64) + enc_inf(4) + enc_alg(4) + iv(16).
g_img_header_size = 96
# Size of one packed load-script command ('<BBHI').
g_ls_cmd_size = 8

# NOTE(review): the next two sizes are not referenced in this file.
g_img_oneoffset_table_size = 4
g_img_onehash_table_size = 64

# --- sbc data: padded public key followed by two u32 fields ------------------
g_public_key_size = 96
g_sbc_auth_alg_size = 4
g_sbc_auth_inf_size = 4
g_sbc_size = (g_public_key_size + g_sbc_auth_alg_size + g_sbc_auth_inf_size)

# --- auth data: padded header hash followed by the padded signature ----------
g_boot_hash_size = 64
g_signature_size = 96
g_auth_size = (g_boot_hash_size + g_signature_size)

def to_int(s):
    '''
    Convert a numeric string to an int; pass any other value through unchanged.

    Strings are parsed with base 0, so prefixes such as "0x" select the radix.
    Works on both Python 2 (str and unicode) and Python 3 (str).
    '''
    try:
        string_types = (str, unicode)  # noqa: F821 - Python 2 only
    except NameError:
        # Python 3 has no separate unicode type.
        string_types = (str,)

    if isinstance(s, string_types):
        return int(s, 0)
    return s

def byte_to_int(bytes):
    '''
    Interpret a byte sequence as a big-endian unsigned integer.

    Accepts Python 2 str (elements are 1-char strings) as well as Python 3
    bytes (elements are already ints). An empty sequence yields 0.
    '''
    result = 0

    for b in bytes:
        # Python 3 bytes iterate as ints; everything else goes through ord().
        value = b if isinstance(b, int) else ord(b)
        result = result * 256 + value

    return result

def read_desc(mi_desc_fn):
    '''
    Load the multi-image description file and normalise its numeric fields.

    The file is parsed with JsonComment wrapped around the json module.
    For every entry of desc['images'] the 'img_enc_alg' and 'img_enc_inf'
    values are converted with to_int().

    Returns the parsed description dict.
    '''
    parser = JsonComment(json)
    # Use a context manager so the file handle is closed deterministically.
    with open(mi_desc_fn) as desc_file:
        desc = parser.loads(desc_file.read())

    for img in desc['images']:
        img['img_enc_alg'] = to_int(img['img_enc_alg'])
        img['img_enc_inf'] = to_int(img['img_enc_inf'])
        # 'img_iv' is deliberately left as a string: converting it to an int
        # would drop leading zero bytes (keys/IVs starting with 0x00..).

    return desc


def gen_ver_magic(major, minor, revision):
    '''
    Build the version word from major/minor/revision.

    Each argument is first passed through to_int(). The result holds major
    in bits 0+, minor shifted by 8, revision shifted by 16 and an 8-bit
    checksum shifted by 24. The checksum is the low byte of
    (major ^ 0xaa) + (minor ^ 0x55) + (revision ^ 0x99).
    '''
    major, minor, revision = to_int(major), to_int(minor), to_int(revision)

    scrambled_sum = (major ^ 0xaa) + (minor ^ 0x55) + (revision ^ 0x99)
    checksum = scrambled_sum & 0xff

    word = major
    word |= minor << 8
    word |= revision << 16
    word |= checksum << 24
    return word

def my_gen_ver_magic(major, minor, revision):
    '''
    Return (version_word, highmagic).

    version_word is computed exactly like gen_ver_magic(): major, minor << 8,
    revision << 16 and an 8-bit checksum << 24. highmagic is the fixed byte
    string 27 21 ca fe converted with byte_to_int().
    '''
    major, minor, revision = to_int(major), to_int(minor), to_int(revision)

    scrambled_sum = (major ^ 0xaa) + (minor ^ 0x55) + (revision ^ 0x99)
    checksum = scrambled_sum & 0xff

    version_word = major | (minor << 8) | (revision << 16) | (checksum << 24)
    highmagic = byte_to_int(b'\x27\x21\xca\xfe')

    return version_word, highmagic

def my_gen_mtk_magic(magic):
    '''
    Split the integer magic number into two words: (lowmagic, highmagic).

    The magic is serialised to g_magic_size bytes in little-endian order with
    my_to_bytes(); lowmagic is byte_to_int() of bytes 4..7 of that string and
    highmagic is byte_to_int() of bytes 0..3.
    '''
    raw = my_to_bytes(magic, g_magic_size, endianess='little')
    upper_half, lower_half = raw[4:8], raw[0:4]

    lowmagic = byte_to_int(upper_half)
    highmagic = byte_to_int(lower_half)
    return lowmagic, highmagic

def pad_to_align(align, current_size):
    '''
    Return the NUL padding needed to round current_size up to a multiple of align.

    A falsy align (0 or None) disables padding and yields an empty string.
    '''
    if align:
        # -n % align is the distance from n up to the next multiple of align.
        return '\0' * (-current_size % align)

    return ''


def dump(data):
    '''
    Print data as a hex dump, 16 values per row, each row prefixed by its offset.

    Written for the Python 2 print statement: the trailing commas keep the
    values of one row on the same line and the bare print ends the row.
    '''
    for offset, item in enumerate(data):
        if offset % 16 == 0:
            print("[%04X]" % offset),

        print("%02x" % ord(item)),
        if (offset + 1) % 16 == 0:
            print

def pack_load_script(mi_desc):
    '''
    Pack the load-script commands of mi_desc into their binary form.

    Every command is encoded with struct format '<BBHI':
        u8  command id   (LOAD = 0, MCU_RESET-ENTRY = 1, MCU-ENTRY = 2)
        u8  image index for LOAD, 'mcu_id' for the MCU commands
        u16 reserved (always 0)
        u32 address ('addr' passed through to_int())

    For LOAD the image index is the position of the first entry in
    mi_desc['images'] whose 'img_file' equals the command's 'img_file'.

    Returns a list built by extending it with each packed command, or None
    (after printing a message) when a LOAD refers to an unknown image file
    or the command name is not recognised.
    '''
    images = mi_desc['images']
    commands = mi_desc['load_srcipt_cmd']
    reserved = 0
    script = []

    for entry in commands:
        name = entry['cmd']

        if name == "LOAD":
            opcode = 0
            wanted_file = entry['img_file']
            addr = to_int(entry['addr'])

            # Locate the first image whose file name matches the command.
            operand = -1
            for index, image in enumerate(images):
                if wanted_file == image['img_file']:
                    operand = index
                    break

            if operand == -1:
                print("Please check img file name")
                return None
        elif name == "MCU-ENTRY" or name == "MCU_RESET-ENTRY":
            opcode = 2 if name == "MCU-ENTRY" else 1
            operand = entry['mcu_id']
            addr = to_int(entry['addr'])
        else:
            print("unknown command: %s" %name)
            return None

        script.extend(struct.pack('<BBHI', opcode, operand, reserved, addr))

    return script


def my_pack_images(mi_desc, privk, enckey, align=None, img_dir=''):
    '''
    Build the complete multi-image package and return it as one byte string.

    Layout of the result (integers are packed little-endian):

    mipack_file
        mi_header_info
            u32 magic (low word)
            u32 magic (high word)
            u32 total_len
            u32 bl_version
            u32 img_number (N)
            u32 load_script_cmd_number (M)

        img_info[1]~img_info[N]
            u32     img_length
            u32     img_offset
            64Bytes img_hash
            u32     img_enc_inf
            u32     img_enc_alg
            16Bytes img_iv

        load_script_cmd[1]~load_script_cmd[M]
            u64     cmd

        sbc data
            96Bytes public_key
            u32     sbc_auth_inf
            u32     sbc_auth_alg

        auth data
            64Bytes boothash
            96Bytes signature

        [img binary 1]
        ...
        [img binary N]

    Parameters:
        mi_desc  description dict (see read_desc()).
        privk    private key handed to cryptoSB.sb_sign(); only used when
                 mi_desc['sbc_auth_inf'] is non-zero.
        enckey   key handed to cryptoSB.my_img_enc() for images whose
                 'img_enc_inf' is 1 or 2.
        align    unused; kept for interface compatibility.
        img_dir  directory containing the image files. When empty, the
                 'img_file' value is used as the path as-is.

    Returns the package, or None (after printing a message) when
    pack_load_script() fails. Raises IndexError when mi_desc['images'] is
    empty.
    '''
    mi_header_struct = struct.Struct('<IIIIII')
    sbc_struct = struct.Struct('<II')

    # NOTE(review): the instance is never used. The call is retained only in
    # case the cryptoSB.HASH constructor has side effects - confirm and drop.
    cryptoSB.HASH()

    ver_magic_l, ver_magic_h = my_gen_mtk_magic(to_int(mi_desc['magic_num']))

    bl_version = mi_desc['bl_version']
    images = mi_desc['images']
    img_num = len(images)
    ls_cmds_num = len(mi_desc['load_srcipt_cmd'])

    # Size of everything that precedes the image binaries:
    # fixed part (mi_header, sbc data, auth data) + img_info + load script.
    bl_total_len = g_mi_header_info_size + g_sbc_size + g_auth_size
    bl_total_len += img_num * g_img_header_size
    bl_total_len += ls_cmds_num * g_ls_cmd_size

    # ---- sbc region ------------------------------------------------------
    # The key is kept as a string (not converted to int) so that leading
    # zero bytes are preserved.
    public_key = cryptoSB.strip_key(mi_desc['public_key'])
    sbc_region = [my_pad_to_align(public_key, 96)]
    pubk = b'\x04' + public_key

    sbc_auth_alg = mi_desc['sbc_auth_alg']
    sbc_auth_inf = mi_desc['sbc_auth_inf']
    sbc_region.append(sbc_struct.pack(sbc_auth_inf, sbc_auth_alg))

    # ---- images ----------------------------------------------------------
    # The first image starts at the next 16-byte boundary after the headers.
    img_offset = (bl_total_len + 15) & (~15)
    img_offsets = []
    img_sizes = []
    img_headers = []
    img_bins = []      # one entry per image; no fixed upper limit

    for img in images:
        img_file = img['img_file']
        img_enc_inf = img['img_enc_inf']
        img_enc_alg = img['img_enc_alg']

        # An empty img_dir must not turn the name into a root-level path.
        img_path = img_dir + '/' + img_file if img_dir else img_file
        with open(img_path, 'rb') as img_fp:
            raw = img_fp.read()

        # Kept as a string so that leading zero bytes are preserved.
        iv = cryptoSB.strip_key(img['img_iv'])

        if img_enc_inf == 1 or img_enc_inf == 2:
            # encrypted image
            out = cryptoSB.my_img_enc(raw, enckey, iv, img_enc_alg, __Align)
        else:
            # plaintext image, padded to the image alignment
            out = my_pad_to_align(raw, __Align)

        length = len(out)

        # Hash of the (possibly encrypted) image binary, padded to 64 bytes.
        imghash = my_pad_to_align(cryptoSB.sb_hash(out, sbc_auth_alg), 64)

        img_hdr = struct.pack('<II', length, img_offset)
        img_hdr += imghash
        img_hdr += struct.pack('<II16s', img_enc_inf, img_enc_alg, iv)

        img_headers.append(img_hdr)
        img_offsets.append(img_offset)
        img_sizes.append(length)
        img_bins.append(out)

        img_offset += length

    # ---- header block: mi_header + img_info + load script + sbc ------------
    total_len = img_offsets[img_num - 1] + img_sizes[img_num - 1]

    header_parts = [mi_header_struct.pack(ver_magic_l,
                                          ver_magic_h,
                                          total_len,
                                          bl_version,
                                          img_num,
                                          ls_cmds_num)]
    header_parts.extend(img_headers)

    ls_script = pack_load_script(mi_desc)
    if ls_script is None:
        print("pack_load_script fail")
        return None

    header_parts.extend(ls_script)
    header_parts.extend(sbc_region)

    blh = b''.join(header_parts)
    if __Align_16_en == 1:
        # Debug aid: align (bl + imgl + sbc); remove for final release.
        blh += pad_to_align(16, len(blh))

    # ---- auth data: hash and signature over the header block ---------------
    boothash = cryptoSB.sb_hash(blh, sbc_auth_alg)
    pack = [blh, my_pad_to_align(boothash, g_boot_hash_size)]

    if sbc_auth_inf == 0:
        # Signing disabled: the signature field is all zero.
        pack.append('\0' * g_signature_size)
    else:
        pem_key_format = mi_desc['pem_key_format']
        if pem_key_format == 1:
            key_mode = 1    # PEM format
        elif pem_key_format == 2:
            key_mode = 2    # DER format
        else:
            key_mode = 0    # string format

        # The signed message is the header block itself, not its hash.
        signature = cryptoSB.sb_sign(privk, pubk, blh, sbc_auth_alg, key_mode)
        pack.append(my_pad_to_align(signature, g_signature_size))

    packed_len = sum(len(part) for part in pack)

    if __Align_16_en == 1:
        # Debug aid: align (bl + imgl + sbc + auth); remove for final release.
        pad = pad_to_align(16, packed_len)
        pack.append(pad)
        packed_len += len(pad)

    # ---- image binaries, each placed at its recorded offset ----------------
    for offset, blob in zip(img_offsets, img_bins):
        # A non-positive gap yields an empty pad.
        pad = '\0' * (offset - packed_len)
        pack.append(pad)
        pack.append(blob)
        packed_len += len(pad) + len(blob)

    return b''.join(pack)

'''Backport of int.to_bytes() for Python 2.7.'''
def my_to_bytes(n, length, endianess='big'):
    '''
    Serialise the non-negative integer n into a byte string.

    The value is zero-padded on the left to at least `length` bytes; a value
    that needs more bytes is not truncated. With endianess == 'big' the most
    significant byte comes first, any other value reverses the byte order.

    Uses binascii.unhexlify() instead of str.decode('hex') so that it works
    on both Python 2 and Python 3.
    '''
    h = '%x' % n
    if len(h) % 2:
        # unhexlify() needs an even number of hex digits.
        h = '0' + h

    s = binascii.unhexlify(h.zfill(length * 2))
    return s if endianess == 'big' else s[::-1]

'''Convert a (long) integer to a byte string.'''
def my_binify(x):
    '''
    Convert the non-negative integer x to its big-endian byte string.

    Leading zero bytes are not preserved (the integer does not carry them).
    A value with an odd number of hex digits is left-padded with one '0' so
    that binascii.unhexlify() accepts it; e.g. 0x123 -> b'\\x01\\x23'.
    '''
    # Drop the '0x' prefix and the 'L' suffix Python 2 adds to longs.
    h = hex(x)[2:].rstrip('L')
    if len(h) % 2:
        h = '0' + h
    return binascii.unhexlify(h)

def my_binify2(x):
    '''
    Convert the non-negative integer x to its big-endian byte string.

    Same as my_binify() except that the '0x' prefix is only removed when
    hex(x) actually starts with it. A value with an odd number of hex digits
    is left-padded with one '0' so that binascii.unhexlify() accepts it.
    '''
    h = hex(x)
    if h[0:2] == '0x':
        h = h[2:]
    # Python 2 appends 'L' to the hex form of longs.
    h = h.rstrip('L')

    if len(h) % 2:
        h = '0' + h

    return binascii.unhexlify(h)

def my_pad_to_align(x, align):
    '''
    Return x extended with zero bytes up to the next multiple of align.

    An align of 0 returns x untouched, as does an x whose length is already
    a multiple of align.
    '''
    if align == 0:
        return x

    # -n % align is the distance from n up to the next multiple of align.
    missing = -len(x) % align
    if missing:
        x += b'\x00' * missing

    return x


def img_enc(mi_desc, key, align=None):
    '''
    Encrypt the first image listed in mi_desc and return the ciphertext.

    The file named by mi_desc['images'][0]['img_file'] is read, zero-padded
    to a multiple of 64 bytes and encrypted with
    cryptoSB.AESCipher(key, 16).aes_encrypt().

    Parameters:
        mi_desc  description dict (see read_desc()).
        key      key handed to cryptoSB.AESCipher.
        align    unused; kept for interface compatibility.

    The image's 'img_iv' is intentionally not read here: its converted value
    was never used, and the conversion via my_binify() fails for the string
    IVs that read_desc() produces.
    '''
    img_file = mi_desc['images'][0]['img_file']

    # Context manager so the image file is closed after reading.
    with open(img_file, 'rb') as img_fp:
        data = img_fp.read()

    data = my_pad_to_align(data, 64)

    aes = cryptoSB.AESCipher(key, 16)
    return aes.aes_encrypt(data)




if __name__ == '__main__':
    import os, sys, getopt


    def print_usage():
        '''Print the command-line synopsis.'''
        # A single string is printed on purpose: under the Python 2 print
        # statement, print(a, b, ...) would output the repr of a tuple.
        print('\n'.join([
            'usage: %s [options] <image config>.json' % os.path.basename(sys.argv[0]),
            'options:',
            '\t[-o | --output out.img]',
            '\t[-h | --help]',
            '\t[-a | --align <n>]',
            '\t[-i | --input <file>] (currently unused)',
            '\t[-k | --prikey hexkey e.g. 0x0102..]',
            '\t[-s | --enckey hexkey e.g. 0x0102..]',
            '\t[-p | --pemdir <pem path>]',
            '\t[-d | --imgdir <image path>]',
        ]))


    def read_binary(path):
        '''Return the whole content of the file at path, closing it afterwards.'''
        with open(path, 'rb') as fp:
            return fp.read()


    def main():
        '''Parse the command line, load the keys and write the packed image.'''
        try:
            opts, args = getopt.getopt(sys.argv[1:],
                                       'ho:a:i:k:s:p:d:',
                                       ['help', 'output=', 'align=', 'input=',
                                        'prikey=', 'enckey=', 'pemdir=', 'imgdir='])
        except getopt.GetoptError as err:
            print(str(err))
            print_usage()
            sys.exit(1)

        out_name = None
        align = 0
        aeskey = None
        privkey = None      # raw -k argument
        privk = None        # key material handed to my_pack_images()
        pem_dir = "binfile"
        img_dir = ''

        for o, a in opts:
            if o in ('-h', '--help'):
                print_usage()
                sys.exit()
            elif o in ('-o', '--output'):
                out_name = a
            elif o in ('-a', '--align'):
                align = int(a)
            elif o in ('-i', '--input'):
                # Accepted for compatibility; the value is not needed currently.
                pass
            elif o in ('-k', '--prikey'):
                privkey = a
            elif o in ('-s', '--enckey'):
                aeskey = a
            elif o in ('-p', '--pemdir'):
                pem_dir = a
            elif o in ('-d', '--imgdir'):
                img_dir = a
            else:
                print_usage()
                sys.exit(1)

        if len(args) >= 1:
            mi_desc_fn = args[0]
        else:
            print_usage()
            sys.exit(1)

        if not out_name:
            fn, ext = os.path.splitext(mi_desc_fn)
            out_name = fn + '.img'

        # Read the json description.
        mi_desc = read_desc(mi_desc_fn)

        cmd_line_key = mi_desc['cmd_line_key']
        sign_priv_key = mi_desc['sign_priv_key']
        aes_enc_sym_key = mi_desc['aes_enc_sym_key']

        # Decide where the keys come from.
        pem_key_format = mi_desc['pem_key_format']
        if pem_key_format == 1:
            # Private key from a PEM file in pem_dir, chosen by sbc_auth_alg.
            # For any other sbc_auth_alg value privk stays None.
            sbc_auth_alg = mi_desc['sbc_auth_alg']
            if sbc_auth_alg == 0:
                privk = read_binary(os.path.join(pem_dir, "ecdsa_p256_private.pem"))
            elif sbc_auth_alg == 1:
                privk = read_binary(os.path.join(pem_dir, "ecdsa_p384_private.pem"))

            key = cryptoSB.strip_key(aes_enc_sym_key)

        else:
            if cmd_line_key == 1:
                # Keys are taken from the command line.
                if privkey is None:
                    print("ERROR: cmd_line_key is 1 but no private key was given (-k)")
                    print_usage()
                    sys.exit(1)
                key = cryptoSB.strip_key(aeskey)
                privk = cryptoSB.strip_key(privkey)
            elif cmd_line_key == 0:
                # Keys are taken from the json description.
                key = cryptoSB.strip_key(aes_enc_sym_key)
                privk = cryptoSB.strip_key(sign_priv_key)
            else:
                print("ERROR: please check cmd_line_key json")
                sys.exit(1)

        mipack = my_pack_images(mi_desc, privk, key, align, img_dir)
        if mipack is None:
            print("my_pack_images fail")
            return

        print('output: %s (%d bytes)' % (out_name, len(mipack)))
        with open(out_name, 'wb') as out_file:
            out_file.write(mipack)

    main()

