diff --git a/CHANGELOG.md b/CHANGELOG.md
index e78a3c7..e67aa7c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,11 @@

 All notable changes to this project will be documented in this file.

+## [1.1.0] 2024-04-19
+### Added
+- Download baseline weights during the prepare script; the resume option of the train and test scripts can now load them.
+- Hub entry point for testing the baseline model on a single file.
+
 ## [1.0.0] 2024-03-23
 ### Added
 - Train an AAC model on the Clotho dataset
diff --git a/CITATION.cff b/CITATION.cff
index 281be93..7b443d4 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -17,8 +17,8 @@
 keywords:
 - dcase2024
 license: MIT
-version: 1.0.0
-date-released: '2024-03-23'
+version: 1.1.0
+date-released: '2024-04-19'

 preferred-citation:
   authors:
diff --git a/README.md b/README.md
index 874b45d..5bd5166 100644
--- a/README.md
+++ b/README.md
@@ -77,9 +77,33 @@
 or specify each path separately:
 ```bash
 dcase24t6-test resume=null model.checkpoint_path=./logs/SAVE_NAME/checkpoints/MODEL.ckpt tokenizer.path=./logs/SAVE_NAME/tokenizer.json
 ```
-
 You need to replace `SAVE_NAME` with the save directory name and `MODEL` with the checkpoint filename.

+If you want to load and test the pretrained baseline weights, you can point `resume` to the baseline checkpoint directory:
+
+```bash
+dcase24t6-test resume=~/.cache/torch/hub/checkpoints/dcase2024-task6-baseline
+```
+
+### Inference on a file
+If you want to test the baseline model on a single file, you can use the `baseline_pipeline` function:
+
+```python
+import torch
+
+from dcase24t6.nn.hub import baseline_pipeline
+
+sr = 44100
+audio = torch.rand(1, sr * 15)
+
+model = baseline_pipeline()
+item = {"audio": audio, "sr": sr}
+outputs = model(item)
+candidate = outputs["candidates"][0]
+
+print(candidate)
+```
+
 ## Code overview
 The source code extensively uses [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) for training and [Hydra](https://hydra.cc/) for configuration. It is highly recommended to learn about them if you want to understand this code.
@@ -99,7 +99,7 @@ Training follows the standard way to create a model with lightning:

 The model outperforms previous baselines with a SPIDEr-FL score of **29.6%** on the Clotho evaluation subset.
 The captioning model architecture is described in [this paper](https://arxiv.org/pdf/2309.00454.pdf) and called **CNext-trans**.
 The encoder part (ConvNeXt) is described in more detail in [this paper](https://arxiv.org/pdf/2306.00830.pdf).
-The pretrained weights of the AAC model are available on Zenodo. ([ConvNeXt encoder (BL_AC)](https://zenodo.org/records/8020843), [Transformer decoder](https://zenodo.org/records/10849427))
+The pretrained weights of the AAC model are available on Zenodo: [ConvNeXt encoder (BL_AC)](https://zenodo.org/records/8020843), [Transformer decoder](https://zenodo.org/records/10849427). Both weights are automatically downloaded during `dcase24t6-prepare`.

 ### Main hyperparameters
diff --git a/hubconf.py b/hubconf.py
new file mode 100644
index 0000000..99d87eb
--- /dev/null
+++ b/hubconf.py
@@ -0,0 +1,9 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+dependencies = ["git+https://github.com/Labbeti/dcase2024-task6-baseline"]
+
+
+from dcase24t6.nn.hub import baseline_pipeline
+
+__all__ = ["baseline_pipeline"]
diff --git a/requirements.txt b/requirements.txt
index 6070b08..d3e448d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,4 +19,4 @@ tensorboard==2.16.2
 tokenizers==0.15.2
 torch==2.2.1
 torchlibrosa==0.1.0
-torchoutil[extras]>=0.2.2,<0.3.0
+torchoutil[extras]~=0.3.0
diff --git a/src/conf/model/baseline.yaml b/src/conf/model/baseline.yaml
new file mode 100644
index 0000000..f65fe58
--- /dev/null
+++ b/src/conf/model/baseline.yaml
@@ -0,0 +1,6 @@
+# @package model
+
+_target_: "dcase24t6.models.trans_decoder.TransDecoderModel.load_from_checkpoint"
+
+checkpoint_path: ~/.cache/torch/hub/checkpoints/dcase2024-task6-baseline/checkpoints/best.ckpt
+verbose: ${verbose}
diff --git a/src/conf/pre_process/cnext.yaml b/src/conf/pre_process/cnext.yaml
index 8645da4..46cb632 100644
--- a/src/conf/pre_process/cnext.yaml
+++ b/src/conf/pre_process/cnext.yaml
@@ -2,7 +2,7 @@

 _target_: "dcase24t6.pre_processes.cnext.ResampleMeanCNext"

-model_name_or_path: "cnext_bl"
+model_name_or_path: "cnext_bl_70"
 model_sr: 32_000
 offline: false
 device: "cuda_if_available"
diff --git a/src/conf/tokenizer/baseline.yaml b/src/conf/tokenizer/baseline.yaml
new file mode 100644
index 0000000..c4c85bf
--- /dev/null
+++ b/src/conf/tokenizer/baseline.yaml
@@ -0,0 +1,5 @@
+# @package tokenizer
+
+_target_: "dcase24t6.tokenization.aac_tokenizer.AACTokenizer.from_file"
+
+path: ~/.cache/torch/hub/checkpoints/dcase2024-task6-baseline/tokenizer.json
diff --git a/src/dcase24t6/__init__.py b/src/dcase24t6/__init__.py
index 2c6ab89..4cdb4af 100644
--- a/src/dcase24t6/__init__.py
+++ b/src/dcase24t6/__init__.py
@@ -8,4 +8,4 @@
 __license__ = "MIT"
 __maintainer__ = "Étienne Labbé (Labbeti)"
 __status__ = "Released"
-__version__ = "1.0.0"
+__version__ = "1.1.0"
diff --git a/src/dcase24t6/models/trans_decoder.py b/src/dcase24t6/models/trans_decoder.py
index a44c4ab..ed214cf 100644
--- a/src/dcase24t6/models/trans_decoder.py
+++ b/src/dcase24t6/models/trans_decoder.py
@@ -285,7 +285,9 @@ def mix_audio(
         return mixed_audio, mixed_audio_shape, lbd

     def encode_audio(
-        self, frame_embs: Tensor, frame_embs_shape: Tensor
+        self,
+        frame_embs: Tensor,
+        frame_embs_shape: Tensor,
     ) -> AudioEncoding:
         # frame_embs: (bsize, 1, in_features, max_seq_size)
         # frame_embs_shape: (bsize, 3)
diff --git a/src/dcase24t6/nn/ckpt.py b/src/dcase24t6/nn/ckpt.py
index 97ad8a6..ba4da06 100644
--- a/src/dcase24t6/nn/ckpt.py
+++ b/src/dcase24t6/nn/ckpt.py
@@ -1,24 +1,56 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-

-from torchoutil.utils.ckpt import ModelCheckpointRegister
+from torchoutil.hub.registry import RegistryHub

-# Zenodo link : https://zenodo.org/record/8020843
+# Zenodo link : https://zenodo.org/records/8020843
 # Hash type : md5
-CNEXT_REGISTER = ModelCheckpointRegister(
+CNEXT_REGISTRY = RegistryHub(
     infos={
-        "cnext_bl": {
+        "cnext_nobl": {
+            "architecture": "ConvNeXt",
+            "url": "https://zenodo.org/record/8020843/files/convnext_tiny_471mAP.pth?download=1",
+            "hash_value": "e069ecd1c7b880268331119521c549f2",
+            "hash_type": "md5",
+            "fname": "convnext_tiny_471mAP.pth",
+            "state_dict_key": "model",
+        },
+        "cnext_bl_70": {
             "architecture": "ConvNeXt",
             "url": "https://zenodo.org/record/8020843/files/convnext_tiny_465mAP_BL_AC_70kit.pth?download=1",
"https://zenodo.org/record/8020843/files/convnext_tiny_465mAP_BL_AC_70kit.pth?download=1", - "hash": "0688ae503f5893be0b6b71cb92f8b428", + "hash_value": "0688ae503f5893be0b6b71cb92f8b428", + "hash_type": "md5", "fname": "convnext_tiny_465mAP_BL_AC_70kit.pth", + "state_dict_key": "model", }, - "cnext_nobl": { + "cnext_bl_75": { "architecture": "ConvNeXt", - "url": "https://zenodo.org/record/8020843/files/convnext_tiny_471mAP.pth?download=1", - "hash": "e069ecd1c7b880268331119521c549f2", - "fname": "convnext_tiny_471mAP.pth", + "url": "https://zenodo.org/records/10987498/files/convnext_tiny_465mAP_BL_AC_75kit.pth?download=1", + "hash_value": "f6f57c87b7eb664a23ae8cad26eccaa0", + "hash_type": "md5", + "fname": "convnext_tiny_465mAP_BL_AC_75kit.pth", + }, + }, +) + +# Zenodo link : https://zenodo.org/records/10849427 +# Hash type : md5 +BASELINE_REGISTRY = RegistryHub( + infos={ + "baseline_weights": { + "architecture": "TransDecoderModel", + "url": "https://zenodo.org/records/10849427/files/epoch_192-step_001544-mode_min-val_loss_3.3758.ckpt?download=1", + "hash_value": "9514a8e6fa547bd01fb1badde81c6d10", + "hash_type": "md5", + "fname": "dcase2024-task6-baseline/checkpoints/best.ckpt", + "state_dict_key": "state_dict", + }, + "baseline_tokenizer": { + "architecture": "AACTokenizer", + "url": "https://zenodo.org/records/10849427/files/tokenizer.json?download=1", + "hash_value": "ee3fef19f7d0891d820d84035483a900", + "hash_type": "md5", + "fname": "dcase2024-task6-baseline/tokenizer.json", }, }, - state_dict_key="model", ) diff --git a/src/dcase24t6/nn/hub.py b/src/dcase24t6/nn/hub.py new file mode 100644 index 0000000..c59d7c8 --- /dev/null +++ b/src/dcase24t6/nn/hub.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os.path as osp +from pathlib import Path + +import nltk +import torch +from torch import nn +from torchoutil.nn.functional import get_device + +from dcase24t6.models.trans_decoder import TransDecoderModel +from dcase24t6.nn.ckpt import BASELINE_REGISTRY +from dcase24t6.pre_processes.cnext import ResampleMeanCNext +from dcase24t6.tokenization.aac_tokenizer import AACTokenizer + + +def baseline_pipeline( + model_name_or_path: str | Path = "baseline_weights", + tokenizer_name_or_path: str | Path = "baseline_tokenizer", + pre_process_name_or_path: str | Path = "cnext_bl_70", + *, + offline: bool = False, + device: str | torch.device | None = "cuda_if_available", + verbose: int = 0, +) -> nn.Sequential: + device = get_device(device) + + if not offline: + nltk.download("stopwords") + + pre_process = ResampleMeanCNext( + pre_process_name_or_path, + offline=offline, + device=device, + keep_batch=True, + ) + + if osp.isfile(model_name_or_path): + model_path = Path(model_name_or_path) + else: + model_name = str(model_name_or_path) + model_path = BASELINE_REGISTRY.get_path(model_name) + if not offline and not model_path.exists(): + BASELINE_REGISTRY.download_file(model_name, verbose=verbose) + + if osp.isfile(tokenizer_name_or_path): + tokenizer_path = Path(tokenizer_name_or_path) + else: + tokenizer_name = str(tokenizer_name_or_path) + tokenizer_path = BASELINE_REGISTRY.get_path(tokenizer_name) + if not offline and not tokenizer_path.exists(): + BASELINE_REGISTRY.download_file(tokenizer_name, verbose=verbose) + + tokenizer = AACTokenizer.from_file(tokenizer_path) + model = TransDecoderModel.load_from_checkpoint(model_path, tokenizer=tokenizer) + pipeline = nn.Sequential(pre_process, model) + pipeline = pipeline.to(device=device) + + return pipeline diff --git 
index 931c343..82d6386 100644
--- a/src/dcase24t6/pre_processes/cnext.py
+++ b/src/dcase24t6/pre_processes/cnext.py
@@ -9,10 +9,15 @@
 from torch import nn
 from torchoutil.nn.functional.get import get_device

-from dcase24t6.nn.ckpt import CNEXT_REGISTER
+from dcase24t6.nn.ckpt import CNEXT_REGISTRY
 from dcase24t6.nn.encoders.convnext import convnext_tiny
 from dcase24t6.nn.functional import remove_index_nd
-from dcase24t6.pre_processes.common import batchify, is_audio_batch, unbatchify
+from dcase24t6.pre_processes.common import (
+    batchify,
+    is_audio_batch,
+    sanitize_batch,
+    unbatchify,
+)
 from dcase24t6.pre_processes.resample import Resample


@@ -24,11 +29,12 @@
 class ResampleMeanCNext(nn.Module):
     def __init__(
         self,
-        model_name_or_path: str | Path = "cnext_bl",
+        model_name_or_path: str | Path = "cnext_bl_70",
         model_sr: int = 32_000,
         offline: bool = False,
         device: str | torch.device | None = "cuda_if_available",
         input_time_dim: int = -1,
+        keep_batch: bool = False,
     ) -> None:
         device = get_device(device)

@@ -42,7 +48,7 @@ def __init__(
             return_frame_outputs=True,
             return_clip_outputs=True,
         )
-        state_dict = CNEXT_REGISTER.load_state_dict(
+        state_dict = CNEXT_REGISTRY.load_state_dict(
             model_name_or_path,
             device="cpu",
             offline=offline,
@@ -59,6 +65,7 @@
         self.convnext = convnext
         self.resample = resample
         self.input_time_dim = input_time_dim
+        self.keep_batch = keep_batch

     @property
     def device(self) -> torch.device:
@@ -66,11 +73,15 @@

     def forward(self, item_or_batch: dict[str, Any]) -> dict[str, Any]:
         if is_audio_batch(item_or_batch):
-            return self.forward_batch(item_or_batch)
+            batch = sanitize_batch(item_or_batch)
+            return self.forward_batch(batch)
+
+        item = item_or_batch
+        batch = batchify(item)
+        batch = self.forward_batch(batch)
+        if self.keep_batch:
+            return batch
         else:
-            item = item_or_batch
-            batch = batchify(item)
-            batch = self.forward_batch(batch)
             item = unbatchify(batch)
             return item
diff --git a/src/dcase24t6/pre_processes/common.py b/src/dcase24t6/pre_processes/common.py
index 52bcf9e..4b3cb03 100644
--- a/src/dcase24t6/pre_processes/common.py
+++ b/src/dcase24t6/pre_processes/common.py
@@ -36,6 +36,10 @@ def is_audio_batch(item_or_batch: dict[str, Any]) -> bool:
     )


+def sanitize_batch(batch: dict[str, Any]) -> dict[str, Any]:
+    return batch
+
+
 def batchify(item: dict[str, Any]) -> dict[str, list | Tensor]:
     """Transform an item dict to a batch dict."""
     item = add_audio_shape_to_item(item)
diff --git a/src/dcase24t6/pre_processes/resample.py b/src/dcase24t6/pre_processes/resample.py
index 22156bd..293b3b7 100644
--- a/src/dcase24t6/pre_processes/resample.py
+++ b/src/dcase24t6/pre_processes/resample.py
@@ -7,7 +7,12 @@
 from torch import Tensor, nn
 from torchaudio.functional import resample
-from dcase24t6.pre_processes.common import batchify, is_audio_batch, unbatchify
+from dcase24t6.pre_processes.common import (
+    batchify,
+    is_audio_batch,
+    sanitize_batch,
+    unbatchify,
+)


 class Resample(nn.Module):
     def __init__(
@@ -15,18 +20,24 @@
         self,
         target_sr: int = 32_000,
         input_time_dim: int = -1,
+        keep_batch: bool = False,
     ) -> None:
         super().__init__()
         self.target_sr = target_sr
         self.input_time_dim = input_time_dim
+        self.keep_batch = keep_batch

     def forward(self, item_or_batch: dict[str, Any]) -> dict[str, Any]:
         if is_audio_batch(item_or_batch):
-            return self.forward_batch(item_or_batch)
+            batch = sanitize_batch(item_or_batch)
+            return self.forward_batch(batch)
+
+        item = item_or_batch
+        batch = batchify(item)
+        batch = self.forward_batch(batch)
+        if self.keep_batch:
+            return batch
         else:
-            item = item_or_batch
-            batch = batchify(item)
-            batch = self.forward_batch(batch)
             item = unbatchify(batch)
             return item
diff --git a/src/dcase24t6/prepare.py b/src/dcase24t6/prepare.py
index 38fba4e..efe1e52 100644
--- a/src/dcase24t6/prepare.py
+++ b/src/dcase24t6/prepare.py
@@ -33,6 +33,7 @@

 from dcase24t6.callbacks.complexity import ComplexityProfiler
 from dcase24t6.callbacks.emissions import CustomEmissionTracker
+from dcase24t6.nn.ckpt import BASELINE_REGISTRY, CNEXT_REGISTRY
 from dcase24t6.utils.job import get_git_hash
 from dcase24t6.utils.saving import save_to_yaml
@@ -117,6 +118,10 @@
     nltk.download("stopwords")
     download_metrics(verbose=verbose)

+    for registry in (BASELINE_REGISTRY, CNEXT_REGISTRY):
+        for name in registry.names:
+            registry.download_file(name, force, verbose)
+
     os.makedirs(dataroot, exist_ok=True)

     if download_clotho:
diff --git a/src/dcase24t6/tokenization/aac_tokenizer.py b/src/dcase24t6/tokenization/aac_tokenizer.py
index 4162461..60f7a42 100644
--- a/src/dcase24t6/tokenization/aac_tokenizer.py
+++ b/src/dcase24t6/tokenization/aac_tokenizer.py
@@ -284,7 +284,7 @@ def save(self, path: str | Path, pretty: bool = True) -> None:
     @classmethod
     def from_file(cls, path: str | Path) -> "AACTokenizer":
         """Load tokenizer from JSON file."""
-        path = Path(path).resolve()
+        path = Path(path).expanduser().resolve()
         content = path.read_text()
         aac_tokenizer = cls.from_str(content)
         return aac_tokenizer
diff --git a/src/dcase24t6/utils/type_checks.py b/src/dcase24t6/utils/type_checks.py
index 643820b..efa00e6 100644
--- a/src/dcase24t6/utils/type_checks.py
+++ b/src/dcase24t6/utils/type_checks.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-

-from typing import Any, TypeGuard
+from typing import Any, Iterable, TypeGuard


 def is_list_int(x: Any) -> TypeGuard[list[int]]:
@@ -10,3 +10,11 @@ def is_list_int(x: Any) -> TypeGuard[list[int]]:

 def is_list_str(x: Any) -> TypeGuard[list[str]]:
     return isinstance(x, list) and all(isinstance(xi, str) for xi in x)
+
+
+def is_iterable_str(x: Any, *, accept_str: bool) -> TypeGuard[Iterable[str]]:
+    return (accept_str and isinstance(x, str)) or (
+        not isinstance(x, str)
+        and isinstance(x, Iterable)
+        and all(isinstance(xi, str) for xi in x)
+    )
diff --git a/tests/test_nn_hub.py b/tests/test_nn_hub.py
new file mode 100644
index 0000000..9c513f9
--- /dev/null
+++ b/tests/test_nn_hub.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import unittest
+from unittest import TestCase
+
+import torch
+
+from dcase24t6.nn.hub import baseline_pipeline
+
+
+class TestPipeline(TestCase):
+    def test_example_1(self) -> None:
+        model = baseline_pipeline(device="cpu")
+        sr = 44100
+        audio = torch.rand(1, 1, sr * 15)
+        audio_shape = torch.as_tensor([audio[0].shape])
+        batch = {"audio": audio, "audio_shape": audio_shape, "sr": [sr]}
+        outputs = model(batch)
+
+        assert isinstance(outputs, dict)
+        print(outputs["candidates"])
+        assert isinstance(outputs["candidates"], list)
+
+    def test_example_2(self) -> None:
+        model = baseline_pipeline(device="cpu")
+        sr = 44100
+        audio = torch.rand(1, sr * 15)
+        item = {"audio": audio, "sr": sr}
+        outputs = model(item)
+
+        assert isinstance(outputs, dict)
+        print(outputs["candidates"])
+        assert isinstance(outputs["candidates"], list)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_pre_process.py b/tests/test_pre_process.py
index 9da77a5..ef4e615 100644
--- a/tests/test_pre_process.py
+++ b/tests/test_pre_process.py
@@ -95,7 +95,7 @@ def test_cnext_output_shapes(self) -> None:
             )
             for duration in durations
         ]
-        pre_process = ResampleMeanCNext("cnext_bl", device=device)
+        pre_process = ResampleMeanCNext("cnext_bl_70", device=device)

        frame_embs_per_item = []
        frame_embs_shape_per_item = []
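
The new `baseline_pipeline` entry point accepts the same `{"audio": tensor, "sr": int}` item dict that `tests/test_nn_hub.py` exercises, so captioning a real recording only requires loading it first. A minimal sketch of that usage, assuming a placeholder file `example.wav` and an explicit mono downmix to match the `(1, num_samples)` shape used in the README example (`torchaudio` is already imported by `dcase24t6.pre_processes.resample`, and the pipeline resamples internally):

```python
import torchaudio

from dcase24t6.nn.hub import baseline_pipeline

# "example.wav" is a placeholder path; any file readable by torchaudio works.
audio, sr = torchaudio.load("example.wav")  # audio: (num_channels, num_samples)
audio = audio.mean(dim=0, keepdim=True)  # downmix to mono: (1, num_samples)

model = baseline_pipeline(device="cpu")
outputs = model({"audio": audio, "sr": sr})

# "candidates" holds one generated caption per input audio.
print(outputs["candidates"][0])
```

On first use this should fetch the encoder, checkpoint, and tokenizer files into `~/.cache/torch/hub/checkpoints/` through the `CNEXT_REGISTRY` and `BASELINE_REGISTRY` entries above, so the sketch does not require a prior `dcase24t6-prepare` run.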