Skip to content

Commit

Permalink
feat: add [--source ***] [--custom_path XXX] for run cmd (#669)
Browse files Browse the repository at this point in the history
* feat: add simple_run example

Signed-off-by: weedge <weege007@gmail.com>

* fix: mv simple_run cmd and add [--source ***] [--custom_path XXX] for run cmd

Signed-off-by: weedge <weege007@gmail.com>

* fix: mv unuse code from ai gen

Signed-off-by: weedge <weege007@gmail.com>

---------

Signed-off-by: weedge <weege007@gmail.com>
  • Loading branch information
weedge committed Aug 7, 2024
1 parent c88e039 commit 6259319
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 15 deletions.
Empty file added examples/__init__.py
Empty file.
86 changes: 71 additions & 15 deletions examples/cmd/run.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
import os, sys

if sys.platform == "darwin":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

now_dir = os.getcwd()
sys.path.append(now_dir)

import argparse
from typing import Optional, List
import argparse
import os
import sys

import numpy as np
import torch

import ChatTTS

from tools.audio import pcm_arr_to_mp3_view
from tools.logger import get_logger
from tools.audio import pcm_arr_to_mp3_view
from tools.normalizer.en import normalizer_en_nemo_text
from tools.normalizer.zh import normalizer_zh_tn

if sys.platform == "darwin":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

now_dir = os.getcwd()
sys.path.append(now_dir)

logger = get_logger("Command")

Expand All @@ -27,12 +29,49 @@ def save_mp3_file(wav, index):
logger.info(f"Audio saved to {mp3_filename}")


def main(texts: List[str], spk: Optional[str] = None, stream=False):
def load_normalizer(chat: ChatTTS.Chat):
# try to load normalizer
try:
chat.normalizer.register("en", normalizer_en_nemo_text())
except ValueError as e:
logger.error(e)
except BaseException:
logger.warning("Package nemo_text_processing not found!")
logger.warning(
"Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing",
)
try:
chat.normalizer.register("zh", normalizer_zh_tn())
except ValueError as e:
logger.error(e)
except BaseException:
logger.warning("Package WeTextProcessing not found!")
logger.warning(
"Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing",
)


def main(texts: List[str],
spk: Optional[str] = None,
stream: bool = False,
source: str = "local",
custom_path: str = "",
):
logger.info("Text input: %s", str(texts))

chat = ChatTTS.Chat(get_logger("ChatTTS"))
logger.info("Initializing ChatTTS...")
if chat.load():
load_normalizer(chat)

is_load = False
if os.path.isdir(custom_path) and source == "custom":
is_load = chat.load(compile=True,
source="custom",
custom_path=custom_path)
else:
is_load = chat.load(compile=True, source=source)

if is_load:
logger.info("Models loaded successfully.")
else:
logger.error("Models load failed.")
Expand Down Expand Up @@ -69,10 +108,14 @@ def main(texts: List[str], spk: Optional[str] = None, stream=False):


if __name__ == "__main__":
r"""
python -m examples.cmd.run \
--source custom --custom_path ../../models/2Noise/ChatTTS 你好喲 ":)"
"""
logger.info("Starting ChatTTS commandline demo...")
parser = argparse.ArgumentParser(
description="ChatTTS Command",
usage='[--spk xxx] [--stream] "Your text 1." " Your text 2."',
usage='[--spk xxx] [--stream] [--source ***] [--custom_path XXX] "Your text 1." " Your text 2."',
)
parser.add_argument(
"--spk",
Expand All @@ -85,12 +128,25 @@ def main(texts: List[str], spk: Optional[str] = None, stream=False):
help="Use stream mode",
action="store_true",
)
parser.add_argument(
"--source",
help="source form [ huggingface(hf download), local(ckpt save to asset dir), custom(define) ]",
type=str,
default="local",
)
parser.add_argument(
"--custom_path",
help="custom defined model path(include asset ckpt dir)",
type=str,
default="",
)
parser.add_argument(
"texts",
help="Original text",
default=["YOUR TEXT HERE"],
nargs=argparse.REMAINDER,
)
args = parser.parse_args()
main(args.texts, args.spk, args.stream)
logger.info(args)
main(args.texts, args.spk, args.stream, args.source, args.custom_path)
logger.info("ChatTTS process finished.")

0 comments on commit 6259319

Please sign in to comment.