Skip to content

Commit

Permalink
optimize: separate speaker from tokenizer
Browse files Browse the repository at this point in the history
- move `spk_stat.pt` into config
- move `sample_audio_speaker` to DVAE
- update rvcmd to `v0.2.7`
  • Loading branch information
fumiama committed Aug 6, 2024
1 parent 2647238 commit f3dcd97
Show file tree
Hide file tree
Showing 12 changed files with 171 additions and 170 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/checksum.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:

- name: Run RVC-Models-Downloader
run: |
wget https://github.com/fumiama/RVC-Models-Downloader/releases/download/v0.2.6/rvcmd_linux_amd64.deb
wget https://github.com/fumiama/RVC-Models-Downloader/releases/download/v0.2.7/rvcmd_linux_amd64.deb
sudo apt -y install ./rvcmd_linux_amd64.deb
rm -f ./rvcmd_linux_amd64.deb
rvcmd -notrs -w 1 -notui assets/chtts
Expand Down
1 change: 1 addition & 0 deletions ChatTTS/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,4 @@ class Config:
dvae: DVAE = DVAE()
gpt: GPT = GPT()
vocos: Vocos = Vocos()
spk_stat: str = "愐穤巩噅廷戇笉屈癐媄垹垧帶爲漈塀殐慄亅倴庲舴猂瑈圐狴夥圓帍戛挠腉耐劤坽喳幾战謇聀崒栄呥倸庭燡欈杁襐褄乭埗幺爃弔摁斐捔兕佖廐舏竾豃磐姓趡佄幒爚欄豄讐皳訵仩帆投謌荃蝐叄圝伆幦抂茁呄掑斃讹傮庞爣蜀橁偐祄亥兡常爂欍扉丐浔佱僈強払伅扂蛐徴憍傞巀戺欀艂琐嗴啥値彷刂權穈扒卤俔贲庛初笂卄贐枴仭亁庛剎猢扃缐趤刁偵幪舏伌煁婐潤晍位弾舙茥穁葏蠣訑企庤刊笍橁溑僔云偁庯戚伍潉膐脴僵噔廃艅匊祂唐憴壝嗙席爥欁虁谐牴帽势弿牳蜁兀蛐傄喩丿帔刔圆衁廐罤庁促帙劢伈汄樐檄勵伴弝舑欍罅虐昴劭勅帜刼朊蕁虐蓴樑伫幨扑謪剀堐稴丵伱弐舮諸赁習俔容厱幫牶謃孄糐答嗝僊帜燲笄終瀒判久僤帘爴茇千孑冄凕佳引扐蜁歁缏裄剽儺恘爋朏眿廐呄塍嘇幻爱茠詁訐剴唭俐幾戊欀硁菐贄楕偒巡爀弎屄莐睳賙凶彎刅漄區唐溴剑劋庽舽猄煃跐夔惥伾庮舎伈罁垑坄怅业怯刁朇獁嶏覔坩俳巶爜朐潁崐萄俹凛常爺笌穀聐此夡倛帡刀匉終窏舣販侽怿扉伥贿憐忓謩姆幌犊漂慆癒却甝兎帼戏欅詂浐朔仹壭帰臷弎恇菐獤帡偖帘爞伅腂皐纤囅充幓戠伥灂丐訤戱倱弋爮嬌癁恐孄侥劬忶刓國詀桒古偩嘄庬戚茝赂监燤嘑勌幦舽持呂諐棤姑再底舡笍艃瀐孴倉傔弋爔猠乁濑塄偽嘧恂舛缇襃厐窴仡刱忕別漇穁岏缴廽价庌爊謈硄讑惤倁儂庭爋伇蝂嶐莔摝傠库刞茄歃戏薤伍伯廮创笠塄熐兴勽俄帅剉最腀砐敤卝侍弆戺朒虃旐蚄梕亖幔牻朣扅贐玔堝噅帡剌圅摀崐彤流僳庙爖嬇啁渐悤堁丛幆刧挜彃悐幤刹嚟恕芁看聀摐焔向乁帖爭欁癃糒圄弙佱廜戤謍婀咐昴焍亩廦艏拼謿芐癤怹兽幸舳朇畁喐稔毝丼弈懲挀譂勑哴啁伎常舭笯晁堑俄叩剔廟爍欦絁夒伤休傑廳戌蜅潆癐彴摑勯床刽欅艁砐忄搉从廡舊猥潂唐委仱僜廼爤朄呃弐礔滵垓幩爄挂筁乐籤刕凟幵爠弉癅乑吴勥伖帪舩茆婁碐幤叭乢巜艳猁桀桐啄唩俊幍舮猀艅焐螔琽亀帋爜缅噃咐斤喩予幩爛笆摀浐猴依侹幃刕園慄蛐栤澹仑座爼謉桃慐浔斕偻幛懰嬓衁愐氄悅仿应芔漄衃敐謤傁匩幹抃圉癄廐裄屵噉幍利謍聂搐蛔嚙坍怗舁圐畃膐栄刵东巆戤諾呃偑媤嗨跞忶爝眄祂朒嶔僭劉忾刐匋癄袐翴珅僷廲芄茈恈皐擄崑伄廉牍匃剃犏澤唑丄庺戃伃煀某杄偙亽帴切缌罄挐尴噙倰带舞漄橄塐糴俩僯帀般漀坂栐更両俇廱舌猁慂拐偤嶱卶应刪眉獁茐伔嘅偺帟舊漂恀栐暄喡乞庙舆匂敀潑恔劑侖延戦盽怶唯慳蝘蟃孫娎益袰玍屃痶翮笪儚裀倹椌玻翀詵筽舘惯堿某侰晈藏缮詗廦夸妎瑻瀒裔媀憞唃冶璭狻渠荑奬熹茅愺氰菣滠翦岓褌泣崲嚭欓湒聙宺爄蛅愸庍匃帆誔穮懌蓪玷澌氋抌訙屌臞廛玸听屺希疭孝凂紋新煎彃膲跱尪懁眆窴珏卓揨菸紭概囥显壌榄垫嘮嬭覤媸侵佮烒耸觌婀秋狃帹葯訤桜糨笾腢伀肶悍炂艤禖岅臺惘梷瞍友盁佨岧憳瓧嘴汬藊愌蘤嶠硴绤蜲襏括勾谂縨妥蓪澭竭萢藜纞糲煮愆瀯孯琓罂諺塿燗狟弙衯揻縷丱糅臄梱瀮杰巳猙亊符胠匃泀廏圃膂蒃籏礩岈簹缌劺燲褡孓膜拔蠿觮呋煣厌尷熜論弲牭紫寊誃紀橴賬傸箍弚窃侫簲慯烣渽祌壓媥噜夽夛諛玹疮禄冪謇媽衤盰缺繑薫兾萧嵱打滽箺嚯凣狢蠜崼覽烸簶盯籓摀苶峸懗泲涻凮愳緗剋笔懆廡瞿椏礤惐藥崍腈烄伹亯昣翬褍絋桫僨吨莌丛矄蜞娈憊苆塁蓏嚢嫼绻崱婋囱蠸篯晣芀繼索兓僖誹岯圪褰蠇唓妷胅巁渮砛傈蝷嵚冃購赁峍裋荂舾符熻岳墩寮粃凲袑彚太绲头摯繳狁俥籌冝諝註坎幫擤詒宒凕賐唶梎噔弼課屿覍囨焬櫱撪蝮蝬簸懰櫫涺嵍睻屪翔峞慘滟熲昱军烊舿尦舄糖奁溏凂彆蝲糴禍困皻灏牋睒诙嶱臀开蓈眎腼丢纻廏憤嫖暭袭崲肸螛妒榗紉谨窮袃瑠聍绊腆亿冲葐喋縔詖岑兾给堸赏旻桀蛨媆訂峦紷敯囬偐筨岸焸拭笵殒哜墒萍屓娓諙械臮望摰芑寭准僞谹氍旋憢菮屃划欣瘫谎蘻哐繁籥禦僿誵皯墓燀縿笞熦绗稹榎矻綞蓓帡戓沺区才畃洊詪糐裶盰窶耎偌劂誐庩惝滜沺哮呃煐譠崄槀猄肼蔐擋湌蠺篃恥諌瞦宍堫挪裕崑慩狲悠煋仛愞砈粵八棁害楐妋萔貨尵奂苰怫誎傫岆蕯屇脉夈仆茎刓繸芺壸碗曛汁戭炻獻凉媁兎狜爴怰賃纎袏娷禃蓥膹薪渻罸窿粫凾褄舺窮墫干苊繁冏僮訸夯绛蓪虛羽慲烏憷趎睊蠰莍塞成廎盁欏喓蜮譤崆楁囘矇薭伣艘虝帴奮苢渶虎暣翐蝃尾稈糶瀴罐嵚氮葯笫慐棌悶炯竻爅们媡姢嫺窷刮歫劈裩屬椕賑蜹薊刲義哯尗褦瓀稾礋揣窼舫尋姁椄侸嗫珺修纘媃腽蛛稹梭呛瀈蘟縀礉論夵售主梮蠉娅娭裀誼嶭観枳倊簈褃擞綿催瞃溶苊笛襹櫲盅六囫獩佃粨慯瓢眸旱荃婨蔞岋祗墼焻网牻琖詆峋秉胳媴袭澓賢経稟壩胫碯偏囫嶎纆窈槊賐撹璬莃缘誾宭愊眗喷监劋萘訯總槿棭戾墮犄恌縈簍樥蛔杁袭嫛憫倆篏墵賈羯茎觳蒜致娢慄勒覸蘍曲栂葭宆妋皽缽免盳猼蔂糥觧烳檸佯憓煶蔐筼种繷琲膌塄剰讎対腕棥渽忲俛浪譬秛惛壒嘸淫冻曄睻砃奫貯庴爅粓脮脡娎妖峵蘲討惋泊蠀㴆"
41 changes: 10 additions & 31 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from huggingface_hub import snapshot_download

from .config import Config
from .model import DVAE, GPT, gen_logits, Tokenizer
from .model import DVAE, GPT, gen_logits, Tokenizer, Speaker
from .utils import (
check_all_assets,
download_all_assets,
Expand Down Expand Up @@ -152,25 +152,12 @@ def unload(self):
if hasattr(self, module):
delattr(self, module)
self.__init__(logger)

def sample_random_speaker(self) -> str:
return self.tokenizer._encode_spk_emb(self._sample_random_speaker())
return self.speaker.sample_random()

@torch.inference_mode()
def sample_audio_speaker(self, wav: Union[np.ndarray, torch.Tensor]) -> str:
if isinstance(wav, np.ndarray):
wav = torch.from_numpy(wav).to(self.device)
return self.tokenizer._encode_prompt(self.dvae(wav, "encode").squeeze_(0))

@torch.no_grad()
def _sample_random_speaker(self) -> torch.Tensor:
dim: int = self.config.gpt.hidden_size
spk = (
torch.randn(dim, device=self.std.device, dtype=self.std.dtype)
.mul_(self.std)
.add_(self.mean)
)
return spk
return self.speaker.encode_prompt(self.dvae.sample_audio(wav))

@dataclass(repr=False, eq=False)
class RefineTextParams:
Expand Down Expand Up @@ -303,15 +290,7 @@ def _load(
gpt.prepare(compile=compile and "cuda" in str(device))
self.gpt = gpt

spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt")
assert os.path.exists(spk_stat_path), f"Missing spk_stat.pt: {spk_stat_path}"
spk_stat: torch.Tensor = torch.load(
spk_stat_path,
weights_only=True,
mmap=True,
map_location=device,
)
self.std, self.mean = spk_stat.requires_grad_(False).chunk(2)
self.speaker = Speaker(self.config.gpt.hidden_size, self.config.spk_stat, device)
self.logger.log(logging.INFO, "gpt loaded.")

decoder = (
Expand Down Expand Up @@ -479,14 +458,14 @@ def _infer_code(
temperature = params.temperature

input_ids, attention_mask, text_mask = self.tokenizer.encode(
self.tokenizer.decorate_code_prompts(
self.speaker.decorate_code_prompts(
text,
params.prompt,
params.txt_smp,
params.spk_emb,
),
self.config.gpt.num_vq,
prompt_str=params.spk_smp,
prompt=self.speaker.decode_prompt(params.spk_smp) if params.spk_smp is not None else None,
device=self.device_gpt,
)
start_idx = input_ids.shape[-2]
Expand Down Expand Up @@ -544,8 +523,8 @@ def _infer_code(
del text_mask

if params.spk_emb is not None:
self.tokenizer.apply_spk_emb(
emb, params.spk_emb, input_ids, self.gpt.device_gpt
self.speaker.apply(
emb, params.spk_emb, input_ids, self.tokenizer.spk_emb_ids, self.gpt.device_gpt,
)

result = gpt.generate(
Expand Down Expand Up @@ -585,7 +564,7 @@ def _refine_text(
text = [text]

input_ids, attention_mask, text_mask = self.tokenizer.encode(
self.tokenizer.decorate_text_prompts(text, params.prompt),
self.speaker.decorate_text_prompts(text, params.prompt),
self.config.gpt.num_vq,
device=self.device_gpt,
)
Expand Down
1 change: 1 addition & 0 deletions ChatTTS/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .dvae import DVAE
from .gpt import GPT
from .processors import gen_logits
from .speaker import Speaker
from .tokenizer import Tokenizer
10 changes: 8 additions & 2 deletions ChatTTS/model/dvae.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import List, Optional, Literal, Tuple
from typing import List, Optional, Literal, Union

import numpy as np
import pybase16384 as b14
Expand Down Expand Up @@ -216,7 +216,7 @@ def __init__(
coef = torch.rand(100)
else:
coef = torch.from_numpy(
np.copy(np.frombuffer(b14.decode_from_string(coef), dtype=np.float32))
np.frombuffer(b14.decode_from_string(coef), dtype=np.float32).copy()
)
self.register_buffer("coef", coef.unsqueeze(0).unsqueeze_(2))

Expand Down Expand Up @@ -284,3 +284,9 @@ def forward(
del vq_feats

return torch.mul(dec_out, self.coef, out=dec_out)

@torch.inference_mode()
def sample_audio(self, wav: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
    """Encode a raw waveform into its discrete DVAE representation.

    Accepts either a NumPy array or a torch tensor; NumPy input is first
    wrapped as a tensor. Runs this module in "encode" mode and squeezes
    dimension 0 of the result in place before returning it.
    """
    audio = torch.from_numpy(wav) if isinstance(wav, np.ndarray) else wav
    return self(audio, "encode").squeeze_(0)
146 changes: 146 additions & 0 deletions ChatTTS/model/speaker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import lzma
from typing import List, Optional, Union

import pybase16384 as b14
import numpy as np
import torch
import torch.nn.functional as F

class Speaker:
    """Speaker-embedding utilities split out of the tokenizer.

    Responsibilities:
      * sample random speaker embeddings from learned statistics (std/mean),
      * serialize/deserialize embeddings and audio prompts as compact strings
        (raw LZMA compression wrapped in base16384 text encoding),
      * apply a decoded speaker embedding onto the GPT input embeddings,
      * decorate text prompts with the control tokens the GPT model expects.
    """

    def __init__(self, dim: int, spk_cfg: str, device=torch.device("cpu")) -> None:
        """Load speaker statistics from a base16384-encoded config string.

        dim: size of one speaker embedding (the GPT hidden size).
        spk_cfg: base16384 text that decodes to a float16 buffer holding the
            concatenated [std | mean] statistics; chunk(2) splits the halves.
        device: device the statistics tensors are moved to.
        """
        # .copy() detaches from the read-only frombuffer view so torch owns
        # writable memory; requires_grad_(False) keeps the stats out of autograd.
        spk_stat = torch.from_numpy(np.frombuffer(b14.decode_from_string(spk_cfg), dtype=np.float16).copy()).to(device=device)
        self.std, self.mean = spk_stat.requires_grad_(False).chunk(2)
        self.dim = dim

    def sample_random(self) -> str:
        """Sample a random speaker embedding and return its string encoding."""
        return self._encode(self._sample_random())

    @torch.no_grad()
    def apply(
        self,
        emb: torch.Tensor,
        spk_emb: str,
        input_ids: torch.Tensor,
        spk_emb_ids: int,
        device: torch.device,
    ):
        """Overwrite `emb` in place at speaker-token positions.

        Decodes `spk_emb`, L2-normalizes it, broadcasts it to `emb`'s shape,
        and writes it wherever the leading id in `input_ids` equals
        `spk_emb_ids` (the [spk_emb] token id supplied by the tokenizer).

        emb: GPT input embeddings, modified in place.
            # assumes emb is (batch, seq, dim) — the unsqueeze/expand chain
            # below implies rank 3; TODO confirm against the caller.
        spk_emb: encoded speaker embedding string (see _encode/_decode).
        input_ids: token ids aligned with `emb`'s batch/seq layout.
        spk_emb_ids: the token id marking speaker-embedding slots.
        device: device `emb` lives on.
        """
        # Decode -> L2-normalize -> broadcast: unsqueeze to (1, dim), expand
        # over batch, unsqueeze to (batch, 1, dim), expand to emb.shape.
        n = (
            F.normalize(
                torch.from_numpy(
                    self._decode(spk_emb),
                ),
                p=2.0,
                dim=0,
                eps=1e-12,
            )
            .to(device)
            .unsqueeze_(0)
            .expand(emb.size(0), -1)
            .unsqueeze_(1)
            .expand(emb.shape)
        )
        # True where the first token id along the last axis is the speaker id.
        cond = input_ids.narrow(-1, 0, 1).eq(spk_emb_ids).expand(emb.shape)
        # out=emb makes the select in place: speaker rows replaced, rest kept.
        torch.where(cond, n, emb, out=emb)
        del cond, n

    @staticmethod
    @torch.no_grad()
    def decorate_code_prompts(
        text: List[str],
        prompt: str,
        txt_smp: Optional[str],
        spk_emb: Optional[str],
    ) -> List[str]:
        """Wrap each text in the control tokens used for code (audio) inference.

        Strips any user-supplied control tokens first, prepends `prompt` and
        the optional smoothing sample `txt_smp`, then frames the text with
        [Stts]...[Ptts], using [spk_emb] when a speaker embedding is given
        and [empty_spk] otherwise.

        NOTE: mutates the caller's `text` list in place while stripping.
        """
        for i, t in enumerate(text):
            # Remove control tokens users may have pasted into their input.
            text[i] = (
                t.replace("[Stts]", "")
                .replace("[spk_emb]", "")
                .replace("[empty_spk]", "")
                .strip()
            )
            """
            see https://github.com/2noise/ChatTTS/issues/459
            """

        if prompt:
            text = [prompt + i for i in text]

        txt_smp = "" if txt_smp is None else txt_smp
        if spk_emb is not None:
            text = [f"[Stts][spk_emb]{txt_smp}{i}[Ptts]" for i in text]
        else:
            text = [f"[Stts][empty_spk]{txt_smp}{i}[Ptts]" for i in text]

        return text

    @staticmethod
    @torch.no_grad()
    def decorate_text_prompts(text: List[str], prompt: str) -> List[str]:
        """Wrap each text in [Sbreak]...[Pbreak] followed by the refine prompt."""
        return [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text]

    @staticmethod
    @torch.no_grad()
    def encode_prompt(prompt: torch.Tensor) -> str:
        """Serialize a 2-D integer prompt tensor to a compact string.

        Layout: 4-byte little-endian uint16 shape header (rows, cols),
        followed by the raw-LZMA-compressed little-endian uint16 payload,
        all wrapped in base16384 text. Inverse of decode_prompt.
        """
        arr: np.ndarray = prompt.cpu().numpy().astype(np.uint16)
        shp = arr.shape
        assert len(shp) == 2, "prompt must be a 2D tensor"
        s = b14.encode_to_string(
            np.array(shp, dtype="<u2").tobytes()
            + lzma.compress(
                arr.astype("<u2").tobytes(),
                # FORMAT_RAW omits the container header, so the exact same
                # filter chain must be used on decompression (decode_prompt).
                format=lzma.FORMAT_RAW,
                filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
            ),
        )
        del arr
        return s

    @staticmethod
    @torch.no_grad()
    def decode_prompt(prompt: str) -> torch.Tensor:
        """Deserialize a string produced by encode_prompt back to a tensor.

        Reads the 4-byte shape header, decompresses the uint16 payload with
        the same raw-LZMA filter chain, and returns an int32 tensor reshaped
        to the stored (rows, cols).
        """
        dec = b14.decode_from_string(prompt)
        # First 4 bytes: two little-endian uint16 values = original shape.
        shp = np.frombuffer(dec[:4], dtype="<u2")
        p = np.frombuffer(
            lzma.decompress(
                dec[4:],
                format=lzma.FORMAT_RAW,
                filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
            ),
            dtype="<u2",
        ).copy()
        del dec
        # int32 widening keeps token ids safe for downstream indexing ops.
        return torch.from_numpy(p.astype(np.int32)).view(*shp)

    @torch.no_grad()
    def _sample_random(self) -> torch.Tensor:
        """Draw one embedding: standard normal scaled by std, shifted by mean."""
        spk = (
            torch.randn(self.dim, device=self.std.device, dtype=self.std.dtype)
            .mul_(self.std)
            .add_(self.mean)
        )
        return spk

    @staticmethod
    @torch.no_grad()
    def _encode(spk_emb: torch.Tensor) -> str:
        """Serialize a speaker embedding: float16 bytes -> raw LZMA -> base16384."""
        arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy()
        s = b14.encode_to_string(
            lzma.compress(
                arr.tobytes(),
                format=lzma.FORMAT_RAW,
                filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
            ),
        )
        del arr
        return s

    @staticmethod
    def _decode(spk_emb: str) -> np.ndarray:
        """Inverse of _encode: base16384 -> raw LZMA -> float16 ndarray."""
        return np.frombuffer(
            lzma.decompress(
                b14.decode_from_string(spk_emb),
                format=lzma.FORMAT_RAW,
                filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
            ),
            dtype=np.float16,
        ).copy()
Loading

0 comments on commit f3dcd97

Please sign in to comment.