Skip to content

Commit

Permalink
chore(format): run black on dev
Browse files Browse the repository at this point in the history
  • Loading branch information
github-actions[bot] committed Aug 2, 2024
1 parent c6bae90 commit d325af4
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 21 deletions.
5 changes: 4 additions & 1 deletion ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,10 @@ def _infer_code(

input_ids, attention_mask, text_mask = self.tokenizer.encode(
self.tokenizer.decorate_code_prompts(
text, params.prompt, params.txt_smp, params.spk_emb,
text,
params.prompt,
params.txt_smp,
params.spk_emb,
),
self.config.gpt.num_vq,
prompt_str=params.spk_smp,
Expand Down
16 changes: 11 additions & 5 deletions ChatTTS/model/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,15 @@ def encode(

@torch.inference_mode
def decode(
self, sequences: Union[List[int], List[List[int]]],
self,
sequences: Union[List[int], List[List[int]]],
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: bool = None,
**kwargs,
):
return self._tokenizer.batch_decode(sequences, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
return self._tokenizer.batch_decode(
sequences, skip_special_tokens, clean_up_tokenization_spaces, **kwargs
)

@staticmethod
def _decode_spk_emb(spk_emb: str) -> np.ndarray:
Expand Down Expand Up @@ -219,11 +222,14 @@ def _encode_spk_emb(spk_emb: torch.Tensor) -> str:
)
del arr
return s

@staticmethod
@torch.no_grad()
def decorate_code_prompts(
text: List[str], prompt: str, txt_smp: Optional[str], spk_emb: Optional[str],
text: List[str],
prompt: str,
txt_smp: Optional[str],
spk_emb: Optional[str],
) -> List[str]:
for i, t in enumerate(text):
text[i] = (
Expand All @@ -244,7 +250,7 @@ def decorate_code_prompts(
text = [f"[Stts][spk_emb]{txt_smp}{i}[Ptts]" for i in text]
else:
text = [f"[Stts][empty_spk]{txt_smp}{i}[Ptts]" for i in text]

return text

@staticmethod
Expand Down
7 changes: 5 additions & 2 deletions ChatTTS/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,30 @@ def _fast_replace(
replaced_words.append((chr(ch), chr(repl_char)))
return result, replaced_words


@jit
def _split_tags(text: str) -> Tuple[List[str], List[str]]:
    """Separate ``text`` into plain-text pieces and bracketed ``[...]`` tags.

    The string is scanned one character at a time. A ``[`` flushes the text
    collected so far (even when it is empty) and starts a new tag; while a
    tag is open every character is added to it; a ``]`` pushes the current
    tag buffer onto the tag list and closes it. Text left over after the
    last character is appended only when it is non-empty.

    Returns a ``(texts, tags)`` pair of lists in order of appearance.
    """
    plain_parts: List[str] = []
    tag_parts: List[str] = []
    text_buf = ""
    tag_buf = ""
    for ch in text:
        if ch == "[":
            # Opening bracket: flush pending text and begin a fresh tag.
            plain_parts.append(text_buf)
            text_buf = ""
            tag_buf = ch
        elif tag_buf == "":
            # Not inside a tag: the character is ordinary text.
            text_buf += ch
        else:
            # Inside a tag: keep extending it (this includes the closing "]").
            tag_buf += ch
        if ch == "]":
            # Closing bracket: emit whatever the tag buffer currently holds.
            tag_parts.append(tag_buf)
            tag_buf = ""
    if text_buf != "":
        plain_parts.append(text_buf)
    return plain_parts, tag_parts


@jit
def _combine_tags(texts: List[str], tags: List[str]) -> str:
text = ""
Expand All @@ -65,6 +67,7 @@ def _combine_tags(texts: List[str], tags: List[str]) -> str:
text += t + tg
return text


class Normalizer:
def __init__(self, map_file_path: str, logger=logging.getLogger(__name__)):
self.logger = logger
Expand Down
43 changes: 30 additions & 13 deletions tests/#655.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,30 +28,37 @@
rand_spk = chat.sample_random_speaker()


text = ['What is [uv_break]your favorite english food?[laugh][lbreak]']
text = ["What is [uv_break]your favorite english food?[laugh][lbreak]"]

fail = False

with TorchSeedContext(12345):
refined_text = chat.infer(
text, refine_text_only=True,
text,
refine_text_only=True,
params_refine_text=ChatTTS.Chat.RefineTextParams(
prompt='[oral_2][laugh_0][break_6]',
prompt="[oral_2][laugh_0][break_6]",
),
)
if refined_text[0] != "like [uv_break] what is [uv_break] your favorite english food [laugh] [lbreak]":
if (
refined_text[0]
!= "like [uv_break] what is [uv_break] your favorite english food [laugh] [lbreak]"
):
fail = True
logger.warning("refined text is '%s'", refined_text[0])

params = ChatTTS.Chat.InferCodeParams(
spk_emb = rand_spk, # add sampled speaker
temperature = .3, # using custom temperature
top_P = 0.7, # top P decode
top_K = 20, # top K decode
spk_emb=rand_spk, # add sampled speaker
temperature=0.3, # using custom temperature
top_P=0.7, # top P decode
top_K=20, # top K decode
)
input_ids, attention_mask, text_mask = chat.tokenizer.encode(
chat.tokenizer.decorate_code_prompts(
text, params.prompt, params.txt_smp, params.spk_emb,
text,
params.prompt,
params.txt_smp,
params.spk_emb,
),
chat.config.gpt.num_vq,
prompt_str=params.spk_smp,
Expand All @@ -62,11 +69,21 @@
input_ids.shape[0], device=input_ids.device, dtype=torch.long
).fill_(input_ids.shape[1])

recoded_text = chat.tokenizer.decode(chat.gpt._prepare_generation_outputs(
input_ids, start_idx, end_idx, [], [], True,
).ids)
recoded_text = chat.tokenizer.decode(
chat.gpt._prepare_generation_outputs(
input_ids,
start_idx,
end_idx,
[],
[],
True,
).ids
)

# Verify that encoding the decorated prompt and decoding it again reproduces
# the expected token string for the first (and only) input sentence.
if (
    recoded_text[0]
    != "[Stts] [spk_emb] [speed_5] what is [uv_break] your favorite english food? [laugh] [lbreak] [Ptts]"
):
    fail = True
    # Report the value that was actually compared. The previous code logged
    # `refined_text` (the output of the earlier refine step), which hid the
    # real mismatch when this check failed.
    logger.warning("recoded text is '%s'", recoded_text[0])

Expand Down

0 comments on commit d325af4

Please sign in to comment.