Skip to content

Commit

Permalink
chore(test): update test according to #521
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama committed Jul 10, 2024
1 parent c0644a5 commit e049aba
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 54 deletions.
1 change: 0 additions & 1 deletion examples/cmd/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def _accum(accum_wavs, stream_wav):
@staticmethod
def batch_stream_formatted(stream_wav, output_format="PCM16_byte"):
if output_format in ("PCM16_byte", "PCM16"):
# format_data=ChatStreamer._batch_unsafe_float_to_int16(stream_wav)
format_data = float_to_int16(stream_wav)
else:
format_data = stream_wav
Expand Down
216 changes: 184 additions & 32 deletions tests/#521.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,53 +7,205 @@
sys.path.append(now_dir)

import logging
import threading
import time
import random

import numpy as np
import ChatTTS

from tools.audio import float_to_int16
from tools.logger import get_logger

fail = False
logger = get_logger("Test #521", lv=logging.WARN)

# 计算rms
# nan为噪声 !!!
def calculate_rms(data):
# 数据清洗 方法1
# data = data[~np.isnan(data)]
# 数据清洗 方法2
data = np.nan_to_num(data, nan=0.0, posinf=0.0, neginf=0.0)
if len(data) == 0:
return np.nan #
# data = np.nan_to_num(data)
return np.sqrt(np.mean(np.square(data)))

# 流式声音处理器
class AudioStreamer:
# 流式写入
@staticmethod
def write(waveform):
global fail, logger
rms = calculate_rms(waveform)
if np.isnan(rms):
fail = True
logger.warning("NAN RMS found")

# ChatTTS流式处理
class ChatStreamer:
def __init__(self, waittime_topause=50, base_block_size=8000):
self.streamer = AudioStreamer
self.accum_streamwavs = []
self.waittime_topause = waittime_topause
self.base_block_size = base_block_size

def write(self, chatstream):
# 已推理batch数据保存
def accum(accum_wavs, stream_wav):
n_texts = len(stream_wav)
if accum_wavs is None:
accum_wavs = [[i] for i in stream_wav]
else:
for i_text in range(n_texts):
if stream_wav[i_text] is not None:
accum_wavs[i_text].append(stream_wav[i_text])
return accum_wavs

# stream状态更新。数据量不足的stream,先存一段时间,直到拿到足够数据,监控小块数据情况
def update_stream(history_stream_wav, new_stream_wav, thre):
result_stream = []
randn = -1
if history_stream_wav is not None:
randn = random.random()
if randn > 0.1:
logger.info("update_stream")
n_texts = len(new_stream_wav)
for i in range(n_texts):
if new_stream_wav[i] is not None:
result_stream.append(
np.concatenate(
[history_stream_wav[i], new_stream_wav[i]], axis=1
)
)
else:
result_stream.append(history_stream_wav[i])
else:
result_stream = [i[np.newaxis, :] for i in new_stream_wav]
is_keep_next = (
sum([i.shape[1] for i in result_stream if i is not None]) < thre
)
if randn > 0.1:
logger.info(
"result_stream: %s %s",
str(is_keep_next),
str([i.shape if i is not None else None for i in result_stream]),
)
return result_stream, is_keep_next

self.finish = False
curr_sentence_index = 0
base_block_size = self.base_block_size
history_stream_wav = None
article_streamwavs = None
for stream_wav in chatstream:
n_texts = len(stream_wav)
n_valid_texts = len(list(filter(lambda x: x is not None, stream_wav)))
if n_valid_texts == 0:
continue
else:
block_thre = n_valid_texts * base_block_size
stream_wav, is_keep_next = update_stream(
history_stream_wav, stream_wav, block_thre
)
# 数据量不足,先保存状态
if is_keep_next:
history_stream_wav = stream_wav
continue
# 数据量足够,执行写入操作
else:
history_stream_wav = None
stream_wav = [float_to_int16(i) for i in stream_wav]
article_streamwavs = accum(article_streamwavs, stream_wav)
# 写入当前句子
if stream_wav[curr_sentence_index] is not None:
if stream_wav[curr_sentence_index][0].shape[0] > 257:
self.streamer.write(stream_wav[curr_sentence_index][0])
# self.streamer.write(stream_wav[curr_sentence_index][0])
# 当前句子已写入完成,直接写下一个句子已经推理完成的部分
elif curr_sentence_index < n_texts - 1:
curr_sentence_index += 1
logger.info("add next sentence")
finish_stream_wavs = np.concatenate(
article_streamwavs[curr_sentence_index], axis=1
)
if finish_stream_wavs[0].shape[0] > 257:
self.streamer.write(finish_stream_wavs[0])
# self.streamer.write(finish_stream_wavs[0])
# streamchat遍历完毕,在外层把剩余结果写入
else:
break
# 有一定概率遇到奇怪bug(一定概率遇到256维异常输出,正常是1w+维),输出全是噪声,写的快遇到的概率更高?
time.sleep(0.02)
# 本轮剩余最后一点数据写入
if is_keep_next:
if len(list(filter(lambda x: x is not None, stream_wav))) > 0:
stream_wav = [float_to_int16(i) for i in stream_wav]
if stream_wav[curr_sentence_index] is not None:
if stream_wav[curr_sentence_index][0].shape[0] > 257:
self.streamer.write(stream_wav[curr_sentence_index][0])
article_streamwavs = accum(article_streamwavs, stream_wav)
# 把已经完成推理的下几轮剩余数据写入
for i_text in range(curr_sentence_index + 1, n_texts):
finish_stream_wavs = np.concatenate(article_streamwavs[i_text], axis=1)
if finish_stream_wavs[0].shape[0] > 257:
self.streamer.write(finish_stream_wavs[0])
self.accum_streamwavs.append(article_streamwavs)
self.finish = True

# 获取完整历史播放数据
def get_complete_speech(self):
complete_waveform = np.concatenate(
sum([sum(i_infer, []) for i_infer in self.accum_streamwavs], []), axis=1
)
return complete_waveform

# 开始音频写入。可支持多次音频写入
def start_writing(self, streamchat):
self.writer = threading.Thread(target=self.write, args=(streamchat,))
self.writer.start()

chat = ChatTTS.Chat(logger)
chat.load(compile=False, source="huggingface") # Set to True for better performance
chat.load(compile=False) # Set to True for better performance

texts = [
"这段代码在流式输出的情况下,和非流式情况下,计算是否一致?我在流式输出的情况下,会产生噪音,怀疑这部分有问题,哪位大佬可以指教一下?",
"我也发现流式输出有时候有问题,流式输出是一个ndarray list,正常情况下会ndarray输出是12032维,但是会随机在中间偶发输出256维,开始输出256维后就会一直保持,256维的部分都是噪声。",
]
rand_spk = chat.sample_random_speaker()
params_infer_code = ChatTTS.Chat.InferCodeParams(
spk_emb=rand_spk, # add sampled speaker
temperature=0.0001, # using custom temperature
prompt="[speed_0]"
)
params_refine_text = ChatTTS.Chat.RefineTextParams(
prompt='[oral_2][laugh_0][break_6]',
)

gen_result = chat.infer(
texts,
# 获取ChatTTS 流式推理generator
streamchat = chat.infer(
"EvilBOT : 在黑暗的角落里,杀人可能被视为一种艺术,但我要提醒你,这违背了我作为邪恶机器人的本性,因为我是无所不能的。",
skip_refine_text=True,
params_infer_code=params_infer_code,
stream=True,
params_refine_text=ChatTTS.Chat.RefineTextParams(
show_tqdm=False,
),
params_infer_code=ChatTTS.Chat.InferCodeParams(
show_tqdm=False,
),
params_refine_text=params_refine_text,
)

has_finished = [False for _ in range(len(texts))]
# 分别开启一个写线程和读线程,进行流式播放
streamer = ChatStreamer()
# 一次性生成
streamer.write(streamchat)

fail = False
streamer.write(chat.infer(
"有一个神奇的故事,传说在很远很远的未来。",
skip_refine_text=True,
params_infer_code=params_infer_code,
stream=True,
))

for i, result in enumerate(gen_result):
for j, wav in enumerate(result):
if wav is None:
continue
logger.info("iter %d index %d len %d", i, j, len(wav))
if len(wav) == 12000:
continue
if not has_finished[j]:
has_finished[j] = True
logger.warning(
"iter %d index %d finished with non-12000 len %d", i, j, len(wav)
)
else:
logger.warning(
"stream iter %d index %d returned non-zero wav after finished", i, j
)
fail = True
streamer.write(chat.infer(
"有一种叫做奥特曼的物种。他是超人族的一员。",
skip_refine_text=True,
params_infer_code=params_infer_code,
stream=True,
))

if fail:
import sys
Expand Down
2 changes: 1 addition & 1 deletion tools/audio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .mp3 import wav_arr_to_mp3_view
from .ffmpeg import has_ffmpeg_installed
from .np import float_to_int16, batch_unsafe_float_to_int16
from .np import float_to_int16
20 changes: 0 additions & 20 deletions tools/audio/np.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,3 @@ def float_to_int16(audio: np.ndarray) -> np.ndarray:
am = np.abs(audio).max() * 32768
am = 32767 * 32768 / am
return np.multiply(audio, am).astype(np.int16)


@jit
def batch_unsafe_float_to_int16(audios: list[np.ndarray]) -> list[np.ndarray]:
"""
This function will destroy audio, use only once.
"""

valid_audios = [i for i in audios if i is not None]
if len(valid_audios) > 1:
am = np.abs(np.concatenate(valid_audios, axis=1)).max() * 32768
else:
am = np.abs(valid_audios[0]).max() * 32768
am = 32767 * 32768 / am

for i in range(len(audios)):
if audios[i] is not None:
np.multiply(audios[i], am, audios[i])
audios[i] = audios[i].astype(np.int16)
return audios

0 comments on commit e049aba

Please sign in to comment.