@@ -1627,7 +1627,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank,
 
 #ifdef PADDLE_WITH_NVSHMEM
 std::tuple<deep_ep::detail::Tensor,
-           deep_ep::detail::Tensor,
+           std::optional<deep_ep::detail::Tensor>,
            deep_ep::detail::Tensor,
            deep_ep::detail::Tensor,
            deep_ep::detail::Tensor,
@@ -1637,6 +1637,7 @@ Buffer::low_latency_dispatch(const deep_ep::detail::Tensor& x,
                              const deep_ep::detail::Tensor& topk_idx,
                              int num_max_dispatch_tokens_per_rank,
                              int num_experts,
+                             bool use_fp8,
                              bool async,
                              bool return_recv_hook) {
   EP_HOST_ASSERT(low_latency_mode);
@@ -1675,12 +1676,13 @@ Buffer::low_latency_dispatch(const deep_ep::detail::Tensor& x,
   if (!return_recv_hook) stream_wait(launch_stream, compute_stream);
 
   // Allocate packed tensors
-  auto packed_recv_x = ConvertPaddleTensorToDetailTensor(
-      paddle::experimental::empty({num_local_experts,
-                                   num_ranks * num_max_dispatch_tokens_per_rank,
-                                   hidden},
-                                  phi::DataType::FLOAT8_E4M3FN,
-                                  x.place()));
+  auto packed_recv_x =
+      ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
+          {num_local_experts,
+           num_ranks * num_max_dispatch_tokens_per_rank,
+           hidden},
+          use_fp8 ? phi::DataType::FLOAT8_E4M3FN : phi::DataType::BFLOAT16,
+          x.place()));
   auto packed_recv_src_info =
       ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
           {num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank},
@@ -1695,25 +1697,32 @@ Buffer::low_latency_dispatch(const deep_ep::detail::Tensor& x,
           {num_local_experts}, phi::DataType::INT32, phi::GPUPlace(device_id)));
 
   // Allocate column-majored scales
-  EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 &&
-                 "TMA requires the number of tokens to be multiple of 4");
-  auto packed_recv_x_scales =
-      ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
-          {num_local_experts,
-           num_scales,
-           num_ranks * num_max_dispatch_tokens_per_rank},
-          phi::DataType::FLOAT32,
-          phi::GPUPlace(device_id)));
-  packed_recv_x_scales =
-      ConvertPaddleTensorToDetailTensor(paddle::experimental::transpose(
-          ConvertDetailTensorToPaddleTensor(packed_recv_x_scales),
-          std::vector<int>{1, 2}));
+  auto packed_recv_x_scales = std::optional<deep_ep::detail::Tensor>();
+
+  float* packed_recv_x_scales_ptr = nullptr;
+
+  if (use_fp8) {
+    EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 &&
+                   "TMA requires the number of tokens to be multiple of 4");
+    packed_recv_x_scales =
+        ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
+            {num_local_experts,
+             num_scales,
+             num_ranks * num_max_dispatch_tokens_per_rank},
+            phi::DataType::FLOAT32,
+            phi::GPUPlace(device_id)));
+    packed_recv_x_scales =
+        ConvertPaddleTensorToDetailTensor(paddle::experimental::transpose(
+            ConvertDetailTensorToPaddleTensor(packed_recv_x_scales.value()),
+            std::vector<int>{0, 2, 1}));
+    packed_recv_x_scales_ptr = packed_recv_x_scales.value().data_ptr<float>();
+  }
 
   // Kernel launch
   auto next_clean_meta = next_buffer.clean_meta();
   auto launcher = [=](int phases) {
     internode_ll::dispatch(packed_recv_x.data_ptr(),
-                           packed_recv_x_scales.data_ptr<float>(),
+                           packed_recv_x_scales_ptr,
                            packed_recv_src_info.data_ptr<int>(),
                            packed_recv_layout_range.data_ptr<int64_t>(),
                            packed_recv_count.data_ptr<int>(),
@@ -1731,6 +1740,7 @@ Buffer::low_latency_dispatch(const deep_ep::detail::Tensor& x,
                            num_experts,
                            rank,
                            num_ranks,
+                           use_fp8,
                            workspace,
                            launch_stream,
                            phases);
@@ -2092,7 +2102,7 @@ Buffer::internode_combine_api(
 }
 
 std::tuple<paddle::Tensor,
-           paddle::Tensor,
+           std::optional<paddle::Tensor>,
            paddle::Tensor,
            paddle::Tensor,
            paddle::Tensor,
@@ -2102,6 +2112,7 @@ Buffer::low_latency_dispatch_api(const paddle::Tensor& x,
                                  const paddle::Tensor& topk_idx,
                                  int num_max_dispatch_tokens_per_rank,
                                  int num_experts,
+                                 bool use_fp8,
                                  bool async,
                                  bool return_recv_hook) {
 #ifdef PADDLE_WITH_NVSHMEM
@@ -2112,12 +2123,18 @@ Buffer::low_latency_dispatch_api(const paddle::Tensor& x,
                                              topk_idx_,
                                              num_max_dispatch_tokens_per_rank,
                                              num_experts,
+                                             use_fp8,
                                              async,
                                              return_recv_hook);
 
   auto packed_recv_x_ = ConvertDetailTensorToPaddleTensor(std::get<0>(res));
-  auto packed_recv_x_scales_ =
-      ConvertDetailTensorToPaddleTensor(std::get<1>(res));
+
+  std::optional<paddle::Tensor> packed_recv_x_scales_;
+  if (std::get<1>(res).has_value()) {
+    packed_recv_x_scales_ =
+        ConvertDetailTensorToPaddleTensor(std::get<1>(res).value());
+  }
+
   auto packed_recv_count_ = ConvertDetailTensorToPaddleTensor(std::get<2>(res));
   auto packed_recv_src_info_ =
       ConvertDetailTensorToPaddleTensor(std::get<3>(res));
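
Note on the pattern this diff introduces: the scales tensor now exists only on the FP8 path, so the kernel launcher receives a raw pointer that stays nullptr for BF16 dispatch. Below is a minimal, self-contained sketch of that optional-to-pointer handling; ScaleTensor, launch_kernel, and dispatch_sketch are hypothetical stand-ins for illustration only, not Paddle or DeepEP APIs.

    #include <cstdio>
    #include <optional>
    #include <vector>

    // Hypothetical stand-in for deep_ep::detail::Tensor.
    struct ScaleTensor {
      std::vector<float> data;
      float* data_ptr() { return data.data(); }
    };

    // Hypothetical stand-in for internode_ll::dispatch: only reports whether
    // a scales pointer was supplied.
    void launch_kernel(const float* scales_ptr, bool use_fp8) {
      std::printf("use_fp8=%d, scales=%s\n", use_fp8 ? 1 : 0,
                  scales_ptr ? "present" : "null");
    }

    void dispatch_sketch(bool use_fp8) {
      // Mirrors the diff: the optional stays empty unless FP8 is requested,
      // and the raw pointer defaults to nullptr.
      std::optional<ScaleTensor> packed_recv_x_scales;
      float* packed_recv_x_scales_ptr = nullptr;

      if (use_fp8) {
        packed_recv_x_scales = ScaleTensor{std::vector<float>(128, 1.0f)};
        packed_recv_x_scales_ptr = packed_recv_x_scales.value().data_ptr();
      }

      launch_kernel(packed_recv_x_scales_ptr, use_fp8);
    }

    int main() {
      dispatch_sketch(true);   // FP8 path: scales allocated and passed through
      dispatch_sketch(false);  // BF16 path: kernel sees a null scales pointer
    }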