@@ -47,10 +47,12 @@ limitations under the License. */
 #include "paddle/phi/core/string_tensor.h"
 
 using phi::distributed::DistTensor;
+using phi::distributed::DistTensorMeta;
 using phi::distributed::Placement;
 using phi::distributed::Placements;
 using phi::distributed::ProcessMesh;
 using phi::distributed::TensorDistAttr;
+using phi::distributed::auto_parallel::str_join;
 
 namespace paddle {
 namespace pybind {
@@ -81,35 +83,51 @@ void EmptyTensorInitializer(TensorObject* self,
                                 paddle::framework::proto::VarType::FP32,
                             const std::vector<int>& dims = {0},
                             framework::proto::VarType::Type var_type =
-                                paddle::framework::proto::VarType::LOD_TENSOR) {
+                                paddle::framework::proto::VarType::LOD_TENSOR,
+                            ProcessMesh* process_mesh = nullptr,
+                            Placements* placements = nullptr) {
   auto ddims = common::make_ddim(dims);
   self->tensor.set_name(name);
   auto autograd_meta = egr::EagerUtils::autograd_meta(&(self->tensor));
   autograd_meta->SetPersistable(persistable);
   if (stop_gradient != -1) {
     autograd_meta->SetStopGradient(static_cast<bool>(stop_gradient));
   }
-  if (var_type == paddle::framework::proto::VarType::LOD_TENSOR) {
-    // TODO(jiabin): Maybe support LOD later
-    std::shared_ptr<phi::DenseTensor> dense_tensor = nullptr;
-    if (dims.size() == 1 && dims[0] == 0) {
-      std::shared_ptr<phi::Allocation> allocation_ptr = nullptr;
-      dense_tensor = std::make_shared<phi::DenseTensor>(
-          allocation_ptr,
-          phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype),
-                               ddims));
-    } else {
-      // TODO(dev): we need enhance check for ddims.
-      dense_tensor = std::make_shared<phi::DenseTensor>(
-          std::make_shared<phi::Allocation>(),
-          phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype),
-                               ddims));
+  if (process_mesh != nullptr) {
+#ifdef PADDLE_WITH_DISTRIBUTE
+    VLOG(6) << "in EmptyTensorInitializer, create DistTensor";
+    self->tensor.set_impl(std::make_shared<DistTensor>());
+#else
+    PADDLE_THROW(platform::errors::Unavailable(
+        "The tensor-based initialization of (Dist)Tensor is not supported "
+        "in the current PaddlePaddle; please recompile and install "
+        "PaddlePaddle "
+        "with the option `WITH_DISTRIBUTE=ON`."));
+#endif
+  } else {
+    VLOG(6) << "in EmptyTensorInitializer, create DenseTensor";
+    if (var_type == paddle::framework::proto::VarType::LOD_TENSOR) {
+      // TODO(jiabin): Maybe support LOD later
+      std::shared_ptr<phi::DenseTensor> dense_tensor = nullptr;
+      if (dims.size() == 1 && dims[0] == 0) {
+        std::shared_ptr<phi::Allocation> allocation_ptr = nullptr;
+        dense_tensor = std::make_shared<phi::DenseTensor>(
+            allocation_ptr,
+            phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype),
+                                 ddims));
+      } else {
+        // TODO(dev): we need enhance check for ddims.
+        dense_tensor = std::make_shared<phi::DenseTensor>(
+            std::make_shared<phi::Allocation>(),
+            phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype),
+                                 ddims));
+      }
+      self->tensor.set_impl(dense_tensor);
+    } else if (var_type == paddle::framework::proto::VarType::SELECTED_ROWS) {
+      std::shared_ptr<phi::SelectedRows> tensor =
+          std::make_shared<phi::SelectedRows>();
+      self->tensor.set_impl(tensor);
     }
-    self->tensor.set_impl(dense_tensor);
-  } else if (var_type == paddle::framework::proto::VarType::SELECTED_ROWS) {
-    std::shared_ptr<phi::SelectedRows> tensor =
-        std::make_shared<phi::SelectedRows>();
-    self->tensor.set_impl(tensor);
   }
 
   if (!autograd_meta->GetMutableGradNode()) {
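The hunk above changes `EmptyTensorInitializer` so that a supplied `ProcessMesh` switches the empty tensor's implementation from `DenseTensor` to `DistTensor` (the latter only in `WITH_DISTRIBUTE` builds). A condensed, hypothetical sketch of that dispatch, with the helper name `MakeEmptyImpl` and the simplified signature invented for illustration (not the PR's code):

```cpp
// Illustrative sketch only. It condenses the branch added above: a non-null
// ProcessMesh selects an empty DistTensor, otherwise an empty DenseTensor
// carrying only meta information is created.
#include <memory>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#endif

std::shared_ptr<phi::TensorBase> MakeEmptyImpl(
    const phi::distributed::ProcessMesh* process_mesh,
    phi::DataType dtype,
    const phi::DDim& ddims) {
#ifdef PADDLE_WITH_DISTRIBUTE
  if (process_mesh != nullptr) {
    // Mesh supplied: start from an empty DistTensor; the distributed
    // attributes are filled in later by the caller.
    return std::make_shared<phi::distributed::DistTensor>();
  }
#endif
  // No mesh (or a non-distributed build, where the real code throws):
  // keep the dense path, an empty DenseTensor with only dtype and dims.
  return std::make_shared<phi::DenseTensor>(
      std::shared_ptr<phi::Allocation>(nullptr),
      phi::DenseTensorMeta(dtype, ddims));
}
```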
@@ -768,12 +786,16 @@ Tensor is the basic data structure in PaddlePaddle. There are some ways to creat
  * 1.
  * def __init__ ()
  * 2.
+ * (should have at least five parameters; five parameters create DenseTensor,
+ *  seven parameters create DistTensor)
  * def __init__ (
  * ** dtype: paddle::framework::proto::VarType::Type,
  * ** dims: vector<int>,
  * ** name: std::string,
  * ** type: paddle::framework::proto::VarType::LodTensor,
- * ** persistable: bool)
+ * ** persistable: bool,
+ * ** process_mesh: phi::distributed::ProcessMesh,
+ * ** placements: std::vector<Placement>)
  * 3. (multi-place)
  * (should have at least one parameter, one parameter equals to case 4, zero
  * parameter equals to case 1)
@@ -797,7 +819,7 @@ Tensor is the basic data structure in PaddlePaddle. There are some ways to creat
  * ** global_tensor: Tensor,
  * ** place: paddle::platform::Place,
  * ** name: std::string,
- * ** process_mesh: phi::distributed::ProcessMesh)
+ * ** process_mesh: phi::distributed::ProcessMesh,
  * ** placements: std::vector<Placement>)
  * 7. (multi-place)
  * (should have at least one parameter, one parameter equals to case 5, zero
@@ -806,7 +828,7 @@ Tensor is the basic data structure in PaddlePaddle. There are some ways to creat
  * ** local_tensor: Tensor,
  * ** global_dims: vector<int>,
  * ** name: std::string,
- * ** process_mesh: phi::distributed::ProcessMesh)
+ * ** process_mesh: phi::distributed::ProcessMesh,
  * ** placements: std::vector<Placement>)
  * 8. (multi-place) (should have at least one parameter, one parameter similar
  * to case 5, zero parameter equals to case 1.)
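The doc comment above now describes the seven-argument form of case 2, whose two extra arguments are a ProcessMesh and a list of Placements. For orientation, a hypothetical sketch of how such values look on the C++ side; the mesh shape, process ids, dimension name and the Replicate placement are made up, and the header paths and the `Placements` alias are assumed from the phi auto_parallel code:

```cpp
// Hypothetical example values for the two new constructor arguments.
#include <memory>
#include <utility>
#include "paddle/phi/core/distributed/auto_parallel/placement_types.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"

using phi::distributed::Placements;
using phi::distributed::ProcessMesh;
using phi::distributed::Replicate;

std::pair<ProcessMesh, Placements> MakeReplicatedArgs() {
  // A 1-D mesh named "x" spanning two processes.
  ProcessMesh mesh(/*shape=*/{2}, /*process_ids=*/{0, 1}, /*dim_names=*/{"x"});
  // Replicate the tensor along that mesh dimension (assuming Placements is a
  // vector of shared Placement pointers, as in phi::distributed).
  Placements placements;
  placements.emplace_back(std::make_shared<Replicate>());
  return {std::move(mesh), std::move(placements)};
}
```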
@@ -995,14 +1017,28 @@ int TensorInit(PyObject* self, PyObject* args, PyObject* kwargs) {
           CastPyArg2ProtoType(kw_type, 0);
       bool persistable = CastPyArg2AttrBoolean(kw_persistable, 0);
 
+      ProcessMesh* process_mesh_ptr = nullptr;
+      if (kw_process_mesh != nullptr) {
+        ProcessMesh process_mesh = CastPyArg2ProcessMesh(kw_process_mesh, 0);
+        process_mesh_ptr = &process_mesh;
+      }
+
+      Placements* placements_ptr = nullptr;
+      if (kw_placements != nullptr) {
+        Placements placements = CastPyArg2VectorOfPlacement(kw_placements, 0);
+        placements_ptr = &placements;
+      }
+
       EmptyTensorInitializer(py_tensor_ptr,
                              act_name,
                              egr::Controller::Instance().GetExpectedPlace(),
                              persistable,
                              /* stop_gradient */ -1,
                              dtype,
                              dims,
-                             var_type);
+                             var_type,
+                             process_mesh_ptr,
+                             placements_ptr);
 
       return 0;
     } else {
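In the kwargs path above, the casted ProcessMesh / Placements values are local to their `if` blocks while the pointers to them are used afterwards. A minimal alternative sketch (not the PR's code) that keeps the values alive in `std::optional` for the whole call, reusing the `kw_*` variables and Cast helpers shown in the hunk:

```cpp
// Sketch only: hold the casted values in optionals so the pointers handed to
// EmptyTensorInitializer stay valid; names reuse those from the hunk above.
#include <optional>

std::optional<ProcessMesh> process_mesh;
ProcessMesh* process_mesh_ptr = nullptr;
if (kw_process_mesh != nullptr) {
  process_mesh = CastPyArg2ProcessMesh(kw_process_mesh, 0);
  process_mesh_ptr = &process_mesh.value();
}

std::optional<Placements> placements;
Placements* placements_ptr = nullptr;
if (kw_placements != nullptr) {
  placements = CastPyArg2VectorOfPlacement(kw_placements, 0);
  placements_ptr = &placements.value();
}
// process_mesh_ptr / placements_ptr remain valid until the optionals go out
// of scope, e.g. across the EmptyTensorInitializer(...) call that follows.
```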
@@ -1025,12 +1061,12 @@ int TensorInit(PyObject* self, PyObject* args, PyObject* kwargs) {
           py_tensor_ptr, kws_map, args, flag_kwargs, args_num);
       return 0;
     } else if (PyObject_TypeCheck(arg0_ptr, p_tensor_type)) {
-      VLOG(6) << "Calling case5's or case6's initializer.";
+      VLOG(6) << "Calling case5's or case6's or case7's initializer.";
       AutoInitTensorByTensor(
           py_tensor_ptr, kws_map, args, flag_kwargs, args_num);
       return 0;
     } else if (PyObject_TypeCheck(arg0_ptr, g_framework_tensor_pytype)) {
-      VLOG(6) << "Calling case7's initializer.";
+      VLOG(6) << "Calling case8's initializer.";
       AutoInitTensorByTensor(py_tensor_ptr,
                              kws_map,
                              args,
@@ -1137,6 +1173,56 @@ int TensorInit(PyObject* self, PyObject* args, PyObject* kwargs) {
           "Please check your code and make sure the first position args is "
           "PyArray."));
     }
+  } else if (args_num == (Py_ssize_t)7) {
+    if (!flag_kwargs) {
+      PyObject* arg0_ptr = PyTuple_GET_ITEM(args, 0);
+      if (PyObject_TypeCheck(arg0_ptr, g_vartype_pytype)) {
+        VLOG(6) << "Calling case2's initializer.";
+        paddle::framework::proto::VarType::Type dtype =
+            CastPyArg2ProtoType(PyTuple_GET_ITEM(args, 0), 0);
+        std::vector<int> dims =
+            CastPyArg2VectorOfInt(PyTuple_GET_ITEM(args, 1), 1);
+        std::string act_name = "";
+        PyObject* name_obj = PyTuple_GET_ITEM(args, 2);
+        if (name_obj == Py_None) {
+          act_name = egr::Controller::Instance().GenerateUniqueName(
+              "generated_tensor");
+        } else {
+          act_name = CastPyArg2AttrString(PyTuple_GET_ITEM(args, 2), 2);
+        }
+        paddle::framework::proto::VarType::Type var_type =
+            CastPyArg2ProtoType(PyTuple_GET_ITEM(args, 3), 3);
+        bool persistable = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 4), 4);
+        ProcessMesh process_mesh =
+            CastPyArg2ProcessMesh(PyTuple_GET_ITEM(args, 5), 5);
+        Placements placements =
+            CastPyArg2VectorOfPlacement(PyTuple_GET_ITEM(args, 6), 6);
+        EmptyTensorInitializer(py_tensor_ptr,
+                               act_name,
+                               egr::Controller::Instance().GetExpectedPlace(),
+                               persistable,
+                               -1,
+                               dtype,
+                               dims,
+                               var_type,
+                               &process_mesh,
+                               &placements);
+        return 0;
+      } else {
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "Incompatible constructor arguments: "
+            "there are only 7 positional args, "
+            "but the first positional arg should be dtype. "
+            "Please check your code and make sure you call the existing "
+            "constructor."));
+      }
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Incompatible constructor arguments: "
+          "there are 7 positional args and the remaining arguments are kwargs. "
+          "Please check your code and make sure you call the existing "
+          "constructor."));
+    }
+  }
   } else {
     PADDLE_THROW(platform::errors::Fatal(
         "Can't not find expected num of args, please check your call, and "