|
| 1 | +/* |
| 2 | + * Copyright (c) Facebook, Inc. and its affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +#include <ATen/ATen.h> |
| 10 | +#include <ATen/cuda/CUDAContext.h> |
| 11 | +#include <c10/cuda/CUDAGuard.h> |
| 12 | +#include <math.h> |
| 13 | +#include <stdio.h> |
| 14 | +#include <stdlib.h> |
| 15 | +#include <thrust/device_vector.h> |
| 16 | +#include <thrust/tuple.h> |
| 17 | +#include "iou_box3d/iou_utils.cuh" |
| 18 | +#include "utils/pytorch3d_cutils.h" |
| 19 | + |
| 20 | +// Parallelize over N*M computations which can each be done |
| 21 | +// independently |
| 22 | +__global__ void IoUBox3DKernel( |
| 23 | + const at::PackedTensorAccessor64<float, 3, at::RestrictPtrTraits> boxes1, |
| 24 | + const at::PackedTensorAccessor64<float, 3, at::RestrictPtrTraits> boxes2, |
| 25 | + at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> vols, |
| 26 | + at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> ious) { |
| 27 | + const size_t N = boxes1.size(0); |
| 28 | + const size_t M = boxes2.size(0); |
| 29 | + |
| 30 | + const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; |
| 31 | + const size_t stride = gridDim.x * blockDim.x; |
| 32 | + |
| 33 | + for (size_t i = tid; i < N * M; i += stride) { |
| 34 | + const size_t n = i / M; // box1 index |
| 35 | + const size_t m = i % M; // box2 index |
| 36 | + |
| 37 | + // Convert to array of structs of face vertices i.e. effectively (F, 3, 3) |
| 38 | + // FaceVerts is a data type defined in iou_utils.cuh |
| 39 | + FaceVerts box1_tris[NUM_TRIS]; |
| 40 | + FaceVerts box2_tris[NUM_TRIS]; |
| 41 | + GetBoxTris(boxes1[n], box1_tris); |
| 42 | + GetBoxTris(boxes2[m], box2_tris); |
| 43 | + |
| 44 | + // Calculate the position of the center of the box which is used in |
| 45 | + // several calculations. This requires a tensor as input. |
| 46 | + const float3 box1_center = BoxCenter(boxes1[n]); |
| 47 | + const float3 box2_center = BoxCenter(boxes2[m]); |
| 48 | + |
| 49 | + // Convert to an array of face vertices |
| 50 | + FaceVerts box1_planes[NUM_PLANES]; |
| 51 | + GetBoxPlanes(boxes1[n], box1_planes); |
| 52 | + FaceVerts box2_planes[NUM_PLANES]; |
| 53 | + GetBoxPlanes(boxes2[m], box2_planes); |
| 54 | + |
| 55 | + // Get Box Volumes |
| 56 | + const float box1_vol = BoxVolume(box1_tris, box1_center, NUM_TRIS); |
| 57 | + const float box2_vol = BoxVolume(box2_tris, box2_center, NUM_TRIS); |
| 58 | + |
| 59 | + // Tris in Box1 intersection with Planes in Box2 |
| 60 | + // Initialize box1 intersecting faces. MAX_TRIS is the |
| 61 | + // max faces possible in the intersecting shape. |
| 62 | + // TODO: determine if the value of MAX_TRIS is sufficient or |
| 63 | + // if we should store the max tris for each NxM computation |
| 64 | + // and throw an error if any exceeds the max. |
| 65 | + FaceVerts box1_intersect[MAX_TRIS]; |
| 66 | + for (int j = 0; j < NUM_TRIS; ++j) { |
| 67 | + // Initialize the faces from the box |
| 68 | + box1_intersect[j] = box1_tris[j]; |
| 69 | + } |
| 70 | + // Get the count of the actual number of faces in the intersecting shape |
| 71 | + int box1_count = BoxIntersections(box2_planes, box2_center, box1_intersect); |
| 72 | + |
| 73 | + // Tris in Box2 intersection with Planes in Box1 |
| 74 | + FaceVerts box2_intersect[MAX_TRIS]; |
| 75 | + for (int j = 0; j < NUM_TRIS; ++j) { |
| 76 | + box2_intersect[j] = box2_tris[j]; |
| 77 | + } |
| 78 | + const int box2_count = |
| 79 | + BoxIntersections(box1_planes, box1_center, box2_intersect); |
| 80 | + |
| 81 | + // If there are overlapping regions in Box2, remove any coplanar faces |
| 82 | + if (box2_count > 0) { |
| 83 | + // Identify if any triangles in Box2 are coplanar with Box1 |
| 84 | + Keep tri2_keep[MAX_TRIS]; |
| 85 | + for (int j = 0; j < MAX_TRIS; ++j) { |
| 86 | + // Initialize the valid faces to be true |
| 87 | + tri2_keep[j].keep = j < box2_count ? true : false; |
| 88 | + } |
| 89 | + for (int b1 = 0; b1 < box1_count; ++b1) { |
| 90 | + for (int b2 = 0; b2 < box2_count; ++b2) { |
| 91 | + const bool is_coplanar = |
| 92 | + IsCoplanarFace(box1_intersect[b1], box2_intersect[b2]); |
| 93 | + if (is_coplanar) { |
| 94 | + tri2_keep[b2].keep = false; |
| 95 | + } |
| 96 | + } |
| 97 | + } |
| 98 | + |
| 99 | + // Keep only the non coplanar triangles in Box2 - add them to the |
| 100 | + // Box1 triangles. |
| 101 | + for (int b2 = 0; b2 < box2_count; ++b2) { |
| 102 | + if (tri2_keep[b2].keep) { |
| 103 | + box1_intersect[box1_count] = box2_intersect[b2]; |
| 104 | + // box1_count will determine the total faces in the |
| 105 | + // intersecting shape |
| 106 | + box1_count++; |
| 107 | + } |
| 108 | + } |
| 109 | + } |
| 110 | + |
| 111 | + // Initialize the vol and iou to 0.0 in case there are no triangles |
| 112 | + // in the intersecting shape. |
| 113 | + float vol = 0.0; |
| 114 | + float iou = 0.0; |
| 115 | + |
| 116 | + // If there are triangles in the intersecting shape |
| 117 | + if (box1_count > 0) { |
| 118 | + // The intersecting shape is a polyhedron made up of the |
| 119 | + // triangular faces that are all now in box1_intersect. |
| 120 | + // Calculate the polyhedron center |
| 121 | + const float3 poly_center = PolyhedronCenter(box1_intersect, box1_count); |
| 122 | + // Compute intersecting polyhedron volume |
| 123 | + vol = BoxVolume(box1_intersect, poly_center, box1_count); |
| 124 | + // Compute IoU |
| 125 | + iou = vol / (box1_vol + box2_vol - vol); |
| 126 | + } |
| 127 | + |
| 128 | + // Write the volume and IoU to global memory |
| 129 | + vols[n][m] = vol; |
| 130 | + ious[n][m] = iou; |
| 131 | + } |
| 132 | +} |
| 133 | + |
| 134 | +std::tuple<at::Tensor, at::Tensor> IoUBox3DCuda( |
| 135 | + const at::Tensor& boxes1, // (N, 8, 3) |
| 136 | + const at::Tensor& boxes2) { // (M, 8, 3) |
| 137 | + // Check inputs are on the same device |
| 138 | + at::TensorArg boxes1_t{boxes1, "boxes1", 1}, boxes2_t{boxes2, "boxes2", 2}; |
| 139 | + at::CheckedFrom c = "IoUBox3DCuda"; |
| 140 | + at::checkAllSameGPU(c, {boxes1_t, boxes2_t}); |
| 141 | + at::checkAllSameType(c, {boxes1_t, boxes2_t}); |
| 142 | + |
| 143 | + // Set the device for the kernel launch based on the device of boxes1 |
| 144 | + at::cuda::CUDAGuard device_guard(boxes1.device()); |
| 145 | + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| 146 | + |
| 147 | + TORCH_CHECK(boxes2.size(2) == boxes1.size(2), "Boxes must have shape (8, 3)"); |
| 148 | + |
| 149 | + TORCH_CHECK( |
| 150 | + (boxes2.size(1) == 8) && (boxes1.size(1) == 8), |
| 151 | + "Boxes must have shape (8, 3)"); |
| 152 | + |
| 153 | + const int64_t N = boxes1.size(0); |
| 154 | + const int64_t M = boxes2.size(0); |
| 155 | + |
| 156 | + auto vols = at::zeros({N, M}, boxes1.options()); |
| 157 | + auto ious = at::zeros({N, M}, boxes1.options()); |
| 158 | + |
| 159 | + if (vols.numel() == 0) { |
| 160 | + AT_CUDA_CHECK(cudaGetLastError()); |
| 161 | + return std::make_tuple(vols, ious); |
| 162 | + } |
| 163 | + |
| 164 | + const size_t blocks = 512; |
| 165 | + const size_t threads = 256; |
| 166 | + |
| 167 | + IoUBox3DKernel<<<blocks, threads, 0, stream>>>( |
| 168 | + boxes1.packed_accessor64<float, 3, at::RestrictPtrTraits>(), |
| 169 | + boxes2.packed_accessor64<float, 3, at::RestrictPtrTraits>(), |
| 170 | + vols.packed_accessor64<float, 2, at::RestrictPtrTraits>(), |
| 171 | + ious.packed_accessor64<float, 2, at::RestrictPtrTraits>()); |
| 172 | + |
| 173 | + AT_CUDA_CHECK(cudaGetLastError()); |
| 174 | + |
| 175 | + return std::make_tuple(vols, ious); |
| 176 | +} |
0 commit comments