
Commit 7944d24

bottler authored and facebook-github-bot committed
gather_scatter on CPU
Summary: CPU implementation of the graph convolution op.

Reviewed By: nikhilaravi, gkioxari

Differential Revision: D21384361

fbshipit-source-id: bc96730e9727bb9aa1b0a232dcb82f0c0d12fe6b
1 parent 4872a2c commit 7944d24
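For context on what the op computes: for every edge (v0, v1), the features of v1 are accumulated into the output row of v0 (and symmetrically into v1 for undirected graphs), with the roles of the two endpoints swapped in the backward pass. Below is a minimal pure-PyTorch sketch of this semantics, mirroring the gather_scatter_python reference used in the tests; the function name and loop structure are illustrative, not code from the repository.

```python
import torch

def gather_scatter_reference(input, edges, directed=False, backward=False):
    # input: (V, D) float32 vertex features; edges: (E, 2) int64 vertex index pairs.
    output = torch.zeros_like(input)
    # The backward pass swaps which endpoint gathers and which one is gathered from.
    v0_idx, v1_idx = (1, 0) if backward else (0, 1)
    for e in range(edges.shape[0]):
        v0, v1 = edges[e, v0_idx], edges[e, v1_idx]
        output[v0] += input[v1]
        if not directed:
            output[v1] += input[v0]
    return output
```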

File tree

4 files changed: +68 −19 lines changed


pytorch3d/csrc/gather_scatter/gather_scatter.cu

Lines changed: 2 additions & 2 deletions
@@ -44,8 +44,8 @@ __global__ void GatherScatterCudaKernel(
 }
 
 at::Tensor GatherScatterCuda(
-    const at::Tensor input,
-    const at::Tensor edges,
+    const at::Tensor& input,
+    const at::Tensor& edges,
     bool directed,
     bool backward) {
   // Check inputs are on the same device

pytorch3d/csrc/gather_scatter/gather_scatter.h

Lines changed: 11 additions & 6 deletions
@@ -20,17 +20,22 @@
 // Returns:
 //    output: float32 Tensor of same shape as input.
 
-// Cuda implementation.
 at::Tensor GatherScatterCuda(
-    const at::Tensor input,
-    const at::Tensor edges,
+    const at::Tensor& input,
+    const at::Tensor& edges,
+    bool directed,
+    bool backward);
+
+at::Tensor GatherScatterCpu(
+    const at::Tensor& input,
+    const at::Tensor& edges,
     bool directed,
     bool backward);
 
 // Exposed implementation.
 at::Tensor GatherScatter(
-    const at::Tensor input,
-    const at::Tensor edges,
+    const at::Tensor& input,
+    const at::Tensor& edges,
     bool directed,
     bool backward) {
   if (input.is_cuda() && edges.is_cuda()) {

@@ -42,5 +47,5 @@ at::Tensor GatherScatter(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
-  AT_ERROR("Not implemented on the CPU");
+  return GatherScatterCpu(input, edges, directed, backward);
 }
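With this dispatch change, calling the op on CPU tensors no longer raises "Not implemented on the CPU" but falls through to GatherScatterCpu. A small usage sketch follows, assuming the Python wrapper is imported from pytorch3d.ops.graph_conv as the tests below suggest; the toy vertex and edge tensors are purely illustrative.

```python
import torch
from pytorch3d.ops.graph_conv import gather_scatter

verts = torch.randn(4, 3)                       # (V, D) float32 vertex features
edges = torch.tensor([[0, 1], [1, 2], [2, 3]])  # (E, 2) int64 edge list
# After this commit the call below runs on CPU instead of raising.
out = gather_scatter(verts, edges, False)       # undirected neighbor sums
```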
pytorch3d/csrc/gather_scatter/gather_scatter_cpu.cpp

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+
+#include <ATen/ATen.h>
+
+at::Tensor GatherScatterCpu(
+    const at::Tensor& input,
+    const at::Tensor& edges,
+    bool directed,
+    bool backward) {
+  const auto num_vertices = input.size(0);
+  const auto input_feature_dim = input.size(1);
+  const auto num_edges = edges.size(0);
+
+  auto output = at::zeros({num_vertices, input_feature_dim}, input.options());
+
+  auto input_a = input.accessor<float, 2>();
+  auto edges_a = edges.accessor<int64_t, 2>();
+  auto output_a = output.accessor<float, 2>();
+  const int v0_idx = backward ? 1 : 0;
+  const int v1_idx = backward ? 0 : 1;
+
+  for (int e = 0; e < num_edges; ++e) {
+    // Get indices of vertices which form the edge.
+    const int64_t v0 = edges_a[e][v0_idx];
+    const int64_t v1 = edges_a[e][v1_idx];
+
+    for (int d = 0; d < input_feature_dim; ++d) {
+      output_a[v0][d] += input_a[v1][d];
+      if (!directed) {
+        output_a[v1][d] += input_a[v0][d];
+      }
+    }
+  }
+  return output;
+}
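The new kernel is a direct loop over edges. An equivalent vectorized formulation using torch.index_add_, given here only to clarify the directed/backward semantics (it is not code from the commit), would be:

```python
import torch

def gather_scatter_vectorized(input, edges, directed=False, backward=False):
    # Same accumulation as the CPU loop above, expressed with scatter-adds.
    v0_idx, v1_idx = (1, 0) if backward else (0, 1)
    v0, v1 = edges[:, v0_idx], edges[:, v1_idx]
    output = torch.zeros_like(input)
    output.index_add_(0, v0, input[v1])      # output[v0] += input[v1] for every edge
    if not directed:
        output.index_add_(0, v1, input[v0])  # and the reverse direction if undirected
    return output
```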

tests/test_graph_conv.py

Lines changed: 20 additions & 11 deletions
@@ -101,17 +101,24 @@ def test_backward(self):
         mesh = ico_sphere()
         verts = mesh.verts_packed()
         edges = mesh.edges_packed()
+        verts_cpu = verts.clone()
+        edges_cpu = edges.clone()
         verts_cuda = verts.clone().to(device)
         edges_cuda = edges.clone().to(device)
         verts.requires_grad = True
+        verts_cpu.requires_grad = True
         verts_cuda.requires_grad = True
 
         neighbor_sums_cuda = gather_scatter(verts_cuda, edges_cuda, False)
+        neighbor_sums_cpu = gather_scatter(verts_cpu, edges_cpu, False)
         neighbor_sums = gather_scatter_python(verts, edges, False)
-        neighbor_sums_cuda.sum().backward()
-        neighbor_sums.sum().backward()
+        randoms = torch.rand_like(neighbor_sums)
+        (neighbor_sums_cuda * randoms.cuda()).sum().backward()
+        (neighbor_sums_cpu * randoms).sum().backward()
+        (neighbor_sums * randoms).sum().backward()
 
-        self.assertClose(verts.grad.cpu(), verts_cuda.grad.cpu())
+        self.assertClose(verts.grad, verts_cuda.grad.cpu())
+        self.assertClose(verts.grad, verts_cpu.grad)
 
     def test_repr(self):
         conv = GraphConv(32, 64, directed=True)

@@ -141,22 +148,24 @@ def test_gather_scatter(self):
         w0 = nn.Linear(3, 1)
         input = w0(verts)
 
-        # output
-        output_cpu = gather_scatter_python(input, edges, False)
+        # undirected
+        output_python = gather_scatter_python(input, edges, False)
         output_cuda = _C.gather_scatter(
             input.to(device=device), edges.to(device=device), False, False
         )
-        self.assertClose(output_cuda.cpu(), output_cpu)
-        with self.assertRaises(Exception) as err:
-            _C.gather_scatter(input.cpu(), edges.cpu(), False, False)
-        self.assertTrue("Not implemented on the CPU" in str(err.exception))
+        self.assertClose(output_cuda.cpu(), output_python)
+
+        output_cpu = _C.gather_scatter(input.cpu(), edges.cpu(), False, False)
+        self.assertClose(output_cpu, output_python)
 
         # directed
-        output_cpu = gather_scatter_python(input, edges, True)
+        output_python = gather_scatter_python(input, edges, True)
         output_cuda = _C.gather_scatter(
             input.to(device=device), edges.to(device=device), True, False
         )
-        self.assertClose(output_cuda.cpu(), output_cpu)
+        self.assertClose(output_cuda.cpu(), output_python)
+        output_cpu = _C.gather_scatter(input.cpu(), edges.cpu(), True, False)
+        self.assertClose(output_cpu, output_python)
 
     @staticmethod
     def graph_conv_forward_backward(
