【AutoParallel】Split bw in vpp #64534

Merged 2 commits on May 24, 2024
1 change: 1 addition & 0 deletions python/paddle/distributed/auto_parallel/constants.py
@@ -120,6 +120,7 @@ def set_field_default_config(category, field, default_value):
set_field_default_config(PIPELINE, "enable_send_recv_overlap", False)
set_field_default_config(PIPELINE, "job_schedule_profiler_start", -1)
set_field_default_config(PIPELINE, "job_schedule_profiler_stop", -1)
set_field_default_config(PIPELINE, "split_backward", False)

#########################################
# quantization configuration
@@ -555,4 +555,5 @@ def _apply_post_optimization(
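The new `split_backward` field defaults to False, so existing configs are unaffected. A minimal usage sketch, assuming the usual `paddle.distributed.fleet.auto` Strategy workflow (only `split_backward` itself comes from this PR; the neighbouring field names are taken from the existing pipeline config, and the model/optimizer setup is omitted):

```python
from paddle.distributed.fleet import auto

# Hedged sketch: enable the new flag next to the existing VPP pipeline settings.
strategy = auto.Strategy()
strategy.pipeline.enable = True
strategy.pipeline.schedule_mode = "VPP"   # virtual pipeline parallel schedule
strategy.pipeline.vpp_degree = 2          # model chunks per pipeline stage
strategy.pipeline.accumulate_steps = 4    # number of micro-batches per step
strategy.pipeline.split_backward = True   # new flag added by this PR

# The strategy is then passed to auto.Engine(..., strategy=strategy) as usual.
```

As the scheduler pass further down shows, the split only takes effect when the number of micro-batches equals the pipeline degree; otherwise the pass falls back to the plain VPP programs.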
"pp_stage": get_pp_stage(self._dist_context, rank),
"vpp_degree": self._strategy.pipeline.vpp_degree,
"dist_context": self._dist_context,
"split_backward": self._strategy.pipeline.split_backward,
}
132 changes: 132 additions & 0 deletions python/paddle/distributed/passes/pass_utils.py
@@ -778,6 +778,138 @@ def _split_ops(block):
return list(type_to_program.keys()), list(type_to_program.values())


def _program_for_vpp_split_bwk(
program,
num_model_chunks,
dist_context,
enable_send_recv_overlap=False,
):
if enable_send_recv_overlap:
_overlap_send_recv(program)
else:
_insert_sync_for_fthenb_1f1b(program, dist_context)

oprole_type = {
0: "forward",
1: "backward",
2: "backward_b",
3: 'backward_w',
4: "optimizer",
}

def _split_ops(block):
type_to_ops = OrderedDict()
for type in oprole_type.values():
chunk_ids = list(range(num_model_chunks))
if type == "optimizer":
type_to_ops[type] = []
else:
chunk_ids = (
chunk_ids if "backward" not in type else reversed(chunk_ids)
)
for chunk_id in chunk_ids:
type_to_ops[type + str(chunk_id)] = []
type_to_ops["fetch"] = []

dealed_op_idx = 0
for ip, op in enumerate(block.ops):
if ip < dealed_op_idx:
continue
if is_forward_op(op):
type = oprole_type[0]
elif is_backward_op(op):
types = _get_backward_op_type(block, op, ip)
dealed_op_idx = dealed_op_idx + len(types) - 1
elif is_optimize_op(op):
type = oprole_type[4]
else:
raise ValueError(
"The op role: "
+ str(op.attr('op_role'))
+ " isn't one of Forward, Backward or Optimizer."
)

dist_op = dist_context.get_dist_op_for_program(op)
if _is_fetch_op(op):
type_to_ops["fetch"].append(op)
elif is_optimize_op(op):
type_to_ops[type].append(op)
elif op.type == "feed":
type_to_ops[type + str(0)].append(op)
elif op.type == "share_buffer":
dist_pre_op = dist_context.get_dist_op_for_program(
block.ops[ip - 1]
)
type_to_ops[type + str(dist_pre_op.dist_attr.chunk_id)].append(
op
)
elif (
dist_op
and type + str(dist_op.dist_attr.chunk_id) in type_to_ops
and not is_backward_op(op)
):
type_to_ops[type + str(dist_op.dist_attr.chunk_id)].append(op)
elif (
dist_op
and type + str(dist_op.dist_attr.chunk_id) in type_to_ops
and is_backward_op(op)
):
for i, type in enumerate(types):
type_to_ops[
"backward" + str(dist_op.dist_attr.chunk_id)
].append(block.ops[ip + i])
type_to_ops[type + str(dist_op.dist_attr.chunk_id)].append(
block.ops[ip + i]
)
else:
raise ValueError(f"There is not dist_attr for op[{op.type}].")
dealed_op_idx = dealed_op_idx + 1

return type_to_ops

type_to_program = OrderedDict()

for ib, src_block in enumerate(program.blocks):
type_to_ops = _split_ops(src_block)
fetch_ops = type_to_ops.pop("fetch", [])
dst_blocks = []

if ib == 0:
for type, ops in type_to_ops.items():
type_to_program[type] = Program()
dst_block = type_to_program[type].block(0)
_add_ops_into_block(src_block, dst_block, ops)
dst_blocks.append(dst_block)
else:
for type, ops in type_to_ops.items():
if len(ops) > 0:
dst_block = type_to_program[type]._create_block(
parent_idx=src_block.parent_idx
)
dst_block._set_forward_block_idx(
src_block.forward_block_idx
)
_add_ops_into_block(src_block, dst_block, ops)
dst_blocks.append(dst_block)

for fetch_op in fetch_ops:
in_name = fetch_op.input('X')[0]
fetch_block = None
for dst_block in dst_blocks:
if dst_block._find_var_recursive(in_name):
fetch_block = dst_block
break

if fetch_block:
_create_program(src_block, fetch_block, fetch_op)

for prog in type_to_program.values():
prog._sync_with_cpp()
prog._roll_to_global_block()

return list(type_to_program.keys()), list(type_to_program.values())
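
This helper mirrors the existing `_program_for_vpp`, but backward ops are additionally bucketed into `backward_b*` (roughly, the activation-gradient part that stays on the critical path) and `backward_w*` (the weight-gradient part that can be deferred) sub-programs per model chunk. A small, self-contained sketch of the key layout `_split_ops` builds for `num_model_chunks = 2` (illustration only, it does not touch real Programs; the real helper also appends a `fetch` bucket):

```python
from collections import OrderedDict

oprole_type = {0: "forward", 1: "backward", 2: "backward_b", 3: "backward_w", 4: "optimizer"}
num_model_chunks = 2

type_to_ops = OrderedDict()
for type_name in oprole_type.values():
    if type_name == "optimizer":
        type_to_ops[type_name] = []      # single bucket, no chunk suffix
        continue
    chunk_ids = range(num_model_chunks)
    if "backward" in type_name:
        chunk_ids = reversed(list(chunk_ids))  # backward chunks run in reverse order
    for chunk_id in chunk_ids:
        type_to_ops[type_name + str(chunk_id)] = []

print(list(type_to_ops.keys()))
# ['forward0', 'forward1', 'backward1', 'backward0',
#  'backward_b1', 'backward_b0', 'backward_w1', 'backward_w0', 'optimizer']
```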


def _get_backward_op_type(block, cur_op, idx):
# deal the ops pattern: [reshape2, reshape2, matmul_v2, reshape2, elementwise_add]
def is_reshape_matmul_pattern(cur_op, idx, ops, ops_len):
@@ -18,7 +18,11 @@

from ...utils.log_utils import get_logger
from ..pass_base import register_pass
from ..pass_utils import _program_for_vpp
from ..pass_utils import (
_program_for_vpp,
_program_for_vpp_split_bwk,
split_matmul_grad_to_matmul,
)
from .pipeline_pass_base import PipelinePassBase

FORWARD = "forward"
@@ -51,6 +55,7 @@ def _create_job_list(self):
stage_id = self.get_attr("pp_stage")
num_stages = self.get_attr("pp_degree")
num_model_chunks = self.get_attr("vpp_degree")
split_backward = self.get_attr("split_backward", False)
for i in range(num_model_chunks):
self._forward_micro_step_counter[i] = 0
self._backward_micro_step_counter[i] = 0
@@ -73,6 +78,9 @@ def _get_virtual_pp_rank(micro_step, forward):
warmup_steps = min(warmup_steps, total_num_steps)

steady_steps = total_num_steps - warmup_steps
real_split_backward = (
accumulate_steps == num_stages
) and split_backward

job_list = []
for micro_step in range(warmup_steps):
@@ -101,26 +109,75 @@ def _get_virtual_pp_rank(micro_step, forward):
bwd_micro_batch_id = self._record_bwd_micro_step(
bwd_virtual_pp_rank
)
bwd_job = core.Job(BACKWARD + str(bwd_virtual_pp_rank))
if real_split_backward:
bwd_job = core.Job(BACKWARD + "_b" + str(bwd_virtual_pp_rank))
else:
bwd_job = core.Job(BACKWARD + str(bwd_virtual_pp_rank))
bwd_job.set_micro_batch_id(bwd_micro_batch_id)
job_list.append(bwd_job)

for micro_step in range(steady_steps, total_num_steps):
virtual_pp_rank = _get_virtual_pp_rank(micro_step, forward=False)
micro_batch_id = self._record_bwd_micro_step(virtual_pp_rank)
bwd_job = core.Job(BACKWARD + str(virtual_pp_rank))
if real_split_backward:
bwd_job = core.Job(BACKWARD + "_b" + str(virtual_pp_rank))
else:
bwd_job = core.Job(BACKWARD + str(virtual_pp_rank))
bwd_job.set_micro_batch_id(micro_batch_id)
job_list.append(bwd_job)
# TODO(lizhiyu): Inserting 'backward_b' and 'backward_w' in an interleaved way can reduce memory usage,
# but it also reduces speed. We should find a better way to use the code here.
# next_virtual_pp_rank = _get_virtual_pp_rank(micro_step + 1, forward=False)
# if next_virtual_pp_rank != virtual_pp_rank:
# for micro_batch_id in range(0, accumulate_steps):
# w_job = core.Job(BACKWARD + "_w" + str(virtual_pp_rank))
# w_job.set_micro_batch_id(micro_batch_id)
# job_list.append(w_job)

if real_split_backward:
for chunk_id in range(num_model_chunks - 1, -1, -1):
for micro_batch_id in range(0, accumulate_steps):
w_job = core.Job(BACKWARD + "_w" + str(chunk_id))
w_job.set_micro_batch_id(micro_batch_id)
job_list.append(w_job)

opt_job = core.Job(OPT)
job_list.append(opt_job)
return job_list
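
With `real_split_backward`, the warmup and steady loops above emit `backward_b` jobs, and the weight-gradient work is flushed afterwards: one `backward_w{chunk_id}` job per micro-batch, chunks in descending order, followed by the optimizer job. A tiny sketch of that tail for `vpp_degree = 2` and `accumulate_steps = 2` (it mirrors the loop above, illustration only):

```python
num_model_chunks, accumulate_steps = 2, 2

tail = []
for chunk_id in range(num_model_chunks - 1, -1, -1):
    for micro_batch_id in range(accumulate_steps):
        tail.append((f"backward_w{chunk_id}", micro_batch_id))
tail.append(("optimizer", None))

print(tail)
# [('backward_w1', 0), ('backward_w1', 1),
#  ('backward_w0', 0), ('backward_w0', 1), ('optimizer', None)]
```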

def _split_matmul_grad_ops_to_matmul(self, program, dist_context):
for block in program.blocks:
matmul_grad_op_idx = []
ops = block.ops
for i, op_i in enumerate(ops):
if (
op_i.type == "matmul_v2_grad"
and not op_i.attr("trans_x")
and not op_i.attr("trans_y")
):
matmul_grad_op_idx.append(i)

for matmul_grad_id in reversed(matmul_grad_op_idx):
split_matmul_grad_to_matmul(
block, matmul_grad_id, dist_context=dist_context
)
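
For background on why `matmul_v2_grad` (with no transposes) is the op being split: its fused kernel produces both the input gradient and the weight gradient, but only the input gradient is needed to keep the backward pass flowing to earlier layers, which is what lets the weight-gradient half move into the `backward_w` jobs. A rough numpy illustration of the two pieces (not the pass implementation itself):

```python
import numpy as np

# For Y = X @ W with trans_x = trans_y = False:
X = np.random.rand(4, 8)
W = np.random.rand(8, 16)
dY = np.random.rand(4, 16)   # gradient arriving from the next layer

dX = dY @ W.T   # input gradient: needed immediately to continue the backward pass (backward_b)
dW = X.T @ dY   # weight gradient: can be deferred to the tail of the schedule (backward_w)
```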

def _partial_programs(self, program):
dist_context = self.get_attr("dist_context")
num_model_chunks = self.get_attr("vpp_degree")
enable_send_recv_overlap = self.get_attr("enable_send_recv_overlap")
accumulate_steps = self.get_attr("num_micro_batches")
num_stages = self.get_attr("pp_degree")
split_backward = self.get_attr("split_backward", False)
types, sub_program_list = _program_for_vpp(
program, num_model_chunks, dist_context, enable_send_recv_overlap
)
if split_backward and accumulate_steps == num_stages:
self._split_matmul_grad_ops_to_matmul(program, dist_context)
types, sub_program_list = _program_for_vpp_split_bwk(
program,
num_model_chunks,
dist_context,
enable_send_recv_overlap,
)
return types, sub_program_list