* v1.0.28

* version checker

* version checker

* rollback cif_v1 for training bug

* fix bug

* fix bug for CIF

* fix bug

---------

Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>
This commit is contained in:
zhifu gao 2024-06-24 10:20:53 +08:00 committed by GitHub
parent 5ac34941d1
commit 068a4054ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 65 additions and 17 deletions

View File

@ -54,7 +54,7 @@ DISTRIBUTED_ARGS="
--nproc_per_node $gpu_num \
--node_rank ${RANK:-0} \
--master_addr ${MASTER_ADDR:-127.0.0.1} \
--master_port ${MASTER_PORT: 26669}
--master_port ${MASTER_PORT:-26669}
"
echo $DISTRIBUTED_ARGS

View File

@ -55,7 +55,7 @@ DISTRIBUTED_ARGS="
--nproc_per_node $gpu_num \
--node_rank ${RANK:-0} \
--master_addr ${MASTER_ADDR:-127.0.0.1} \
--master_port ${MASTER_PORT: 26669}
--master_port ${MASTER_PORT:-26669}
"
echo $DISTRIBUTED_ARGS

View File

@ -37,7 +37,7 @@ DISTRIBUTED_ARGS="
--nproc_per_node $gpu_num \
--node_rank ${RANK:-0} \
--master_addr ${MASTER_ADDR:-127.0.0.1} \
--master_port ${MASTER_PORT: 26669}
--master_port ${MASTER_PORT:-26669}
"
echo $DISTRIBUTED_ARGS

View File

@ -37,7 +37,7 @@ DISTRIBUTED_ARGS="
--nproc_per_node $gpu_num \
--node_rank ${RANK:-0} \
--master_addr ${MASTER_ADDR:-127.0.0.1} \
--master_port ${MASTER_PORT: 26669}
--master_port ${MASTER_PORT:-26669}
"
echo $DISTRIBUTED_ARGS

View File

@ -51,7 +51,7 @@ DISTRIBUTED_ARGS="
--nproc_per_node $gpu_num \
--node_rank ${RANK:-0} \
--master_addr ${MASTER_ADDR:-127.0.0.1} \
--master_port ${MASTER_PORT: 26669}
--master_port ${MASTER_PORT:-26669}
"
echo $DISTRIBUTED_ARGS

View File

@ -52,7 +52,7 @@ DISTRIBUTED_ARGS="
--nproc_per_node $gpu_num \
--node_rank ${RANK:-0} \
--master_addr ${MASTER_ADDR:-127.0.0.1} \
--master_port ${MASTER_PORT: 26669}
--master_port ${MASTER_PORT:-26669}
"
echo $DISTRIBUTED_ARGS

View File

@ -50,7 +50,7 @@ DISTRIBUTED_ARGS="
--nproc_per_node $gpu_num \
--node_rank ${RANK:-0} \
--master_addr ${MASTER_ADDR:-127.0.0.1} \
--master_port ${MASTER_PORT: 26669}
--master_port ${MASTER_PORT:-26669}
"
echo $DISTRIBUTED_ARGS

View File

@ -52,7 +52,7 @@ DISTRIBUTED_ARGS="
--nproc_per_node $gpu_num \
--node_rank ${RANK:-0} \
--master_addr ${MASTER_ADDR:-127.0.0.1} \
--master_port ${MASTER_PORT: 26669}
--master_port ${MASTER_PORT:-26669}
"
echo $DISTRIBUTED_ARGS

View File

@ -111,6 +111,13 @@ class AutoModel:
def __init__(self, **kwargs):
try:
from funasr.utils.version_checker import check_for_update
check_for_update()
except:
pass
log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
logging.basicConfig(level=log_level)

View File

@ -33,8 +33,8 @@ def compute_wer(
if cn_postprocess:
value = " ".join(value)
value = value.replace(" ", "")
if value[0] == "":
value = value[1:]
# if value[0] == "请":
# value = value[1:]
value = [x for x in value]
hyp_dict[key] = value
with open(ref_file, "r") as ref_reader:

View File

@ -80,7 +80,7 @@ class CifPredictor(torch.nn.Module):
hidden, alphas, token_num, mask=mask
)
acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
if target_length is None and self.tail_threshold > 0.0:
token_num_int = torch.max(token_num).type(torch.int32).item()
@ -245,7 +245,7 @@ class CifPredictorV2(torch.nn.Module):
hidden, alphas, token_num, mask=None
)
acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
if target_length is None and self.tail_threshold > 0.0:
token_num_int = torch.max(token_num).type(torch.int32).item()
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
@ -506,7 +506,10 @@ def cif_v1_export(hidden, alphas, threshold: float):
frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
prefix_sum = torch.cumsum(alphas, dim=1)
# prefix_sum = torch.cumsum(alphas, dim=1)
prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(
torch.float32
) # cumsum precision degradation cause wrong result in extreme
prefix_sum_floor = torch.floor(prefix_sum)
dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
@ -518,8 +521,8 @@ def cif_v1_export(hidden, alphas, threshold: float):
fires[fire_idxs] = 1
fires = fires + prefix_sum - prefix_sum_floor
# prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
frames = prefix_sum_hidden[fire_idxs]
shift_frames = torch.roll(frames, 1, dims=0)
@ -530,6 +533,7 @@ def cif_v1_export(hidden, alphas, threshold: float):
shift_frames[shift_batch_idxs] = 0
remains = fires - torch.floor(fires)
# remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
@ -537,8 +541,11 @@ def cif_v1_export(hidden, alphas, threshold: float):
frames = frames - shift_frames + shift_remain_frames - remain_frames
max_label_len = batch_len.max()
# max_label_len = batch_len.max()
max_label_len = alphas.sum(dim=-1)
max_label_len = torch.floor(max_label_len).max().to(dtype=torch.int64)
# frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
frame_fires_idxs = indices < batch_len.unsqueeze(1)
@ -667,7 +674,10 @@ def cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=False):
fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
prefix_sum = torch.cumsum(alphas, dim=1)
# prefix_sum = torch.cumsum(alphas, dim=1)
prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(
torch.float32
) # cumsum precision degradation cause wrong result in extreme
prefix_sum_floor = torch.floor(prefix_sum)
dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
@ -689,6 +699,8 @@ def cif_v1(hidden, alphas, threshold):
device = hidden.device
dtype = hidden.dtype
batch_size, len_time, hidden_size = hidden.size()
# frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
# prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
@ -702,6 +714,7 @@ def cif_v1(hidden, alphas, threshold):
shift_frames[shift_batch_idxs] = 0
remains = fires - torch.floor(fires)
# remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
@ -709,8 +722,12 @@ def cif_v1(hidden, alphas, threshold):
frames = frames - shift_frames + shift_remain_frames - remain_frames
max_label_len = batch_len.max()
# max_label_len = batch_len.max()
max_label_len = (
torch.round(alphas.sum(-1)).int().max()
) # torch.round to calculate the max length
# frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
frame_fires_idxs = indices < batch_len.unsqueeze(1)

View File

@ -0,0 +1,24 @@
import requests
from packaging import version
from funasr import __version__ # Ensure that __version__ is defined in your package's __init__.py
def get_pypi_version(package_name):
    """Return the latest released version of *package_name* on PyPI.

    Args:
        package_name: Distribution name to look up (e.g. "funasr").

    Returns:
        packaging.version.Version: the latest version parsed from the
        PyPI JSON API ("info" -> "version").

    Raises:
        Exception: if the PyPI endpoint does not answer with HTTP 200.
        requests.exceptions.RequestException: on network failure/timeout.
    """
    url = f"https://pypi.org/pypi/{package_name}/json"
    # Bug fix: the original call had no timeout, so an unreachable PyPI
    # (offline box, firewalled cluster) would hang the caller indefinitely
    # — and this runs inside AutoModel.__init__ per the hunk above.
    response = requests.get(url, timeout=5)
    if response.status_code == 200:
        data = response.json()
        return version.parse(data["info"]["version"])
    else:
        raise Exception("Failed to retrieve version information from PyPI.")
def check_for_update():
    """Compare the installed funasr version with the latest PyPI release.

    Prints an upgrade hint when a newer release exists, otherwise confirms
    that the installed version is current.
    """
    installed = version.parse(__version__)
    latest = get_pypi_version("funasr")
    if installed >= latest:
        print(f"You are using the latest version of funasr-{installed}")
        return
    print(f"New version available: {latest}. Your current version is {installed}.")
    print('Please use the command "pip install -U funasr" to upgrade.')