FunASR/runtime/onnxruntime/third_party/kaldi/lm/arpa-file-parser.h
Yabin Li 702ec03ad8
Dev new (#1065)
* add hotword for deploy_tools

* Support wfst decoder and contextual biasing (#1039)

* Support wfst decoder and contextual biasing

* Turn on fstbin compilation

---------

Co-authored-by: gongbo.gb <gongbo.gb@alibaba-inc.com>

* mv funasr/runtime runtime

* Fix crash caused by OOV in hotwords list

* funasr infer

* funasr infer

* funasr infer

* funasr infer

* funasr infer

* fix some bugs about fst hotword; support wfst for websocket server and clients; mv runtime out of funasr; modify relative docs

* del onnxruntime/include/gflags

* update tensor.h

* update run_server.sh

* update deploy tools

* update deploy tools

* update websocket-server

* update funasr-wss-server

* Remove self loop propagation

* Update websocket_protocol_zh.md

* Update websocket_protocol_zh.md

* update hotword protocol

* author zhaomingwork: change hotwords for h5 and java

* update hotword protocol

* catch exception for json_fst_hws

* update hotword on message

* update onnx benchmark for ngram&hotword

* update docs

* update funasr-wss-serve

* add NONE for LM_DIR

* update docs

* update run_server.sh

* add whats-new

* modify whats-new

* update whats-new

* update whats-new

* Support decoder option for beam searching

* update benchmark_onnx_cpp

* Support decoder option for websocket

* fix bug of CompileHotwordEmbedding

* update html client

* update docs

---------

Co-authored-by: gongbo.gb <35997837+aibulamusi@users.noreply.github.com>
Co-authored-by: gongbo.gb <gongbo.gb@alibaba-inc.com>
Co-authored-by: 游雁 <zhifu.gzf@alibaba-inc.com>
2023-11-07 18:34:29 +08:00

147 lines
5.3 KiB
C++

// lm/arpa-file-parser.h
// Copyright 2014 Guoguo Chen
// Copyright 2016 Smart Action Company LLC (kkm)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_LM_ARPA_FILE_PARSER_H_
#define KALDI_LM_ARPA_FILE_PARSER_H_
#include <fst/fst-decl.h>
#include <string>
#include <vector>
#include "base/kaldi-types.h"
#include "itf/options-itf.h"
namespace kaldi {
/**
  Settings that steer ArpaFileParser.

  A default-constructed instance has all three symbol fields set to -1,
  uses the kRaiseError OOV policy, and caps warnings at 30.
*/
struct ArpaParseOptions {
  /// Policy applied when the file contains an out-of-vocabulary (OOV) word.
  enum OovHandling {
    kRaiseError,      ///< Abort on OOV words.
    kAddToSymbols,    ///< Add novel words to the symbol table.
    kReplaceWithUnk,  ///< Replace OOV words with <unk>.
    kSkipNGram        ///< Skip n-gram with OOV word and continue.
  };

  ArpaParseOptions()
      : bos_symbol(-1),
        eos_symbol(-1),
        unk_symbol(-1),
        oov_handling(kRaiseError),
        max_warnings(30) {}

  /// Registers the command-line options of this structure with opts.
  ///
  /// Only the warning cap is registered. The remaining fields are handled
  /// differently by each client program: some take integer symbols, while
  /// others are given words on their command line.
  void Register(OptionsItf *opts) {
    opts->Register("max-arpa-warnings", &max_warnings,
                   "Maximum warnings to report on ARPA parsing, "
                   "0 to disable, -1 to show all");
  }

  int32 bos_symbol;          ///< Symbol for <s>; required, non-epsilon.
  int32 eos_symbol;          ///< Symbol for </s>; required, non-epsilon.
  int32 unk_symbol;          ///< Symbol for <unk>; required for
                             ///< kReplaceWithUnk.
  OovHandling oov_handling;  ///< How OOV words in the file are handled.
  int32 max_warnings;        ///< Cap on reported warnings; a negative
                             ///< value means unlimited.
};
/**
  One n-gram entry parsed from an ARPA LM file.

  A default-constructed NGram has no words and both scores set to zero.
*/
struct NGram {
  NGram() : words(), logprob(0.0f), backoff(0.0f) {}

  std::vector<int32> words;  ///< Symbols in left to right order.
  float logprob;             ///< Log-prob of the n-gram.
  float backoff;             ///< Log-backoff weight of the n-gram; stays
                             ///< at zero when none is specified.
};
/**
  Abstract base class for converting an ARPA LM file.

  A derived class must implement ConsumeNGram() and may override the other
  protected hooks (ReadStarted(), HeaderAvailable(), ReadComplete()), whose
  default bodies are empty. See ConstArpaLmBuilder and ArpaLmCompiler for
  usage examples.
*/
class ArpaFileParser {
 public:
  /// Builds a parser from the given options and an optional symbol table.
  ///
  /// When a symbol table is supplied, the file is expected to hold text
  /// n-grams whose words are mapped to symbols through that table.
  /// bos_symbol and eos_symbol in the options must be valid symbols in the
  /// table, and so must unk_symbol if it is provided. The parser does not
  /// own the table, but may add to it when oov_handling is kAddToSymbols.
  ///
  /// When the symbol table is a null pointer, the file is expected to hold
  /// integer symbol values and oov_handling has no effect. bos_symbol and
  /// eos_symbol must still be valid symbols.
  ArpaFileParser(const ArpaParseOptions& options, fst::SymbolTable* symbols);

  virtual ~ArpaFileParser();

  /// Reads an ARPA LM file from the stream.
  void Read(std::istream &is);

  /// Returns the parser options.
  const ArpaParseOptions& Options() const {
    return options_;
  }

 protected:
  /// Hook invoked before reading starts; the place for a derived class to
  /// prepare its state.
  virtual void ReadStarted() {}

  /// Hook invoked once the ARPA header with the expected number of n-grams
  /// has been read. NgramCounts() is valid from this point on.
  virtual void HeaderAvailable() {}

  /// Mandatory hook that processes the current n-gram. N-grams arrive in
  /// file order, which guarantees that every (k-1)-gram is processed before
  /// the first k-gram.
  virtual void ConsumeNGram(const NGram&) = 0;

  /// Hook invoked after the last n-gram has been consumed.
  virtual void ReadComplete() {}

  /// Read-only view of the symbol table. The table is not owned; do not
  /// make this accessor public.
  const fst::SymbolTable* Symbols() const {
    return symbols_;
  }

  /// Current line number; meant to be called inside ConsumeNGram().
  int32 LineNumber() const {
    return line_number_;
  }

  /// Formatted reference to the line being compiled, for use in
  /// diagnostics; meant to be called inside ConsumeNGram().
  std::string LineReference() const;

  /// Increments the warning count. Returns true if a warning should be
  /// printed, or false if the count has exceeded the set maximum.
  bool ShouldWarn();

  /// N-gram counts; valid from the moment HeaderAvailable() is called.
  const std::vector<int32>& NgramCounts() const {
    return ngram_counts_;
  }

 private:
  ArpaParseOptions options_;         // Returned by Options().
  fst::SymbolTable* symbols_;        // Not owned by the parser.
  int32 line_number_;                // Returned by LineNumber().
  uint32 warning_count_;
  std::string current_line_;
  std::vector<int32> ngram_counts_;  // Returned by NgramCounts().
};
} // namespace kaldi
#endif // KALDI_LM_ARPA_FILE_PARSER_H_