Commit 7735abf

Support fp8
1 parent e81a049 commit 7735abf

26 files changed: +688 −5 lines

paddle/fluid/framework/data_type_transform.cc

Lines changed: 4 additions & 0 deletions
@@ -194,6 +194,10 @@ void TransDataType(const phi::DenseTensor& in,
     case proto::VarType::FP32:
      framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx));
      break;
+    case proto::VarType::FP8:
+      framework::VisitDataType(
+          dst_type, CastDataType<platform::float8_e4m3>(in, out, ctx));
+      break;
     case proto::VarType::FP64:
      framework::VisitDataType(dst_type, CastDataType<double>(in, out, ctx));
      break;
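
For reference, the 8-bit e4m3 layout that CastDataType<platform::float8_e4m3> casts to and from packs 1 sign bit, 4 exponent bits (bias 7), and 3 mantissa bits. Below is a minimal, standalone decoding sketch assuming the common e4m3fn convention (no infinities, a single NaN bit pattern); it is illustrative only and independent of the phi::dtype::float8_e4m3 implementation.

```cpp
#include <cmath>
#include <cstdint>

// Decode one e4m3fn byte into a float (assumptions: bias 7, S.1111.111 is the
// only NaN encoding, exponent 0 marks subnormals). Not Paddle's implementation.
float DecodeE4M3(uint8_t bits) {
  const int sign = (bits >> 7) & 0x1;
  const int exp = (bits >> 3) & 0xF;
  const int mant = bits & 0x7;
  float value;
  if (exp == 0xF && mant == 0x7) {
    value = NAN;  // no +/-inf in e4m3fn; this pattern is NaN
  } else if (exp == 0) {
    value = std::ldexp(static_cast<float>(mant) / 8.0f, -6);  // subnormal
  } else {
    value = std::ldexp(1.0f + static_cast<float>(mant) / 8.0f, exp - 7);
  }
  return sign != 0 ? -value : value;
}
```

With this layout the largest finite magnitude is 448, which is why fp8 pipelines typically pair casts like the one above with a scaling factor.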

paddle/fluid/framework/framework.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ message VarType {
     BF16 = 22;
     COMPLEX64 = 23;
     COMPLEX128 = 24;
+    FP8 = 32;

     // Other types that may need additional descriptions
     LOD_TENSOR = 7;

paddle/fluid/platform/float8_e4m3.h

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/common/float8_e4m3.h"
+
+namespace paddle {
+namespace platform {
+using float8_e4m3 = phi::dtype::float8_e4m3;
+using namespace phi::dtype;  // NOLINT
+}  // namespace platform
+}  // namespace paddle
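
The new header only aliases the phi type into paddle::platform for fluid-side code. A hedged usage sketch, assuming phi::dtype::float8_e4m3 exposes float conversions in the same style as phi::dtype::float16 and bfloat16 (an assumption, not something this diff shows):

```cpp
#include "paddle/fluid/platform/float8_e4m3.h"

// Assumption: float8_e4m3 can be constructed from float and cast back,
// mirroring phi::dtype::float16. Round-tripping quantizes to e4m3 precision.
void Fp8RoundTripExample() {
  paddle::platform::float8_e4m3 x(1.5f);
  float back = static_cast<float>(x);  // 1.5f is exactly representable in e4m3
  (void)back;
}
```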

paddle/fluid/pybind/protobuf.cc

Lines changed: 1 addition & 0 deletions
@@ -290,6 +290,7 @@ void BindVarDesc(pybind11::module *m) {
       .value("INT64", pd::proto::VarType::INT64)
       .value("FP16", pd::proto::VarType::FP16)
       .value("FP32", pd::proto::VarType::FP32)
+      .value("FP8", pd::proto::VarType::FP8)
       .value("FP64", pd::proto::VarType::FP64)
       .value("BF16", pd::proto::VarType::BF16)
       .value("COMPLEX64", pd::proto::VarType::COMPLEX64)

paddle/fluid/pybind/pybind.cc

Lines changed: 2 additions & 0 deletions
@@ -77,6 +77,7 @@ limitations under the License. */
 #include "paddle/fluid/memory/allocation/allocator_strategy.h"
 #include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/float16.h"
+#include "paddle/fluid/platform/float8_e4m3.h"
 #include "paddle/fluid/prim/utils/utils.h"
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 #include "paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h"

@@ -2979,6 +2980,7 @@ All parameter, weight, gradient are variables in Paddle.
       .value("COMPLEX128", phi::DataType::COMPLEX128)
       .value("FLOAT16", phi::DataType::FLOAT16)
       .value("BFLOAT16", phi::DataType::BFLOAT16)
+      .value("FLOAT8", phi::DataType::FLOAT8)
       .export_values();

 #if defined(PADDLE_WITH_PSLIB) && !defined(PADDLE_WITH_HETERPS)

paddle/fluid/pybind/tensor.cc

Lines changed: 33 additions & 0 deletions
@@ -68,6 +68,7 @@ limitations under the License. */
 #include "paddle/fluid/imperative/amp_auto_cast.h"
 #include "paddle/fluid/imperative/layer.h"
 #include "paddle/fluid/memory/allocation/allocator_strategy.h"
+#include "paddle/fluid/memory/memcpy.h"
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/memory/allocation/cuda_ipc_allocator.h"
 #endif

@@ -203,6 +204,28 @@ static void TensorCopyFrom(phi::DenseTensor *dst,
   }
 }

+template <typename PlaceType>
+static void TensorCopyFromPaddleTensor(phi::DenseTensor *dst,
+                                       const paddle::Tensor &src,
+                                       const PlaceType &place,
+                                       int64_t batch_size) {
+  // paddle::memory::Copy(dst->place(),
+  //                      dst->Holder()->ptr(),
+  //                      place,
+  //                      src.data(),
+  //                      src.numel());
+
+#if defined(PADDLE_WITH_CUDA)
+  if (dst->place() == phi::GPUPlace() && place == phi::GPUPlace()) {
+    cudaMemcpy(
+        dst->Holder()->ptr(), src.data(), src.size(), cudaMemcpyDeviceToDevice);
+  } else if (dst->place() == phi::CPUPlace() && place == phi::GPUPlace()) {
+    cudaMemcpy(
+        dst->Holder()->ptr(), src.data(), src.size(), cudaMemcpyDeviceToHost);
+  }
+#endif
+}
+
 void BindTensor(pybind11::module &m) {  // NOLINT
   using namespace paddle::framework;    // NOLINT
   py::class_<phi::DenseTensor> framework_tensor(

@@ -349,6 +372,16 @@ void BindTensor(pybind11::module &m) {  // NOLINT
       py::arg("tensor"),
       py::arg("place"),
       py::arg("batch_size") = -1)
+      .def("_copy_from_paddle_tensor",
+           &TensorCopyFromPaddleTensor<paddle::platform::Place>,
+           py::arg("tensor"),
+           py::arg("place"),
+           py::arg("batch_size") = -1)
+      .def("_copy_from_paddle_tensor",
+           &TensorCopyFromPaddleTensor<paddle::platform::CUDAPlace>,
+           py::arg("tensor"),
+           py::arg("place"),
+           py::arg("batch_size") = -1)
       .def("set",
            SetTensorFromPyArray<paddle::platform::CPUPlace>,
            py::arg("array"),

paddle/phi/common/data_type.h

Lines changed: 9 additions & 0 deletions
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/complex.h"
 #include "paddle/phi/common/float16.h"
+#include "paddle/phi/common/float8_e4m3.h"
 #include "paddle/utils/test_macros.h"

 namespace phi {

@@ -33,6 +34,7 @@ using complex128 = ::phi::dtype::complex<double>;
 using float16 = ::phi::dtype::float16;
 using bfloat16 = ::phi::dtype::bfloat16;
 using pstring = ::phi::dtype::pstring;
+using float8 = ::phi::dtype::float8_e4m3;

 // The enum value are consistent with jit/property.proto
 enum class TEST_API DataType {

@@ -70,6 +72,7 @@ enum class TEST_API DataType {
   // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
   BFLOAT16,

+  FLOAT8,
   NUM_DATA_TYPES,
   // See Note [ Why we need ALL in basic kernel key member? ]
   ALL_DTYPE = UNDEFINED,

@@ -80,6 +83,7 @@ inline size_t SizeOf(DataType data_type) {
     case DataType::BOOL:
     case DataType::UINT8:
     case DataType::INT8:
+    case DataType::FLOAT8:
      return 1;
     case DataType::BFLOAT16:
     case DataType::FLOAT16:

@@ -120,6 +124,7 @@ inline size_t SizeOf(DataType data_type) {
   _(int64_t, DataType::INT64)               \
   _(uint64_t, DataType::UINT64)             \
   _(bfloat16, DataType::BFLOAT16)           \
+  _(float8, DataType::FLOAT8)               \
   _(float16, DataType::FLOAT16)             \
   _(float, DataType::FLOAT32)               \
   _(double, DataType::FLOAT64)              \

@@ -188,6 +193,9 @@ inline std::ostream& operator<<(std::ostream& os, DataType dtype) {
     case DataType::BFLOAT16:
       os << "bfloat16";
       break;
+    case DataType::FLOAT8:
+      os << "float8";
+      break;
     case DataType::FLOAT16:
       os << "float16";
       break;

@@ -262,6 +270,7 @@ using bfloat16 = phi::bfloat16;
 using complex64 = phi::complex64;
 using complex128 = phi::complex128;
 using float16 = phi::float16;
+using float8 = phi::float8;
 using pstring = phi::pstring;

 }  // namespace paddle
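
A minimal usage sketch of the data_type.h additions above, assuming the phi headers are on the include path:

```cpp
#include <iostream>

#include "paddle/phi/common/data_type.h"

int main() {
  // FLOAT8 is a one-byte type, so the extended SizeOf reports 1.
  std::cout << phi::SizeOf(phi::DataType::FLOAT8) << std::endl;  // prints 1
  // The new operator<< case prints the lowercase name.
  std::cout << phi::DataType::FLOAT8 << std::endl;  // prints float8
  return 0;
}
```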
