[MMDet3D] How to Infer — 3D Box Coordinates

jun94 · Published in jun-devpBlog · 9 min read · Apr 8, 2024


In this article, we will go over how to use MMDet3D to retrieve predictions from a trained model, in particular the actual coordinate values of the predicted 3D bounding boxes.

This article is the second part of the previous one, so I kindly recommend that readers who haven't read it go over it first.

As in the previous article, this one assumes the following:

  • MMDet3D version 1.4.0 is already installed.
  • The task is monocular 3D object detection, where a single image is given as input to the model.
  • FCOS3D is chosen as the 3D object detector.
  • The nuScenes dataset is placed under PROJECT_ROOT/data.

In what format are 3D boxes represented?

3D bounding boxes can be represented in various formats. Two popular ones are 8-corner (x, y, z) and (centroid, dimension, orientation).

8-corner(x, y, z) Format

  • The 8-corner format is an intuitive way to represent a 3D box, as it simply lists the 3D coordinates (x, y, z) of each of the box's 8 corner points.
  • While it is simple and needs no further decoding, one has to keep track of the order of the 8 points. Once the order gets mixed up, it can be hard to figure out which points define which face of the 3D box.
  • Another disadvantage of this format is memory usage. Since it stores all 8 corner points, each box requires an array of shape (8, 3), i.e., 24 values. Compared with the 7 values of the (centroid, dimension, orientation) format, this takes notably more space.
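To make the size difference concrete, here is a minimal NumPy sketch that expands the 7-value representation into the 8-corner one. It assumes the box center is the geometric center, that the sizes are full extents along x, y and z, that yaw rotates around the z-axis, and a corner order of my own choosing. `box_to_corners` is a hypothetical helper, not an MMDet3D function; MMDet3D's own corner order and axis conventions depend on the coordinate system.

```python
import numpy as np


def box_to_corners(box: np.ndarray) -> np.ndarray:
    """Expand a (7,) box [cx, cy, cz, dx, dy, dz, yaw] into (8, 3) corners.

    Assumptions (illustrative only): (cx, cy, cz) is the geometric center,
    (dx, dy, dz) are the full sizes along x, y, z, and yaw rotates around z.
    """
    cx, cy, cz, dx, dy, dz, yaw = box
    # Unit-cube corners in a fixed (assumed) order: bottom face first, then top face.
    unit = np.array([
        [0.5, 0.5, -0.5], [0.5, -0.5, -0.5], [-0.5, -0.5, -0.5], [-0.5, 0.5, -0.5],
        [0.5, 0.5, 0.5], [0.5, -0.5, 0.5], [-0.5, -0.5, 0.5], [-0.5, 0.5, 0.5],
    ])
    local = unit * np.array([dx, dy, dz])  # scale to the box size
    c, s = np.cos(yaw), np.sin(yaw)
    rot_z = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
    return local @ rot_z.T + np.array([cx, cy, cz])  # rotate, then translate


box = np.array([1.0, 2.0, 0.5, 4.0, 2.0, 1.0, np.pi / 2])
corners = box_to_corners(box)
print(box.shape, corners.shape)  # (7,) (8, 3)
```

The 7 values are enough to regenerate all 24 corner coordinates on demand, which is why storing the compact form and decoding only when needed is usually the cheaper choice.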

(centroid, dimension, orientation) Format

  • This representation requires 7 values to express a 3D box: the first three for the center point (center x, center y, center z), the next three for the box size (size along the x-axis, y-axis and z-axis), and the last one for the orientation (yaw).
  • As this representation only requires an array of shape (7,), while the 8-corner format requires (8, 3), it is more commonly used in practice.
  • Note that this representation can also take 9 values instead of 7 when the orientation is represented in three dimensions, namely roll, pitch and yaw.
A 3D box represented with 9 values (from the OpenLabel concept paper)
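With 9 values, the single yaw angle is replaced by three angles that together form a 3×3 rotation matrix. A minimal sketch, assuming the common Z-Y-X (yaw, then pitch, then roll) composition order; other datasets and tools may use a different order or axis convention, so check before reusing it. `rotation_from_rpy` is a hypothetical helper.

```python
import numpy as np


def rotation_from_rpy(roll: float, pitch: float, yaw: float) -> np.ndarray:
    """Build R = Rz(yaw) @ Ry(pitch) @ Rx(roll) (assumed Z-Y-X order)."""
    cr, sr = np.cos(roll), np.sin(roll)
    cp, sp = np.cos(pitch), np.sin(pitch)
    cy, sy = np.cos(yaw), np.sin(yaw)
    rx = np.array([[1.0, 0.0, 0.0], [0.0, cr, -sr], [0.0, sr, cr]])
    ry = np.array([[cp, 0.0, sp], [0.0, 1.0, 0.0], [-sp, 0.0, cp]])
    rz = np.array([[cy, -sy, 0.0], [sy, cy, 0.0], [0.0, 0.0, 1.0]])
    return rz @ ry @ rx


# With roll = pitch = 0 the 9-value box reduces to the 7-value one:
# the matrix is a pure rotation around z by yaw.
R = rotation_from_rpy(0.0, 0.0, 0.3)
print(np.allclose(R[:2, :2], [[np.cos(0.3), -np.sin(0.3)], [np.sin(0.3), np.cos(0.3)]]))  # True
```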

Data class for 3D Box in MMDet3D

As this article addresses a monocular 3D detector, FCOS3D, the predicted 3D bounding boxes are represented with respect to the camera; i.e., a dedicated structure called CameraInstance3DBoxes is used.

Illustration of visualized 3D boxes and the corresponding raw data from the class CameraInstance3DBoxes.

Among all member variables of CameraInstance3DBoxes, the ones I believe are worth mentioning are illustrated and described below. Those I personally found most useful are highlighted in bold.

Visualization of the member variables of the class CameraInstance3DBoxes.
  • center (N, 3): The center of the bottom face of each 3D box. Contrary to intuition, MMDet3D follows the KITTI convention and defines the center of a 3D box as the center of its bottom face; see the Remark section below for further details. As CameraInstance3DBoxes holds multiple boxes at once, N stands for the number of boxes.
  • dims (N, 3): The size of each 3D box along each axis (x, y, z).
  • yaw (N,): The orientation of each 3D box with respect to the camera coordinate system.
  • gravity_center (N, 3): The actual center of each 3D box, i.e., its geometric center (equivalently, the average of its 8 corners).
  • bottom_center (N, 3): Same as center above.
  • corners (N, 8, 3): The (x, y, z) coordinates of all 8 corner points of each 3D box.
  • bev (N, 5): An oriented 2D box representing the top face of each 3D box, i.e., the box seen in bird's-eye view. As the 2D box is oriented, it takes 5 attributes: center point (x, y), size (size_x, size_y) and rotation angle (yaw).
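As far as I understand the MMDet3D 1.x source, CameraInstance3DBoxes keeps everything in one tensor with at least 7 columns per box, [x, y, z, x_size, y_size, z_size, yaw], where (x, y, z) is the bottom center, and derives the properties above from it. The NumPy sketch below mimics a few of them so the shapes are easy to see. It is an illustration under that layout assumption, not the library's code, and `camera_box_properties` is a hypothetical name. In particular, the library may adjust the sign of the yaw in `bev` for camera boxes, which this sketch does not reproduce.

```python
import numpy as np


def camera_box_properties(tensor: np.ndarray) -> dict:
    """Mimic a few CameraInstance3DBoxes properties for an (N, 7) array.

    Assumed layout per row: [x, y, z, x_size, y_size, z_size, yaw],
    where (x, y, z) is the bottom center in camera coordinates.
    """
    return {
        "center": tensor[:, 0:3],  # (N, 3), bottom center
        "dims": tensor[:, 3:6],    # (N, 3), sizes along x, y, z
        "yaw": tensor[:, 6],       # (N,), one angle per box
        # (N, 5) bird's-eye view: x, z, x_size, z_size, yaw
        # (the camera y-axis is vertical, so it is dropped; yaw sign kept as-is here)
        "bev": tensor[:, [0, 2, 3, 5, 6]],
    }


boxes = np.array([
    [1.0, 1.5, 10.0, 4.0, 1.6, 1.8, 0.0],
    [-2.0, 1.4, 20.0, 4.5, 1.5, 2.0, 1.57],
])
for name, value in camera_box_properties(boxes).items():
    print(name, value.shape)
# center (2, 3)
# dims (2, 3)
# yaw (2,)
# bev (2, 5)
```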

While CameraInstance3DBoxes provides numerous methods, I would like to highlight two of them: overlaps and points_in_boxes_part.

  • overlaps(boxes1, boxes2, mode) -> overlap_scores: Computes the degree of overlap between two sets of boxes (box-set 1 and box-set 2), giving one score per pair of boxes. Usually, the metric used to measure how much two boxes overlap is IoU (intersection over union).
  • points_in_boxes_part(points, boxes) -> box_indices: Returns the index of the box each given point falls in (-1 for points outside every box).
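To see what these two methods compute, here is a simplified NumPy sketch. It is not MMDet3D's implementation: `iou_3d_axis_aligned` ignores yaw entirely (the library handles rotated boxes), and `points_in_boxes` assumes a z-up frame, boxes given by their geometric center, and rotation only around the vertical axis. Both function names are hypothetical. Like points_in_boxes_part, the second function returns -1 for points that lie in no box.

```python
import numpy as np


def iou_3d_axis_aligned(boxes1: np.ndarray, boxes2: np.ndarray) -> np.ndarray:
    """(N, M) IoU matrix for boxes [cx, cy, cz, dx, dy, dz, ...]; yaw is ignored."""
    min1 = boxes1[:, None, :3] - boxes1[:, None, 3:6] / 2  # (N, 1, 3)
    max1 = boxes1[:, None, :3] + boxes1[:, None, 3:6] / 2
    min2 = boxes2[None, :, :3] - boxes2[None, :, 3:6] / 2  # (1, M, 3)
    max2 = boxes2[None, :, :3] + boxes2[None, :, 3:6] / 2
    # Overlap extent along each axis, clipped at zero for disjoint boxes.
    inter = np.clip(np.minimum(max1, max2) - np.maximum(min1, min2), 0, None).prod(axis=-1)
    vol1 = boxes1[:, 3:6].prod(axis=-1)[:, None]
    vol2 = boxes2[:, 3:6].prod(axis=-1)[None, :]
    return inter / (vol1 + vol2 - inter)


def points_in_boxes(points: np.ndarray, boxes: np.ndarray) -> np.ndarray:
    """Index of the first (N, 7) box containing each (M, 3) point, or -1 (z-up, yaw around z)."""
    result = np.full(len(points), -1, dtype=np.int64)
    for i, (cx, cy, cz, dx, dy, dz, yaw) in enumerate(boxes):
        # Express the points in the box frame by undoing the translation and the yaw.
        rel = points - np.array([cx, cy, cz])
        c, s = np.cos(-yaw), np.sin(-yaw)
        local_x = c * rel[:, 0] - s * rel[:, 1]
        local_y = s * rel[:, 0] + c * rel[:, 1]
        inside = (np.abs(local_x) <= dx / 2) & (np.abs(local_y) <= dy / 2) & (np.abs(rel[:, 2]) <= dz / 2)
        result[(result == -1) & inside] = i
    return result


a = np.array([[0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 0.0]])
b = np.array([[1.0, 0.0, 0.0, 2.0, 2.0, 2.0, 0.0], [5.0, 5.0, 5.0, 1.0, 1.0, 1.0, 0.0]])
print(iou_3d_axis_aligned(a, b))  # IoU 1/3 with the first box, 0 with the second
pts = np.array([[0.5, 0.5, 0.0], [3.0, 0.0, 0.0]])
print(points_in_boxes(pts, a))  # [ 0 -1]
```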

Remark

Please note that the center of each 3D box is the center of its bottom face, following the KITTI convention, as discussed here.

  • Thus, you may need an additional conversion to get the actual center (gravity center) of the 3D box.
  • Luckily, the calculation is simple; the well-established figure below gives an intuition for where the two centers lie relative to each other.
Further details are discussed here.
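The conversion itself is a single subtraction. In MMDet3D's camera coordinate system the y-axis points downward, so moving from the bottom center up to the gravity center means subtracting half the box height (y_size) from y. The NumPy sketch below assumes an (N, 7) layout [x, y, z, x_size, y_size, z_size, yaw]; `bottom_to_gravity_center` is a hypothetical helper, and with the real class you can simply read `gravity_center`.

```python
import numpy as np


def bottom_to_gravity_center(tensor: np.ndarray) -> np.ndarray:
    """Shift the bottom centers of (N, 7) camera boxes up by half their height."""
    gravity = tensor[:, 0:3].copy()  # copy so the input boxes stay untouched
    # Camera y points down, so "up" means decreasing y by y_size / 2.
    gravity[:, 1] -= tensor[:, 4] / 2
    return gravity


boxes = np.array([[1.0, 1.5, 10.0, 4.0, 1.6, 1.8, 0.0]])
print(bottom_to_gravity_center(boxes))  # y moves from 1.5 to 1.5 - 1.6 / 2 = 0.7
```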

It is also worth mentioning that MMDet3D provides classes to represent 3D boxes in other coordinate systems, e.g., LiDAR.

  • Base Box3D class
  • Box3D class for Camera coordinate system
  • Box3D class for LiDAR coordinate system
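Moving between these classes is essentially a change of frame. MMDet3D's box classes provide a `convert_to` method for this; the NumPy sketch below only shows the underlying idea for box centers, assuming a 4×4 homogeneous `lidar2cam` matrix like the one stored in the nuScenes info files. It deliberately ignores the conversion of sizes and yaw (the axes are permuted between frames), and `transform_points` as well as the example matrix are hypothetical.

```python
import numpy as np


def transform_points(points: np.ndarray, matrix: np.ndarray) -> np.ndarray:
    """Apply a 4x4 homogeneous transform to (N, 3) points."""
    homo = np.hstack([points, np.ones((len(points), 1))])  # (N, 4)
    return (homo @ matrix.T)[:, :3]


# Hypothetical lidar2cam: LiDAR (x forward, y left, z up) -> camera (x right, y down, z forward),
# with the camera mounted 0.3 m behind the LiDAR origin.
lidar2cam = np.array([
    [0.0, -1.0, 0.0, 0.0],
    [0.0, 0.0, -1.0, 0.0],
    [1.0, 0.0, 0.0, 0.3],
    [0.0, 0.0, 0.0, 1.0],
])
centers_lidar = np.array([[10.0, 2.0, -1.0]])
centers_cam = transform_points(centers_lidar, lidar2cam)
# Going back to LiDAR uses the inverse matrix (cam2lidar).
back = transform_points(centers_cam, np.linalg.inv(lidar2cam))
print(centers_cam, np.allclose(back, centers_lidar))
```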

How to Obtain Raw Data of Predicted 3D Boxes?

Prerequisite: Info.pkl Generation

Before getting into the code that retrieves the inference results (3D boxes) from MMDet3D, let us make sure the dataset is properly prepared, i.e., pre-processed, following the link below.

Once the pre-processing step is done, you will find that the directory where the nuScenes dataset is stored has changed as shown below.

After running mmdetection3d/tools/create_data.py

This step does not change the content of nuScenes; it re-formats the dataset into .pkl info files for easier access and use during the data-loading stage.

A detailed description of the contents stored in each pkl file can be found here.
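The inferencer code later in this article reads only a handful of fields from these files: `data_list`, and per sample `images[cam_type]` with `img_path`, `cam2img`, `lidar2cam` and optionally `lidar2img`. The sketch below builds a tiny stand-in dictionary with that nesting (the file name, image name and matrix values are made up), saves it with Python's `pickle`, and reads the calibration back. For the real files, the inferencer code uses `mmengine.load(path)`; I am assuming here that a generated info file is a pickled dict, which is how I understand MMDet3D's info files.

```python
import os
import pickle
import tempfile

import numpy as np

# Hypothetical miniature info file with the same nesting the inferencer expects.
info = {
    "data_list": [
        {
            "images": {
                "CAM_FRONT": {
                    "img_path": "sample_front.jpg",  # made-up file name
                    "cam2img": [[1000.0, 0.0, 800.0], [0.0, 1000.0, 450.0], [0.0, 0.0, 1.0]],
                    "lidar2cam": np.eye(4).tolist(),  # made-up extrinsics
                }
            }
        }
    ]
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "mini_infos.pkl")
    with open(path, "wb") as f:
        pickle.dump(info, f)
    with open(path, "rb") as f:
        loaded = pickle.load(f)

cam = loaded["data_list"][0]["images"]["CAM_FRONT"]
cam2img = np.asarray(cam["cam2img"], dtype=np.float32)  # (3, 3) intrinsics
lidar2cam = np.asarray(cam["lidar2cam"], dtype=np.float32)  # (4, 4) extrinsics
print(cam["img_path"], cam2img.shape, lidar2cam.shape)  # sample_front.jpg (3, 3) (4, 4)
```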

Run Inference with Monocular 3D Detector, FCOS3D

The following image gives an overview of this section: given an input image, what inference results we obtain and how.

In general, running inference is as simple as a single command line:

$ python mmdetection3d/demo/mono_det_demo.py {IMG} {GT_INFO} {MODEL} {CKPT}

where

  • IMG: input image to be fed into the MODEL.
  • GT_INFO: the .pkl info file, a format widely used in MMDet3D that holds the processed ground-truth labels of the original dataset. See the Prerequisite section.
  • MODEL: the config (.py) file that describes the model architecture.
  • CKPT: trained weights to be loaded for MODEL.

An example for each argument is as follows,

While the above command is generally sufficient to run inference with various models, I found that a minor modification is required to run FCOS3D on the nuScenes dataset with MMDet3D 1.4.0.

To address this, I wrote a custom class, FCOS3DInferencer, which inherits from MMDet3D's original MonoDet3DInferencer:

import mmengine
import numpy as np
from mmdet3d.apis.inferencers.base_3d_inferencer import InputsType, track
from mmdet3d.apis.inferencers.mono_det3d_inferencer import MonoDet3DInferencer
from mmdet3d.registry import INFERENCERS
from mmengine.config import Config
from mmengine.config.utils import MODULE2PACKAGE
from mmengine.fileio import get_file_backend, isdir, join_path, list_dir_or_file
import os.path as osp
from typing import Optional, Tuple, Union


@INFERENCERS.register_module(name="det3d-mono", force=True)
class FCOS3DInferencer(MonoDet3DInferencer):

    def __call__(
        self, inputs: InputsType, batch_size: int = 1, return_datasamples: bool = False, **kwargs
    ) -> Optional[dict]:
        """Call the inferencer.

        Args:
            inputs (InputsType): Inputs for the inferencer.
            batch_size (int): Batch size. Defaults to 1.
            return_datasamples (bool): Whether to return results as
                :obj:`BaseDataElement`. Defaults to False.
            **kwargs: Key words arguments passed to :meth:`preprocess`,
                :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
                Each key in kwargs should be in the corresponding set of
                ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
                and ``postprocess_kwargs``.

        Returns:
            dict: Inference and visualization results.
        """

        (
            preprocess_kwargs,
            forward_kwargs,
            visualize_kwargs,
            postprocess_kwargs,
        ) = self._dispatch_kwargs(**kwargs)

        # Use the camera type given by the caller instead of the KITTI-only default "CAM2".
        if "cam_type" in preprocess_kwargs:
            cam_type = preprocess_kwargs.pop("cam_type")
        else:
            cam_type = visualize_kwargs["cam_type_dir"]
        ori_inputs = self._inputs_to_list(inputs, cam_type=cam_type)
        inputs = self.preprocess(ori_inputs, batch_size=batch_size, **preprocess_kwargs)
        preds = []

        results_dict = {"predictions": [], "visualization": []}
        for data in track(inputs, description="Inference") if self.show_progress else inputs:
            preds.extend(self.forward(data, **forward_kwargs))
            visualization = self.visualize(ori_inputs, preds, **visualize_kwargs)
            results = self.postprocess(preds, visualization, return_datasamples, **postprocess_kwargs)
            # results_dict['predictions'].extend(results['predictions'])  # disabled to prevent a memory error
            if results["visualization"] is not None:
                results_dict["visualization"].extend(results["visualization"])
        return results_dict

    def _inputs_to_list(self, inputs: Union[dict, list], cam_type="CAM2", **kwargs) -> list:
        """Preprocess the inputs to a list.

        Preprocess inputs to a list according to its type:

        - list or tuple: return inputs
        - dict: the value with key 'img' is
            - Directory path: return all files in the directory
            - other cases: return a list containing the string. The string
              could be a path to file, a url or other types of string according
              to the task.

        Args:
            inputs (Union[dict, list]): Inputs for the inferencer.

        Returns:
            list: List of input for the :meth:`preprocess`.
        """
        if isinstance(inputs, dict):
            assert "infos" in inputs
            infos = inputs.pop("infos")

            if isinstance(inputs["img"], str):
                img = inputs["img"]
                backend = get_file_backend(img)
                if hasattr(backend, "isdir") and isdir(img):
                    # Backends like HttpsBackend do not implement `isdir`, so
                    # only those backends that implement `isdir` could accept
                    # the inputs as a directory
                    filename_list = list_dir_or_file(img, list_dir=False)
                    inputs = [{"img": join_path(img, filename)} for filename in filename_list]

            if not isinstance(inputs, (list, tuple)):
                # A list of images is expanded; a single image (e.g. one file path) is wrapped.
                img = inputs["img"]
                img_list = img if isinstance(img, (list, tuple)) else [img]
                inputs = [{"img": img_path} for img_path in img_list]

            # get cam2img, lidar2cam and lidar2img from infos
            info_list = mmengine.load(infos)["data_list"]
            assert len(info_list) == len(inputs)
            for index, input in enumerate(inputs):
                data_info = info_list[index]
                img_path = data_info["images"][cam_type]["img_path"]
                if isinstance(input["img"], str) and osp.basename(img_path) != osp.basename(input["img"]):
                    raise ValueError(f"the info file of {img_path} is not provided.")
                cam2img = np.asarray(data_info["images"][cam_type]["cam2img"], dtype=np.float32)
                lidar2cam = np.asarray(data_info["images"][cam_type]["lidar2cam"], dtype=np.float32)
                if "lidar2img" in data_info["images"][cam_type]:
                    lidar2img = np.asarray(data_info["images"][cam_type]["lidar2img"], dtype=np.float32)
                else:
                    # Embed the intrinsics in a 4x4 identity so they can be chained with lidar2cam.
                    cam2img_ = np.eye(4, dtype=np.float32)
                    cam2img_[: cam2img.shape[0], : cam2img.shape[1]] = cam2img
                    lidar2img = cam2img_ @ lidar2cam

                input["cam2img"] = cam2img
                input["lidar2cam"] = lidar2cam
                input["lidar2img"] = lidar2img
        elif isinstance(inputs, (list, tuple)):
            # get cam2img, lidar2cam and lidar2img from infos
            for input in inputs:
                assert "infos" in input
                infos = input.pop("infos")
                info_list = mmengine.load(infos)["data_list"]
                assert len(info_list) == 1, "Only support single sample info in `.pkl`, when inputs is a list."
                data_info = info_list[0]
                img_path = data_info["images"][cam_type]["img_path"]
                if isinstance(input["img"], str) and osp.basename(img_path) != osp.basename(input["img"]):
                    raise ValueError(f"the info file of {img_path} is not provided.")
                cam2img = np.asarray(data_info["images"][cam_type]["cam2img"], dtype=np.float32)
                lidar2cam = np.asarray(data_info["images"][cam_type]["lidar2cam"], dtype=np.float32)
                if "lidar2img" in data_info["images"][cam_type]:
                    lidar2img = np.asarray(data_info["images"][cam_type]["lidar2img"], dtype=np.float32)
                else:
                    # Same padding as above: a (3, 3) @ (4, 4) product would fail.
                    cam2img_ = np.eye(4, dtype=np.float32)
                    cam2img_[: cam2img.shape[0], : cam2img.shape[1]] = cam2img
                    lidar2img = cam2img_ @ lidar2cam
                input["cam2img"] = cam2img
                input["lidar2cam"] = lidar2cam
                input["lidar2img"] = lidar2img

        return list(inputs)

    def _load_model_from_metafile(self, model: str) -> Tuple[Config, str]:
        """Load config and weights from metafile.

        Args:
            model (str): model name defined in metafile.

        Returns:
            Tuple[Config, str]: Loaded Config and weights path defined in
            metafile.
        """
        model = model.lower()

        assert self.scope is not None, "scope should be initialized if you want to load config from metafile."
        assert self.scope in MODULE2PACKAGE, f"{self.scope} not in {MODULE2PACKAGE}!, please pass a valid scope."

        # Helpers inherited from the base inferencer, called through `self`.
        repo_or_mim_dir = self._get_repo_or_mim_dir(self.scope)
        for model_cfg in self._get_models_from_metafile(repo_or_mim_dir):
            model_name = model_cfg["Name"].lower()
            model_aliases = model_cfg.get("Alias", [])
            if isinstance(model_aliases, str):
                model_aliases = [model_aliases.lower()]
            else:
                model_aliases = [alias.lower() for alias in model_aliases]
            if model_name in model or model in model_aliases:
                cfg = Config.fromfile(osp.join(repo_or_mim_dir, model_cfg["Config"]))
                weights = model_cfg["Weights"]
                weights = weights[0] if isinstance(weights, list) else weights
                return cfg, weights
        raise ValueError(f"Cannot find model: {model} in {self.scope}")
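When the info file has no `lidar2img`, `_inputs_to_list` builds it by padding the intrinsics `cam2img` to 4×4 and chaining them with `lidar2cam`. The sketch below shows what that matrix is for: projecting a LiDAR-frame point into pixel coordinates. The intrinsics and extrinsics are made-up values, and `project_lidar_point` is a hypothetical helper.

```python
import numpy as np

# Made-up intrinsics (3x3) and extrinsics (4x4) for illustration.
cam2img = np.array([[1000.0, 0.0, 800.0], [0.0, 1000.0, 450.0], [0.0, 0.0, 1.0]], dtype=np.float32)
lidar2cam = np.array([
    [0.0, -1.0, 0.0, 0.0],
    [0.0, 0.0, -1.0, 0.0],
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
], dtype=np.float32)

# Same padding as in _inputs_to_list: embed the 3x3 intrinsics in a 4x4 identity.
cam2img_ = np.eye(4, dtype=np.float32)
cam2img_[:3, :3] = cam2img
lidar2img = cam2img_ @ lidar2cam


def project_lidar_point(point: np.ndarray, lidar2img: np.ndarray) -> np.ndarray:
    """Project one (3,) LiDAR point to (u, v) pixels; assumes it lies in front of the camera."""
    u, v, depth, _ = lidar2img @ np.append(point, 1.0)
    return np.array([u / depth, v / depth])  # perspective division by depth


print(project_lidar_point(np.array([10.0, 0.0, 0.0]), lidar2img))  # [800. 450.], the principal point
```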

Most of the code is the same as the original; the two changes I made are:

  • Removed the hard-coded cam_type, CAM2, which only applies to the KITTI dataset.
  • Stopped accumulating predictions in results_dict inside __call__.

Note that I intentionally disabled this because, when a large number of images is given (e.g., the entire nuScenes dataset), the inference results keep stacking up and eventually cause a memory overflow. To prevent this, I simply commented out the line where results_dict is updated.

If you are worried that this modification keeps you from getting the inference results: regardless of it, the results are still stored in the output directory, which you can specify as an argument to mono_det_demo.py.
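Since the coordinates end up on disk, the last step is reading them back. As far as I can tell, MMDet3D 1.x writes one JSON file per image into a `preds` folder inside the output directory, holding lists named `bboxes_3d`, `scores_3d` and `labels_3d`; please check your own output files, as this layout is an assumption here. The sketch below collects those boxes, keeps the confident ones, and converts their bottom centers to gravity centers (camera boxes, y pointing down). `load_predictions` and the 0.3 threshold are hypothetical choices.

```python
import glob
import json
import os

import numpy as np


def load_predictions(out_dir: str, score_thr: float = 0.3) -> dict:
    """Collect confident 3D boxes from <out_dir>/preds/*.json (assumed layout).

    Each JSON is assumed to hold "bboxes_3d" (rows starting with
    [x, y, z, x_size, y_size, z_size, yaw], bottom-center camera boxes),
    "scores_3d" and "labels_3d".
    """
    results = {}
    for path in sorted(glob.glob(os.path.join(out_dir, "preds", "*.json"))):
        with open(path) as f:
            pred = json.load(f)
        scores = np.asarray(pred["scores_3d"], dtype=np.float64)
        labels = np.asarray(pred["labels_3d"], dtype=np.int64)
        keep = scores >= score_thr
        boxes = np.asarray(pred["bboxes_3d"], dtype=np.float64)
        boxes = boxes[keep] if keep.any() else np.zeros((0, 7))
        gravity = boxes[:, :3].copy()
        gravity[:, 1] -= boxes[:, 4] / 2  # camera y points down: move up by half the height
        name = os.path.splitext(os.path.basename(path))[0]
        results[name] = {
            "gravity_center": gravity,
            "dims": boxes[:, 3:6],
            "yaw": boxes[:, 6],
            "scores": scores[keep],
            "labels": labels[keep],
        }
    return results


# "outputs" is a placeholder for whatever output directory you passed to the demo.
for image_name, pred in load_predictions("outputs").items():
    print(image_name, pred["gravity_center"].shape)
```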

Any corrections, suggestions, and comments are welcome.

