Skip to content

Commit

Permalink
fix clip_slowfast_pann qvh pretrain
Browse files Browse the repository at this point in the history
  • Loading branch information
awkrail committed Sep 10, 2024
1 parent a5a333c commit a25816c
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 19 deletions.
11 changes: 4 additions & 7 deletions training/cg_detr_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,19 +449,16 @@ def _get_video_feat_by_vid(self, vid):
def _get_audio_feat_by_vid(self, vid):
a_feat_list = []
for _feat_dir in self.a_feat_dirs:
if self.dset_name == 'qvhighlight':
if self.a_feat_types == "clap":
_feat_path = join(_feat_dir, f"{vid}.npz")
_feat = np.load(_feat_path)["features"][:self.max_a_l].astype(np.float32)
elif self.a_feat_types == "pann":
if self.dset_name == 'qvhighlight' or self.dset_name == 'qvhighlight_pretrain':
if self.a_feat_types == "pann":
_feat_path = join(_feat_dir, f"{vid}.npy")
_feat = np.load(_feat_path)[:self.max_a_l].astype(np.float32)
else:
raise NotImplementedError()
raise NotImplementedError
_feat = l2_normalize_np_array(_feat) # normalize?
a_feat_list.append(_feat)
else:
raise NotImplementedError()
raise NotImplementedError

# some features are slightly longer than the others
min_len = min([len(e) for e in a_feat_list])
Expand Down
13 changes: 5 additions & 8 deletions training/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def __getitem__(self, index):
model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"], model_inputs["saliency_all_labels"] = \
self.get_saliency_labels_sub_as_query(meta["relevant_windows"][0], ctx_l)
else:
raise NotImplementedError()
raise NotImplementedError

return dict(meta=meta, model_inputs=model_inputs)

Expand Down Expand Up @@ -471,19 +471,16 @@ def _get_video_feat_by_vid(self, vid):
def _get_audio_feat_by_vid(self, vid):
a_feat_list = []
for _feat_dir in self.a_feat_dirs:
if self.dset_name == 'qvhighlight':
if self.a_feat_types == "clap":
_feat_path = join(_feat_dir, f"{vid}.npz")
_feat = np.load(_feat_path)["features"][:self.max_a_l].astype(np.float32)
elif self.a_feat_types == "pann":
if self.dset_name == 'qvhighlight' or self.dset_name == 'qvhighlight_pretrain':
if self.a_feat_types == "pann":
_feat_path = join(_feat_dir, f"{vid}.npy")
_feat = np.load(_feat_path)[:self.max_a_l].astype(np.float32)
else:
raise NotImplementedError()
raise NotImplementedError
_feat = l2_normalize_np_array(_feat) # normalize?
a_feat_list.append(_feat)
else:
raise NotImplementedError()
raise NotImplementedError

# some features are slightly longer than the others
min_len = min([len(e) for e in a_feat_list])
Expand Down
4 changes: 0 additions & 4 deletions training/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,7 @@
SOFTWARE.
"""

import pprint
import numpy as np
import torch
from lighthouse.common.utils.basic_utils import load_jsonl
from training.standalone_eval.eval import eval_submission
from tqdm import tqdm


Expand Down

0 comments on commit a25816c

Please sign in to comment.