@@ -2504,6 +2504,88 @@ void FusedTransposeSplitQuantInferMeta(const MetaTensor& x,
2504
2504
" x.shape[1] (%d) must be <= 65535 * 128" , N));
2505
2505
}
2506
2506
2507
+ void FusedTransposeWLCHSplitQuantInferMeta (const MetaTensor& x,
2508
+ const IntArray& tokens_per_expert,
2509
+ bool pow_2_scales,
2510
+ std::vector<MetaTensor*> outs,
2511
+ std::vector<MetaTensor*> scales) {
2512
+ PADDLE_ENFORCE_EQ (
2513
+ x.dtype (),
2514
+ DataType::BFLOAT16,
2515
+ common::errors::InvalidArgument (
2516
+ " The dtype of Input(x) must be BFLOAT16, but received %s" ,
2517
+ x.dtype ()));
2518
+
2519
+ auto x_dims = x.dims ();
2520
+
2521
+ PADDLE_ENFORCE_EQ (
2522
+ x_dims.size (),
2523
+ 4 ,
2524
+ common::errors::InvalidArgument (
2525
+ " Input(x) must have dimension of 4, but got %d." , x_dims.size ()));
2526
+
2527
+ const int64_t M = x_dims[0 ] * x_dims[1 ] * x_dims[2 ];
2528
+ const int64_t H = x_dims[3 ];
2529
+
2530
+ auto tokens_list = tokens_per_expert.GetData ();
2531
+ const size_t num_experts = tokens_list.size ();
2532
+
2533
+ PADDLE_ENFORCE_EQ (
2534
+ outs.size (),
2535
+ num_experts,
2536
+ common::errors::InvalidArgument (
2537
+ " Size of outs (%d) must equal size of tokens_per_expert (%d)" ,
2538
+ outs.size (),
2539
+ num_experts));
2540
+
2541
+ PADDLE_ENFORCE_EQ (
2542
+ scales.size (),
2543
+ num_experts,
2544
+ common::errors::InvalidArgument (
2545
+ " Size of scales (%d) must equal size of tokens_per_expert (%d)" ,
2546
+ scales.size (),
2547
+ num_experts));
2548
+
2549
+ int64_t sum_tokens = 0 ;
2550
+ for (size_t i = 0 ; i < num_experts; ++i) {
2551
+ const int64_t tokens = tokens_list[i];
2552
+
2553
+ PADDLE_ENFORCE_EQ (
2554
+ tokens % 128 ,
2555
+ 0 ,
2556
+ common::errors::InvalidArgument (
2557
+ " tokens_per_expert[%d] (%d) must be divisible by 128" , i, tokens));
2558
+
2559
+ sum_tokens += tokens;
2560
+
2561
+ if (outs[i] != nullptr ) {
2562
+ outs[i]->set_dims (common::make_ddim ({H, tokens}));
2563
+ outs[i]->set_dtype (DataType::FLOAT8_E4M3FN);
2564
+ outs[i]->set_layout (x.layout ());
2565
+ }
2566
+
2567
+ if (scales[i] != nullptr ) {
2568
+ scales[i]->set_dims (common::make_ddim ({tokens / 128 , H}));
2569
+ scales[i]->set_dtype (DataType::FLOAT32);
2570
+ scales[i]->set_layout (x.layout ());
2571
+ }
2572
+ }
2573
+
2574
+ PADDLE_ENFORCE_EQ (
2575
+ sum_tokens,
2576
+ M,
2577
+ common::errors::InvalidArgument (" Sum of tokens_per_expert (%d) must "
2578
+ " equal the upper dims of Input(x) (%d)" ,
2579
+ sum_tokens,
2580
+ M));
2581
+ PADDLE_ENFORCE_LE (
2582
+ H,
2583
+ 65535 * 128 ,
2584
+ common::errors::InvalidArgument (" Currently only supports the hidden size "
2585
+ " of Input(x) <= 65535 * 128, but got %d." ,
2586
+ H));
2587
+ }
2588
+
2507
2589
void YoloBoxXPUInferMeta (const MetaTensor& x,
2508
2590
const MetaTensor& x_max,
2509
2591
const MetaTensor& grid,
0 commit comments