nms_3d
Apply 3‑D Non‑Maximum Suppression (NMS) to a set of scored axis‑aligned boxes. The function keeps the highest‑scoring box, removes all boxes whose IoU with it exceeds a chosen threshold, then repeats until no boxes remain.
nms_3d(
prediction_boxes: torch.Tensor,
iou_threshold: float = 0.5,
debug: bool = False
) -> torch.Tensor
Parameters
Name | Type | Description |
---|---|---|
prediction_boxes | torch.Tensor | Shape (N, 7). Columns: SCORE, X_MIN, Y_MIN, Z_MIN, X_MAX, Y_MAX, Z_MAX . |
iou_threshold | float | IoU cutoff for suppression (0‒1 ). Default 0.5. |
debug | bool | If True, prints each suppression step. |
Returns
torch.Tensor
– The retained boxes after NMS, shape (M, 7) where M ≤ N.
Example
import torch
from nms_3d import nms_3d
prediction_boxes = torch.tensor([
[0.95, 10, 10, 10, 20, 20, 20], # kept
[0.90, 12, 12, 12, 22, 22, 22], # suppressed (overlaps first)
[0.85, 50, 50, 50, 60, 60, 60], # kept
[0.80, 55, 55, 55, 65, 65, 65], # suppressed (overlaps third)
[0.75,100,100,100,110,110,110] # kept
])
filtered = nms_3d(
prediction_boxes=prediction_boxes,
iou_threshold=0.5,
debug=True
)
print(filtered)