Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Format code with yapf, autopep8 and isort #7

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 52 additions & 26 deletions dair_v2x/pp_vehicle.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class Loader:
"""加载数据集标签信息

将 car, truck, van, bus 类别都设置为vehicle;
忽略 tricyclist, barrowlist, pedestrianignore,
忽略 tricyclist, barrowlist, pedestrianignore,
carignore, otherignore, unknown_movable, unknown_unmovable 类别
"""

Expand All @@ -33,7 +33,9 @@ def __init__(self, data_dir: str, train_split: float = 0.85) -> None:
self.data_dir = data_dir
self.__categories, self.__vhicles = self.__get_categories()
self.__data_info_path = os.path.join(data_dir, "data_info.json") # 数据信息
self.__train_data_info, self.__val_data_info = self.__get_train_val_info(train_split)
self.__train_data_info, self.__val_data_info = self.__get_train_val_info(
train_split
)

@property
def train_info(self) -> dict:
Expand All @@ -59,11 +61,16 @@ def vhicles(self) -> list:
def __get_categories() -> dict:
"""获取类别与id对应关系"""
__categories = [
"vehicle", "pedestrian", "cyclist", "motorcyclist", "barrowlist"
"vehicle",
"pedestrian",
"cyclist",
"motorcyclist",
"barrowlist",
]

return {_category: _id
for _id, _category in enumerate(__categories, start=1)}, ["car", "truck", "van", "bus"]
return {
_category: _id for _id, _category in enumerate(__categories, start=1)
}, ["car", "truck", "van", "bus"]

def __get_train_val_info(self, train_split) -> Tuple[dict, dict]:
"""分割训练集和验证集"""
Expand Down Expand Up @@ -113,30 +120,48 @@ def format2coco(self, data_info: dict, json_path: str) -> None:
for _info in tqdm(data_info):
file_name = _info["image_path"]
img_id, _ = os.path.splitext(os.path.basename(file_name))
coco_json["images"].append({
"id": int(img_id),
"file_name": file_name,
"width": 1920,
"height": 1080
})

annos = self.__get_annotations(os.path.join(self.data_dir, _info["label_camera_std_path"])) # 获取标注信息
coco_json["images"].append(
{
"id": int(img_id),
"file_name": file_name,
"width": 1920,
"height": 1080,
}
)

annos = self.__get_annotations(
os.path.join(self.data_dir, _info["label_camera_std_path"])
) # 获取标注信息
for _anno in annos:
if category := _anno["type"].lower() in self.vhicles: # 将所有的车辆类别都设为vehicle
if (
category := _anno["type"].lower() in self.vhicles
): # 将所有的车辆类别都设为vehicle
category = "vehicle"
if category_id := self.categories.get(category) is not None: # 获取类别id
xywh = self.__bbox2xywh(_anno["2d_box"]) # coco的bbox[xmin, ymin, width, height]
coco_json["annotations"].append({
"id": len(coco_json["annotations"]),
"image_id": int(img_id),
"category_id": category_id,
"bbox": xywh,
"area": xywh[-2] * xywh[-1],
"iscrowd": 0
})
xywh = self.__bbox2xywh(
_anno["2d_box"]
) # coco的bbox[xmin, ymin, width, height]
coco_json["annotations"].append(
{
"id": len(coco_json["annotations"]),
"image_id": int(img_id),
"category_id": category_id,
"bbox": xywh,
"area": xywh[-2] * xywh[-1],
"iscrowd": 0,
}
)
item_id += 1
coco_json["categories"] = [{"id": _id, "name": _name} for _name, _id in self.categories.items()] # 类别信息
json.dump(coco_json, open(json_path, "w+", encoding='utf-8'), indent=4, sort_keys=False, ensure_ascii=False) # 保存json
coco_json["categories"] = [
{"id": _id, "name": _name} for _name, _id in self.categories.items()
] # 类别信息
json.dump(
coco_json,
open(json_path, "w+", encoding="utf-8"),
indent=4,
sort_keys=False,
ensure_ascii=False,
) # 保存json

def processing(self) -> None:
"""处理进程"""
Expand All @@ -163,7 +188,8 @@ def processing(self) -> None:

def parse_args():
parser = argparse.ArgumentParser(
description="DAIR-V2X dataset to PP-vehicle format.")
description="DAIR-V2X dataset to PP-vehicle format."
)
parser.add_argument("--data_dir", type=str, help="数据位置")
return parser.parse_args()

Expand Down