#!/usr/bin/env python

import argparse
import sys
import struct
import os
import math
import xml.dom.minidom

# Maximum strobe-select delay, in clock cycles. calculate_sdr_acccon() caps
# the strobe value at this limit and adds any excess cycles to trlt instead.
max_strobe_dly = 3

def calculate_sdr_acccon(ac, rate):
	"""Derive the SDR-mode ACCCON register value and the strobe select.

	ac   -- dict of AC timing values; the keys read here are tALH, tCLH,
	        tCLS, tALS, tCR, tWHR, tREH, tWH, tWC, tWP, tRC, tRP and tREA
	rate -- clock rate in Hz

	Returns the tuple (acccon, tsel).
	"""
	# every formula below works with the clock rate in KHz
	rate /= 1000

	def to_cycles(duration):
		# timing value -> number of clock cycles, rounded up
		return int(math.ceil(float(duration) * rate / 1000000))

	def to_time(cycles):
		# clock cycles -> time, in the unit used by the AC timing values
		return cycles * 1000000 / rate

	def half(cycles):
		# half of a cycle count, rounded up
		return int(math.ceil(float(cycles) / 2))

	# fields derived directly from one group of timings
	tpoecs = to_cycles(max(ac["tALH"], ac["tCLH"])) & 0xf
	tprecs = to_cycles(max(ac["tCLS"], ac["tALS"])) & 0x3f
	tc2r = half(to_cycles(ac["tCR"]) - 1) & 0x3f
	tw2r = half(to_cycles(ac["tWHR"]) - 1) & 0xf
	twh = (to_cycles(max(ac["tREH"], ac["tWH"])) - 1) & 0xf

	# twst: at least tWP, stretched when twh alone leaves tWC unsatisfied
	write_low = 0
	if to_time(twh + 1) < ac["tWC"]:
		write_low = ac["tWC"] - to_time(twh + 1)
	twst = (to_cycles(max(ac["tWP"], write_low)) - 1) & 0xf

	# trlt: at least tRP, stretched when twh alone leaves tRC unsatisfied
	read_low = 0
	if to_time(twh + 1) < ac["tRC"]:
		read_low = ac["tRC"] - to_time(twh + 1)
	trlt = (to_cycles(max(ac["tRP"], read_low)) - 1) & 0xf

	# strobe select: extra cycles needed when trlt is shorter than tREA
	tsel = 0
	if to_time(trlt + 1) < ac["tREA"]:
		tsel = to_cycles(ac["tREA"]) - (trlt + 1)
		if tsel > max_strobe_dly:
			# NOTE(review): trlt is not masked again after this adjustment
			trlt += tsel - max_strobe_dly
			tsel = max_strobe_dly

	# pack the fields into the register layout
	acccon = ((tpoecs << 28) | (tprecs << 22) | (tc2r << 16) |
		(tw2r << 12) | (twh << 8) | (twst << 4) | trlt)

	return (acccon, tsel)

def check_sdr_acccon_match_spec(ac, rate, acccon):
	"""Check an SDR ACCCON value against the NAND AC timing requirements.

	ac     -- dict of AC timing values; keys read: tWH, tREH, tWP, tRP,
	          tWC, tRC
	rate   -- clock rate in Hz
	acccon -- register value to verify

	Returns 0 when every checked timing is satisfied, 1 on the first
	violation found.
	"""
	# work with the clock rate in KHz
	rate /= 1000

	# each field stores (cycles - 1)
	high = ((acccon >> 8) & 0xf) + 1
	write_low = ((acccon >> 4) & 0xf) + 1
	read_low = (acccon & 0xf) + 1

	# (programmed cycles, minimum timing it has to cover), in check order
	requirements = (
		(high, "tWH"),
		(high, "tREH"),
		(write_low, "tWP"),
		(read_low, "tRP"),
		(high + write_low, "tWC"),
		(high + read_low, "tRC"),
	)
	for cycles, name in requirements:
		if cycles * 1000000 / rate < ac[name]:
			return 1
	# check more below

	return 0

def calculate_sdr_speed(ac, acccon, array, rate):
	"""Estimate SDR read, write and erase throughput for an ACCCON value.

	ac     -- dict of AC timing values; keys read: tR, tPROG, tBERS
	acccon -- register value whose twh/twst/trlt fields set the cycle time
	array  -- dict with page_size, spare_size and page_per_block
	rate   -- clock rate in Hz

	Returns the tuple (rspeed, wspeed, espeed); the callers print these
	as KB/s.
	"""
	# work with the clock rate in KHz
	rate /= 1000

	# each field stores (cycles - 1)
	high = ((acccon >> 8) & 0xf) + 1
	write_low = ((acccon >> 4) & 0xf) + 1
	read_low = (acccon & 0xf) + 1
	chunk = array["page_size"] + array["spare_size"]

	def page_speed(cycles, busy_key):
		# time to move one page plus spare, then the array busy time
		elapsed = float(cycles) * 1000000 / rate
		elapsed = (elapsed * chunk + ac[busy_key]) / 1000
		return (1000000 / elapsed) * (array["page_size"] / 1024)

	rspeed = page_speed(high + read_low, "tR")
	wspeed = page_speed(high + write_low, "tPROG")

	# erase works on a whole block and involves no data transfer
	erase_time = ac["tBERS"] / 1000
	block_kb = array["page_size"] * array["page_per_block"] / 1024
	espeed = (1000000 / erase_time) * block_kb

	return (rspeed, wspeed, espeed)

def get_array_para(nand):
	"""Read the array geometry from an <array> node and print a summary.

	nand -- DOM node whose <entry type=... value=...> children describe
	        the array

	Returns a dict holding "name" (kept as a string) plus "page_size",
	"spare_size" and "page_per_block". A KeyError is raised by the summary
	line when one of the four entries is missing.
	"""
	evaluated = ("page_size", "spare_size", "page_per_block")
	array = {}
	for child in nand.childNodes:
		if child.nodeName == "entry":
			kind = child.getAttribute("type")
			if kind == "name":
				array["name"] = child.getAttribute("value")
			elif kind in evaluated:
				# NOTE(security): the attribute is evaluated as a Python
				# expression, so only trusted XML files may be used
				array[kind] = eval(child.getAttribute("value"))
	summary = "NAND: %s, page size: %d, spare size: %d, page per block: %d"
	print (summary % (array["name"], array["page_size"],
		array["spare_size"], array["page_per_block"]))

	return array

def get_sdr_ac_timing(ac_timing):
	"""Collect the SDR AC timing values from an <ac_timing> node.

	ac_timing -- DOM node whose <entry type=... value=...> children hold
	             the timing parameters

	Returns a dict keyed by parameter name. Entries with an unknown type
	are ignored. A KeyError is raised by the summary line when tR, tPROG
	or tBERS is missing.
	"""
	known = (
		"tREA", "tREH", "tCR", "tRP", "tWP", "tWH", "tWHR", "tCLS",
		"tALS", "tCLH", "tALH", "tWC", "tRC", "tR", "tPROG", "tBERS",
	)
	ac = {}
	for child in ac_timing.childNodes:
		if child.nodeName == "entry":
			key = child.getAttribute("type")
			if key in known:
				# NOTE(security): the attribute is evaluated as a Python
				# expression, so only trusted XML files may be used
				ac[key] = eval(child.getAttribute("value"))
	print ("tR: %d, tPROG: %d, tBERS: %d ns" % (ac["tR"], ac["tPROG"], ac["tBERS"]))
	return ac

def sdr_calculator(nand, freq):
	"""Print ACCCON/strobe settings and estimated throughput for SDR mode.

	nand -- file name of the NAND description XML, looked up under ./sdr/
	freq -- file name of the frequency table XML, looked up under ./freq/

	For every <entry> of the frequency table the register value is
	calculated, verified against the AC timing and printed together with
	the estimated speeds; finally the fastest read and write frequencies
	are reported.

	Raises Exception when the NAND file lacks an <ac_timing> or <array>
	section, or when a calculated ACCCON value does not satisfy the spec.
	"""
	print ("This is SDR mode calculator!!")
	nand_file = "./sdr/" + nand
	doc_nand = xml.dom.minidom.parse(nand_file)
	root_nand = doc_nand.documentElement
	ac = None
	array = None
	for nand_node in root_nand.childNodes:
		if nand_node.nodeName == "ac_timing":
			ac = get_sdr_ac_timing(nand_node)
		elif nand_node.nodeName == "array":
			array = get_array_para(nand_node)
	if ac is None or array is None:
		# report the real problem instead of a NameError further down
		raise Exception("%s: missing <ac_timing> or <array> section" % nand_file)

	# get frequency
	freq_file = "./freq/" + freq
	root_freq = xml.dom.minidom.parse(freq_file)
	# NOTE(review): if no node is named "frequency" the last top-level node
	# is used, exactly as before
	for frequency in root_freq.childNodes:
		if frequency.nodeName == "frequency":
			break

	# NOTE(security): eval() runs whatever expression the XML attribute
	# holds; only use trusted frequency files
	divider = eval(frequency.getAttribute("divider"))

	# calculate acccon, stobe, performance
	print ("Read IO speed: %.2f KB/s, Write IO speed: %.2f KB/s" %
		((1000000000 / float(ac["tRC"]) / 1024), (1000000000 / float(ac["tWC"]) / 1024)))
	out_format = "{:^12}\t{:^12}\t{:^5}\t{:^12}\t{:^12}\t{:^12}\t"
	print (out_format.format("freq","acccon","strobe","read(KB/s)","write(KB/s)", "erase(KB/s)"))
	performance = []
	for node in frequency.childNodes:
		if node.nodeName != "entry":
			continue
		rate = eval(node.getAttribute("value"))
		rate /= divider
		(acccon, strobe) = calculate_sdr_acccon(ac, rate)
		# the checker returns 0 on success and 1 on any violation, so the
		# test has to be "non-zero" ("> 1" could never trigger)
		if check_sdr_acccon_match_spec(ac, rate, acccon) != 0:
			print ("acccon %#x strobe %d not match spec" % (acccon, strobe))
			raise Exception("not match spec")
		(rspeed, wspeed, espeed) = calculate_sdr_speed(ac, acccon, array, rate)
		rate *= divider
		print (out_format.format(rate, hex(acccon), strobe, format(rspeed, ".2f"), \
			format(wspeed, ".2f"), espeed))
		performance.append({"rate":rate, "acccon":hex(acccon), "strobe":strobe, \
					"rspeed":rspeed, "wspeed":wspeed, "espeed":espeed})

	if not performance:
		# nothing to rank; avoids referencing a "best" entry that does not exist
		print ("no frequency entry found in %s" % freq_file)
		return

	# max() keeps the first entry among equals
	best_read = max(performance, key=lambda item: item["rspeed"])
	best_write = max(performance, key=lambda item: item["wspeed"])
	print ("frequency[%d] has the best read performance [%.2f KB/s]" %
			(best_read["rate"], best_read["rspeed"]))
	# report the write speed of the best *write* entry
	print ("frequency[%d] has the best write performance [%.2f KB/s]" %
			(best_write["rate"], best_write["wspeed"]))

def get_onfi_ac_timing(ac_timing):
	"""Collect the ONFI DDR AC timing values from an <ac_timing> node.

	ac_timing -- DOM node whose <entry type=... value=...> children hold
	             the timing parameters

	Returns a dict keyed by parameter name, with "tCKWR" added as a copy
	of "tDQSCK". Entries with an unknown type are ignored. A KeyError is
	raised when tR, tPROG, tBERS or tDQSCK is missing.
	"""
	known = (
		"tCAD", "tWRCK", "tDQSCK", "tWHR", "tWPRE", "tWPST",
		"rate_max", "tR", "tPROG", "tBERS",
	)
	ac = {}
	for child in ac_timing.childNodes:
		if child.nodeName == "entry":
			key = child.getAttribute("type")
			if key in known:
				# NOTE(security): the attribute is evaluated as a Python
				# expression, so only trusted XML files may be used
				ac[key] = eval(child.getAttribute("value"))
	print ("tR: %d, tPROG: %d, tBERS: %d ns" % (ac["tR"], ac["tPROG"], ac["tBERS"]))

	# tCKWR starts out equal to tDQSCK
	ac["tCKWR"] = ac["tDQSCK"]

	return ac

def calculate_onfi_acccon(ac, rate):
	"""Derive the ACCCON and ACCCON1 register values for ONFI DDR mode.

	ac   -- dict of AC timing values; keys read: tWPRE, tWPST, tDQSCK,
	        tCAD, tWRCK, tWHR. The dict is updated in place with
	        "tWPRE_real", "tWPST_real" and "tCKWR".
	rate -- clock rate in Hz

	Returns the tuple (acccon, acccon1).
	"""
	# work with the clock rate in KHz
	rate /= 1000

	def to_cycles(duration):
		# timing value -> number of clock cycles, rounded up
		return int(math.ceil(float(duration) * rate / 1000000))

	def cycles_to_time(cycles):
		# (possibly fractional) clock cycles -> time, rounded up
		return int(math.ceil(float(cycles) * 1000000 / rate))

	# tWPRE and tWPST are scaled by the clock period here
	ac["tWPRE_real"] = cycles_to_time(ac["tWPRE"])
	ac["tWPST_real"] = cycles_to_time(ac["tWPST"])
	# tCKWR is tDQSCK rounded up to a whole number of clock cycles
	dqsck_cycles = math.ceil(float(ac["tDQSCK"]) * rate / 1000000)
	ac["tCKWR"] = int(math.ceil(dqsck_cycles * 1000000 / rate))

	# acccon register setting
	# tprecs covers tCAD/tWRCK, plus whatever part of the write preamble
	# exceeds one and a half clock periods
	tprecs = max(ac["tCAD"], ac["tWRCK"])
	preamble_base = int(math.ceil(1.5 * 1000000 / rate))
	if ac["tWPRE_real"] > preamble_base:
		tprecs = max(tprecs, ac["tWPRE_real"] - preamble_base)
	tprecs = to_cycles(tprecs) & 0x3f
	# tw2r
	tw2r = to_cycles(ac["tWHR"])
	tw2r = int(math.ceil(float(tw2r - 1) / 2)) & 0xf
	acccon = (tprecs << 22) | (tw2r << 12)

	# acccon1 register setting
	trdpst = (to_cycles(ac["tCKWR"]) - 1) & 0x3f
	twrpst = (to_cycles(ac["tWPST_real"]) - 1) & 0x3f
	acccon1 = (trdpst << 8) | twrpst

	return (acccon, acccon1)

def get_toggle_ac_timing(ac_timing):
	"""Collect the Toggle DDR AC timing values from an <ac_timing> node.

	ac_timing -- DOM node whose <entry type=... value=...> children hold
	             the timing parameters

	Returns a dict keyed by parameter name. Entries with an unknown type
	are ignored. A KeyError is raised by the summary line when tR, tPROG
	or tBERS is missing.
	"""
	known = (
		"tCH", "tCALS", "tCALH", "tWPRE", "tWPST", "tWPSTH", "tCR",
		"tDQSRE", "tRPSTH", "tCDQSS", "tWHR", "rate_max",
		"tR", "tPROG", "tBERS",
	)
	ac = {}
	for child in ac_timing.childNodes:
		if child.nodeName == "entry":
			key = child.getAttribute("type")
			if key in known:
				# NOTE(security): the attribute is evaluated as a Python
				# expression, so only trusted XML files may be used
				ac[key] = eval(child.getAttribute("value"))
	print ("tR: %d, tPROG: %d, tBERS: %d ns" % (ac["tR"], ac["tPROG"], ac["tBERS"]))

	return ac

def calculate_toggle_acccon(ac, rate):
	"""Derive the ACCCON and ACCCON1 register values for Toggle DDR mode.

	ac   -- dict of AC timing values; keys read: tDQSRE, tCALH, tWPSTH,
	        tRPSTH, tCALS, tCDQSS, tCR, tWHR, tCH, tWPRE, tWPST. The dict
	        is updated in place with "tRPST".
	rate -- clock rate in Hz

	Returns the tuple (acccon, acccon1).
	"""
	# turn clock rate from HZ into KHZ
	rate /= 1000

	def to_cycles(duration):
		# timing value -> number of clock cycles, rounded up
		return int(math.ceil(float(duration) * rate / 1000000))

	# tRPST is tDQSRE plus half a clock period
	ac["tRPST"] = ac["tDQSRE"] + 0.5 * 1000000 / rate

	# acccon register setting
	# calculate tpoecs
	tpoecs = max(max(ac["tCALH"], ac["tWPSTH"]), ac["tRPSTH"])
	tpoecs = (to_cycles(tpoecs) - 1) & 0xf
	# calculate tprecs
	tprecs = max(ac["tCALS"], ac["tCDQSS"])
	tprecs = (to_cycles(tprecs) - 1) & 0x3f
	# calculate tc2r (divide by 2.0 so the halving is never an integer division)
	tc2r = int(math.ceil((to_cycles(ac["tCR"]) - 1) / 2.0)) & 0x3f
	# calculate tw2r; the field occupies bits 12-15 only, so it must be
	# masked to 4 bits (a 6-bit mask would spill into tc2r at bit 16)
	tw2r = int(math.ceil((to_cycles(ac["tWHR"]) - 1) / 2.0)) & 0xf
	# calculate twh
	twh = (to_cycles(max(ac["tCH"], ac["tCALH"])) - 1) & 0xf
	# calculate twst
	twst = (to_cycles(ac["tCALS"]) - 1) & 0xf
	trlt = 0
	acccon = ((tpoecs << 28) | (tprecs << 22) | (tc2r << 16) |
		(tw2r << 12) | (twh << 8) | (twst << 4) | trlt)

	# acccon1 register setting
	# calculate trdpre
	# NOTE(review): the read preamble is derived from tWPRE, same as the
	# write preamble below - confirm that no separate read value is intended
	trdpre = (to_cycles(ac["tWPRE"]) - 1) & 0x3f
	# calculate trdpst
	trdpst = (to_cycles(ac["tRPST"]) - 1) & 0x3f
	# calculate twrpre
	twrpre = (to_cycles(ac["tWPRE"]) - 1) & 0x3f
	# calculate twrpst
	twrpst = (to_cycles(ac["tWPST"]) - 1) & 0x3f
	acccon1 = (trdpre << 24) | (twrpre << 16) | (trdpst << 8) | twrpst

	return (acccon, acccon1)

def calculate_ddr_speed(ac, array, rate):
	"""Estimate DDR read, write and erase throughput.

	ac    -- dict of AC timing values; keys read: tR, tPROG, tBERS
	array -- dict with page_size, spare_size and page_per_block
	rate  -- clock rate in Hz; two transfers per clock cycle are assumed

	Returns the tuple (rspeed, wspeed, espeed); the callers print these
	as KB/s.
	"""
	# work with the clock rate in KHz
	rate /= 1000

	chunk = array["page_size"] + array["spare_size"]

	def page_speed(busy_key):
		# time to move one page plus spare, then the array busy time
		elapsed = 1000000 / float(rate * 2)
		elapsed = (elapsed * chunk + ac[busy_key]) / 1000
		return (1000000 / elapsed) * (array["page_size"] / 1024)

	rspeed = page_speed("tR")
	wspeed = page_speed("tPROG")

	# erase works on a whole block and involves no data transfer
	erase_time = ac["tBERS"] / 1000
	block_kb = array["page_size"] * array["page_per_block"] / 1024
	espeed = (1000000 / erase_time) * block_kb

	return (rspeed, wspeed, espeed)

def ddr_calculator(nand, freq, mode):
	"""Print ACCCON/ACCCON1 settings and estimated throughput for a DDR mode.

	nand -- file name of the NAND description XML, looked up under ./<mode>/
	freq -- file name of the frequency table XML, looked up under ./freq/
	mode -- "ddr_onfi" or "ddr_toggle"

	Frequencies whose divided rate exceeds the NAND's rate_max are skipped
	and listed after the table.

	Raises Exception for an unsupported mode, or when the NAND file lacks
	an <ac_timing> or <array> section.
	"""
	print ("This is %s mode calculator!!" % mode)
	# resolve the mode-specific helpers once, and reject an unknown mode
	# before any file is opened
	if mode == "ddr_onfi":
		get_ac_timing = get_onfi_ac_timing
		calculate_acccon = calculate_onfi_acccon
	elif mode == "ddr_toggle":
		get_ac_timing = get_toggle_ac_timing
		calculate_acccon = calculate_toggle_acccon
	else:
		raise Exception("unsupported ddr mode!!")

	nand_file = "./" + mode + "/" + nand
	doc_nand = xml.dom.minidom.parse(nand_file)
	root_nand = doc_nand.documentElement
	ac = None
	array = None
	for nand_node in root_nand.childNodes:
		if nand_node.nodeName == "ac_timing":
			ac = get_ac_timing(nand_node)
		elif nand_node.nodeName == "array":
			array = get_array_para(nand_node)
	if ac is None or array is None:
		# report the real problem instead of a NameError further down
		raise Exception("%s: missing <ac_timing> or <array> section" % nand_file)

	# get frequency
	freq_file = "./freq/" + freq
	root_freq = xml.dom.minidom.parse(freq_file)
	# NOTE(review): if no node is named "frequency" the last top-level node
	# is used, exactly as before
	for frequency in root_freq.childNodes:
		if frequency.nodeName == "frequency":
			break

	# NOTE(security): eval() runs whatever expression the XML attribute
	# holds; only use trusted frequency files
	divider = eval(frequency.getAttribute("divider"))

	out_format = "{:^12}\t{:^12}\t{:^12}\t{:^12}\t{:^12}\t{:^12}\t"
	print (out_format.format("freq","acccon","acccon1","read(KB/s)","write(KB/s)", "erase(KB/s)"))
	rate_invalid = []
	for node in frequency.childNodes:
		if node.nodeName != "entry":
			continue
		rate = eval(node.getAttribute("value"))
		rate /= divider
		if rate > ac["rate_max"]:
			# remember the undivided rate for the report after the table
			rate_invalid.append(rate * divider)
			continue
		(acccon, acccon1) = calculate_acccon(ac, rate)
		(rspeed, wspeed, espeed) = calculate_ddr_speed(ac, array, rate)
		rate *= divider
		print (out_format.format(rate, hex(acccon), hex(acccon1), format(rspeed, ".2f"), \
			format(wspeed, ".2f"), espeed))

	for rate in rate_invalid:
		print ("rate[%d] not supported!" % rate)

def ddr_toggle_calculator(nand, freq):
	"""Placeholder entry point for Toggle DDR mode.

	Only prints a banner; nand and freq are accepted but not used.
	"""
	banner = "This is Toggle DDR mode calculator!!"
	print (banner)

def main(argv):
	"""Parse the command line and run the calculator for the chosen mode.

	argv -- full argument vector including the program name at index 0
	        (as in sys.argv); when empty or None, sys.argv is used

	Raises Exception when the mode is not sdr, ddr_onfi or ddr_toggle.
	argparse exits with SystemExit on malformed arguments.
	"""
	parser = argparse.ArgumentParser()
	parser.add_argument('mode', help = 'working mode: sdr, ddr_onfi, ddr_toggle')
	parser.add_argument('nand', help= 'nand AC timing xml file')
	parser.add_argument('freq', help = 'frequency file')
	# honour the argv that was passed in; argv[0] is the program name
	args = parser.parse_args(argv[1:] if argv else None)

	if args.mode == "sdr":
		sdr_calculator(args.nand, args.freq)
	elif args.mode in ("ddr_onfi", "ddr_toggle"):
		ddr_calculator(args.nand, args.freq, args.mode)
	else:
		raise Exception("Working mode not supported!")

# Run the calculator only when this file is executed as a script. main() has
# no return statement, so sys.exit() receives None on success.
if __name__ == "__main__":
        sys.exit(main(sys.argv))
