
Commit a61a3a2

heavyrain-lzychen2016013 authored and committed
【AutoParallel】Split bw in vpp (PaddlePaddle#64534)
* split bw in vpp
* polish
1 parent f9b8a33 commit a61a3a2

File tree: 4 files changed, +194 -3 lines

python/paddle/distributed/auto_parallel/constants.py

Lines changed: 1 addition & 0 deletions
@@ -120,6 +120,7 @@ def set_field_default_config(category, field, default_value):
 set_field_default_config(PIPELINE, "enable_send_recv_overlap", False)
 set_field_default_config(PIPELINE, "job_schedule_profiler_start", -1)
 set_field_default_config(PIPELINE, "job_schedule_profiler_stop", -1)
+set_field_default_config(PIPELINE, "split_backward", False)
 
 #########################################
 # quantization configuration
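
This registers a new split_backward option (default False) under the pipeline strategy category. A user would opt in through the auto-parallel strategy object, roughly as in the sketch below; the surrounding pipeline fields are pre-existing options from this same constants module, and the exact Engine wiring is assumed rather than shown in this commit:

    from paddle.distributed.fleet import auto

    strategy = auto.Strategy()
    strategy.pipeline.enable = True
    strategy.pipeline.schedule_mode = "VPP"   # virtual-pipeline-parallel schedule
    strategy.pipeline.vpp_degree = 2          # model chunks per pipeline stage
    strategy.pipeline.accumulate_steps = 4    # the split only takes effect when this equals pp_degree
    strategy.pipeline.split_backward = True   # new flag added by this commit (defaults to False)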

python/paddle/distributed/auto_parallel/static/parallelizer_v2.py

Lines changed: 1 addition & 0 deletions
@@ -556,4 +556,5 @@ def _apply_post_optimization(
                 "pp_stage": get_pp_stage(self._dist_context, rank),
                 "vpp_degree": self._strategy.pipeline.vpp_degree,
                 "dist_context": self._dist_context,
+                "split_backward": self._strategy.pipeline.split_backward,
             }
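
The new dictionary entry is forwarded as a pass attribute, so the VPP scheduler pass can read it back with a default. A minimal sketch of the consuming side, mirroring the get_attr call that appears in the pipeline_vpp.py diff below:

    # inside the VPP scheduler pass (a PipelinePassBase subclass)
    split_backward = self.get_attr("split_backward", False)  # False unless enabled in the strategy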

python/paddle/distributed/passes/pass_utils.py

Lines changed: 132 additions & 0 deletions
@@ -778,6 +778,138 @@ def _split_ops(block):
     return list(type_to_program.keys()), list(type_to_program.values())
 
 
+def _program_for_vpp_split_bwk(
+    program,
+    num_model_chunks,
+    dist_context,
+    enable_send_recv_overlap=False,
+):
+    if enable_send_recv_overlap:
+        _overlap_send_recv(program)
+    else:
+        _insert_sync_for_fthenb_1f1b(program, dist_context)
+
+    oprole_type = {
+        0: "forward",
+        1: "backward",
+        2: "backward_b",
+        3: 'backward_w',
+        4: "optimizer",
+    }
+
+    def _split_ops(block):
+        type_to_ops = OrderedDict()
+        for type in oprole_type.values():
+            chunk_ids = list(range(num_model_chunks))
+            if type == "optimizer":
+                type_to_ops[type] = []
+            else:
+                chunk_ids = (
+                    chunk_ids if "backward" not in type else reversed(chunk_ids)
+                )
+                for chunk_id in chunk_ids:
+                    type_to_ops[type + str(chunk_id)] = []
+        type_to_ops["fetch"] = []
+
+        dealed_op_idx = 0
+        for ip, op in enumerate(block.ops):
+            if ip < dealed_op_idx:
+                continue
+            if is_forward_op(op):
+                type = oprole_type[0]
+            elif is_backward_op(op):
+                types = _get_backward_op_type(block, op, ip)
+                dealed_op_idx = dealed_op_idx + len(types) - 1
+            elif is_optimize_op(op):
+                type = oprole_type[4]
+            else:
+                raise ValueError(
+                    "The op role: "
+                    + str(op.attr('op_role'))
+                    + " isn't one of Forward, Backward or Optimizer."
+                )
+
+            dist_op = dist_context.get_dist_op_for_program(op)
+            if _is_fetch_op(op):
+                type_to_ops["fetch"].append(op)
+            elif is_optimize_op(op):
+                type_to_ops[type].append(op)
+            elif op.type == "feed":
+                type_to_ops[type + str(0)].append(op)
+            elif op.type == "share_buffer":
+                dist_pre_op = dist_context.get_dist_op_for_program(
+                    block.ops[ip - 1]
+                )
+                type_to_ops[type + str(dist_pre_op.dist_attr.chunk_id)].append(
+                    op
+                )
+            elif (
+                dist_op
+                and type + str(dist_op.dist_attr.chunk_id) in type_to_ops
+                and not is_backward_op(op)
+            ):
+                type_to_ops[type + str(dist_op.dist_attr.chunk_id)].append(op)
+            elif (
+                dist_op
+                and type + str(dist_op.dist_attr.chunk_id) in type_to_ops
+                and is_backward_op(op)
+            ):
+                for i, type in enumerate(types):
+                    type_to_ops[
+                        "backward" + str(dist_op.dist_attr.chunk_id)
+                    ].append(block.ops[ip + i])
+                    type_to_ops[type + str(dist_op.dist_attr.chunk_id)].append(
+                        block.ops[ip + i]
+                    )
+            else:
+                raise ValueError(f"There is not dist_attr for op[{op.type}].")
+            dealed_op_idx = dealed_op_idx + 1
+
+        return type_to_ops
+
+    type_to_program = OrderedDict()
+
+    for ib, src_block in enumerate(program.blocks):
+        type_to_ops = _split_ops(src_block)
+        fetch_ops = type_to_ops.pop("fetch", [])
+        dst_blocks = []
+
+        if ib == 0:
+            for type, ops in type_to_ops.items():
+                type_to_program[type] = Program()
+                dst_block = type_to_program[type].block(0)
+                _add_ops_into_block(src_block, dst_block, ops)
+                dst_blocks.append(dst_block)
+        else:
+            for type, ops in type_to_ops.items():
+                if len(ops) > 0:
+                    dst_block = type_to_program[type]._create_block(
+                        parent_idx=src_block.parent_idx
+                    )
+                    dst_block._set_forward_block_idx(
+                        src_block.forward_block_idx
+                    )
+                    _add_ops_into_block(src_block, dst_block, ops)
+                    dst_blocks.append(dst_block)
+
+        for fetch_op in fetch_ops:
+            in_name = fetch_op.input('X')[0]
+            fetch_block = None
+            for dst_block in dst_blocks:
+                if dst_block._find_var_recursive(in_name):
+                    fetch_block = dst_block
+                    break
+
+            if fetch_block:
+                _create_program(src_block, fetch_block, fetch_op)
+
+    for prog in type_to_program.values():
+        prog._sync_with_cpp()
+        prog._roll_to_global_block()
+
+    return list(type_to_program.keys()), list(type_to_program.values())
+
+
 def _get_backward_op_type(block, cur_op, idx):
     # deal the ops pattern: [reshape2, reshape2, matmul_v2, reshape2, elementwise_add]
     def is_reshape_matmul_pattern(cur_op, idx, ops, ops_len):
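
_program_for_vpp_split_bwk buckets every op by (role, chunk_id): forward chunks in ascending order, backward chunks in descending order, with backward ops further classified by _get_backward_op_type into backward_b (activation gradient) and backward_w (weight gradient). Each non-empty bucket later becomes its own sub-program and, eventually, its own job. The framework-free toy sketch below illustrates only that bucketing idea; the op tuples and the bucket_ops helper are hypothetical stand-ins for Paddle ops and the real _split_ops logic:

    from collections import OrderedDict

    # Hypothetical stand-in for ops: (role, chunk_id) pairs, where backward work is
    # already tagged as "backward_b" (dgrad) or "backward_w" (wgrad).
    ops = [
        ("forward", 0), ("forward", 1),
        ("backward_b", 1), ("backward_w", 1),
        ("backward_b", 0), ("backward_w", 0),
        ("optimizer", None),
    ]

    def bucket_ops(ops, num_model_chunks):
        buckets = OrderedDict()
        # Forward buckets in ascending chunk order, backward buckets in descending
        # order, mirroring the execution order of a virtual-pipeline schedule.
        for chunk in range(num_model_chunks):
            buckets[f"forward{chunk}"] = []
        for kind in ("backward", "backward_b", "backward_w"):
            for chunk in reversed(range(num_model_chunks)):
                buckets[f"{kind}{chunk}"] = []
        buckets["optimizer"] = []

        for role, chunk in ops:
            if role == "optimizer":
                buckets["optimizer"].append((role, chunk))
                continue
            if role.startswith("backward"):
                # Split backward ops are recorded under the coarse backward{chunk}
                # bucket as well as under their fine-grained bucket, as the pass does.
                buckets[f"backward{chunk}"].append((role, chunk))
            buckets[f"{role}{chunk}"].append((role, chunk))
        return buckets

    print([name for name, bucket in bucket_ops(ops, 2).items() if bucket])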

python/paddle/distributed/passes/pipeline_scheduler_pass/pipeline_vpp.py

Lines changed: 60 additions & 3 deletions
@@ -18,7 +18,11 @@
 
 from ...utils.log_utils import get_logger
 from ..pass_base import register_pass
-from ..pass_utils import _program_for_vpp
+from ..pass_utils import (
+    _program_for_vpp,
+    _program_for_vpp_split_bwk,
+    split_matmul_grad_to_matmul,
+)
 from .pipeline_pass_base import PipelinePassBase
 
 FORWARD = "forward"
@@ -51,6 +55,7 @@ def _create_job_list(self):
         stage_id = self.get_attr("pp_stage")
         num_stages = self.get_attr("pp_degree")
         num_model_chunks = self.get_attr("vpp_degree")
+        split_backward = self.get_attr("split_backward", False)
         for i in range(num_model_chunks):
             self._forward_micro_step_counter[i] = 0
             self._backward_micro_step_counter[i] = 0
@@ -73,6 +78,9 @@ def _get_virtual_pp_rank(micro_step, forward):
         warmup_steps = min(warmup_steps, total_num_steps)
 
         steady_steps = total_num_steps - warmup_steps
+        real_split_backward = (
+            accumulate_steps == num_stages
+        ) and split_backward
 
         job_list = []
         for micro_step in range(warmup_steps):
@@ -101,26 +109,75 @@ def _get_virtual_pp_rank(micro_step, forward):
             bwd_micro_batch_id = self._record_bwd_micro_step(
                 bwd_virtual_pp_rank
             )
-            bwd_job = core.Job(BACKWARD + str(bwd_virtual_pp_rank))
+            if real_split_backward:
+                bwd_job = core.Job(BACKWARD + "_b" + str(bwd_virtual_pp_rank))
+            else:
+                bwd_job = core.Job(BACKWARD + str(bwd_virtual_pp_rank))
             bwd_job.set_micro_batch_id(bwd_micro_batch_id)
             job_list.append(bwd_job)
 
         for micro_step in range(steady_steps, total_num_steps):
             virtual_pp_rank = _get_virtual_pp_rank(micro_step, forward=False)
             micro_batch_id = self._record_bwd_micro_step(virtual_pp_rank)
-            bwd_job = core.Job(BACKWARD + str(virtual_pp_rank))
+            if real_split_backward:
+                bwd_job = core.Job(BACKWARD + "_b" + str(virtual_pp_rank))
+            else:
+                bwd_job = core.Job(BACKWARD + str(virtual_pp_rank))
             bwd_job.set_micro_batch_id(micro_batch_id)
             job_list.append(bwd_job)
+            # TODO(lizhiyu): Inserting 'backward_b' and 'backward_w' interleavedly can decrease the memory,
+            # but it reduces the speed. We should find the better way to use the code here.
+            # next_virtual_pp_rank = _get_virtual_pp_rank(micro_step + 1, forward=False)
+            # if next_virtual_pp_rank != virtual_pp_rank:
+            #     for micro_batch_id in range(0, accumulate_steps):
+            #         w_job = core.Job(BACKWARD + "_w" + str(virtual_pp_rank))
+            #         w_job.set_micro_batch_id(micro_batch_id)
+            #         job_list.append(w_job)
+
+        if real_split_backward:
+            for chunk_id in range(num_model_chunks - 1, -1, -1):
+                for micro_batch_id in range(0, accumulate_steps):
+                    w_job = core.Job(BACKWARD + "_w" + str(chunk_id))
+                    w_job.set_micro_batch_id(micro_batch_id)
+                    job_list.append(w_job)
 
         opt_job = core.Job(OPT)
         job_list.append(opt_job)
         return job_list
 
+    def _split_matmul_grad_ops_to_matmul(self, program, dist_context):
+        for block in program.blocks:
+            matmul_grad_op_idx = []
+            ops = block.ops
+            for i, op_i in enumerate(ops):
+                if (
+                    op_i.type == "matmul_v2_grad"
+                    and not op_i.attr("trans_x")
+                    and not op_i.attr("trans_y")
+                ):
+                    matmul_grad_op_idx.append(i)
+
+            for matmul_grad_id in reversed(matmul_grad_op_idx):
+                split_matmul_grad_to_matmul(
+                    block, matmul_grad_id, dist_context=dist_context
+                )
+
     def _partial_programs(self, program):
         dist_context = self.get_attr("dist_context")
         num_model_chunks = self.get_attr("vpp_degree")
         enable_send_recv_overlap = self.get_attr("enable_send_recv_overlap")
+        accumulate_steps = self.get_attr("num_micro_batches")
+        num_stages = self.get_attr("pp_degree")
+        split_backward = self.get_attr("split_backward", False)
         types, sub_program_list = _program_for_vpp(
             program, num_model_chunks, dist_context, enable_send_recv_overlap
         )
+        if split_backward and accumulate_steps == num_stages:
+            self._split_matmul_grad_ops_to_matmul(program, dist_context)
+            types, sub_program_list = _program_for_vpp_split_bwk(
+                program,
+                num_model_chunks,
+                dist_context,
+                enable_send_recv_overlap,
+            )
         return types, sub_program_list
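
When real_split_backward holds (split_backward enabled and accumulate_steps equal to the number of stages), the weight-gradient work is pulled out of the 1F1B body: the backward_b jobs keep the pipeline moving, and all backward_w jobs are appended at the end, chunk by chunk in reverse order, before the single optimizer job. The framework-free sketch below reproduces only that tail of the schedule; the warmup and steady-state bookkeeping of the real pass is omitted:

    def tail_jobs(num_model_chunks, accumulate_steps, pp_degree, split_backward):
        """Jobs appended after the forward/backward_b body of the schedule."""
        jobs = []
        if split_backward and accumulate_steps == pp_degree:
            # All weight-gradient jobs run at the end, per chunk in reverse order.
            for chunk_id in range(num_model_chunks - 1, -1, -1):
                for micro_batch_id in range(accumulate_steps):
                    jobs.append((f"backward_w{chunk_id}", micro_batch_id))
        jobs.append(("optimizer", None))
        return jobs

    # e.g. 2 model chunks, 2 micro-batches on a 2-stage pipeline:
    print(tail_jobs(num_model_chunks=2, accumulate_steps=2, pp_degree=2, split_backward=True))
    # [('backward_w1', 0), ('backward_w1', 1), ('backward_w0', 0), ('backward_w0', 1), ('optimizer', None)]

Before repartitioning the program, _split_matmul_grad_ops_to_matmul rewrites each non-transposed matmul_v2_grad into separate matmul ops via split_matmul_grad_to_matmul, so the activation-gradient and weight-gradient halves become distinct ops that the splitting pass can route into the backward_b and backward_w sub-programs respectively.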
