mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Update run_evaluate.py (#2175)
This commit is contained in:
parent
1a45b647a8
commit
c3e667b217
@ -9,16 +9,14 @@ from fun_text_processing.text_normalization.data_loader_utils import (
|
||||
training_data_to_tokens,
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
Runs Evaluation on data in the format of : <semiotic class>\t<unnormalized text>\t<`self` if trivial class or normalized text>
|
||||
like the Google text normalization data https://www.kaggle.com/richardwilliamsproat/text-normalization-for-english-russian-and-polish
|
||||
"""
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--input", help="input file path", type=str)
|
||||
parser.add_argument("--input", help="input file path", type=str, required=True)
|
||||
parser.add_argument(
|
||||
"--lang",
|
||||
help="language",
|
||||
@ -39,15 +37,13 @@ def parse_args():
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage:
|
||||
# python run_evaluate.py --input=<INPUT> --cat=<CATEGORY> --filter
|
||||
args = parse_args()
|
||||
if args.lang == "en":
|
||||
from fun_text_processing.inverse_text_normalization.en.clean_eval_data import (
|
||||
filter_loaded_data,
|
||||
)
|
||||
from fun_text_processing.inverse_text_normalization.en.clean_eval_data import filter_loaded_data
|
||||
|
||||
file_path = args.input
|
||||
inverse_normalizer = InverseNormalizer()
|
||||
|
||||
@ -57,6 +53,7 @@ if __name__ == "__main__":
|
||||
if args.filter:
|
||||
training_data = filter_loaded_data(training_data)
|
||||
|
||||
# Evaluate at sentence level if no specific category is provided
|
||||
if args.category is None:
|
||||
print("Sentence level evaluation...")
|
||||
sentences_un_normalized, sentences_normalized, _ = training_data_to_sentences(training_data)
|
||||
@ -68,12 +65,12 @@ if __name__ == "__main__":
|
||||
)
|
||||
print("- Accuracy: " + str(sentences_accuracy))
|
||||
|
||||
# Evaluate at token level
|
||||
print("Token level evaluation...")
|
||||
tokens_per_type = training_data_to_tokens(training_data, category=args.category)
|
||||
token_accuracy = {}
|
||||
for token_type in tokens_per_type:
|
||||
for token_type, (tokens_un_normalized, tokens_normalized) in tokens_per_type.items():
|
||||
print("- Token type: " + token_type)
|
||||
tokens_un_normalized, tokens_normalized = tokens_per_type[token_type]
|
||||
print(" - Data: " + str(len(tokens_normalized)) + " tokens")
|
||||
tokens_prediction = inverse_normalizer.inverse_normalize_list(tokens_normalized)
|
||||
print(" - Denormalized. Evaluating...")
|
||||
@ -81,9 +78,9 @@ if __name__ == "__main__":
|
||||
tokens_prediction, tokens_un_normalized, input=tokens_normalized
|
||||
)
|
||||
print(" - Accuracy: " + str(token_accuracy[token_type]))
|
||||
token_count_per_type = {
|
||||
token_type: len(tokens_per_type[token_type][0]) for token_type in tokens_per_type
|
||||
}
|
||||
|
||||
# Calculate weighted token accuracy
|
||||
token_count_per_type = {token_type: len(tokens) for token_type, (tokens, _) in tokens_per_type.items()}
|
||||
token_weighted_accuracy = [
|
||||
token_count_per_type[token_type] * accuracy
|
||||
for token_type, accuracy in token_accuracy.items()
|
||||
@ -96,19 +93,17 @@ if __name__ == "__main__":
|
||||
if token_type not in known_types:
|
||||
raise ValueError("Unexpected token type: " + token_type)
|
||||
|
||||
# Output table summarizing evaluation results if no specific category is provided
|
||||
if args.category is None:
|
||||
c1 = ["Class", "sent level"] + known_types
|
||||
c2 = ["Num Tokens", len(sentences_normalized)] + [
|
||||
token_count_per_type[known_type] if known_type in tokens_per_type else "0"
|
||||
for known_type in known_types
|
||||
str(token_count_per_type.get(known_type, 0)) for known_type in known_types
|
||||
]
|
||||
c3 = ["Denormalization", sentences_accuracy] + [
|
||||
token_accuracy[known_type] if known_type in token_accuracy else "0"
|
||||
for known_type in known_types
|
||||
c3 = ["Denormalization", str(sentences_accuracy)] + [
|
||||
str(token_accuracy.get(known_type, "0")) for known_type in known_types
|
||||
]
|
||||
|
||||
for i in range(len(c1)):
|
||||
print(f"{str(c1[i]):10s} | {str(c2[i]):10s} | {str(c3[i]):5s}")
|
||||
print(f"{c1[i]:10s} | {c2[i]:10s} | {c3[i]:5s}")
|
||||
else:
|
||||
print(f"numbers\t{token_count_per_type[args.category]}")
|
||||
print(f"Denormalization\t{token_accuracy[args.category]}")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user