bug fix for punc and umap

This commit is contained in:
shixian.shi 2024-01-23 11:34:03 +08:00
parent 2c3183b611
commit ae4dceecf0
3 changed files with 5 additions and 3 deletions

View File

@ -119,6 +119,7 @@ class UmapHdbscan:
self.metric = metric
def __call__(self, X):
from umap.umap_ import UMAP
umap_X = umap.UMAP(
n_neighbors=self.n_neighbors,
min_dist=0.0,
@ -156,6 +157,7 @@ class ClusterBackend(torch.nn.Module):
if X.shape[0] < 20:
return np.zeros(X.shape[0], dtype='int')
if X.shape[0] < 2048 or k is not None:
# unexpected corner case
labels = self.spectral_cluster(X, k)
else:
labels = self.umap_hdbscan_cluster(X)

View File

@ -336,10 +336,11 @@ class CTTransformer(torch.nn.Module):
elif new_mini_sentence[-1] != "" and new_mini_sentence[-1] != "" and len(new_mini_sentence[-1].encode())!=1:
new_mini_sentence_out = new_mini_sentence + ""
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
if len(punctuations): punctuations[-1] = 2
elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
new_mini_sentence_out = new_mini_sentence + "."
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
if len(punctuations): punctuations[-1] = 2
# keep a punctuations array for punc segment
if punc_array is None:
punc_array = punctuations
@ -347,6 +348,5 @@ class CTTransformer(torch.nn.Module):
punc_array = torch.cat([punc_array, punctuations], dim=0)
result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
results.append(result_i)
return results, meta_data

View File

@ -38,7 +38,7 @@ requirements = {
# "protobuf",
"tqdm",
"hdbscan",
"umap",
"umap_learn",
"jaconv",
"hydra-core>=1.3.2",
],