Skip to content

Commit

Permalink
Merge pull request #98 from kurusugawa-computer/modify-upload-task
Browse files Browse the repository at this point in the history
タスク作成処理を`initiateTasksGenerationAPI`から`putTask`APIを使うように変更しました。
  • Loading branch information
seraphr committed Jul 28, 2022
2 parents 8dde682 + 420939a commit 5f4b88e
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 66 deletions.
30 changes: 8 additions & 22 deletions anno3d/annofab/task.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import json
from pathlib import Path
from typing import List, Optional
from typing import Collection, List, Optional

from annofabapi import AnnofabApi
from annofabapi import models as afm
from annofabapi.dataclass.task import Task
from annofabapi.models import TaskStatus

from anno3d.annofab.model import CuboidAnnotationDetail, TaskGenerateResponse
from anno3d.annofab.model import CuboidAnnotationDetail
from anno3d.annofab.project import ProjectApi
from anno3d.annofab.uploader import AnnofabStorageUploader


class TaskApi:
Expand All @@ -30,24 +28,6 @@ def project_id(self) -> str:
def _decode_task(task: afm.Task) -> Task:
return Task.from_dict(task)

def create_tasks_by_csv(self, csv_path: Path) -> TaskGenerateResponse:
client = self._client
project_id = self._project_id
uploader = AnnofabStorageUploader(client, project_id)
project = self._project.get_project(project_id)
if project is None:
raise RuntimeError("指定されたプロジェクト(={})が見つかりませんでした。".format(project_id))

uploaded_path = uploader.upload_tempdata(csv_path)

body_params = {
"task_generate_rule": {"csv_data_path": uploaded_path, "_type": "ByInputDataCsv"},
"project_last_updated_datetime": project.updated_datetime,
}
task_generate_result, _ = client.initiate_tasks_generation(project_id, request_body=body_params)

return TaskGenerateResponse.from_dict(task_generate_result)

def get_task(self, task_id: str) -> Optional[Task]:
client = self._client
project_id = self._project_id
Expand All @@ -57,6 +37,12 @@ def get_task(self, task_id: str) -> Optional[Task]:

return self._decode_task(result)

def put_task(self, task_id: str, input_data_ids: Collection[str]) -> Task:
client = self._client
project_id = self._project_id
result, _ = client.put_task(project_id, task_id, request_body={"input_data_id_list": input_data_ids})
return self._decode_task(result)

def put_cuboid_annotations(
self, task_id: str, input_data_id: str, annotations: List[CuboidAnnotationDetail]
) -> None:
Expand Down
77 changes: 33 additions & 44 deletions anno3d/kitti/scene_uploader.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import asyncio
import json
import logging
import tempfile
import time
import uuid
from dataclasses import dataclass
from enum import Enum
Expand All @@ -13,7 +10,6 @@
import numpy as np
from annofabapi import AnnofabApi
from annofabapi.dataclass.annotation_specs import LabelV2
from annofabapi.models import JobStatus
from scipy.spatial.transform import Rotation

from anno3d.annofab.model import (
Expand Down Expand Up @@ -112,12 +108,20 @@ def id_to_paths(frame_id: str) -> FilePaths:
return [id_to_paths(frame_id) for frame_id in scene.id_list]

@staticmethod
def _create_task_def_csv(
csv_path: Path,
id_prefix: str,
data_and_pathss: List[Tuple[DataId, FilePaths]],
chunk_size: Optional[int] = None,
def _get_task_to_data_dict(
id_prefix: str, data_and_pathss: List[Tuple[DataId, FilePaths]], chunk_size: Optional[int] = None,
) -> Dict[TaskId, List[Tuple[DataId, FilePaths]]]:
"""
タスクと入力データの関係を示すdictを取得します。
Args:
id_prefix: タスクIDのプレフィックス
data_and_pathss: 入力データとファイルパス情報をペアにしたlist
chunk_size: タスクに含める入力データの個数。Noneの場合は、タスクにすべての入力データを含めます。
Returns:
タスクと入力データの関係を示すdict
"""

if chunk_size is None:
chunked_by_tasks = iter([data_and_pathss])
Expand All @@ -128,38 +132,24 @@ def _create_task_def_csv(
task_id_template = "{id_prefix}_{task_count}"

result_dict: Dict[TaskId, List[Tuple[DataId, FilePaths]]] = {}
with csv_path.open("w", encoding="UTF-8") as writer:
for task_count, data_list in enumerate(chunked_by_tasks):
task_id = task_id_template.format(id_prefix=id_prefix, task_count=task_count)
result_dict[TaskId(task_id)] = data_list
for data_id, paths in data_list:
# XXX ここで `paths.pcd.name` は input_data_nameの指定なんだけど、ファイル名がinput_data_nameである
# というのは、ただの実装詳細なので、本来はやりたくない…
line = f"{task_id},{paths.pcd.name},{data_id}"
writer.write(f"{line}\r\n")

with csv_path.open("r") as reader:
logger.info("task def csv: \n%s", reader.read())
for task_count, data_list in enumerate(chunked_by_tasks):
task_id = task_id_template.format(id_prefix=id_prefix, task_count=task_count)
result_dict[TaskId(task_id)] = data_list

return result_dict

def _create_task(self, project_id: str, csv_path: Path) -> None:
def _create_tasks(self, project_id: str, task_to_data_dict: Dict[TaskId, List[Tuple[DataId, FilePaths]]]):
"""タスクを作成します。
Notes:
タスクは数件しか作成しないことを想定しているので、同期的に`putTask`APIを実行しています。
"""
project = self._project
task = TaskApi(self._client, project, project_id)
response = task.create_tasks_by_csv(csv_path)
job = response.job

while job.job_status == JobStatus.PROGRESS:
logger.info("タスクの作成完了を待っています。")
time.sleep(5)
new_info = project.get_job(project_id, job)
if new_info is None:
raise RuntimeError(f"ジョブ(={job.job_id})が取得できませんでした。")
job = new_info

if job.job_status == JobStatus.FAILED:
detail = json.dumps(job.job_detail, ensure_ascii=False)
errors = json.dumps(job.errors, ensure_ascii=False)
raise RuntimeError(f"タスクの作成に失敗しました: {errors}: {detail}")

for task_id, data_id_and_pathss in task_to_data_dict.items():
input_data_id_list = [input_data_id for input_data_id, _ in data_id_and_pathss]
task.put_task(task_id, input_data_id_list)

def _label_to_cuboids(
self, id_to_label: Dict[str, LabelV2], labels: List[KittiLabel]
Expand Down Expand Up @@ -282,14 +272,13 @@ async def upload_scene_async(self, scene: Scene, uploader_input: SceneUploaderIn
if uploader_input.kind == UploadKind.DATA_ONLY:
return

with tempfile.TemporaryDirectory() as tempdir_str:
csv_path = Path(tempdir_str) / "task_create.csv"
task_to_data_dict = self._create_task_def_csv(
csv_path, uploader_input.task_id_prefix, data_and_pathss, uploader_input.frame_per_task
)
self._create_task(uploader_input.project_id, csv_path)
task_to_data_dict = self._get_task_to_data_dict(
uploader_input.task_id_prefix, data_and_pathss, uploader_input.frame_per_task
)

logger.info("タスクの作成が完了しました")
logger.info("タスクを%d件作成します。", len(task_to_data_dict))
self._create_tasks(uploader_input.project_id, task_to_data_dict)
logger.info("%d件のタスクを作成しました", len(task_to_data_dict))
if uploader_input.kind == UploadKind.CREATE_TASK:
return

Expand Down

0 comments on commit 5f4b88e

Please sign in to comment.