Skip to content

Commit f2d8333

Browse files
authored
add nvtx event for profiling (#66133)
1 parent a5b80d0 commit f2d8333

File tree

3 files changed, with 31 additions and 3 deletions.

paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc

+11-2
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h"
2525
#include "paddle/fluid/platform/collective_helper.h"
2626
#include "paddle/fluid/platform/device_context.h"
27+
#include "paddle/fluid/platform/profiler/event_tracing.h"
2728
#include "paddle/phi/core/infermeta_utils.h"
2829
#include "paddle/phi/core/meta_tensor.h"
2930
#include "paddle/phi/core/type_defs.h"
30-
3131
#include "paddle/pir/include/core/builtin_attribute.h"
3232
#include "paddle/pir/include/core/operation.h"
3333
#include "paddle/pir/include/core/value.h"
@@ -181,14 +181,23 @@ PhiKernelInstruction::~PhiKernelInstruction() { delete phi_kernel_; }
181181
// Runs this instruction in three steps: (1) the infer-meta callback, if an
// infer-meta interface is set, (2) ShareVarBuffer for every pair returned by
// InplaceInfo(), and (3) the phi kernel call on kernel_context_.
// VLOG(6) messages mark the start and end of the infer-meta and kernel steps.
// This commit wraps steps (1) and (3) in platform::RecordEvent objects for
// profiling (the commit title says "nvtx event").
void PhiKernelInstruction::Run() {
182182
VLOG(6) << "Begin run op " << phi_op_name_ << " infer meta.";
183183
if (infer_meta_interface_) {
// Profiling event for the infer-meta call. Presumably the event ends when
// record_event goes out of scope at the end of this if-body -- confirm
// against profiler/event_tracing.h. The trailing 1 is presumably the trace
// level -- TODO confirm.
184+
platform::RecordEvent record_event("PhiKernelInstruction::infermeta",
185+
platform::TracerEventType::UserDefined,
186+
1);
184187
infer_meta_interface_->infer_meta_(&(infer_meta_context_));
185188
}
186189
VLOG(6) << "End run op " << phi_op_name_ << " infer meta.";
// Call ShareVarBuffer on each (first, second) pair from InplaceInfo() before
// the kernel runs.
187190
for (auto& pair : this->InplaceInfo()) {
188191
ShareVarBuffer(pair.first, pair.second);
189192
}
190193
VLOG(6) << "Begin run op " << phi_op_name_ << " kernel.";
// Diff note: the line marked "191-" is the old bare kernel call removed by
// this commit; the braced block marked "194+".."199+" replaces it.
191-
(*(phi_kernel_))(&(kernel_context_));
194+
{
// The extra braces limit record_event's lifetime to the kernel call only, so
// the event presumably covers just the call below -- confirm RecordEvent is
// scope-based.
195+
platform::RecordEvent record_event("PhiKernelInstruction::kernel launch",
196+
platform::TracerEventType::UserDefined,
197+
1);
198+
(*(phi_kernel_))(&(kernel_context_));
199+
}
200+
192201
VLOG(6) << "End run op " << phi_op_name_ << " kernel.";
193202
}
194203

paddle/fluid/framework/new_executor/pir_interpreter.cc

+5-1
Original file line numberDiff line numberDiff line change
@@ -1897,7 +1897,11 @@ void PirInterpreter::RunInstructionBase(InstructionBase* instr_node) {
18971897
}
18981898

18991899
if (!instr_node->IsArtificial()) {
1900-
instr_node->Run();
1900+
{
1901+
platform::RecordEvent record(
1902+
"InstrRun", platform::TracerEventType::UserDefined, 10);
1903+
instr_node->Run();
1904+
}
19011905

19021906
if (FLAGS_benchmark) {
19031907
instr_node->DeviceContext().Wait();

paddle/fluid/ir_adaptor/translator/op_translator.cc

+15
Original file line numberDiff line numberDiff line change
@@ -3179,6 +3179,20 @@ struct RepeatInterLeaveGradOpTranscriber : public OpTranscriber {
31793179
}
31803180
};
31813181

3182+
// OpTranscriber specialization registered under the "top_p_sampling" key in
// OpTranslator's special_handlers. It overrides HandleNonexistentAttribute to
// write fixed values for three attribute names:
//   "seed" -> Int32Attribute(-1)
//   "k"    -> Int32Attribute(0)
//   "mode" -> StrAttribute("truncated")
// Presumably this hook is called for attributes the translated op expects but
// the legacy op description does not carry -- confirm against the
// OpTranscriber base class.
struct TopPSamplingOpTranscriber : public OpTranscriber {
// ctx: IR context used to create the attribute objects.
// attribute_map: output map; the entry keyed by info.name is written in place.
// info: description of the missing attribute; only info.name is read here.
// NOTE(review): any other attribute name falls through without writing an
// entry. Verify that no other attribute of this op relies on base-class
// handling, since this override never calls the base implementation.
3183+
void HandleNonexistentAttribute(pir::IrContext* ctx,
3184+
pir::AttributeMap* attribute_map,
3185+
const OpAttributeInfo& info) override {
3186+
if (info.name == "seed") {
3187+
(*attribute_map)[info.name] = pir::Int32Attribute::get(ctx, -1);
3188+
} else if (info.name == "k") {
3189+
(*attribute_map)[info.name] = pir::Int32Attribute::get(ctx, 0);
3190+
} else if (info.name == "mode") {
3191+
(*attribute_map)[info.name] = pir::StrAttribute::get(ctx, "truncated");
3192+
}
3193+
}
3194+
};
3195+
31823196
struct FusedElemwiseAddActivationOpTranscriber : public OpTranscriber {
31833197
void HandleNonexistentAttribute(pir::IrContext* ctx,
31843198
pir::AttributeMap* attribute_map,
@@ -3629,6 +3643,7 @@ OpTranslator::OpTranslator() {
36293643
special_handlers["slice"] = SliceOpTranscriber();
36303644
special_handlers["split"] = SplitOpTranscriber();
36313645
special_handlers["sum"] = AddNOpTranscriber();
3646+
special_handlers["top_p_sampling"] = TopPSamplingOpTranscriber();
36323647
special_handlers["tril_triu"] = TrilAndTriuOpTranscriber();
36333648
special_handlers["tril_triu_grad"] = TrilAndTriuGradOpTranscriber();
36343649
special_handlers["matmul"] = LegacyMatmulOpTranscriber();

0 commit comments

Comments
 (0)