Skip to content

Commit

Permalink
chore: restore some latest changes
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama committed Jul 21, 2024
1 parent fe68af9 commit 9c0a1df
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
8 changes: 5 additions & 3 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def _load(
device = select_device()
self.logger.info("use device %s", str(device))
self.device = device
self.device_gpt = device if "mps" not in str(device) else torch.device("cpu")
self.compile = compile

feature_extractor = instantiate_class(
Expand Down Expand Up @@ -299,7 +300,8 @@ def _load(
gpt = GPT(
**cfg,
use_flash_attn=use_flash_attn,
device=device,
device=self.device,
device_gpt=self.device_gpt,
logger=self.logger,
).eval()
assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
Expand Down Expand Up @@ -537,7 +539,7 @@ def _infer_code(
text,
self.config.gpt.num_vq,
prompt_str=params.spk_smp,
device=self.device,
device=self.device_gpt,
)
start_idx = input_ids.shape[-2]

Expand Down Expand Up @@ -597,7 +599,7 @@ def _refine_text(
input_ids, attention_mask, text_mask = self.tokenizer.encode(
text,
self.config.gpt.num_vq,
device=self.device,
device=self.device_gpt,
)

start_idx = input_ids.shape[-2]
Expand Down
3 changes: 2 additions & 1 deletion ChatTTS/model/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,15 @@ def __init__(
num_vq=4,
use_flash_attn=False,
device=torch.device("cpu"),
device_gpt=torch.device("cpu"),
logger=logging.getLogger(__name__),
):
super().__init__()

self.logger = logger

self.device = device
self.device_gpt = device if "mps" not in str(device) else torch.device("cpu")
self.device_gpt = device_gpt

self.num_vq = num_vq
self.num_audio_tokens = num_audio_tokens
Expand Down

0 comments on commit 9c0a1df

Please sign in to comment.