Merge pull request #244 from huilin16/visdrone
add finetune_visdrone
CaitinZhao committed Dec 14, 2023
2 parents a287863 + 17bbb0d commit b376123
Showing 11 changed files with 513 additions and 0 deletions.
162 changes: 162 additions & 0 deletions examples/finetune_visdrone/README.md
@@ -0,0 +1,162 @@
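# Drone Aerial Image Detection with MindYOLO: A Case Study

## Dataset Introduction

[VisDrone-Dataset](https://github.com/VisDrone/VisDrone-Dataset) is the dataset of a drone recognition challenge released by Tianjin University in 2019. It covers 5 subtasks: image object detection, video object detection, single-object tracking, multi-object tracking, and crowd counting.

The image object detection dataset contains 10 object categories and 10209 images with 54200 objects in total. The images are 2000\*1500\*3 in size, were captured by drones, and are annotated with horizontal bounding boxes.

Each image has a matching txt file, and each line of that file describes 1 object with 8 values: bbox_left, bbox_top, bbox_width, bbox_height, score, object_category, truncation, and occlusion. Here bbox_left, bbox_top, bbox_width, and bbox_height are the x coordinate of the top-left corner, the y coordinate of the top-left corner, the box width, and the box height.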
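As a minimal sketch of the annotation layout described above, the snippet below parses one line into named fields. The field names follow the list above; the `VisDroneBox` type, the `parse_visdrone_line` helper, and the sample line are hypothetical and only serve as illustration. Tolerating a trailing comma is an assumption about the input, not something stated here.

```python
from typing import NamedTuple


class VisDroneBox(NamedTuple):
    """One object from a VisDrone-DET annotation line (hypothetical helper type)."""
    bbox_left: int
    bbox_top: int
    bbox_width: int
    bbox_height: int
    score: int
    object_category: int
    truncation: int
    occlusion: int


def parse_visdrone_line(line: str) -> VisDroneBox:
    # Values are comma-separated; drop empty parts so a trailing comma is tolerated.
    parts = [p for p in line.strip().split(",") if p.strip()]
    if len(parts) != 8:
        raise ValueError(f"expected 8 values, got {len(parts)}: {line!r}")
    return VisDroneBox(*(int(p) for p in parts))


# Made-up annotation line, for illustration only.
box = parse_visdrone_line("684,8,273,116,0,0,0,0")
print(box.bbox_left, box.bbox_top, box.bbox_width, box.bbox_height)
```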

The original dataset is laid out as follows:

```
ROOT_DIR
├── VisDrone2019-DET-train
│   ├── images
│   │   ├── 0000000_00000_d_0000001.jpg
│   │   ├── 0000000_00000_d_0000002.jpg
│   │   ├── ...
│   │   └── ...
│   └── annotations
│       ├── 0000000_00000_d_0000001.txt
│       ├── 0000000_00000_d_0000002.txt
│       ├── ...
│       └── ...
└── VisDrone2019-DET-val
    ├── images
    │   ├── 0000001_00000_d_0000001.jpg
    │   ├── 0000001_00000_d_0000002.jpg
    │   ├── ...
    │   └── ...
    └── annotations
        ├── 0000001_00000_d_0000001.txt
        ├── 0000001_00000_d_0000002.txt
        ├── ...
        └── ...
```

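## Dataset Format Conversion

In mindyolo, training uses data in yolo format, while evaluation uses a COCO-style json file, so the dataset has to be converted.

In yolo-format annotations, each image has a matching txt file, and each line of that file describes 1 object with 5 values: id, center_x, center_y, w, and h. Here center_x, center_y, w, and h are the normalized x coordinate of the box center, the normalized y coordinate of the box center, the normalized box width, and the normalized box height. The image paths are recorded in train.txt and val.txt.

The yolo bbox format differs from the box format of the visdrone dataset, so the conversion takes the following steps:

- Create the new folder structure and read the original file list;
- Copy the images into the new folders;
- Read the original box information and convert the top-left corner coordinates into center coordinates;
- For train, save the txt files; for val, save the box information together with the image and category information as a json file.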
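The coordinate step above can be sketched as a small function. It mirrors the arithmetic used in convert_visdrone2yolo.py (top-left corner plus half the size, divided by the image size, rounded to 6 decimals); the function name `visdrone_to_yolo` and the sample box are hypothetical.

```python
def visdrone_to_yolo(bbox_left, bbox_top, bbox_width, bbox_height, img_width, img_height):
    """Convert a top-left pixel box into a normalized (center_x, center_y, w, h) box."""
    center_x = (bbox_left + 0.5 * bbox_width) / img_width
    center_y = (bbox_top + 0.5 * bbox_height) / img_height
    w = bbox_width / img_width
    h = bbox_height / img_height
    # Round to 6 decimals, as the conversion script does before writing the txt file.
    return tuple(round(v, 6) for v in (center_x, center_y, w, h))


# Made-up box: 1000x750 pixels with its top-left corner at (500, 375) in a 2000x1500 image.
print(visdrone_to_yolo(500, 375, 1000, 750, 2000, 1500))  # (0.5, 0.5, 0.5, 0.5)
```

All four outputs lie in [0, 1] as long as the box stays inside the image, which is a quick sanity check after conversion.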

See [convert_visdrone2yolo.py](./convert_visdrone2yolo.py) for the full implementation. Run it as follows:

```
python convert_visdrone2yolo.py \
--img_dir /path_to_visdrone/VisDrone2019-DET-train/images \
--gt_dir /path_to_visdrone/VisDrone2019-DET-train/annotations \
--json_path /path_to_visdrone/visdrone/annotations/instances_train2017.json \
--img_dir_dst /path_to_visdrone/visdrone/train/images \
--gt_dir_dst /path_to_visdrone/visdrone/train/labels \
--txt_path /path_to_visdrone/visdrone/train.txt
python convert_visdrone2yolo.py \
--img_dir /path_to_visdrone/VisDrone2019-DET-val/images \
--gt_dir /path_to_visdrone/VisDrone2019-DET-val/annotations \
--json_path /path_to_visdrone/visdrone/annotations/instances_val2017.json \
--img_dir_dst /path_to_visdrone/visdrone/val/images \
--gt_dir_dst /path_to_visdrone/visdrone/val/labels \
--txt_path /path_to_visdrone/visdrone/val.txt
```

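Running the commands above generates a yolo-format visdrone dataset in the same parent directory without modifying the original dataset.

After conversion, the visdrone dataset contains the following: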

```
visdrone
├── train.txt
├── val.txt
├── train
│   ├── images
│   │   ├── 000001.jpg
│   │   ├── 000002.jpg
│   │   ├── ...
│   │   └── ...
│   └── labels
│       ├── 000001.txt
│       ├── 000002.txt
│       ├── ...
│       └── ...
├── annotations
│   ├── instances_train2017.json
│   └── instances_val2017.json
└── val
    ├── images
    │   ├── 000001.jpg
    │   ├── 000002.jpg
    │   ├── ...
    │   └── ...
    └── labels
        ├── 000001.txt
        ├── 000002.txt
        ├── ...
        └── ...
```
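A quick way to check a converted dataset is to confirm that every image has a matching label file. The sketch below assumes the `images` / `labels` sibling layout shown in the tree above; `label_path_for` and `find_missing_labels` are hypothetical helpers, not part of MindYOLO or of the conversion script.

```python
import os


def label_path_for(image_path: str) -> str:
    """Map .../images/NAME.jpg to .../labels/NAME.txt."""
    head, name = os.path.split(image_path)
    parent, images_dir = os.path.split(head)
    if images_dir != "images":
        raise ValueError(f"expected the image to sit in an 'images' folder: {image_path!r}")
    stem = os.path.splitext(name)[0]
    return os.path.join(parent, "labels", stem + ".txt")


def find_missing_labels(image_paths):
    """Return the image paths whose label file does not exist on disk."""
    return [p for p in image_paths if not os.path.isfile(label_path_for(p))]


# Path built from the folder names in the tree above.
print(label_path_for(os.path.join("visdrone", "train", "images", "000001.jpg")))
```

In practice the image paths would come from the lines of train.txt or val.txt.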

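## Model Selection

To choose a concrete model size, the val split is extracted as a miniVisDrone dataset so that network performance can be trained and tested quickly. The plan is to train with the newer models yolov7 and yolov8.

First, the lightweight models yolov7t and yolov8m are trained on mini visdrone for 50 epochs and reach accuracies of 0.215 and 0.203 respectively. Among the lightweight models yolov8 performs worse than yolov7, so the following experiments train yolov7t first and later switch to yolov7l and yolov8l depending on the training results.

## Writing the yaml Configuration File

MindYOLO supports an inheritance mechanism for yaml files, so a newly written configuration file only needs to inherit from the existing native yaml configuration files that MindYOLO provides. The final dataset configuration is in [visdrone.yaml](./visdrone.yaml), and the yolov8l configuration is in [yolov8-l-visdrone.yaml](./yolov8-l-visdrone.yaml).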
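The idea behind configuration inheritance can be illustrated with a recursive dictionary merge: the child file only lists the keys it changes, and everything else comes from the base. This is a generic sketch, not MindYOLO's actual loader; `merge_config` and every key and value below are hypothetical.

```python
import copy


def merge_config(base: dict, override: dict) -> dict:
    """Return a new config in which keys from `override` replace or extend `base`."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Nested sections are merged key by key instead of being replaced wholesale.
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical keys and values, chosen only for illustration.
base_cfg = {"data": {"dataset_name": "coco", "nc": 80}, "img_size": 640, "epochs": 300}
child_cfg = {"data": {"dataset_name": "visdrone"}, "img_size": 1280}
print(merge_config(base_cfg, child_cfg))
```

Keeping the child file small this way makes it easy to see exactly which settings differ from the stock configuration.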

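## Optimization Strategies

- Switch to a model with more parameters: yolov7tiny has a low accuracy ceiling, so it is replaced with yolov7large
- Use a larger image shape: [sta_anno.py](./sta_anno.py) and [sta_img.py](./sta_img.py) produce the size distributions of the images and of the object boxes. The statistics show that the image shapes are all above 1400 * 800 while the object shapes are all below 200 * 200, so this is a small-object detection task. The 640 * 640 input size of the original configuration is too small, and the input image size has to be increased.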



<img src="./pic/img_shape.png" alt="img_shape" style="zoom:50%;" />


Image shape distribution

<img src="./pic/obj_shape.png" alt="obj_shape" style="zoom:50%;" />

Object shape distribution

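- Use smaller anchors: since the object shapes are all below 200\*200, the yolov7 anchors are changed from [12,16,19,36,40,28] - [36,75,76,55,72,146] - [142,119,192,243,459,481] to [12,16,19,36,40,28] - [36,75,76,55,50,50] - [72,146,142,119,192,243], which shrinks the large anchors.
- Use a tiling strategy: since this is a small-object detection task, the training set is cut into small tiles while validation uses the full-size original images (no improvement)
- Adjust the learning rate: both a higher and a lower learning rate were tried (no improvement)
- Use data augmentation: for yolov7l training, adding copy_paste=0.3 improves accuracy by 0.007; adding mixup=0.3 and flipud=0.3 lowers it by 0.03
- Switch to yolov8: the best accuracy of yolov7 is 0.355, and switching to yolov8 raises it to 0.365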
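The tiling strategy from the list above can be sketched as a function that computes overlapping crop windows over an image. The exact tiling used in these experiments is not described, so the tile size, the overlap, and the `tile_windows` helper are all assumptions made for illustration.

```python
def tile_windows(img_width, img_height, tile_size, overlap):
    """Return (x0, y0, x1, y1) windows that cover the image, overlapping by `overlap` pixels."""
    if not 0 <= overlap < tile_size:
        raise ValueError("overlap must satisfy 0 <= overlap < tile_size")
    stride = tile_size - overlap

    def starts(length):
        # A side shorter than one tile gets a single window starting at 0.
        if length <= tile_size:
            return [0]
        positions = list(range(0, length - tile_size, stride))
        positions.append(length - tile_size)  # last window sits flush with the border
        return positions

    return [
        (x, y, min(x + tile_size, img_width), min(y + tile_size, img_height))
        for y in starts(img_height)
        for x in starts(img_width)
    ]


# Assumed values: a 2000x1500 image cut into 640x640 tiles with a 128-pixel overlap.
windows = tile_windows(2000, 1500, tile_size=640, overlap=128)
print(len(windows), windows[0], windows[-1])
```

Boxes that fall inside a window then need their coordinates shifted by the window's (x0, y0) before the labels for that tile are written.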

## Final Accuracy


<img src="./pic/results.png" alt="results" style="zoom:75%;" />


## Inference Results

Use /demo/predict.py to test the trained model weights and visualize the inference results. Run it as follows:

```
python demo/predict.py --config ./yolov8-l-visdrone.yaml --weight=/path_to_ckpt/WEIGHT.ckpt --image_path /path_to_image/IMAGE.jpg
```

The inference results look like this:


<img src="./pic/000001.jpg" alt="000001" style="zoom:75%;" />



<img src="./pic/000002.jpg" alt="000002" style="zoom:75%;" />
107 changes: 107 additions & 0 deletions examples/finetune_visdrone/convert_visdrone2yolo.py
@@ -0,0 +1,107 @@
import os
import cv2
from tqdm import tqdm
import json
import argparse
from skimage import io
import pandas as pd

# VisDrone-DET category ids and names, written to the COCO-style json as-is
categories_info = [
    {"id": 0, "name": "ignored regions"},
    {"id": 1, "name": "pedestrian"},
    {"id": 2, "name": "people"},
    {"id": 3, "name": "bicycle"},
    {"id": 4, "name": "car"},
    {"id": 5, "name": "van"},
    {"id": 6, "name": "truck"},
    {"id": 7, "name": "tricycle"},
    {"id": 8, "name": "awning-tricycle"},
    {"id": 9, "name": "bus"},
    {"id": 10, "name": "motor"},
    {"id": 11, "name": "others"}
]


def data_convert(img_dir_src, anno_dir_src, json_path, img_dir_dst, anno_dir_dst, txt_path):
    # Create the output folders if they do not exist yet
    if not os.path.exists(img_dir_dst):
        os.makedirs(img_dir_dst)
    if not os.path.exists(anno_dir_dst):
        os.makedirs(anno_dir_dst)
    if not os.path.exists(os.path.dirname(json_path)):
        os.makedirs(os.path.dirname(json_path))
    file_list = os.listdir(img_dir_src)
    box_num = 0
    anno_infos = []
    img_infos = []
    df_img = pd.DataFrame(None, columns=['img_name'])
    for idx, img_name_src in enumerate(tqdm(file_list)):
        # Images and labels are renamed to a zero-padded running index
        anno_name_src = img_name_src.replace('.jpg', '.txt')
        img_name_dst = '%06d.jpg' % idx
        anno_name_dst = img_name_dst.replace('.jpg', '.txt')
        img_path_src = os.path.join(img_dir_src, img_name_src)
        anno_path_src = os.path.join(anno_dir_src, anno_name_src)
        img_path_dst = os.path.join(img_dir_dst, img_name_dst)
        anno_path_dst = os.path.join(anno_dir_dst, anno_name_dst)
        df_img.loc[len(df_img)] = img_path_dst
        img = io.imread(img_path_src)
        io.imsave(img_path_dst, img)
        img_height, img_width = img.shape[:2]

        # COCO-style image record
        img_info = {}
        img_info["file_name"] = img_name_dst
        img_info["height"] = img_height
        img_info["width"] = img_width
        img_info["id"] = idx
        img_infos.append(img_info)

        # Top-left pixel box -> normalized center box for the yolo txt label
        df_anno = pd.read_csv(anno_path_src, index_col=None, header=None,
                              names=['bbox_left', 'bbox_top', 'bbox_width', 'bbox_height', 'score', 'category',
                                     'truncation', 'occlusion'])
        df_anno['bbox_center_x'] = (df_anno['bbox_left'] + 0.5 * df_anno['bbox_width']) / img_width
        df_anno['bbox_center_y'] = (df_anno['bbox_top'] + 0.5 * df_anno['bbox_height']) / img_height
        df_anno['bbox_w'] = df_anno['bbox_width'] / img_width
        df_anno['bbox_h'] = df_anno['bbox_height'] / img_height

        df_anno_dst = df_anno[['category', 'bbox_center_x', 'bbox_center_y', 'bbox_w', 'bbox_h']]
        df_anno_dst = df_anno_dst.round({"bbox_center_x": 6, "bbox_center_y": 6, "bbox_w": 6, "bbox_h": 6})
        df_anno_dst.to_csv(anno_path_dst, header=None, index=None, sep=' ')

        # COCO-style annotation records keep the original pixel xywh box
        for i, row in df_anno.iterrows():
            anno_info = {}
            bbox_xywh = [int(row['bbox_left']), int(row['bbox_top']), int(row['bbox_width']), int(row['bbox_height'])]
            anno_info["image_id"] = idx
            anno_info["score"] = float(row['score'])
            anno_info["bbox"] = bbox_xywh
            anno_info["category_id"] = int(row['category'])
            anno_info["id"] = box_num
            anno_info["iscrowd"] = 0
            anno_info["segmentation"] = []
            anno_info["area"] = bbox_xywh[2] * bbox_xywh[3]
            anno_info["truncation"] = float(row['truncation'])
            anno_info["occlusion"] = float(row['occlusion'])
            box_num += 1

            anno_infos.append(anno_info)

    # One image path per line (train.txt / val.txt)
    df_img.to_csv(txt_path, header=None, index=None)
    data_info = {}
    data_info["images"] = img_infos
    data_info["annotations"] = anno_infos
    data_info["categories"] = categories_info
    # print(data_info)
    json_str = json.dumps(data_info)

    with open(json_path, 'w') as json_file:
        json_file.write(json_str)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--img_dir', default='', type=str, help='img_dir')
    parser.add_argument('--gt_dir', default='', type=str, help='gt_dir')
    parser.add_argument('--json_path', default='', type=str, help='json_path')
    parser.add_argument('--img_dir_dst', default='', type=str, help='img_dir_dst')
    parser.add_argument('--gt_dir_dst', default='', type=str, help='gt_dir_dst')
    parser.add_argument('--txt_path', default='', type=str, help='txt_path')
    opt = parser.parse_args()
    data_convert(opt.img_dir, opt.gt_dir, opt.json_path, opt.img_dir_dst, opt.gt_dir_dst, opt.txt_path)
Binary file added examples/finetune_visdrone/pic/000001.jpg
Binary file added examples/finetune_visdrone/pic/000002.jpg
Binary file added examples/finetune_visdrone/pic/img_shape.png
Binary file added examples/finetune_visdrone/pic/obj_shape.png
Binary file added examples/finetune_visdrone/pic/results.png