-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
20 changed files
with
252 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
dependencies = ["git+https://github.com/Labbeti/dcase2024-task6-baseline"] | ||
|
||
|
||
from dcase24t6.nn.hub import baseline_pipeline | ||
|
||
__all__ = ["baseline_pipeline"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# @package model | ||
|
||
_target_: "dcase24t6.models.trans_decoder.TransDecoderModel.load_from_checkpoint" | ||
|
||
checkpoint_path: ~/.cache/torch/hub/checkpoints/dcase2024-task6-baseline/checkpoints/best.ckpt | ||
verbose: ${verbose} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# @package tokenizer | ||
|
||
_target_: "dcase24t6.tokenization.aac_tokenizer.AACTokenizer.from_file" | ||
|
||
path: ~/.cache/torch/hub/checkpoints/dcase2024-task6-baseline/tokenizer.json |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,56 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
from torchoutil.utils.ckpt import ModelCheckpointRegister | ||
from torchoutil.hub.registry import RegistryHub | ||
|
||
# Zenodo link : https://zenodo.org/record/8020843 | ||
# Zenodo link : https://zenodo.org/records/8020843 | ||
# Hash type : md5 | ||
CNEXT_REGISTER = ModelCheckpointRegister( | ||
CNEXT_REGISTRY = RegistryHub( | ||
infos={ | ||
"cnext_bl": { | ||
"cnext_nobl": { | ||
"architecture": "ConvNeXt", | ||
"url": "https://zenodo.org/record/8020843/files/convnext_tiny_471mAP.pth?download=1", | ||
"hash_value": "e069ecd1c7b880268331119521c549f2", | ||
"hash_type": "md5", | ||
"fname": "convnext_tiny_471mAP.pth", | ||
"state_dict_key": "model", | ||
}, | ||
"cnext_bl_70": { | ||
"architecture": "ConvNeXt", | ||
"url": "https://zenodo.org/record/8020843/files/convnext_tiny_465mAP_BL_AC_70kit.pth?download=1", | ||
"hash": "0688ae503f5893be0b6b71cb92f8b428", | ||
"hash_value": "0688ae503f5893be0b6b71cb92f8b428", | ||
"hash_type": "md5", | ||
"fname": "convnext_tiny_465mAP_BL_AC_70kit.pth", | ||
"state_dict_key": "model", | ||
}, | ||
"cnext_nobl": { | ||
"cnext_bl_75": { | ||
"architecture": "ConvNeXt", | ||
"url": "https://zenodo.org/record/8020843/files/convnext_tiny_471mAP.pth?download=1", | ||
"hash": "e069ecd1c7b880268331119521c549f2", | ||
"fname": "convnext_tiny_471mAP.pth", | ||
"url": "https://zenodo.org/records/10987498/files/convnext_tiny_465mAP_BL_AC_75kit.pth?download=1", | ||
"hash_value": "f6f57c87b7eb664a23ae8cad26eccaa0", | ||
"hash_type": "md5", | ||
"fname": "convnext_tiny_465mAP_BL_AC_75kit.pth", | ||
}, | ||
}, | ||
) | ||
|
||
# Zenodo link : https://zenodo.org/records/10849427 | ||
# Hash type : md5 | ||
BASELINE_REGISTRY = RegistryHub( | ||
infos={ | ||
"baseline_weights": { | ||
"architecture": "TransDecoderModel", | ||
"url": "https://zenodo.org/records/10849427/files/epoch_192-step_001544-mode_min-val_loss_3.3758.ckpt?download=1", | ||
"hash_value": "9514a8e6fa547bd01fb1badde81c6d10", | ||
"hash_type": "md5", | ||
"fname": "dcase2024-task6-baseline/checkpoints/best.ckpt", | ||
"state_dict_key": "state_dict", | ||
}, | ||
"baseline_tokenizer": { | ||
"architecture": "AACTokenizer", | ||
"url": "https://zenodo.org/records/10849427/files/tokenizer.json?download=1", | ||
"hash_value": "ee3fef19f7d0891d820d84035483a900", | ||
"hash_type": "md5", | ||
"fname": "dcase2024-task6-baseline/tokenizer.json", | ||
}, | ||
}, | ||
state_dict_key="model", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
import os.path as osp | ||
from pathlib import Path | ||
|
||
import nltk | ||
import torch | ||
from torch import nn | ||
from torchoutil.nn.functional import get_device | ||
|
||
from dcase24t6.models.trans_decoder import TransDecoderModel | ||
from dcase24t6.nn.ckpt import BASELINE_REGISTRY | ||
from dcase24t6.pre_processes.cnext import ResampleMeanCNext | ||
from dcase24t6.tokenization.aac_tokenizer import AACTokenizer | ||
|
||
|
||
def baseline_pipeline( | ||
model_name_or_path: str | Path = "baseline_weights", | ||
tokenizer_name_or_path: str | Path = "baseline_tokenizer", | ||
pre_process_name_or_path: str | Path = "cnext_bl_70", | ||
*, | ||
offline: bool = False, | ||
device: str | torch.device | None = "cuda_if_available", | ||
verbose: int = 0, | ||
) -> nn.Sequential: | ||
device = get_device(device) | ||
|
||
if not offline: | ||
nltk.download("stopwords") | ||
|
||
pre_process = ResampleMeanCNext( | ||
pre_process_name_or_path, | ||
offline=offline, | ||
device=device, | ||
keep_batch=True, | ||
) | ||
|
||
if osp.isfile(model_name_or_path): | ||
model_path = Path(model_name_or_path) | ||
else: | ||
model_name = str(model_name_or_path) | ||
model_path = BASELINE_REGISTRY.get_path(model_name) | ||
if not offline and not model_path.exists(): | ||
BASELINE_REGISTRY.download_file(model_name, verbose=verbose) | ||
|
||
if osp.isfile(tokenizer_name_or_path): | ||
tokenizer_path = Path(tokenizer_name_or_path) | ||
else: | ||
tokenizer_name = str(tokenizer_name_or_path) | ||
tokenizer_path = BASELINE_REGISTRY.get_path(tokenizer_name) | ||
if not offline and not tokenizer_path.exists(): | ||
BASELINE_REGISTRY.download_file(tokenizer_name, verbose=verbose) | ||
|
||
tokenizer = AACTokenizer.from_file(tokenizer_path) | ||
model = TransDecoderModel.load_from_checkpoint(model_path, tokenizer=tokenizer) | ||
pipeline = nn.Sequential(pre_process, model) | ||
pipeline = pipeline.to(device=device) | ||
|
||
return pipeline |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.