Skip to content

Commit

Permalink
chore(format): run black on dev (#602)
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] and github-actions[bot] committed Jul 19, 2024
1 parent 6f645b6 commit 3e4ed88
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions ChatTTS/model/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,12 +406,15 @@ def generate(
attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_(
attention_mask
)

progress = inputs_ids.size(1)
# pre-allocate inputs_ids
inputs_ids_buf = torch.zeros(
inputs_ids.size(0), progress+max_new_token, inputs_ids.size(2),
dtype=inputs_ids.dtype, device=inputs_ids.device,
inputs_ids.size(0),
progress + max_new_token,
inputs_ids.size(2),
dtype=inputs_ids.dtype,
device=inputs_ids.device,
)
inputs_ids_buf.narrow(1, 0, progress).copy_(inputs_ids)
del inputs_ids
Expand Down Expand Up @@ -502,17 +505,25 @@ def generate(
logits = logits.reshape(-1, logits.size(2))
# logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
inputs_ids_sliced = inputs_ids.narrow(
1, start_idx, inputs_ids.size(1)-start_idx,
1,
start_idx,
inputs_ids.size(1) - start_idx,
).permute(0, 2, 1)
logits_token = inputs_ids_sliced.reshape(
inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1),
-1,
).to(self.device)
del inputs_ids_sliced
else:
logits_token = inputs_ids.narrow(
1, start_idx, inputs_ids.size(1)-start_idx,
).narrow(2, 0, 1).to(self.device)
logits_token = (
inputs_ids.narrow(
1,
start_idx,
inputs_ids.size(1) - start_idx,
)
.narrow(2, 0, 1)
.to(self.device)
)

logits /= temperature

Expand Down

0 comments on commit 3e4ed88

Please sign in to comment.