[TRTLLM-6019] feat: Remove cutlass min latency code from AutoTuner. #5394

Open · wants to merge 1 commit into main
46 changes: 12 additions & 34 deletions tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
@@ -22,12 +22,9 @@ def bmm_out(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
class MoERunner(TunableRunner):
# avoid overhead of creating a new runner in forward pass
runner_dict = dict()
- # TODO: only profile for min_latency_mode = False due to the error in the moe_kernels
tuning_config = TuningConfig(dynamic_tensor_specs=(
DynamicTensorSpec(0, 0, get_last_power_of_2_num_tokens_buckets(8192),
- lambda x: min(last_positive_power_of_2(x), 8192)),
- DynamicTensorSpec(3, 0, (0, ), lambda x: x),
- ))
+ lambda x: min(last_positive_power_of_2(x), 8192)), ))
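The surviving DynamicTensorSpec tells the tuner to bucket the token dimension (tensor 0, dim 0) into powers of two capped at 8192; the removed DynamicTensorSpec(3, 0, (0, ), lambda x: x) described the sentinel min-latency tensor, which is no longer a tuner input. A minimal sketch of the rounding rule the remaining lambda expresses, with last_positive_power_of_2 reimplemented here as an assumption rather than taken from the TensorRT-LLM helpers:

```python
# Sketch of the token-bucketing rule used by the tuning config above.
# `last_positive_power_of_2` is assumed to round a positive integer down
# to the nearest power of two; the real helper lives in tensorrt_llm.
def last_positive_power_of_2(x: int) -> int:
    assert x > 0
    return 1 << (x.bit_length() - 1)

def round_num_tokens(x: int, tune_max_num_tokens: int = 8192) -> int:
    # Mirrors: lambda x: min(last_positive_power_of_2(x), 8192)
    return min(last_positive_power_of_2(x), tune_max_num_tokens)

assert round_num_tokens(1000) == 512     # 1000 tokens fall into the 512 bucket
assert round_num_tokens(20000) == 8192   # capped at tune_max_num_tokens
```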

def __init__(
self,
@@ -44,6 +41,7 @@ def __init__(
enable_alltoall: bool,
use_deepseek_fp8_block_scale: bool,
use_w4a8_group_scaling: bool,
+ min_latency_mode: bool,
Collaborator:
Why do you need to pass the argument here? Should self.min_latency_mode just be false?

Collaborator (Author):
I think removing the flag from tuning_config has achieved our goal of decorrelating it from the autotuner. For now we can keep passing min_latency_mode through to preserve consistent behavior, and remove it completely in a future commit once the final API is clear.

):
self.x_dtype = x_dtype
self.weight_dtype = weight_dtype
@@ -58,7 +56,7 @@ def __init__(
self.enable_alltoall = enable_alltoall
self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
self.use_w4a8_group_scaling = use_w4a8_group_scaling

+ self.min_latency_mode = min_latency_mode
instance_key = (x_dtype, weight_dtype, output_dtype,
use_deepseek_fp8_block_scale, use_w4a8_group_scaling)
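instance_key feeds the class-level runner_dict cache noted above ("avoid overhead of creating a new runner in forward pass"): runners are reused per configuration, and min_latency_mode is not part of the key. A rough sketch of that keyed-cache pattern, with hypothetical names since the lookup code itself is outside this diff:

```python
from typing import Any, Callable, Dict, Tuple

class KeyedRunnerCache:
    # Hypothetical illustration of the runner_dict / instance_key pattern;
    # not the actual MoERunner lookup code, which is outside this diff.
    _runners: Dict[Tuple, Any] = {}

    @classmethod
    def get_or_create(cls, key: Tuple, factory: Callable[[], Any]) -> Any:
        # Reuse the runner built for this configuration on an earlier forward
        # pass; only build a new one the first time the key is seen.
        if key not in cls._runners:
            cls._runners[key] = factory()
        return cls._runners[key]
```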

@@ -74,22 +72,7 @@ def get_valid_tactics(
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:
- x, _, _, min_latency_mode_tensor = inputs
- min_latency_mode = min_latency_mode_tensor.size(0) == 1
- m = x.shape[0]
-
- # Only profile m <= 128 for min latency mode = True
- # Profile all valid buckets for min latency mode = False
- # TODO: min_latency_mode = True will cause the following error:
- # Cannot profile configuration 4: Cutlass GEMM Tactic
- # [TensorRT-LLM][ERROR] Assertion failed: Failed to initialize cutlass TMA WS grouped gemm.
- # Should be fixed in the moe_kernels in the future.
- invalid = (m > 128 and
- min_latency_mode) or (m <= 128 and min_latency_mode and
- (not self.weight_dtype == torch.int64))
-
- return [] if invalid else list(
- range(self.fused_moe_runner.get_tactic_num()))
+ return range(self.fused_moe_runner.get_tactic_num())

def forward(
self,
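For reference, the deleted lines above decoded the mode from a sentinel tensor: the old fused_moe op passed torch.empty(0) as a fourth tuner input, and get_valid_tactics recovered the flag from its size (a size of 1 presumably would have signalled min-latency mode). A standalone reconstruction of that encode/decode round trip, shown only to illustrate what is being removed:

```python
import torch

# Encode side (old fused_moe): the flag travelled as the length of a dummy tensor.
min_latency_mode = False
min_latency_tensor = torch.empty(1) if min_latency_mode else torch.empty(0)

# Decode side (old get_valid_tactics / forward): recover the flag from the size.
decoded_min_latency_mode = min_latency_tensor.size(0) == 1
assert decoded_min_latency_mode == min_latency_mode
```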
@@ -98,8 +81,7 @@ def forward(
tactic: int = -1,
do_preparation: bool = False,
):
- x, fc1_expert_weights, fc2_expert_weights, min_latency_mode_tensor = inputs
- min_latency_mode = min_latency_mode_tensor.size(0) == 1
+ x, fc1_expert_weights, fc2_expert_weights = inputs
# determine if we should use min latency mode according to the profiled seq len
self.fused_moe_runner.run_gemm_profile(
x,
@@ -113,7 +95,7 @@
self.cluster_size,
self.cluster_rank,
self.enable_alltoall,
- min_latency_mode,
+ self.min_latency_mode,
gemm_idx,
tactic,
do_preparation,
@@ -122,13 +104,11 @@
@classmethod
@lru_cache(maxsize=None)
def refine_tuning_config(cls, tune_max_num_tokens: int):
- cls.tuning_config = TuningConfig(dynamic_tensor_specs=(
- DynamicTensorSpec(
+ cls.tuning_config = TuningConfig(
+ dynamic_tensor_specs=(DynamicTensorSpec(
0, 0, get_last_power_of_2_num_tokens_buckets(
tune_max_num_tokens), lambda x: min(
- last_positive_power_of_2(x), tune_max_num_tokens)),
- DynamicTensorSpec(3, 0, (0, ), lambda x: x),
- ))
+ last_positive_power_of_2(x), tune_max_num_tokens)), ))
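Since refine_tuning_config is wrapped in lru_cache, the class-level tuning_config is rebuilt at most once per distinct tune_max_num_tokens; repeated calls with the same limit hit the cache and do nothing. A small sketch of that cached-classmethod pattern, with illustrative names only:

```python
from functools import lru_cache

class TuningConfigHolder:
    # Illustrative stand-in for MoERunner's class-level tuning_config.
    config = None

    @classmethod
    @lru_cache(maxsize=None)
    def refine(cls, max_tokens: int) -> None:
        # Executed once per distinct max_tokens; later calls are served
        # from the lru_cache and leave the stored config untouched.
        cls.config = {"max_tokens": max_tokens}

TuningConfigHolder.refine(8192)
TuningConfigHolder.refine(8192)  # cached, body not executed again
TuningConfigHolder.refine(4096)  # new value, config rebuilt
```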


@torch.library.custom_op("trtllm::fused_moe", mutates_args=())
Expand Down Expand Up @@ -157,9 +137,6 @@ def fused_moe(
tuner = AutoTuner.get()
MoERunner.refine_tuning_config(tune_max_num_tokens)

- # TODO: set min_latency_mode always to False due to the error in the moe_kernels
- min_latency_tensor = torch.empty(0)

# allocate workspace for profiling
moe_runner = MoERunner(
x_dtype=input.dtype,
@@ -175,21 +152,22 @@
enable_alltoall=enable_alltoall,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
use_w4a8_group_scaling=use_w4a8_group_scaling,
+ min_latency_mode=min_latency_mode,
)

_, gemm_tactic_1 = tuner.choose_one(
"trtllm::fused_moe::gemm1",
[moe_runner],
MoERunner.tuning_config,
- [input, fc1_expert_weights, fc2_expert_weights, min_latency_tensor],
+ [input, fc1_expert_weights, fc2_expert_weights],
gemm_idx=1,
)

_, gemm_tactic_2 = tuner.choose_one(
"trtllm::fused_moe::gemm2",
[moe_runner],
MoERunner.tuning_config,
- [input, fc1_expert_weights, fc2_expert_weights, min_latency_tensor],
+ [input, fc1_expert_weights, fc2_expert_weights],
gemm_idx=2,
)
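With the sentinel tensor gone, both choose_one calls see only the three real inputs. Conceptually the tuner profiles every tactic a runner reports as valid and keeps the fastest one; the sketch below is an assumption about that selection loop, not the actual AutoTuner implementation:

```python
import time
from typing import Any, List, Tuple

def choose_one_sketch(runners: List[Any], inputs: List[Any], **kwargs) -> Tuple[Any, int]:
    # Toy stand-in for AutoTuner.choose_one (assumed behavior, not the real
    # implementation): time every (runner, tactic) pair on the already-bucketed
    # inputs and keep the fastest combination.
    best_runner, best_tactic, best_time = None, -1, float("inf")
    for runner in runners:
        for tactic in runner.get_valid_tactics(inputs, profile=None):
            start = time.perf_counter()
            runner.forward(inputs, tactic=tactic, **kwargs)
            elapsed = time.perf_counter() - start
            if elapsed < best_time:
                best_runner, best_tactic, best_time = runner, tactic, elapsed
    return best_runner, best_tactic
```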
