mirror of
https://github.com/espressif/esp-sr.git
synced 2025-09-15 15:28:44 +08:00
95 lines
3.6 KiB
Python
95 lines
3.6 KiB
Python
import io
|
|
import os
|
|
import argparse
|
|
import shutil
|
|
def calculate_total_size(folder_path):
    """Return the total size, in bytes, of the files below ``folder_path``.

    Every entry that ``os.path.isfile`` accepts contributes its
    ``os.path.getsize`` value; every entry that ``os.path.isdir`` accepts is
    searched in turn, so nested files are counted as well.
    """
    byte_count = 0
    # Directories still waiting to be listed; used instead of recursion.
    pending_dirs = [folder_path]
    while pending_dirs:
        current_dir = pending_dirs.pop()
        for entry_name in os.listdir(current_dir):
            entry_path = os.path.join(current_dir, entry_name)
            if os.path.isdir(entry_path):
                pending_dirs.append(entry_path)
            if os.path.isfile(entry_path):
                byte_count += os.path.getsize(entry_path)
    return byte_count
|
|
|
|
# (sdkconfig option, model directory) pairs for the wake-word models.
# Every option found in the active configuration adds its directory, in
# this order.
_WAKENET_OPTIONS = (
    ('CONFIG_SR_WN_WN7Q8_XIAOAITONGXUE', 'wn7q8_xiaoaitongxue'),
    ('CONFIG_SR_WN_WN7_XIAOAITONGXUE', 'wn7_xiaoaitongxue'),
    ('CONFIG_SR_WN_WN8_HILEXIN', 'wn8_hilexin'),
    ('CONFIG_SR_WN_WN8_ALEXA', 'wn8_alexa'),
    ('CONFIG_SR_WN_WN8_HIESP', 'wn8_hiesp'),
    ('CONFIG_SR_WN_WN9_XIAOAITONGXUE', 'wn9_xiaoaitongxue'),
    ('CONFIG_SR_WN_WN9_HILEXIN', 'wn9_hilexin'),
    ('CONFIG_SR_WN_WN9_ALEXA', 'wn9_alexa'),
    ('CONFIG_SR_WN_WN9_HIESP', 'wn9_hiesp'),
    ('CONFIG_SR_WN_WN9_NIHAOXIAOZHI', 'wn9_nihaoxiaozhi'),
    ('CONFIG_SR_WN_WN9_CUSTOMWORD', 'wn9_customword'),
)

# Command-word model tables: within one table only the FIRST matching option
# is used. Matching is by substring, so a "..._QUANT8" option must stay ahead
# of the shorter option name that is its prefix.
_MULTINET_CN_OPTIONS = (
    ('CONFIG_SR_MN_CN_MULTINET3_SINGLE_RECOGNITION', 'mn3_cn'),
    ('CONFIG_SR_MN_CN_MULTINET4_5_SINGLE_RECOGNITION_QUANT8', 'mn4q8_cn'),
    ('CONFIG_SR_MN_CN_MULTINET4_5_SINGLE_RECOGNITION', 'mn4_cn'),
    ('CONFIG_SR_MN_CN_MULTINET5_RECOGNITION_QUANT8', 'mn5q8_cn'),
)
_MULTINET_EN_OPTIONS = (
    ('CONFIG_SR_MN_EN_MULTINET5_SINGLE_RECOGNITION_QUANT8', 'mn5q8_en'),
    ('CONFIG_SR_MN_EN_MULTINET5_SINGLE_RECOGNITION', 'mn5_en'),
)


def _read_enabled_options(sdkconfig_path):
    """Return ``(wn_string, mn_string)`` built from the sdkconfig file.

    Each string is the concatenation of the lines that mention
    ``CONFIG_SR_WN`` / ``CONFIG_SR_MN`` respectively and do not start with
    ``#``.
    """
    wn_lines = []
    mn_lines = []
    # Decode explicitly and tolerate stray bytes: only ASCII option names are
    # inspected, so an unrelated non-ASCII value must not abort the run on a
    # machine whose default encoding is not UTF-8.
    with io.open(sdkconfig_path, "r", encoding="utf-8", errors="replace") as f:
        for label in f:
            label = label.strip("\n")
            if label.startswith('#'):
                continue
            if 'CONFIG_SR_WN' in label:
                wn_lines.append(label)
            if 'CONFIG_SR_MN' in label:
                mn_lines.append(label)
    return ''.join(wn_lines), ''.join(mn_lines)


def _first_enabled(options, enabled_string):
    """Return ``[model]`` for the first option found in ``enabled_string``, else ``[]``."""
    for option, model in options:
        if option in enabled_string:
            return [model]
    return []


def _copy_models(source_dir, model_names, target_dir):
    """Copy each ``source_dir/<name>`` directory tree to ``target_dir/<name>``."""
    for model_name in model_names:
        shutil.copytree(source_dir + '/' + model_name, target_dir + '/' + model_name)


def main():
    """Collect the models selected in sdkconfig into ``<model_path>/model/target``.

    Reads ``<project_path>/sdkconfig``, rebuilds the ``target`` directory from
    scratch with the selected wake-word and command-word model directories,
    and prints a recommended partition size derived from the copied files.
    """
    # input parameter
    parser = argparse.ArgumentParser(description='Model generator tool')
    # Both paths are used unconditionally below, so a missing one is reported
    # by argparse as a usage error instead of failing later with a TypeError.
    parser.add_argument('-d1', '--project_path', required=True,
                        help='directory that contains the sdkconfig file')
    parser.add_argument('-d2', '--model_path', required=True,
                        help="directory that contains the 'model' directory")
    args = parser.parse_args()

    sdkconfig_path = args.project_path + '/sdkconfig'
    model_path = args.model_path + '/model'

    print(sdkconfig_path)
    print(model_path)

    wn_string, mn_string = _read_enabled_options(sdkconfig_path)

    wakenet_model = [model for option, model in _WAKENET_OPTIONS if option in wn_string]
    # At most one Chinese and one English command-word model are selected.
    multinet_model = (_first_enabled(_MULTINET_CN_OPTIONS, mn_string)
                      + _first_enabled(_MULTINET_EN_OPTIONS, mn_string))

    print(wakenet_model)
    print(multinet_model)

    # Start from an empty target directory so models dropped from the
    # configuration do not linger from a previous run.
    target_model = model_path + '/target'
    if os.path.exists(target_model):
        shutil.rmtree(target_model)
    os.makedirs(target_model)
    _copy_models(model_path + '/wakenet_model', wakenet_model, target_model)
    _copy_models(model_path + '/multinet_model', multinet_model, target_model)

    total_size = calculate_total_size(target_model)
    # Size in KB plus a fixed 900 KB allowance, truncated to a multiple of 4 KB.
    print("Recommended model partition size: ", str(int((total_size / 1024 + 900) / 4 ) * 4) + 'KB')


if __name__ == '__main__':
    main()
|