
Commit 0868cd2

jeff41404 authored and co63oc committed
fix model run error when using auto parallel and recompute (use_reentrant=False) (PaddlePaddle#65188)

* fix model run error when auto parallel and recompute are used with use_reentrant=False
* fix the defect of TensorWrapper not considering DistTensor
* add unittest
* fix recompute not yet supporting CPU when use_reentrant is False
1 parent 791045f commit 0868cd2
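For context, here is a minimal sketch of the combination this commit repairs: a DistTensor input flowing through recompute with `use_reentrant=False`. The mesh shape, layer sizes, and script layout are illustrative assumptions rather than code from the diff; `dist.shard_tensor` and `recompute` are PaddlePaddle's public APIs.

```python
# Sketch of the previously failing combination: auto parallel (DistTensor
# inputs) + recompute(use_reentrant=False). Run under two ranks, e.g.
#   python -m paddle.distributed.launch --devices=0,1 this_script.py
import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.utils import recompute

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
layer = paddle.nn.Linear(16, 16)

x = paddle.randn([4, 16])
x = dist.shard_tensor(x, mesh, [dist.Shard(0)])  # x is now a DistTensor
x.stop_gradient = False

# Before this fix the backward pass failed: the non-reentrant recompute path
# packs activations through TensorWrapper, whose recover() assumed DenseTensor.
y = recompute(layer, x, use_reentrant=False)
y.sum().backward()
```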

File tree

5 files changed: +205 −37 lines

paddle/fluid/eager/tensor_wrapper.h

Lines changed: 22 additions & 2 deletions
```diff
@@ -152,16 +152,36 @@ class TensorWrapper {
 #ifndef PADDLE_NO_PYTHON
     if (packed_value_ && unpack_hook_) {
       auto tensor_unpacked = (*unpack_hook_)(packed_value_);
-      auto src_dense_tensor =
-          static_cast<phi::DenseTensor*>(tensor_unpacked.impl().get());
+      phi::DenseTensor* src_dense_tensor = nullptr;
+      if (tensor_unpacked.is_dense_tensor()) {
+        VLOG(6) << "tensor_unpacked is DenseTensor";
+        src_dense_tensor =
+            static_cast<phi::DenseTensor*>(tensor_unpacked.impl().get());
+      } else if (tensor_unpacked.is_dist_tensor()) {
+        VLOG(6) << "tensor_unpacked is DistTensor";
+        src_dense_tensor = static_cast<phi::distributed::DistTensor*>(
+                               tensor_unpacked.impl().get())
+                               ->unsafe_mutable_value();
+      } else {
+        PADDLE_THROW(
+            paddle::platform::errors::Fatal("Unrecognized tensor_unpacked type "
+                                            "for egr::TensorWrapper::recover"));
+      }
+
       if (intermidiate_tensor_.is_dense_tensor()) {
+        VLOG(6) << "intermidiate_tensor_ is DenseTensor";
         static_cast<phi::DenseTensor*>(intermidiate_tensor_.impl().get())
             ->ResetHolder(src_dense_tensor->MoveMemoryHolder());
       } else if (intermidiate_tensor_.is_dist_tensor()) {
+        VLOG(6) << "intermidiate_tensor_ is DistTensor";
         static_cast<phi::distributed::DistTensor*>(
             intermidiate_tensor_.impl().get())
             ->unsafe_mutable_value()
             ->ResetHolder(src_dense_tensor->MoveMemoryHolder());
+      } else {
+        PADDLE_THROW(paddle::platform::errors::Fatal(
+            "Unrecognized intermidiate_tensor_ type for "
+            "egr::TensorWrapper::recover"));
       }
     } else {
 #endif
```
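The `recover()` change above is exercised through saved-tensor hooks: with `use_reentrant=False`, recompute registers pack/unpack hooks, and whatever unpack returns is handed back to TensorWrapper — which, after this commit, may be a DistTensor. A hedged sketch of that hook mechanism from the Python side, assuming the `paddle.autograd.saved_tensors_hooks` context manager and using deliberately trivial hooks:

```python
import paddle

def pack(tensor):
    # recompute's inner_pack stashes a buffer-sharing copy instead of the
    # tensor itself; this trivial hook just passes the tensor through.
    return tensor

def unpack(packed):
    # The return value is what egr::TensorWrapper::recover() receives; after
    # this commit it may be a DistTensor as well as a DenseTensor.
    return packed

with paddle.autograd.saved_tensors_hooks(pack, unpack):
    x = paddle.randn([4, 4])
    x.stop_gradient = False
    y = paddle.matmul(x, x)

y.sum().backward()  # recover() runs when backward consumes the saved tensors
```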

paddle/fluid/pybind/eager.cc

Lines changed: 113 additions & 27 deletions
```diff
@@ -47,10 +47,12 @@ limitations under the License. */
 #include "paddle/phi/core/string_tensor.h"
 
 using phi::distributed::DistTensor;
+using phi::distributed::DistTensorMeta;
 using phi::distributed::Placement;
 using phi::distributed::Placements;
 using phi::distributed::ProcessMesh;
 using phi::distributed::TensorDistAttr;
+using phi::distributed::auto_parallel::str_join;
 
 namespace paddle {
 namespace pybind {
@@ -81,35 +83,51 @@ void EmptyTensorInitializer(TensorObject* self,
                                 paddle::framework::proto::VarType::FP32,
                             const std::vector<int>& dims = {0},
                             framework::proto::VarType::Type var_type =
-                                paddle::framework::proto::VarType::LOD_TENSOR) {
+                                paddle::framework::proto::VarType::LOD_TENSOR,
+                            ProcessMesh* process_mesh = nullptr,
+                            Placements* placements = nullptr) {
   auto ddims = common::make_ddim(dims);
   self->tensor.set_name(name);
   auto autograd_meta = egr::EagerUtils::autograd_meta(&(self->tensor));
   autograd_meta->SetPersistable(persistable);
   if (stop_gradient != -1) {
     autograd_meta->SetStopGradient(static_cast<bool>(stop_gradient));
   }
-  if (var_type == paddle::framework::proto::VarType::LOD_TENSOR) {
-    // TODO(jiabin): Maybe support LOD later
-    std::shared_ptr<phi::DenseTensor> dense_tensor = nullptr;
-    if (dims.size() == 1 && dims[0] == 0) {
-      std::shared_ptr<phi::Allocation> allocation_ptr = nullptr;
-      dense_tensor = std::make_shared<phi::DenseTensor>(
-          allocation_ptr,
-          phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype),
-                               ddims));
-    } else {
-      // TODO(dev): we need enhance check for ddims.
-      dense_tensor = std::make_shared<phi::DenseTensor>(
-          std::make_shared<phi::Allocation>(),
-          phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype),
-                               ddims));
+  if (process_mesh != nullptr) {
+#ifdef PADDLE_WITH_DISTRIBUTE
+    VLOG(6) << "in EmptyTensorInitializer, create DistTensor";
+    self->tensor.set_impl(std::make_shared<DistTensor>());
+#else
+    PADDLE_THROW(platform::errors::Unavailable(
+        "The tensor-based initialization of (Dist)Tensor is not supported "
+        "in the current PaddlePaddle, please recompile and install "
+        "PaddlePaddle "
+        "with the option of `WITH_DISTRIBUTE=ON`."));
+#endif
+  } else {
+    VLOG(6) << "in EmptyTensorInitializer, create DenseTensor";
+    if (var_type == paddle::framework::proto::VarType::LOD_TENSOR) {
+      // TODO(jiabin): Maybe support LOD later
+      std::shared_ptr<phi::DenseTensor> dense_tensor = nullptr;
+      if (dims.size() == 1 && dims[0] == 0) {
+        std::shared_ptr<phi::Allocation> allocation_ptr = nullptr;
+        dense_tensor = std::make_shared<phi::DenseTensor>(
+            allocation_ptr,
+            phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype),
+                                 ddims));
+      } else {
+        // TODO(dev): we need enhance check for ddims.
+        dense_tensor = std::make_shared<phi::DenseTensor>(
+            std::make_shared<phi::Allocation>(),
+            phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype),
+                                 ddims));
+      }
+      self->tensor.set_impl(dense_tensor);
+    } else if (var_type == paddle::framework::proto::VarType::SELECTED_ROWS) {
+      std::shared_ptr<phi::SelectedRows> tensor =
+          std::make_shared<phi::SelectedRows>();
+      self->tensor.set_impl(tensor);
     }
-    self->tensor.set_impl(dense_tensor);
-  } else if (var_type == paddle::framework::proto::VarType::SELECTED_ROWS) {
-    std::shared_ptr<phi::SelectedRows> tensor =
-        std::make_shared<phi::SelectedRows>();
-    self->tensor.set_impl(tensor);
   }
 
   if (!autograd_meta->GetMutableGradNode()) {
@@ -768,12 +786,16 @@ Tensor is the basic data structure in PaddlePaddle. There are some ways to creat
  * 1.
  * def __init__ ()
  * 2.
+ * (should have at least five parameter, five parameters create DenseTensor,
+ * seven parameters create DistTensor)
  * def __init__ (
  * ** dtype: paddle::framework::proto::VarType::Type,
  * ** dims: vector<int>,
  * ** name: std::string,
  * ** type: paddle::framework::proto::VarType::LodTensor,
- * ** persistable: bool)
+ * ** persistable: bool,
+ * ** process_mesh: phi::distributed::ProcessMesh,
+ * ** placements: std::vector<Placement>)
  * 3. (multi-place)
  * (should have at least one parameter, one parameter equals to case 4, zero
  * parameter equals to case 1)
@@ -797,7 +819,7 @@ Tensor is the basic data structure in PaddlePaddle. There are some ways to creat
  * ** global_tensor: Tensor,
  * ** place: paddle::platform::Place,
  * ** name: std::string,
- * ** process_mesh: phi::distributed::ProcessMesh)
+ * ** process_mesh: phi::distributed::ProcessMesh,
  * ** placements: std::vector<Placement>)
  * 7. (multi-place)
  * (should have at least one parameter, one parameter equals to case 5, zero
@@ -806,7 +828,7 @@ Tensor is the basic data structure in PaddlePaddle. There are some ways to creat
  * ** local_tensor: Tensor,
  * ** global_dims: vector<int>,
  * ** name: std::string,
- * ** process_mesh: phi::distributed::ProcessMesh)
+ * ** process_mesh: phi::distributed::ProcessMesh,
  * ** placements: std::vector<Placement>)
  * 8. (multi-place) (should have at least one parameter, one parameter similar
  * to case 5, zero parameter equals to case 1.)
@@ -995,14 +1017,28 @@ int TensorInit(PyObject* self, PyObject* args, PyObject* kwargs) {
         CastPyArg2ProtoType(kw_type, 0);
     bool persistable = CastPyArg2AttrBoolean(kw_persistable, 0);
 
+    ProcessMesh* process_mesh_ptr = nullptr;
+    if (kw_process_mesh != nullptr) {
+      ProcessMesh process_mesh = CastPyArg2ProcessMesh(kw_process_mesh, 0);
+      process_mesh_ptr = &process_mesh;
+    }
+
+    Placements* placements_ptr = nullptr;
+    if (kw_placements != nullptr) {
+      Placements placements = CastPyArg2VectorOfPlacement(kw_placements, 0);
+      placements_ptr = &placements;
+    }
+
     EmptyTensorInitializer(py_tensor_ptr,
                            act_name,
                            egr::Controller::Instance().GetExpectedPlace(),
                            persistable,
                            /* stop_gradient */ -1,
                            dtype,
                            dims,
-                           var_type);
+                           var_type,
+                           process_mesh_ptr,
+                           placements_ptr);
 
     return 0;
   } else {
@@ -1025,12 +1061,12 @@ int TensorInit(PyObject* self, PyObject* args, PyObject* kwargs) {
           py_tensor_ptr, kws_map, args, flag_kwargs, args_num);
       return 0;
     } else if (PyObject_TypeCheck(arg0_ptr, p_tensor_type)) {
-      VLOG(6) << "Calling case5's or case6's initializer.";
+      VLOG(6) << "Calling case5's or case6's or case7's initializer.";
       AutoInitTensorByTensor(
           py_tensor_ptr, kws_map, args, flag_kwargs, args_num);
       return 0;
     } else if (PyObject_TypeCheck(arg0_ptr, g_framework_tensor_pytype)) {
-      VLOG(6) << "Calling case7's initializer.";
+      VLOG(6) << "Calling case8's initializer.";
       AutoInitTensorByTensor(py_tensor_ptr,
                              kws_map,
                              args,
@@ -1137,6 +1173,56 @@ int TensorInit(PyObject* self, PyObject* args, PyObject* kwargs) {
           "Please check your code and make sure the first position args is "
           "PyArray."));
     }
+  } else if (args_num == (Py_ssize_t)7) {
+    if (!flag_kwargs) {
+      PyObject* arg0_ptr = PyTuple_GET_ITEM(args, 0);
+      if (PyObject_TypeCheck(arg0_ptr, g_vartype_pytype)) {
+        VLOG(6) << "Calling case2's initializer.";
+        paddle::framework::proto::VarType::Type dtype =
+            CastPyArg2ProtoType(PyTuple_GET_ITEM(args, 0), 0);
+        std::vector<int> dims =
+            CastPyArg2VectorOfInt(PyTuple_GET_ITEM(args, 1), 1);
+        std::string act_name = "";
+        PyObject* name_obj = PyTuple_GET_ITEM(args, 2);
+        if (name_obj == Py_None) {
+          act_name = egr::Controller::Instance().GenerateUniqueName(
+              "generated_tensor");
+        } else {
+          act_name = CastPyArg2AttrString(PyTuple_GET_ITEM(args, 2), 2);
+        }
+        paddle::framework::proto::VarType::Type var_type =
+            CastPyArg2ProtoType(PyTuple_GET_ITEM(args, 3), 3);
+        bool persistable = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 4), 4);
+        ProcessMesh process_mesh =
+            CastPyArg2ProcessMesh(PyTuple_GET_ITEM(args, 5), 5);
+        Placements placements =
+            CastPyArg2VectorOfPlacement(PyTuple_GET_ITEM(args, 6), 6);
+        EmptyTensorInitializer(py_tensor_ptr,
+                               act_name,
+                               egr::Controller::Instance().GetExpectedPlace(),
+                               persistable,
+                               -1,
+                               dtype,
+                               dims,
+                               var_type,
+                               &process_mesh,
+                               &placements);
+        return 0;
+      } else {
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "Incompatible constructor arguments, "
+            "there are only 7 position args,"
+            "but the first position args should be dtype. "
+            "Please check your code and make sure you call the existed "
+            "constructor."));
+      }
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Incompatible constructor arguments, "
+          "there are 7 position args and remaining arguments are kwargs,"
+          "Please check your code and make sure you call the existed "
+          "constructor."));
+    }
   } else {
     PADDLE_THROW(platform::errors::Fatal(
         "Can't not find expected num of args, please check your call, and "
```

python/paddle/distributed/fleet/recompute/recompute.py

Lines changed: 22 additions & 7 deletions
```diff
@@ -295,6 +295,8 @@ def _recompute_without_reentrant(
     cur_device = paddle.get_device()
     if 'gpu:' in cur_device:
         fw_cuda_rng_state = paddle.get_cuda_rng_state()
+    elif 'cpu' in cur_device:
+        fw_cuda_rng_state = paddle.get_rng_state()
     elif 'xpu:' in cur_device:
         fw_cuda_rng_state = paddle.get_rng_state()
     elif (
@@ -346,13 +348,26 @@ def inner_pack(inner_x):
             return
 
         if inner_x.is_contiguous():
-            tmp_tensor = core.eager.Tensor(
-                inner_x.dtype,
-                inner_x.shape,
-                inner_x.name + "cpy",
-                core.VarDesc.VarType.LOD_TENSOR,
-                inner_x.persistable,
-            )
+            if inner_x.is_dist():
+                # TODO(jeff41404): it seems better to use `tmp_tensor = core.eager.Tensor(inner_x)`,
+                # but other errors will be triggered during the current period, and can be modified after resolution
+                tmp_tensor = core.eager.Tensor(
+                    inner_x.dtype,
+                    inner_x.shape,
+                    inner_x.name + "cpy",
+                    core.VarDesc.VarType.LOD_TENSOR,
+                    inner_x.persistable,
+                    inner_x.process_mesh,
+                    inner_x.placements,
+                )
+            else:
+                tmp_tensor = core.eager.Tensor(
+                    inner_x.dtype,
+                    inner_x.shape,
+                    inner_x.name + "cpy",
+                    core.VarDesc.VarType.LOD_TENSOR,
+                    inner_x.persistable,
+                )
             inner_x._share_buffer_to(tmp_tensor)
             storage[holder_list[unpack_counter - 1]()] = tmp_tensor
         else:
```
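Besides the DistTensor-aware copy in `inner_pack`, the new `'cpu'` branch means the non-reentrant path no longer falls through the device checks when capturing the forward RNG state on CPU. A minimal single-process sketch (layer shapes are illustrative):

```python
import paddle
from paddle.distributed.fleet.utils import recompute

paddle.set_device("cpu")

layer = paddle.nn.Sequential(paddle.nn.Linear(8, 8), paddle.nn.ReLU())
x = paddle.randn([2, 8])
x.stop_gradient = False

# Before this commit, no branch captured the CPU RNG state, so the
# use_reentrant=False path raised an unsupported-device error here.
y = recompute(layer, x, use_reentrant=False)
y.sum().backward()
```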

test/auto_parallel/semi_auto_parallel_simple_net.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -40,6 +40,7 @@ def __init__(
         self,
         param_prefix="",
         is_recompute=False,
+        recompute_use_reentrant=True,
         is_pp=False,
         pp_reshard_dist_attr=None,
     ):
@@ -50,6 +51,7 @@ def __init__(
 
         self.is_pp = is_pp
         self.is_recompute = is_recompute
+        self.recompute_use_reentrant = recompute_use_reentrant
         self.pp_reshard_dist_attr = pp_reshard_dist_attr
         self.linear_0 = nn.Linear(
             IMAGE_SIZE, IMAGE_SIZE, weight_attr_0, bias_attr=False
@@ -70,7 +72,10 @@ def _inner_forward_fn(self, x):
 
     def forward(self, x):
         if self.is_recompute:
-            return recompute(self._inner_forward_fn, x)
+            if self.recompute_use_reentrant:
+                return recompute(self._inner_forward_fn, x)
+            else:
+                return recompute(self._inner_forward_fn, x, use_reentrant=False)
         else:
             return self._inner_forward_fn(x)
 
```

test/auto_parallel/semi_auto_parallel_simple_net_recompute.py

Lines changed: 42 additions & 0 deletions
```diff
@@ -81,6 +81,25 @@ def test_dp_demo_net(self):
             self.check_tensor_eq(param, param_base, rtol=1e-4)
             self.check_tensor_eq(param.grad, param_base.grad)
 
+    def test_dp_demo_net_use_reentrant_false(self):
+        self.set_random_seed(self._seed)
+        (
+            self.dp_loss,
+            self.dp_parameters,
+        ) = self.run_dynamic_recompute(
+            DemoNet(
+                "recompute_use_reentrant_false_dp_demo",
+                is_recompute=True,
+                recompute_use_reentrant=False,
+            ),
+            shard_input=True,
+        )
+        self.check_tensor_eq(self.dp_loss, self.base_loss)
+        self.check_tensor_eq(self.dp_loss, self.base_loss)
+        for param, param_base in zip(self.dp_parameters, self.base_parameters):
+            self.check_tensor_eq(param, param_base, rtol=1e-4)
+            self.check_tensor_eq(param.grad, param_base.grad)
+
     def test_mp_demo_net(self):
         self.set_random_seed(self._seed)
         mp_layer = dist.shard_layer(
@@ -98,9 +117,32 @@ def test_mp_demo_net(self):
             self.check_tensor_eq(param, param_base)
             self.check_tensor_eq(param.grad, param_base.grad)
 
+    def test_mp_demo_net_use_reentrant_false(self):
+        self.set_random_seed(self._seed)
+        mp_layer = dist.shard_layer(
+            DemoNet(
+                "recompute_use_reentrant_false_mp_demo",
+                is_recompute=True,
+                recompute_use_reentrant=False,
+            ),
+            self._mesh,
+            self.shard_fn,
+        )
+        (
+            self.mp_loss,
+            self.mp_parameters,
+        ) = self.run_dynamic_recompute(mp_layer)
+
+        self.check_tensor_eq(self.mp_loss, self.base_loss)
+        for param, param_base in zip(self.mp_parameters, self.base_parameters):
+            self.check_tensor_eq(param, param_base)
+            self.check_tensor_eq(param.grad, param_base.grad)
+
     def run_test_case(self):
         self.test_dp_demo_net()
+        self.test_dp_demo_net_use_reentrant_false()
         self.test_mp_demo_net()
+        self.test_mp_demo_net_use_reentrant_false()
 
 
 if __name__ == '__main__':
```
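Note: the two new cases are wired into `run_test_case` alongside the existing DP/MP tests. These semi-auto-parallel tests are multi-rank; outside the CI wrappers they would typically be spawned with something like `python -m paddle.distributed.launch --devices=0,1 semi_auto_parallel_simple_net_recompute.py` (the device list here is an assumption; in the repository, the accompanying launcher test handles process spawning).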