""" This implementation is adapted from https://github.com/wenet-e2e/wekws/blob/main/wekws/bin/compute_det.py."""

import os
import json
import logging
import argparse
import threading

import kaldiio
import torch

from funasr.utils.kws_utils import split_mixed_label


class thread_wrapper(threading.Thread):
    """Thread subclass that stores the return value of its target function."""

    def __init__(self, func, args=()):
        super(thread_wrapper, self).__init__()
        self.func = func
        self.args = args
        self.result = []

    def run(self):
        self.result = self.func(*self.args)

    def get_result(self):
        # self.result is initialized in __init__, so it is always readable;
        # callers should join() the thread before collecting the result.
        return self.result


def space_mixed_label(input_str):
    """Split a mixed Chinese/English label and rejoin it with single spaces."""
    splits = split_mixed_label(input_str)
    return ' '.join(splits)


def read_lists(list_file):
    """Read all non-empty lines of a text file into a list."""
    lists = []
    with open(list_file, 'r', encoding='utf8') as fin:
        for line in fin:
            if line.strip() != '':
                lists.append(line.strip())
    return lists


def make_pair(wav_lists, trans_lists):
    logging.info('make pair for wav-trans list')

    trans_table = {}
    for line in trans_lists:
        arr = line.strip().replace('\t', ' ').split()
        if len(arr) < 2:
            logging.debug('invalid line in trans file: {}'.format(
                line.strip()))
            continue

        # keep everything after the utterance id; str.replace(arr[0], '')
        # would also corrupt transcriptions that happen to contain the id
        trans_table[arr[0]] = ' '.join(arr[1:])

    lists = []
    for line in wav_lists:
        arr = line.strip().replace('\t', ' ').split()
        if len(arr) == 2 and arr[0] in trans_table:
            lists.append(
                dict(key=arr[0],
                     txt=trans_table[arr[0]],
                     wav=arr[1],
                     sample_rate=16000))
        else:
            logging.debug("can't find corresponding trans for key: {}".format(
                arr[0]))
            continue

    return lists


def count_duration(tid, data_lists):
    results = []

    for obj in data_lists:
        assert 'key' in obj
        assert 'wav' in obj
        assert 'txt' in obj
        key = obj['key']
        wav_file = obj['wav']
        txt = obj['txt']

        try:
            # kaldiio returns a (sample_rate, waveform) pair for wav entries
            rate, waveform = kaldiio.load_mat(wav_file)
            waveform = torch.tensor(waveform, dtype=torch.float32)
            waveform = waveform.unsqueeze(0)
            frames = len(waveform[0])
            duration = frames / float(rate)
        except Exception:
            logging.info(f'load file failed: {wav_file}')
            duration = 0.0

        obj['duration'] = duration
        results.append(obj)

    return results


def load_data_and_score(keywords_list, data_file, trans_file, score_file):
    """Bucket every utterance into keyword hits or fillers for each keyword.

    Each score-file line is expected to look like
    '<utt_id> detected <keyword> <confidence>'; any other tag in the second
    field is treated as not detected.
    """
    # score_table: {uttid: {'kw': keyword, 'confi': confidence}}
    score_table = {}
    with open(score_file, 'r', encoding='utf8') as fin:
        # read score file and store in table
        for line in fin:
            arr = line.strip().split()
            key = arr[0]
            is_detected = arr[1]
            if is_detected == 'detected':
                if key not in score_table:
                    score_table.update(
                        {key: {
                            'kw': space_mixed_label(arr[2]),
                            'confi': float(arr[3])
                        }})
            else:
                if key not in score_table:
                    score_table.update({key: {'kw': 'unknown', 'confi': -1.0}})

    wav_lists = read_lists(data_file)
    trans_lists = read_lists(trans_file)
    data_lists = make_pair(wav_lists, trans_lists)
    logging.info(f'origin list samples: {len(data_lists)}')

    # count the duration of each wave, split across worker threads
    num_workers = 8
    start = 0
    step = int(len(data_lists) / num_workers)
    tasks = []
    for idx in range(num_workers):
        if idx != num_workers - 1:
            task = thread_wrapper(count_duration,
                                  (idx, data_lists[start:start + step]))
        else:
            task = thread_wrapper(count_duration, (idx, data_lists[start:]))
        task.start()
        tasks.append(task)
        start += step

    duration_lists = []
    for task in tasks:
        task.join()
        duration_lists += task.get_result()
    logging.info(f'after list samples: {len(duration_lists)}')

    # build an empty structure for keyword-filler info
    keyword_filler_table = {}
    for keyword in keywords_list:
        keyword = space_mixed_label(keyword)
        keyword_filler_table[keyword] = {}
        keyword_filler_table[keyword]['keyword_table'] = {}
        keyword_filler_table[keyword]['keyword_duration'] = 0.0
        keyword_filler_table[keyword]['filler_table'] = {}
        keyword_filler_table[keyword]['filler_duration'] = 0.0

    for obj in duration_lists:
        assert 'key' in obj
        assert 'wav' in obj
        assert 'txt' in obj
        assert 'duration' in obj

        key = obj['key']
        wav_file = obj['wav']
        txt = obj['txt']
        txt = space_mixed_label(txt)
        txt_regstr_lrblk = ' ' + txt + ' '
        duration = obj['duration']
        assert key in score_table

        for keyword in keywords_list:
            keyword = space_mixed_label(keyword)
            keyword_regstr_lrblk = ' ' + keyword + ' '
            if txt_regstr_lrblk.find(keyword_regstr_lrblk) != -1:
                if keyword == score_table[key]['kw']:
                    keyword_filler_table[keyword]['keyword_table'].update(
                        {key: score_table[key]['confi']})
                else:
                    # utterance contains this keyword but was detected as
                    # something else: count it as a miss with confidence -1.0
                    keyword_filler_table[keyword]['keyword_table'].update(
                        {key: -1.0})
                keyword_filler_table[keyword]['keyword_duration'] += duration
            else:
                if keyword == score_table[key]['kw']:
                    keyword_filler_table[keyword]['filler_table'].update(
                        {key: score_table[key]['confi']})
                else:
                    # even if the utterance was detected as another keyword,
                    # it is not a false alarm for this keyword
                    keyword_filler_table[keyword]['filler_table'].update(
                        {key: -1.0})
                keyword_filler_table[keyword]['filler_duration'] += duration

    return keyword_filler_table


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='compute det curve')
    parser.add_argument('--keywords',
                        type=str,
                        required=True,
                        help='comma-separated list of all preset keywords')
    parser.add_argument('--test_data', required=True, help='test data file')
    parser.add_argument('--trans_data',
                        required=True,
                        default='',
                        help='transcription of test data')
    parser.add_argument('--score_file', required=True, help='score file')
    parser.add_argument('--step',
                        type=float,
                        default=0.001,
                        help='threshold step')
    parser.add_argument('--stats_dir',
                        required=True,
                        help='directory to save det stats files')
    args = parser.parse_args()

    # drop any handlers installed by imported modules before configuring logging
    root_logger = logging.getLogger()
    handlers = root_logger.handlers[:]
    for handler in handlers:
        root_logger.removeHandler(handler)
        handler.close()

    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')

    keywords_list = args.keywords.strip().split(',')
    keyword_filler_table = load_data_and_score(keywords_list, args.test_data,
                                               args.trans_data,
                                               args.score_file)

    stats_files = {}
    for keyword in keywords_list:
        keyword = space_mixed_label(keyword)
        keyword_dur = keyword_filler_table[keyword]['keyword_duration']
        keyword_num = len(keyword_filler_table[keyword]['keyword_table'])
        filler_dur = keyword_filler_table[keyword]['filler_duration']
        filler_num = len(keyword_filler_table[keyword]['filler_table'])
        if keyword_num <= 0:
            print('Can\'t compute det for {} without positive sample'.format(keyword))
            continue
        if filler_num <= 0:
            print('Can\'t compute det for {} without negative sample'.format(keyword))
            continue

        logging.info('Computing det for {}'.format(keyword))
        logging.info('  Keyword duration: {} Hours, wave number: {}'.format(
            keyword_dur / 3600.0, keyword_num))
        logging.info('  Filler duration: {} Hours'.format(filler_dur / 3600.0))

        stats_file = os.path.join(args.stats_dir, 'stats.' + keyword.replace(' ', '_') + '.txt')
        with open(stats_file, 'w', encoding='utf8') as fout:
            threshold = 0.0
            while threshold <= 1.0:
                num_false_reject = 0
                num_true_detect = 0
                # traverse the whole keyword_table
                for key, confi in keyword_filler_table[keyword][
                        'keyword_table'].items():
                    if confi < threshold:
                        num_false_reject += 1
                    else:
                        num_true_detect += 1

                num_false_alarm = 0
                # traverse the whole filler_table
                for key, confi in keyword_filler_table[keyword][
                        'filler_table'].items():
                    if confi >= threshold:
                        num_false_alarm += 1
                        # print(f'false alarm: {keyword}, {key}, {confi}')

                # false_reject_rate = num_false_reject / keyword_num
                true_detect_rate = num_true_detect / keyword_num

                # clamp to a tiny positive value so the rates below are never
                # exactly zero
                num_false_alarm = max(num_false_alarm, 1e-6)
                false_alarm_per_hour = num_false_alarm / (filler_dur / 3600.0)
                false_alarm_rate = num_false_alarm / filler_num

                fout.write('{:.3f} {:.6f} {:.6f} {:.6f}\n'.format(
                    threshold, true_detect_rate, false_alarm_rate,
                    false_alarm_per_hour))
                threshold += args.step

        stats_files[keyword] = stats_file