mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
更新go client 的原生实现 (#2532)
This commit is contained in:
parent
8b0fb74bde
commit
ab2148ec18
@ -2,19 +2,18 @@ package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-audio/wav"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
@ -49,6 +48,63 @@ type AudioData struct {
|
||||
AudioBytes string `json:"audio_bytes"`
|
||||
}
|
||||
|
||||
// 完全模拟Python wave库的行为读取WAV文件
|
||||
func readWAVFile(filePath string, chunkSize []int, chunkInterval int) (*AudioData, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open WAV file: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// 使用wav库解码WAV文件,但只获取基本信息
|
||||
decoder := wav.NewDecoder(file)
|
||||
if !decoder.IsValidFile() {
|
||||
return nil, fmt.Errorf("invalid WAV file format")
|
||||
}
|
||||
|
||||
// 获取WAV文件格式信息
|
||||
format := decoder.Format()
|
||||
sampleRate := int(format.SampleRate)
|
||||
|
||||
// 读取所有音频数据
|
||||
buf, err := decoder.FullPCMBuffer()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read full PCM data: %v", err)
|
||||
}
|
||||
|
||||
// 获取原始音频字节数据(等同于Python的bytes(frames))
|
||||
audioBytes := make([]byte, len(buf.Data)*2)
|
||||
for i, sample := range buf.Data {
|
||||
// 确保与Python wave库的字节序一致(小端序16位)
|
||||
sample16 := int16(sample)
|
||||
audioBytes[i*2] = byte(sample16 & 0xFF)
|
||||
audioBytes[i*2+1] = byte((sample16 >> 8) & 0xFF)
|
||||
}
|
||||
|
||||
// 完全按照Python的计算方式计算参数
|
||||
// Python: stride = int(60 * chunk_size[1] / chunk_interval / 1000 * sample_rate * 2)
|
||||
// 注意:Python中的除法是浮点除法,Go中需要显式转换避免整数除法截断
|
||||
stride := int(float64(60*chunkSize[1]) / float64(chunkInterval) / 1000.0 * float64(sampleRate) * 2.0)
|
||||
|
||||
// 添加安全检查防止除零
|
||||
if stride <= 0 {
|
||||
return nil, fmt.Errorf("calculated stride is zero or negative: %d", stride)
|
||||
}
|
||||
|
||||
// Python: chunk_num = (len(audio_bytes) - 1) // stride + 1
|
||||
chunkNum := (len(audioBytes)-1)/stride + 1
|
||||
|
||||
// 编码为Base64(与Python的base64.b64encode().decode('utf-8')一致)
|
||||
audioBase64 := base64.StdEncoding.EncodeToString(audioBytes)
|
||||
|
||||
return &AudioData{
|
||||
SampleRate: sampleRate,
|
||||
Stride: stride,
|
||||
ChunkNum: chunkNum,
|
||||
AudioBytes: audioBase64,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func IntSlicetoString(nums []int) string {
|
||||
strNums := make([]string, len(nums))
|
||||
for i, num := range nums {
|
||||
@ -109,20 +165,10 @@ func recordFromScp(chunk_begin, chunk_size int) {
|
||||
}
|
||||
|
||||
if strings.HasSuffix(wav_path, ".wav") {
|
||||
cmd := exec.Command("python", "wavhandler.py", wav_path, IntSlicetoString(args.chunk_size), strconv.Itoa(args.chunk_interval))
|
||||
|
||||
var out bytes.Buffer
|
||||
cmd.Stdout = &out
|
||||
err := cmd.Run()
|
||||
// 使用Go原生实现替换Python调用
|
||||
audioData, err := readWAVFile(wav_path, args.chunk_size, args.chunk_interval)
|
||||
if err != nil {
|
||||
fmt.Println("Error running Python script:", err)
|
||||
return
|
||||
}
|
||||
|
||||
var audioData AudioData
|
||||
err = json.Unmarshal(out.Bytes(), &audioData)
|
||||
if err != nil {
|
||||
fmt.Println("Error parsing JSON:", err)
|
||||
fmt.Println("Error reading WAV file:", err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -222,7 +268,7 @@ func message(id string) {
|
||||
var ibestWriter *os.File
|
||||
var err error
|
||||
if args.output_dir != "" {
|
||||
filePath := fmt.Sprintf("%s/text.%d", args.output_dir, id)
|
||||
filePath := fmt.Sprintf("%s/text.%s", args.output_dir, id)
|
||||
ibestWriter, err = os.OpenFile(filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to open file: %v", err)
|
||||
|
||||
@ -1,27 +0,0 @@
|
||||
import sys
|
||||
import wave
|
||||
import json
|
||||
import base64
|
||||
|
||||
if __name__ == "__main__":
|
||||
wav_path = sys.argv[1]
|
||||
chunk_size = [int(x) for x in sys.argv[2].split(",")]
|
||||
chunk_interval = int(sys.argv[3])
|
||||
|
||||
with wave.open(wav_path, "rb") as wav_file:
|
||||
params = wav_file.getparams()
|
||||
sample_rate = wav_file.getframerate()
|
||||
frames = wav_file.readframes(wav_file.getnframes())
|
||||
audio_bytes = bytes(frames)
|
||||
|
||||
stride = int(60 * chunk_size[1] / chunk_interval / 1000 * sample_rate * 2)
|
||||
chunk_num = (len(audio_bytes) - 1) // stride + 1
|
||||
|
||||
result = {
|
||||
"sample_rate": sample_rate,
|
||||
"stride": stride,
|
||||
"chunk_num": chunk_num,
|
||||
"audio_bytes": base64.b64encode(audio_bytes).decode('utf-8')
|
||||
}
|
||||
|
||||
print(json.dumps(result))
|
||||
Loading…
Reference in New Issue
Block a user