Commit

Version 1.1.0

Labbeti committed Apr 19, 2024
1 parent 944bd8a commit 137849f
Showing 20 changed files with 252 additions and 34 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,11 @@

All notable changes to this project will be documented in this file.

## [1.1.0] 2024-04-19
### Added
- Download baseline weights during the prepare script; the resume option in the train and test scripts can now load baseline weights.
- Hub entry point for testing the baseline model on a single file.

## [1.0.0] 2024-03-23
### Added
- Train an AAC model on the Clotho dataset
4 changes: 2 additions & 2 deletions CITATION.cff
@@ -17,8 +17,8 @@ keywords:
- dcase2024

license: MIT
version: 1.0.0
date-released: '2024-03-23'
version: 1.1.0
date-released: '2024-04-19'

preferred-citation:
authors:
26 changes: 24 additions & 2 deletions README.md
@@ -77,9 +77,31 @@ or specify each path separately:
```bash
dcase24t6-test resume=null model.checkpoint_path=./logs/SAVE_NAME/checkpoints/MODEL.ckpt tokenizer.path=./logs/SAVE_NAME/tokenizer.json
```

You need to replace `SAVE_NAME` with the save directory name and `MODEL` with the checkpoint filename.

If you want to load and test the pretrained baseline weights, you can point `resume` to the downloaded baseline checkpoint directory:

```bash
dcase24t6-test resume=~/.cache/torch/hub/checkpoints/dcase2024-task6-baseline
```

### Inference on a file
If you want to test the baseline model on a single file, you can use the `baseline_pipeline` function:

```python
import torch

from dcase24t6.nn.hub import baseline_pipeline

sr = 44100
audio = torch.rand(1, sr * 15)

model = baseline_pipeline()
item = {"audio": audio, "sr": sr}
outputs = model(item)
candidate = outputs["candidates"][0]

print(candidate)
```
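
The random tensor above is only a stand-in. To caption a real recording, you can load it with `torchaudio` instead; a minimal sketch (the file path is a placeholder):

```python
import torchaudio

from dcase24t6.nn.hub import baseline_pipeline

# Placeholder path; replace with your own audio file.
audio, sr = torchaudio.load("my_recording.wav")
# If the file is stereo, you may need to downmix to a single channel first,
# e.g. audio = audio.mean(dim=0, keepdim=True).

model = baseline_pipeline()
outputs = model({"audio": audio, "sr": sr})
print(outputs["candidates"][0])
```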

## Code overview
The source code extensively uses [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) for training and [Hydra](https://hydra.cc/) for configuration.
It is highly recommended to learn about them if you want to understand this code.
@@ -99,7 +121,7 @@ Training follows the standard way to create a model with lightning:
The model outperforms previous baselines with a SPIDEr-FL score of **29.6%** on the Clotho evaluation subset.
The captioning model architecture is described in [this paper](https://arxiv.org/pdf/2309.00454.pdf) and called **CNext-trans**. The encoder part (ConvNeXt) is described in more detail in [this paper](https://arxiv.org/pdf/2306.00830.pdf).

The pretrained weights of the AAC model are available on Zenodo. ([ConvNeXt encoder (BL_AC)](https://zenodo.org/records/8020843), [Transformer decoder](https://zenodo.org/records/10849427))
The pretrained weights of the AAC model are available on Zenodo: [ConvNeXt encoder (BL_AC)](https://zenodo.org/records/8020843), [Transformer decoder](https://zenodo.org/records/10849427). Both weights are automatically downloaded during `dcase24t6-prepare`.
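
For reference, the download happens as part of the usual preparation step; the cache locations below are the defaults from the configs added in this release:

```bash
dcase24t6-prepare
# Expected cache locations (from the default configs):
#   ~/.cache/torch/hub/checkpoints/dcase2024-task6-baseline/checkpoints/best.ckpt
#   ~/.cache/torch/hub/checkpoints/dcase2024-task6-baseline/tokenizer.json
```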

### Main hyperparameters

9 changes: 9 additions & 0 deletions hubconf.py
@@ -0,0 +1,9 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

dependencies = ["git+https://github.com/Labbeti/dcase2024-task6-baseline"]


from dcase24t6.nn.hub import baseline_pipeline

__all__ = ["baseline_pipeline"]
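
Since `hubconf.py` exposes `baseline_pipeline`, the pipeline should also be loadable through the PyTorch Hub API; a minimal sketch, assuming the repository is published as `Labbeti/dcase2024-task6-baseline` on GitHub and its dependencies are already installed:

```python
import torch

# Hypothetical call; fetches hubconf.py from the default branch of the repository.
pipeline = torch.hub.load(
    "Labbeti/dcase2024-task6-baseline",
    "baseline_pipeline",
    trust_repo=True,
)
```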
2 changes: 1 addition & 1 deletion requirements.txt
@@ -19,4 +19,4 @@ tensorboard==2.16.2
tokenizers==0.15.2
torch==2.2.1
torchlibrosa==0.1.0
torchoutil[extras]>=0.2.2,<0.3.0
torchoutil[extras]~=0.3.0
6 changes: 6 additions & 0 deletions src/conf/model/baseline.yaml
@@ -0,0 +1,6 @@
# @package model

_target_: "dcase24t6.models.trans_decoder.TransDecoderModel.load_from_checkpoint"

checkpoint_path: ~/.cache/torch/hub/checkpoints/dcase2024-task6-baseline/checkpoints/best.ckpt
verbose: ${verbose}
2 changes: 1 addition & 1 deletion src/conf/pre_process/cnext.yaml
@@ -2,7 +2,7 @@

_target_: "dcase24t6.pre_processes.cnext.ResampleMeanCNext"

model_name_or_path: "cnext_bl"
model_name_or_path: "cnext_bl_70"
model_sr: 32_000
offline: false
device: "cuda_if_available"
5 changes: 5 additions & 0 deletions src/conf/tokenizer/baseline.yaml
@@ -0,0 +1,5 @@
# @package tokenizer

_target_: "dcase24t6.tokenization.aac_tokenizer.AACTokenizer.from_file"

path: ~/.cache/torch/hub/checkpoints/dcase2024-task6-baseline/tokenizer.json
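
These files appear to add a `baseline` option to the `model` and `tokenizer` Hydra config groups; a hypothetical override, assuming the groups resolve in the usual way:

```bash
# Hypothetical invocation; assumes model/baseline.yaml and tokenizer/baseline.yaml
# are selected through their Hydra config groups.
dcase24t6-test resume=null model=baseline tokenizer=baseline
```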
2 changes: 1 addition & 1 deletion src/dcase24t6/__init__.py
@@ -8,4 +8,4 @@
__license__ = "MIT"
__maintainer__ = "Étienne Labbé (Labbeti)"
__status__ = "Released"
__version__ = "1.0.0"
__version__ = "1.1.0"
4 changes: 3 additions & 1 deletion src/dcase24t6/models/trans_decoder.py
@@ -285,7 +285,9 @@ def mix_audio(
return mixed_audio, mixed_audio_shape, lbd

def encode_audio(
self, frame_embs: Tensor, frame_embs_shape: Tensor
self,
frame_embs: Tensor,
frame_embs_shape: Tensor,
) -> AudioEncoding:
# frame_embs: (bsize, 1, in_features, max_seq_size)
# frame_embs_shape: (bsize, 3)
52 changes: 42 additions & 10 deletions src/dcase24t6/nn/ckpt.py
@@ -1,24 +1,56 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from torchoutil.utils.ckpt import ModelCheckpointRegister
from torchoutil.hub.registry import RegistryHub

# Zenodo link : https://zenodo.org/record/8020843
# Zenodo link : https://zenodo.org/records/8020843
# Hash type : md5
CNEXT_REGISTER = ModelCheckpointRegister(
CNEXT_REGISTRY = RegistryHub(
infos={
"cnext_bl": {
"cnext_nobl": {
"architecture": "ConvNeXt",
"url": "https://zenodo.org/record/8020843/files/convnext_tiny_471mAP.pth?download=1",
"hash_value": "e069ecd1c7b880268331119521c549f2",
"hash_type": "md5",
"fname": "convnext_tiny_471mAP.pth",
"state_dict_key": "model",
},
"cnext_bl_70": {
"architecture": "ConvNeXt",
"url": "https://zenodo.org/record/8020843/files/convnext_tiny_465mAP_BL_AC_70kit.pth?download=1",
"hash": "0688ae503f5893be0b6b71cb92f8b428",
"hash_value": "0688ae503f5893be0b6b71cb92f8b428",
"hash_type": "md5",
"fname": "convnext_tiny_465mAP_BL_AC_70kit.pth",
"state_dict_key": "model",
},
"cnext_nobl": {
"cnext_bl_75": {
"architecture": "ConvNeXt",
"url": "https://zenodo.org/record/8020843/files/convnext_tiny_471mAP.pth?download=1",
"hash": "e069ecd1c7b880268331119521c549f2",
"fname": "convnext_tiny_471mAP.pth",
"url": "https://zenodo.org/records/10987498/files/convnext_tiny_465mAP_BL_AC_75kit.pth?download=1",
"hash_value": "f6f57c87b7eb664a23ae8cad26eccaa0",
"hash_type": "md5",
"fname": "convnext_tiny_465mAP_BL_AC_75kit.pth",
},
},
)

# Zenodo link : https://zenodo.org/records/10849427
# Hash type : md5
BASELINE_REGISTRY = RegistryHub(
infos={
"baseline_weights": {
"architecture": "TransDecoderModel",
"url": "https://zenodo.org/records/10849427/files/epoch_192-step_001544-mode_min-val_loss_3.3758.ckpt?download=1",
"hash_value": "9514a8e6fa547bd01fb1badde81c6d10",
"hash_type": "md5",
"fname": "dcase2024-task6-baseline/checkpoints/best.ckpt",
"state_dict_key": "state_dict",
},
"baseline_tokenizer": {
"architecture": "AACTokenizer",
"url": "https://zenodo.org/records/10849427/files/tokenizer.json?download=1",
"hash_value": "ee3fef19f7d0891d820d84035483a900",
"hash_type": "md5",
"fname": "dcase2024-task6-baseline/tokenizer.json",
},
},
state_dict_key="model",
)
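
For reference, a hedged sketch of using these registries directly, mirroring the calls made elsewhere in this commit (`download_file`, `get_path`, `load_state_dict`); the exact signatures come from `torchoutil` and may differ slightly:

```python
from dcase24t6.nn.ckpt import BASELINE_REGISTRY, CNEXT_REGISTRY

# Download the baseline checkpoint and tokenizer into the torch hub cache,
# as dcase24t6-prepare does for every registered name.
BASELINE_REGISTRY.download_file("baseline_weights", verbose=1)
BASELINE_REGISTRY.download_file("baseline_tokenizer", verbose=1)

# Resolve the local cache path of the checkpoint (used by baseline_pipeline).
ckpt_path = BASELINE_REGISTRY.get_path("baseline_weights")

# Load the ConvNeXt encoder weights, as ResampleMeanCNext does.
state_dict = CNEXT_REGISTRY.load_state_dict("cnext_bl_70", device="cpu", offline=False)
```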
60 changes: 60 additions & 0 deletions src/dcase24t6/nn/hub.py
@@ -0,0 +1,60 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os.path as osp
from pathlib import Path

import nltk
import torch
from torch import nn
from torchoutil.nn.functional import get_device

from dcase24t6.models.trans_decoder import TransDecoderModel
from dcase24t6.nn.ckpt import BASELINE_REGISTRY
from dcase24t6.pre_processes.cnext import ResampleMeanCNext
from dcase24t6.tokenization.aac_tokenizer import AACTokenizer


def baseline_pipeline(
model_name_or_path: str | Path = "baseline_weights",
tokenizer_name_or_path: str | Path = "baseline_tokenizer",
pre_process_name_or_path: str | Path = "cnext_bl_70",
*,
offline: bool = False,
device: str | torch.device | None = "cuda_if_available",
verbose: int = 0,
) -> nn.Sequential:
device = get_device(device)

if not offline:
nltk.download("stopwords")

pre_process = ResampleMeanCNext(
pre_process_name_or_path,
offline=offline,
device=device,
keep_batch=True,
)

if osp.isfile(model_name_or_path):
model_path = Path(model_name_or_path)
else:
model_name = str(model_name_or_path)
model_path = BASELINE_REGISTRY.get_path(model_name)
if not offline and not model_path.exists():
BASELINE_REGISTRY.download_file(model_name, verbose=verbose)

if osp.isfile(tokenizer_name_or_path):
tokenizer_path = Path(tokenizer_name_or_path)
else:
tokenizer_name = str(tokenizer_name_or_path)
tokenizer_path = BASELINE_REGISTRY.get_path(tokenizer_name)
if not offline and not tokenizer_path.exists():
BASELINE_REGISTRY.download_file(tokenizer_name, verbose=verbose)

tokenizer = AACTokenizer.from_file(tokenizer_path)
model = TransDecoderModel.load_from_checkpoint(model_path, tokenizer=tokenizer)
pipeline = nn.Sequential(pre_process, model)
pipeline = pipeline.to(device=device)

return pipeline
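
As a usage note, `baseline_pipeline` also accepts local file paths: the `osp.isfile` branches above bypass the registry lookup, so a checkpoint produced by `dcase24t6-train` can be wrapped the same way. A hedged sketch (paths are placeholders):

```python
from dcase24t6.nn.hub import baseline_pipeline

# Default: resolve the registry names and download the files if missing.
pipeline = baseline_pipeline(device="cpu", verbose=1)

# Local files: placeholder paths, e.g. a checkpoint produced by dcase24t6-train.
pipeline = baseline_pipeline(
    model_name_or_path="logs/SAVE_NAME/checkpoints/MODEL.ckpt",
    tokenizer_name_or_path="logs/SAVE_NAME/tokenizer.json",
    device="cpu",
)
```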
27 changes: 19 additions & 8 deletions src/dcase24t6/pre_processes/cnext.py
@@ -9,10 +9,15 @@
from torch import nn
from torchoutil.nn.functional.get import get_device

from dcase24t6.nn.ckpt import CNEXT_REGISTER
from dcase24t6.nn.ckpt import CNEXT_REGISTRY
from dcase24t6.nn.encoders.convnext import convnext_tiny
from dcase24t6.nn.functional import remove_index_nd
from dcase24t6.pre_processes.common import batchify, is_audio_batch, unbatchify
from dcase24t6.pre_processes.common import (
batchify,
is_audio_batch,
sanitize_batch,
unbatchify,
)
from dcase24t6.pre_processes.resample import Resample


@@ -24,11 +29,12 @@ class ResampleMeanCNext(nn.Module):

def __init__(
self,
model_name_or_path: str | Path = "cnext_bl",
model_name_or_path: str | Path = "cnext_bl_70",
model_sr: int = 32_000,
offline: bool = False,
device: str | torch.device | None = "cuda_if_available",
input_time_dim: int = -1,
keep_batch: bool = False,
) -> None:
device = get_device(device)

@@ -42,7 +48,7 @@ def __init__(
return_frame_outputs=True,
return_clip_outputs=True,
)
state_dict = CNEXT_REGISTER.load_state_dict(
state_dict = CNEXT_REGISTRY.load_state_dict(
model_name_or_path,
device="cpu",
offline=offline,
@@ -59,18 +65,23 @@ def __init__(
self.convnext = convnext
self.resample = resample
self.input_time_dim = input_time_dim
self.keep_batch = keep_batch

@property
def device(self) -> torch.device:
return self.convnext.device

def forward(self, item_or_batch: dict[str, Any]) -> dict[str, Any]:
if is_audio_batch(item_or_batch):
return self.forward_batch(item_or_batch)
batch = sanitize_batch(item_or_batch)
return self.forward_batch(batch)

item = item_or_batch
batch = batchify(item)
batch = self.forward_batch(batch)
if self.keep_batch:
return batch
else:
item = item_or_batch
batch = batchify(item)
batch = self.forward_batch(batch)
item = unbatchify(batch)
return item

4 changes: 4 additions & 0 deletions src/dcase24t6/pre_processes/common.py
@@ -36,6 +36,10 @@ def is_audio_batch(item_or_batch: dict[str, Any]) -> bool:
)


def sanitize_batch(batch: dict[str, Any]) -> dict[str, Any]:
return batch


def batchify(item: dict[str, Any]) -> dict[str, list | Tensor]:
"""Transform a item dict to a batch dict."""
item = add_audio_shape_to_item(item)
21 changes: 16 additions & 5 deletions src/dcase24t6/pre_processes/resample.py
@@ -7,26 +7,37 @@
from torch import Tensor, nn
from torchaudio.functional import resample

from dcase24t6.pre_processes.common import batchify, is_audio_batch, unbatchify
from dcase24t6.pre_processes.common import (
batchify,
is_audio_batch,
sanitize_batch,
unbatchify,
)


class Resample(nn.Module):
def __init__(
self,
target_sr: int = 32_000,
input_time_dim: int = -1,
keep_batch: bool = False,
) -> None:
super().__init__()
self.target_sr = target_sr
self.input_time_dim = input_time_dim
self.keep_batch = keep_batch

def forward(self, item_or_batch: dict[str, Any]) -> dict[str, Any]:
if is_audio_batch(item_or_batch):
return self.forward_batch(item_or_batch)
batch = sanitize_batch(item_or_batch)
return self.forward_batch(batch)

item = item_or_batch
batch = batchify(item)
batch = self.forward_batch(batch)
if self.keep_batch:
return batch
else:
item = item_or_batch
batch = batchify(item)
batch = self.forward_batch(batch)
item = unbatchify(batch)
return item
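
The `keep_batch` flag added here (and to `ResampleMeanCNext` above) skips the final `unbatchify`, so a single item passed to the pre-process comes back in batched form; this is what `baseline_pipeline` relies on to chain the pre-process with the captioning model. A hedged sketch, assuming the `{"audio", "sr"}` item format used in the README example:

```python
import torch

from dcase24t6.pre_processes.resample import Resample

# Hypothetical single item, mirroring the README example format.
item = {"audio": torch.rand(1, 44_100), "sr": 44_100}

# Default behaviour: the item is batchified, processed, then un-batchified.
out_item = Resample(target_sr=32_000)(item)

# keep_batch=True: the output stays a batch dict of size 1.
out_batch = Resample(target_sr=32_000, keep_batch=True)(item)
```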

5 changes: 5 additions & 0 deletions src/dcase24t6/prepare.py
@@ -33,6 +33,7 @@

from dcase24t6.callbacks.complexity import ComplexityProfiler
from dcase24t6.callbacks.emissions import CustomEmissionTracker
from dcase24t6.nn.ckpt import BASELINE_REGISTRY, CNEXT_REGISTRY
from dcase24t6.utils.job import get_git_hash
from dcase24t6.utils.saving import save_to_yaml

@@ -117,6 +118,10 @@ def prepare_data_metrics_models(
nltk.download("stopwords")
download_metrics(verbose=verbose)

for registry in (BASELINE_REGISTRY, CNEXT_REGISTRY):
for name in registry.names:
registry.download_file(name, force, verbose)

os.makedirs(dataroot, exist_ok=True)

if download_clotho:
2 changes: 1 addition & 1 deletion src/dcase24t6/tokenization/aac_tokenizer.py
@@ -284,7 +284,7 @@ def save(self, path: str | Path, pretty: bool = True) -> None:
@classmethod
def from_file(cls, path: str | Path) -> "AACTokenizer":
"""Load tokenizer from JSON file."""
path = Path(path).resolve()
path = Path(path).resolve().expanduser()
content = path.read_text()
aac_tokenizer = cls.from_str(content)
return aac_tokenizer