feat: Using partition API to replace SPIFFS file system

This commit is contained in:
sxy 2023-03-17 12:05:22 +08:00
parent ae281b3599
commit 235142ad62
11 changed files with 438 additions and 169 deletions

View File

@ -1,7 +1,8 @@
# Change log for esp-sr
## Unreleased
- 6x reduction in model loading time. The latency of MultiNet6 loading is reduced from 12s to 2s.
- Read all parameters sequentially, which reduces about 5x in model loading time.
- Use esp_partition_mmap to replace spiffs file system, which further reduces about 3x in model loading time
## 1.2.0
- ESP-DSP dependency is now installed from the component registry

View File

@ -81,70 +81,29 @@ elseif(${IDF_TARGET} STREQUAL "esp32s3")
"-Wl,--end-group")
set(MVMODEL_EXE ${COMPONENT_PATH}/model/movemodel.py)
idf_build_get_property(build_dir BUILD_DIR)
set(image_file ${build_dir}/srmodels/srmodels.bin)
add_custom_command(
OUTPUT ${PROJECT_DIR}/target/_MODEL_INFO_
COMMENT "Running move model..."
COMMAND python ${MVMODEL_EXE} -d1 ${PROJECT_DIR} -d2 ${COMPONENT_PATH}
DEPENDS ${COMPONENT_DIR}/model/
VERBATIM)
OUTPUT ${image_file}
COMMENT "Move and Pack models..."
COMMAND python ${MVMODEL_EXE} -d1 ${PROJECT_DIR} -d2 ${COMPONENT_PATH} -d3 ${build_dir}
DEPENDS ${PROJECT_DIR}/sdkconfig
VERBATIM)
add_custom_target(model)
add_dependencies(${COMPONENT_LIB} model)
idf_build_get_property(idf_path IDF_PATH)
set(spiffsgen_py ${PYTHON} ${idf_path}/components/spiffs/spiffsgen.py)
get_filename_component(base_dir_full_path ${PROJECT_DIR}/target/ ABSOLUTE)
add_custom_target(srmodels_bin ALL DEPENDS ${image_file})
add_dependencies(flash srmodels_bin)
partition_table_get_partition_info(size "--partition-name model" "size")
partition_table_get_partition_info(offset "--partition-name model" "offset")
if("${size}" AND "${offset}" AND CONFIG_MODEL_IN_SPIFFS AND CONFIG_USE_WAKENET)
set(image_file ${CMAKE_BINARY_DIR}/model.bin)
if(CONFIG_SPIFFS_USE_MAGIC)
set(use_magic "--use-magic")
endif()
if(CONFIG_SPIFFS_USE_MAGIC_LENGTH)
set(use_magic_len "--use-magic-len")
endif()
if(CONFIG_SPIFFS_FOLLOW_SYMLINKS)
set(follow_symlinks "--follow-symlinks")
endif()
# Execute SPIFFS image generation; this always executes as there is no way to specify for CMake to watch for
# contents of the base dir changing.
add_custom_target(spiffs_model_bin ALL
COMMAND ${spiffsgen_py} ${size} ${base_dir_full_path} ${image_file}
--page-size=${CONFIG_SPIFFS_PAGE_SIZE}
--obj-name-len=${CONFIG_SPIFFS_OBJ_NAME_LEN}
--meta-len=${CONFIG_SPIFFS_META_LENGTH}
${follow_symlinks}
${use_magic}
${use_magic_len}
DEPENDS ${PROJECT_DIR}/target/_MODEL_INFO_
)
set_property(DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}" APPEND PROPERTY
ADDITIONAL_MAKE_CLEAN_FILES
${image_file})
idf_component_get_property(main_args esptool_py FLASH_ARGS)
idf_component_get_property(sub_args esptool_py FLASH_SUB_ARGS)
# Last (optional) parameter is the encryption for the target. In our
# case, spiffs is not encrypt so pass FALSE to the function.
esptool_py_flash_target(model-flash "${main_args}" "${sub_args}" ALWAYS_PLAINTEXT)
esptool_py_flash_to_partition(model-flash "model" "${image_file}")
add_dependencies(model-flash spiffs_model_bin)
if("${size}" AND "${offset}")
# add_dependencies(model-flash srmodels_bin)
# esptool_py_flash_to_partition(model-flash "model" "${image_file}")
esptool_py_flash_to_partition(flash "model" "${image_file}")
add_dependencies(flash spiffs_model_bin)
else()
set(message "Failed to create SPIFFS image for partition 'model'. "
"Check project configuration if using the correct partition table file.")
set(message "Failed to find model in partition table file"
"Please add a line(Name=model, Size>recommended size in log) to the partition file.")
endif()
elseif(${IDF_TARGET} STREQUAL "esp32s2")
set(COMPONENT_ADD_INCLUDEDIRS

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -2,6 +2,8 @@ import io
import os
import argparse
import shutil
from pack_model import pack_models
def calculate_total_size(folder_path):
total_size = 0
for file_name in os.listdir(folder_path):
@ -12,98 +14,115 @@ def calculate_total_size(folder_path):
total_size = total_size + os.path.getsize(path)
return total_size
def copy_wakenet_from_sdkconfig(model_path, sdkconfig_path, target_path):
"""
Copy wakenet model from model_path to target_path based on sdkconfig
"""
with io.open(sdkconfig_path, "r") as f:
models_string = ''
for label in f:
label = label.strip("\n")
if 'CONFIG_SR_WN' in label and label[0] != '#':
models_string += label
models = []
if "CONFIG_SR_WN_WN7Q8_XIAOAITONGXUE" in models_string:
models.append('wn7q8_xiaoaitongxue')
if "CONFIG_SR_WN_WN7_XIAOAITONGXUE" in models_string:
models.append('wn7_xiaoaitongxue')
if "CONFIG_SR_WN_WN8_HILEXIN" in models_string:
models.append('wn8_hilexin')
if "CONFIG_SR_WN_WN8_ALEXA" in models_string:
models.append('wn8_alexa')
if "CONFIG_SR_WN_WN8_HIESP" in models_string:
models.append('wn8_hiesp')
if "CONFIG_SR_WN_WN9_XIAOAITONGXUE" in models_string:
models.append('wn9_xiaoaitongxue')
if "CONFIG_SR_WN_WN9_HILEXIN" in models_string:
models.append('wn9_hilexin')
if "CONFIG_SR_WN_WN9_ALEXA" in models_string:
models.append('wn9_alexa')
if "CONFIG_SR_WN_WN9_HIESP" in models_string:
models.append('wn9_hiesp')
if "CONFIG_SR_WN_WN9_NIHAOXIAOZHI" in models_string:
models.append('wn9_nihaoxiaozhi')
if "CONFIG_SR_WN_WN9_CUSTOMWORD" in models_string:
models.append('wn9_customword')
for item in models:
shutil.copytree(model_path + '/wakenet_model/' + item, target_path+'/'+item)
def copy_multinet_from_sdkconfig(model_path, sdkconfig_path, target_path):
"""
Copy multinet model from model_path to target_path based on sdkconfig
"""
with io.open(sdkconfig_path, "r") as f:
models_string = ''
for label in f:
label = label.strip("\n")
if 'CONFIG_SR_MN' in label and label[0] != '#':
models_string += label
models = []
if "CONFIG_SR_MN_CN_MULTINET3_SINGLE_RECOGNITION" in models_string and len(models) < 2:
models.append('mn3_cn')
elif "CONFIG_SR_MN_CN_MULTINET4_5_SINGLE_RECOGNITION_QUANT8" in models_string:
models.append('mn4q8_cn')
elif "CONFIG_SR_MN_CN_MULTINET4_5_SINGLE_RECOGNITION" in models_string and len(models) < 2:
models.append('mn4_cn')
elif "CONFIG_SR_MN_CN_MULTINET5_RECOGNITION_QUANT8" in models_string and len(models) < 2:
models.append('mn5q8_cn')
elif "CONFIG_SR_MN_EN_MULTINET5_SINGLE_RECOGNITION_QUANT8" in models_string and len(models) < 2:
models.append('mn5q8_en')
elif "CONFIG_SR_MN_EN_MULTINET5_SINGLE_RECOGNITION" in models_string and len(models) < 2:
models.append('mn5_en')
elif "CONFIG_SR_MN_EN_MULTINET6_QUANT" in models_string and len(models) < 2:
models.append('mn6_en')
elif "CONFIG_SR_MN_CN_MULTINET6_QUANT" in models_string and len(models) < 2:
models.append('mn6_cn')
if "MULTINET6" in models_string:
models.append('fst')
for item in models:
shutil.copytree(model_path + '/multinet_model/' + item, target_path+'/'+item)
def copy_nsnet_from_sdkconfig(model_path, sdkconfig_path, target_path):
"""
Copy nsnet model from model_path to target_path based on sdkconfig
"""
with io.open(sdkconfig_path, "r") as f:
models_string = ''
for label in f:
label = label.strip("\n")
if 'CONFIG_SR_NSN' in label and label[0] != '#':
models_string += label
models = []
if "CONFIG_SR_NSN_NSNET1" in models_string:
models.append('nsnet1')
for item in models:
shutil.copytree(model_path + '/nsnet_model/' + item, target_path+'/'+item)
if __name__ == '__main__':
# input parameter
parser = argparse.ArgumentParser(description='Model generator tool')
parser.add_argument('-d1', '--project_path')
parser.add_argument('-d2', '--model_path')
parser.add_argument('-d3', '--build_path')
args = parser.parse_args()
sdkconfig_path = args.project_path + '/sdkconfig'
model_path = args.model_path + '/model'
target_path = args.build_path + '/srmodels'
print(sdkconfig_path)
print(model_path)
if os.path.exists(target_path):
shutil.rmtree(target_path)
os.makedirs(target_path)
with io.open(sdkconfig_path, "r") as f:
WN_STRING = ''
MN_STRING = ''
NSN_STRING = ''
for label in f:
label = label.strip("\n")
if 'CONFIG_SR_WN' in label and label[0] != '#':
WN_STRING += label
if 'CONFIG_SR_MN' in label and label[0] != '#':
MN_STRING += label
if 'CONFIG_SR_NSN' in label and label[0] != '#':
NSN_STRING += label
wakenet_model = []
if "CONFIG_SR_WN_WN7Q8_XIAOAITONGXUE" in WN_STRING:
wakenet_model.append('wn7q8_xiaoaitongxue')
if "CONFIG_SR_WN_WN7_XIAOAITONGXUE" in WN_STRING:
wakenet_model.append('wn7_xiaoaitongxue')
if "CONFIG_SR_WN_WN8_HILEXIN" in WN_STRING:
wakenet_model.append('wn8_hilexin')
if "CONFIG_SR_WN_WN8_ALEXA" in WN_STRING:
wakenet_model.append('wn8_alexa')
if "CONFIG_SR_WN_WN8_HIESP" in WN_STRING:
wakenet_model.append('wn8_hiesp')
if "CONFIG_SR_WN_WN9_XIAOAITONGXUE" in WN_STRING:
wakenet_model.append('wn9_xiaoaitongxue')
if "CONFIG_SR_WN_WN9_HILEXIN" in WN_STRING:
wakenet_model.append('wn9_hilexin')
if "CONFIG_SR_WN_WN9_ALEXA" in WN_STRING:
wakenet_model.append('wn9_alexa')
if "CONFIG_SR_WN_WN9_HIESP" in WN_STRING:
wakenet_model.append('wn9_hiesp')
if "CONFIG_SR_WN_WN9_NIHAOXIAOZHI" in WN_STRING:
wakenet_model.append('wn9_nihaoxiaozhi')
if "CONFIG_SR_WN_WN9_CUSTOMWORD" in WN_STRING:
wakenet_model.append('wn9_customword')
multinet_model = []
if "CONFIG_SR_MN_CN_MULTINET3_SINGLE_RECOGNITION" in MN_STRING and len(multinet_model) < 2:
multinet_model.append('mn3_cn')
elif "CONFIG_SR_MN_CN_MULTINET4_5_SINGLE_RECOGNITION_QUANT8" in MN_STRING:
multinet_model.append('mn4q8_cn')
elif "CONFIG_SR_MN_CN_MULTINET4_5_SINGLE_RECOGNITION" in MN_STRING and len(multinet_model) < 2:
multinet_model.append('mn4_cn')
elif "CONFIG_SR_MN_CN_MULTINET5_RECOGNITION_QUANT8" in MN_STRING and len(multinet_model) < 2:
multinet_model.append('mn5q8_cn')
if "CONFIG_SR_MN_EN_MULTINET5_SINGLE_RECOGNITION_QUANT8" in MN_STRING and len(multinet_model) < 2:
multinet_model.append('mn5q8_en')
elif "CONFIG_SR_MN_EN_MULTINET5_SINGLE_RECOGNITION" in MN_STRING and len(multinet_model) < 2:
multinet_model.append('mn5_en')
elif "CONFIG_SR_MN_EN_MULTINET6_QUANT" in MN_STRING and len(multinet_model) < 2:
multinet_model.append('mn6_en')
elif "CONFIG_SR_MN_CN_MULTINET6_QUANT" in MN_STRING and len(multinet_model) < 2:
multinet_model.append('mn6_cn')
nsnet_model = ''
if "CONFIG_SR_NSN_NSNET1" in NSN_STRING:
nsnet_model = 'nsnet1'
print(wakenet_model)
print(multinet_model)
print(nsnet_model)
target_model = args.project_path + '/target'
if os.path.exists(target_model):
shutil.rmtree(target_model)
os.makedirs(target_model)
if len(wakenet_model) != 0:
for wakenet_model_item in wakenet_model:
shutil.copytree(model_path + '/wakenet_model/' + wakenet_model_item, target_model+'/'+wakenet_model_item)
if len(multinet_model) != 0:
for multinet_model_item in multinet_model:
shutil.copytree(model_path + '/multinet_model/' + multinet_model_item, target_model+'/'+multinet_model_item)
if nsnet_model != '':
shutil.copytree(model_path + '/nsnet_model/' + nsnet_model, target_model+'/'+nsnet_model)
# os.system("cp %s %s" % (wakenet_model+'/_MODEL_INFO_', target_model))
shutil.copytree(f'{model_path}/multinet_model/fst', target_model + '/fst')
total_size = calculate_total_size(target_model)
print("Recommended model partition size: ", str(int((total_size / 1024 + 900) / 4 ) * 4) + 'KB')
copy_multinet_from_sdkconfig(model_path, sdkconfig_path, target_path)
copy_wakenet_from_sdkconfig(model_path, sdkconfig_path, target_path)
total_size = calculate_total_size(target_path)
pack_models(target_path, "srmodels.bin")
print("Recommended model partition size: ", str(int((total_size / 1024 + 900) / 4 ) * 4) + 'KB')

122
model/pack_model.py Normal file
View File

@ -0,0 +1,122 @@
import os
import struct
import argparse
def struct_pack_string(string, max_len=None):
"""
pack string to binary data.
if max_len is None, max_len = len(string) + 1
else len(string) < max_len, the left will be padded by struct.pack('x')
string: input python string
max_len: output
"""
if max_len == None :
max_len = len(string)
else:
assert len(string) <= max_len
left_num = max_len - len(string)
out_bytes = None
for char in string:
if out_bytes == None:
out_bytes = struct.pack('b', ord(char))
else:
out_bytes += struct.pack('b', ord(char))
for i in range(left_num):
out_bytes += struct.pack('x')
return out_bytes
def read_data(filename):
"""
Read binary data, like index and mndata
"""
data = None
with open(filename, "rb") as f:
data = f.read()
return data
def pack_models(model_path, out_file="srmodels.bin"):
"""
Pack all models into one binary file by the following format:
{
model_num: int
model1_info: model_info_t
model2_info: model_info_t
...
model1_index,model1_data,model1_MODEL_INFO
model1_index,model1_data,model1_MODEL_INFO
...
}model_pack_t
{
model_name: char[32]
file_number: int
file1_name: char[32]
file1_start: int
file1_len: int
file2_name: char[32]
file2_start: int // data_len = info_start - data_start
file2_len: int
...
}model_info_t
model_path: the path of models
out_file: the ouput binary filename
"""
models = {}
file_num = 0
model_num = 0
for root, dirs, _ in os.walk(model_path):
for model_name in dirs:
models[model_name] = {}
model_dir = os.path.join(root, model_name)
model_num += 1
for _, _, files in os.walk(model_dir):
for file_name in files:
file_num += 1
file_path = os.path.join(model_dir, file_name)
models[model_name][file_name] = read_data(file_path)
model_num = len(models)
header_len = 4 + model_num*(32+4) + file_num*(32+4+4)
out_bin = struct.pack('I', model_num) # model number
data_bin = None
for key in models:
model_bin = struct_pack_string(key, 32) # + model name
model_bin += struct.pack('I', len(models[key])) # + file number in this model
for file_name in models[key]:
model_bin += struct_pack_string(file_name, 32) # + file name
if data_bin == None:
model_bin += struct.pack('I', header_len)
data_bin = models[key][file_name]
model_bin += struct.pack('I', len(models[key][file_name]))
# print(file_name, header_len, len(models[key][file_name]), len(data_bin))
else:
model_bin += struct.pack('I', header_len+len(data_bin))
# print(file_name, header_len+len(data_bin), len(models[key][file_name]))
data_bin += models[key][file_name]
model_bin += struct.pack('I', len(models[key][file_name]))
out_bin += model_bin
assert len(out_bin) == header_len
out_bin += data_bin
out_file = os.path.join(model_path, out_file)
with open(out_file, "wb") as f:
f.write(out_bin)
if __name__ == "__main__":
# input parameter
parser = argparse.ArgumentParser(description='Model package tool')
parser.add_argument('-m', '--model_path', help="the path of model files")
parser.add_argument('-o', '--out_file', default="srmodels.bin", help="the path of binary file")
args = parser.parse_args()
# convert(args.model_path, args.out_file)
pack_models(model_path=args.model_path, out_file=args.out_file)

View File

@ -1,20 +1,37 @@
#pragma once
#define SRMODEL_STRING_LENGTH 32
#ifdef ESP_PLATFORM
#include "esp_partition.h"
#endif
typedef struct
{
char **model_name; // the name of models, like "wn9_hilexin"(wakenet9, hilexin), "mn5_en"(multinet5, english)
char *partition_label; // partition label used to save the files of model
int num; // the number of models
// char *name; // the name of model, like "wn9_hilexin"(wakenet9, hilexin)
int num; // the number of files
char **files; // the model files, like wn9_index, wn9_data
char **data; // the pointer of file data
int *sizes; // the size of different file
} srmodel_data_t;
typedef struct
{
char **model_name; // the name of models, like "wn9_hilexin"(wakenet9, hilexin), "mn5_en"(multinet5, english)
char *partition_label; // partition label used to save the files of model
spi_flash_mmap_handle_t *mmap_handle;// mmap_handle if using esp_partition_mmap else NULL;
int num; // the number of models
srmodel_data_t **model_data; // the model data , NULL if spiffs format
} srmodel_list_t;
#define MODEL_NAME_MAX_LENGTH 64
/**
* @brief Return all avaliable models in spiffs or selected in Kconfig.
* @brief Return all avaliable models in flash.
*
* @param partition_label The spiffs label defined in your partition file used to save models.
*
* @return all avaliable models in spiffs,save as srmodel_list_t.
* @return all avaliable models,save as srmodel_list_t.
*/
srmodel_list_t* esp_srmodel_init(const char* partition_label);
@ -78,6 +95,14 @@ void srmodel_spiffs_deinit(srmodel_list_t *models);
*/
char *get_model_base_path(void);
/**
* @brief Return static srmodel pointer.
* static srmodel pointer will be set after esp_srmodel_init
*
* @return the pointer of srmodel_list_t
*/
srmodel_list_t *get_static_srmodels(void);
#ifdef ESP_PLATFORM
#include "dl_lib_coefgetter_if.h"

View File

@ -15,14 +15,38 @@
static char *TAG = "MODEL_LOADER";
static char *SRMODE_BASE_PATH = "/srmodel";
static srmodel_list_t *static_srmodels = NULL;
void set_model_base_path(const char *base_path)
{
SRMODE_BASE_PATH = (char *)base_path;
}
static srmodel_list_t* srmodel_list_alloc(void)
{
srmodel_list_t *models = (srmodel_list_t*) malloc(sizeof(srmodel_list_t));
models->mmap_handle = NULL;
models->model_data = NULL;
models->model_name = NULL;
models->num = 0;
models->partition_label = NULL;
return models;
}
#ifdef ESP_PLATFORM
srmodel_list_t *read_models_form_spiffs(esp_vfs_spiffs_conf_t *conf)
{
if (static_srmodels == NULL)
static_srmodels = srmodel_list_alloc();
else
return static_srmodels;
srmodel_list_t *models = static_srmodels;
struct dirent *ret;
DIR *dir = NULL;
dir = opendir(conf->base_path);
srmodel_list_t *models = NULL;
int model_num = 0;
int idx = 0;
@ -46,7 +70,6 @@ srmodel_list_t *read_models_form_spiffs(esp_vfs_spiffs_conf_t *conf)
if (model_num == 0) {
return models;
} else {
models = malloc(sizeof(srmodel_list_t));
models->num = model_num;
models->partition_label = (char *)conf->partition_label;
models->model_name = malloc(models->num*sizeof(char*));
@ -140,12 +163,18 @@ void srmodel_spiffs_deinit(srmodel_list_t *models)
}
free(models);
}
models = NULL;
}
srmodel_list_t *srmodel_config_init()
{
srmodel_list_t *models = malloc(sizeof(srmodel_list_t));
if (static_srmodels == NULL)
static_srmodels = srmodel_list_alloc();
else
return static_srmodels;
srmodel_list_t *models = static_srmodels;
models->num = 2;
models->model_name = malloc(models->num*sizeof(char*));
for (int i=0; i<models->num; i++) {
@ -174,7 +203,7 @@ srmodel_list_t *srmodel_config_init()
free(models);
models = NULL;
}
return models;
}
@ -189,6 +218,7 @@ void srmodel_config_deinit(srmodel_list_t *models)
}
free(models);
}
models = NULL;
}
model_coeff_getter_t* srmodel_get_model_coeff(char *model_name)
@ -197,6 +227,103 @@ model_coeff_getter_t* srmodel_get_model_coeff(char *model_name)
return gettercb;
}
static uint32_t read_int32(char *data) {
uint32_t value = 0;
value |= data[0] << 0;
value |= data[1] << 8;
value |= data[2] << 16;
value |= data[3] << 24;
return value;
}
srmodel_list_t *srmodel_mmap_init(esp_partition_t *part)
{
if (static_srmodels == NULL)
static_srmodels = srmodel_list_alloc();
else
return static_srmodels;
srmodel_list_t *models = static_srmodels;
void *root = NULL;
models->mmap_handle = (spi_flash_mmap_handle_t *)malloc(sizeof(spi_flash_mmap_handle_t));
esp_err_t err=esp_partition_mmap(part, 0, part->size, SPI_FLASH_MMAP_DATA, &root, models->mmap_handle);
if (err != ESP_OK) {
ESP_LOGE(TAG, "Can not map %s partition!\n", part->label);
return NULL;
} else {
ESP_LOGI(TAG, "partition %s size: %d by mmap\n", part->label, part->size);
}
models->partition_label = part->label;
char *start = root;
char *data = root;
int str_len = SRMODEL_STRING_LENGTH;
int int_len = 4;
//read model number
models->num = read_int32(data);
data += int_len;
models->model_data = (srmodel_data_t **)malloc(sizeof(srmodel_data_t*) * models->num);
models->model_name = (char **)malloc(sizeof(char*) * models->num);
for (int i=0; i<models->num; i++) {
srmodel_data_t *model_data = (srmodel_data_t *) malloc(sizeof(srmodel_data_t));
// read model name
models->model_name[i] = (char*)malloc((strlen(data)+1)*sizeof(char));
strcpy(models->model_name[i], data);
data += str_len;
printf("%s\n", models->model_name[i]);
//read model number
int file_num = read_int32(data);
model_data->num = file_num;
data += int_len;
model_data->files = (char **) malloc(sizeof(char*)*file_num);
model_data->data = (void **) malloc(sizeof(void*)*file_num);
model_data->sizes = (int *) malloc(sizeof(int)*file_num);
for (int j=0; j<file_num; j++) {
//read file name
// model_data->files[j] = (char*)malloc(str_len*sizeof(char));
// memcpy(model_data->files[j], data, str_len);
model_data->files[j] = data;
data += str_len;
//read file start index
int index = read_int32(data);
data += int_len;
model_data->data[j] = start + index;
//read file size
int size = read_int32(data);
data += int_len;
model_data->sizes[j] = size;
}
models->model_data[i] = model_data;
}
set_model_base_path(NULL);
return models;
}
void srmodel_mmap_deinit(srmodel_list_t *models)
{
if (models != NULL) {
// esp_partition_munmap(models->mmap_handle); // support esp-idf v5.0
spi_flash_munmap(models->mmap_handle);
if (models->num>0) {
for (int i=0; i<models->num; i++) {
free(models->model_data[i]->files);
free(models->model_data[i]->data);
free(models->model_data[i]->sizes);
free(models->model_data[i]);
free(models->model_name[i]);
}
}
free(models->model_data);
free(models->model_name);
free(models);
}
models = NULL;
}
#endif
char *get_model_base_path(void)
@ -204,15 +331,12 @@ char *get_model_base_path(void)
return SRMODE_BASE_PATH;
}
int set_model_base_path(const char *base_path)
srmodel_list_t *get_static_srmodels(void)
{
if (base_path == NULL) return 0;
SRMODE_BASE_PATH = (char *)base_path;
return 1;
return static_srmodels;
}
char* _join_path_(const char* dirname, const char *filename)
static char* join_path(const char* dirname, const char *filename)
{
if (dirname == NULL || filename == NULL)
return NULL;
@ -234,11 +358,17 @@ char* _join_path_(const char* dirname, const char *filename)
srmodel_list_t* srmodel_sdcard_init(const char *base_path)
{
printf("Initializing models from path: %s\n", base_path);
if (static_srmodels == NULL)
static_srmodels = srmodel_list_alloc();
else
return static_srmodels;
srmodel_list_t *models = static_srmodels;
set_model_base_path(base_path);
struct dirent *ret;
DIR *dir = NULL;
dir = opendir(base_path);
srmodel_list_t *models = NULL;
int model_num = 0;
int idx = 0;
FILE* fp;
@ -250,8 +380,8 @@ srmodel_list_t* srmodel_sdcard_init(const char *base_path)
{ // NULL if reach the end of directory
if (ret->d_type == DT_DIR) { // if d_type is directory
char *sub_path = _join_path_(base_path, ret->d_name);
char *info_file = _join_path_(sub_path, "_MODEL_INFO_");
char *sub_path = join_path(base_path, ret->d_name);
char *info_file = join_path(sub_path, "_MODEL_INFO_");
fp = fopen(info_file, "r");
if (fp != NULL) {
model_num ++; // If _MODLE_INFO_ file exists, model_num ++
@ -268,7 +398,6 @@ srmodel_list_t* srmodel_sdcard_init(const char *base_path)
if (model_num == 0) {
return models;
} else {
models = malloc(sizeof(srmodel_list_t));
models->num = model_num;
models->partition_label = NULL;
models->model_name = malloc(models->num*sizeof(char*));
@ -283,8 +412,8 @@ srmodel_list_t* srmodel_sdcard_init(const char *base_path)
{ // NULL if reach the end of directory
if (ret->d_type == DT_DIR) { // if d_type is directory
char *sub_path = _join_path_(base_path, ret->d_name);
char *info_file = _join_path_(sub_path, "_MODEL_INFO_");
char *sub_path = join_path(base_path, ret->d_name);
char *info_file = join_path(sub_path, "_MODEL_INFO_");
fp = fopen(info_file, "r");
if (fp != NULL) {
memcpy(models->model_name[idx], ret->d_name, strlen(ret->d_name));
@ -312,22 +441,36 @@ void srmodel_sdcard_deinit(srmodel_list_t *models)
}
free(models);
}
models = NULL;
}
srmodel_list_t* esp_srmodel_init(const char* base_path)
srmodel_list_t* esp_srmodel_init(const char* partition_label)
{
#ifdef ESP_PLATFORM
#ifdef CONFIG_IDF_TARGET_ESP32
return srmodel_config_init();
#else
return srmodel_spiffs_init(base_path);
const esp_partition_t* part = NULL;
// find spiffs partition
part = esp_partition_find_first(
ESP_PARTITION_TYPE_DATA, ESP_PARTITION_SUBTYPE_ANY, partition_label
);
return srmodel_mmap_init(part);
if (part) {
return srmodel_mmap_init(part);
} else {
ESP_LOGE(TAG, "Can not find %s in partition table\n", partition_label);
}
return NULL;
#endif
#else
return srmodel_sdcard_init(base_path);
return srmodel_sdcard_init(partition_label);
#endif
}
@ -347,7 +490,7 @@ void esp_srmodel_deinit(srmodel_list_t *models)
}
// repackage strstr function to support needle==NULL
char *_esp_strstr_(const char *haystack, const char *needle)
static char *esp_strstr(const char *haystack, const char *needle)
{
if (needle == NULL) return (char *)haystack;
else return (char *)strstr(haystack, needle);
@ -360,8 +503,8 @@ char *esp_srmodel_filter(srmodel_list_t *models, const char *keyword1, const cha
// return the first model name including specific keyword
for (int i=0; i<models->num; i++) {
if (_esp_strstr_(models->model_name[i], keyword1) != NULL) {
if (_esp_strstr_(models->model_name[i], keyword2) != NULL)
if (esp_strstr(models->model_name[i], keyword1) != NULL) {
if (esp_strstr(models->model_name[i], keyword2) != NULL)
return models->model_name[i];
}
}