From 77ca130a24d115dcc007e7da796ff4e8b4dfce40 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Thu, 17 Apr 2025 22:02:05 +0800 Subject: [PATCH 1/5] [Dy2St] Cleanup legacy run program op --- paddle/fluid/operators/CMakeLists.txt | 4 +- paddle/fluid/operators/run_program_op.cc | 265 ------------------ paddle/fluid/operators/unity_build_rule.cmake | 12 +- paddle/fluid/pybind/CMakeLists.txt | 3 +- 4 files changed, 5 insertions(+), 279 deletions(-) delete mode 100644 paddle/fluid/operators/run_program_op.cc diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 9bdafeeb87ebbf..d7f3e091bffcd3 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -78,12 +78,10 @@ endif() set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi common phi_utils static_prim_api get_expected_kernel_func) -register_operators(EXCLUDES py_func_op generated_op1 generated_op2 generated_op3 generated_op4 load_combine_op run_program_op quantize_linear_op +register_operators(EXCLUDES py_func_op generated_op1 generated_op2 generated_op3 generated_op4 load_combine_op quantize_linear_op save_combine_op sync_batch_norm_op activation_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS} processgroup_comm_utils) op_library(generated_op UNITY SRCS generated_op1.cc generated_op2.cc generated_op3.cc generated_op4.cc DEPS ${OP_HEADER_DEPS}) -op_library(run_program_op DEPS executor_cache ${OP_HEADER_DEPS}) -target_link_libraries(run_program_op phi common) op_library(quantize_linear_op DEPS phi common) op_library(save_combine_op DEPS phi) op_library(load_combine_op DEPS phi) diff --git a/paddle/fluid/operators/run_program_op.cc b/paddle/fluid/operators/run_program_op.cc deleted file mode 100644 index 526a1ced5a502b..00000000000000 --- a/paddle/fluid/operators/run_program_op.cc +++ /dev/null @@ -1,265 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include - -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle::operators { - -using BlockDesc = framework::BlockDesc; - -class RunProgramOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), - true, - common::errors::NotFound( - "Input(X) of RunProgramOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasOutputs("Out"), - true, - common::errors::NotFound( - "Output(Out) of RunProgramOp should not be null.")); - } - - protected: - /* [Why use single type kernel]: - * - * This op is similar to a control flow op, it doses not need - * a op kernel, but in order to make it execute under dynamic - * graph mode, implement it with op kernel. - * - * So whether the kernel data type is int, float or other type, - * which has no effect on its execution logic, so directly - * specified a data type here. - * - * Of course, the data type here is also not important. - */ - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); - } - - phi::KernelKey GetKernelTypeForVar( - const std::string& var_name, - const phi::DenseTensor& tensor, - const phi::KernelKey& expected_kernel_type) const override { - return phi::KernelKey(phi::Backend::ALL_BACKEND, - expected_kernel_type.layout(), - expected_kernel_type.dtype()); - } -}; - -class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", - "(vector)" - "The input tensors of RunProgram operator, also the feed targets " - "of loaded program.") - .AsDuplicable(); - AddInput("Params", - "(vector)" - "The input parameter of RunProgram operator, also the parameters " - "of the loaded program.") - .AsDuplicable() - .AsDispensable(); - AddOutput("Out", - "(vector)" - "The output tensors of RunProgram operator, also the fetch " - "targets of the loaded program.") - .AsDuplicable(); - AddOutput("OutScope", - "(StepScopeVar)" - "A vector of execution scope in RunProgram operator, which " - "contains at most one scope." - "NOTE: Do not use Scope directly because Scope output is not " - "currently supported."); - AddOutput("DOut", - "(vector)" - "The output tensors for GRAD Tensors in RunProgram forward " - "operator, the forward operator contains GRAD Tensors when it " - "computes double grad.") - .AsDuplicable() - .AsDispensable(); - AddOutput("CUDAGraph", "The output CUDA Graph when use_cuda_graph=True.") - .AsDispensable(); - AddAttr("global_block", - "(BlockDesc *)" - "The global block of executed program desc."); - AddAttr("start_op_index", - "(int64_t)" - "The index of the op to start execution"); - AddAttr("end_op_index", - "(int64_t)" - "The index of the op to stop execution"); - AddAttr("is_test", - "(bool, default false) Set to true for inference only, false " - "for training.") - .SetDefault(false); - AddAttr( - "in_pir_pt_mode", - "(bool, default false) Set to true when need to run in pir mode") - .SetDefault(false); - AddAttr( - "program_id", - "(int64_t)" - "The unique hash id used as cache key for ExecutorInfoCache."); - AddAttr("cuda_graph_capture_mode", - "(str, default '') The CUDA Graph capture mode. " - "Default '' means no CUDA Graph capturing.") - .SetDefault(""); - AddAttr("cuda_graph_pool_id", - "(int64_t, default 0) The CUDA Graph memory pool ID.") - .SetDefault(0); - AddAttr("use_interpretorcore", - "(bool, default false) Set to true for use interpretercore.") - .SetDefault(false); - AddAttr("forward_global_block", - "(BlockDesc *)" - "The global block of executed forward program desc.") - .SetDefault(nullptr); - AddAttr("backward_global_block", - "(BlockDesc *)" - "The global block of executed backward program desc.") - .SetDefault(nullptr); - AddAttr>("param_grad_names", - "std::vector" - "The names of parameter gradients.") - .SetDefault({}); - AddAttr>("out_grad_names", - "std::vector" - "The names of output gradients.") - .SetDefault({}); - AddAttr>("x_names", - "std::vector" - "The names of input tensors.") - .SetDefault({}); - AddAttr>("x_grad_names", - "std::vector" - "The names of input gradients.") - .SetDefault({}); - AddComment(R"DOC( -RunProgram operator. - -The RunProgram operator receives a program's feed targets, fetch targets, -and parameters, and receives the forward and backward program desc -as attributes, and then executes the program by executor. - -NOTE: This operator is added so that the inference model stored by -`fluid.io.save_inference_model` under the static graph mode can be loaded -under the dynamic graph mode for fine-tuning or inferencing. - -)DOC"); - } -}; - -class RunProgramGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), - true, - common::errors::NotFound( - "Input(X) of RunProgramGradOp should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasInputs(framework::GradVarName("Out")), - true, - common::errors::NotFound( - "Input(Out@GRAD) of RunProgramGradOp should not be null.")); - // NOTE: The X@GRAD and Params@GRAD may not exist, - // because they can be set stop_gradient = True - } - - protected: - /* see [Why use single type kernel] */ - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); - } - - phi::KernelKey GetKernelTypeForVar( - const std::string& var_name, - const phi::DenseTensor& tensor, - const phi::KernelKey& expected_kernel_type) const override { - return phi::KernelKey(phi::Backend::ALL_BACKEND, - expected_kernel_type.layout(), - expected_kernel_type.dtype()); - } -}; - -template -struct FilterHelper {}; - -template <> -struct FilterHelper { - static void filter(const BlockDesc* desc, - imperative::TracedVarList* vec) { - auto f = [desc](std::shared_ptr ptr) { - return !desc->HasVar(ptr->Name()); - }; - auto new_end = std::remove_if(vec->begin(), vec->end(), f); - vec->resize(new_end - vec->begin()); - } -}; - -template <> -struct FilterHelper { - static void filter(const BlockDesc* desc, std::vector* vec) { - auto f = [desc](const std::string& name) { return !desc->HasVar(name); }; - auto new_end = std::remove_if(vec->begin(), vec->end(), f); - vec->resize(new_end - vec->begin()); - } -}; - -template -class RunProgramGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr grad_op) const override { - grad_op->SetType("run_program_grad"); - grad_op->SetInput("X", this->Input("X")); - grad_op->SetInput("Params", this->Input("Params")); - grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - grad_op->SetInput("OutScope", this->Output("OutScope")); - grad_op->SetInput("DOut", this->Output("DOut")); - if (this->HasOutput("CUDAGraph")) { - grad_op->SetInput("CUDAGraph", this->Output("CUDAGraph")); - } - grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - - auto block_desc = - PADDLE_GET_CONST(BlockDesc*, this->GetAttr("global_block")); - auto params_grad = this->InputGrad("Params"); - FilterHelper::filter(block_desc, ¶ms_grad); // filter the vector. - grad_op->SetOutput(framework::GradVarName("Params"), params_grad); - grad_op->SetAttrMap(this->Attrs()); - } -}; - -} // namespace paddle::operators - -namespace ops = paddle::operators; -REGISTER_OPERATOR(run_program, - ops::RunProgramOp, - ops::RunProgramOpMaker, - ops::RunProgramGradOpMaker, - ops::RunProgramGradOpMaker); -REGISTER_OPERATOR(run_program_grad, ops::RunProgramGradOp); diff --git a/paddle/fluid/operators/unity_build_rule.cmake b/paddle/fluid/operators/unity_build_rule.cmake index fca285a9a552c7..6544ea6bb06084 100644 --- a/paddle/fluid/operators/unity_build_rule.cmake +++ b/paddle/fluid/operators/unity_build_rule.cmake @@ -216,7 +216,6 @@ register_unity_group( cc roi_align_op.cc roll_op.cc - run_program_op.cc sampling_id_op.cc save_combine_op.cc save_op.cc @@ -317,7 +316,6 @@ register_unity_group( partial_concat_op.cc pyramid_hash_op.cc recurrent_op.cc - run_program_op.cc softmax_with_cross_entropy_op.cc warpctc_op.cc) register_unity_group(cc lstm_op.cu.cc rnn_op.cu.cc split_op.cu.cc @@ -439,14 +437,8 @@ register_unity_group( pad3d_op.cu pad_constant_like_op.cu pad_op.cu) -register_unity_group( - cu - partial_sum_op.cu - pixel_shuffle_op.cu - prelu_op.cu - run_program_op.cu - pull_box_extended_sparse_op.cu - pull_box_sparse_op.cu) +register_unity_group(cu partial_sum_op.cu pixel_shuffle_op.cu prelu_op.cu + pull_box_extended_sparse_op.cu pull_box_sparse_op.cu) register_unity_group(cu range_op.cu reverse_op.cu partial_concat_op.cu kldiv_loss_op.cu instance_norm_op.cu) register_unity_group( diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index d1c3ebde7a5fda..c22f2aee0116e1 100755 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -44,7 +44,8 @@ set(PYBIND_DEPS prim_utils detail_op_handle type_info - auto_parallel) + auto_parallel + executor_cache) if(WITH_GPU) list(APPEND PYBIND_DEPS gpu_event_timer) From 9871796e5bc7e09c8708500ac3b7d079189a792b Mon Sep 17 00:00:00 2001 From: SigureMo Date: Thu, 17 Apr 2025 22:09:52 +0800 Subject: [PATCH 2/5] remove ut --- .../test_run_program_op_deprecated.py | 533 ------------------ 1 file changed, 533 deletions(-) delete mode 100644 test/deprecated/legacy_test/test_run_program_op_deprecated.py diff --git a/test/deprecated/legacy_test/test_run_program_op_deprecated.py b/test/deprecated/legacy_test/test_run_program_op_deprecated.py deleted file mode 100644 index 0e84d9227add17..00000000000000 --- a/test/deprecated/legacy_test/test_run_program_op_deprecated.py +++ /dev/null @@ -1,533 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import contextlib -import os -import unittest - -import numpy as np - -import paddle -from paddle import _legacy_C_ops, base -from paddle.base import core, framework -from paddle.base.dygraph.base import switch_to_static_graph - -paddle.enable_static() - - -@contextlib.contextmanager -def program_scope_guard(): - prog = base.Program() - startup_prog = base.Program() - scope = base.core.Scope() - with base.scope_guard(scope): - with base.program_guard(prog, startup_prog): - with base.unique_name.guard(): - yield - - -@switch_to_static_graph -def _add_build_strategy_for(input_program, start_op_index, end_op_index): - compiled_program = paddle.static.CompiledProgram( - core.Graph(input_program.desc, start_op_index, end_op_index), - build_strategy=paddle.static.BuildStrategy(), - ) - compiled_program._compile( - core.Scope(), paddle.framework._current_expected_place() - ) - ir_graph = paddle.base.framework.IrGraph(compiled_program._graph) - built_program = ir_graph.to_program() - return built_program - - -@switch_to_static_graph -def _build_program_by_desc(program_desc): - prog = framework.Program() - prog.desc = program_desc - prog.blocks = [ - framework.Block(prog, i) for i in range(prog.desc.num_blocks()) - ] - prog._sync_with_cpp() - return prog - - -# NOTE: Because RunProgramOp has a special output of type std::vector, -# the OpTest cannot be used in RunProgramOp. The variable type cannot be specified -# when creating output variables in OpTest, default type is DenseTensor -# NOTE: the gradient test method in OpTest also cannot be used for RunProgramOp, -# because it hold BlockDesc type attr, OperatorFactory can't parse this attr type -# when create Operator, so here compare gradients with static graph -# NOTE: Here rewrite a simple unittest framework for RunProgramOp -class RunProgramOpTest(unittest.TestCase): - def build_model(self): - raise NotImplementedError( - "RunProgramOp test should implement build_model" - ) - - def check_output(self): - places = [] - if ( - os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower() - in ['1', 'true', 'on'] - or not core.is_compiled_with_cuda() - ): - places.append(base.CPUPlace()) - if core.is_compiled_with_cuda(): - places.append(base.CUDAPlace(0)) - for place in places: - # TODO: RunProgramOp is not recommended for use in static graph mode now - self.expect_outs = self.run_static_model(place, is_test=True) - self.check_output_with_place(place) - - def check_grad(self): - places = [] - if ( - os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower() - in ['1', 'true', 'on'] - or not core.is_compiled_with_cuda() - ): - places.append(base.CPUPlace()) - if core.is_compiled_with_cuda(): - places.append(base.CUDAPlace(0)) - for place in places: - # TODO: RunProgramOp is not recommended for use in static graph mode now - self.expect_grads = self.run_static_model(place, is_test=False) - self.check_grad_with_place(place) - - def run_static_model(self, place, is_test=True): - with program_scope_guard(): - startup_program = base.default_startup_program() - main_program = base.default_main_program() - - self.build_model() - - exe = base.Executor(place) - exe.run(startup_program) - - if is_test: - fetch_list = self.output_names['Out'] - else: - fetch_list = self.get_param_grad_names() - - outs = exe.run( - main_program, feed=self.inputs['X'], fetch_list=fetch_list - ) - return outs - - def get_program_desc(self): - with program_scope_guard(): - fwd_op_num = self.build_model() - return base.default_main_program().desc, fwd_op_num - - def get_forward_backward_program_desc( - self, whole_program_desc, forward_op_num, output_num - ): - program = _build_program_by_desc(whole_program_desc) - forward_program = _add_build_strategy_for(program, 0, forward_op_num) - backward_program = _add_build_strategy_for( - program, - forward_op_num + output_num, - program.desc.block(0).op_size(), - ) - return forward_program.desc, backward_program.desc - - def prepare_attrs(self): - return [ - 'global_block', - self.program_desc.block(0), - 'start_op_index', - 0, - 'end_op_index', - self.fwd_op_num, - 'program_id', - paddle.utils._hash_with_id(self.program_desc, self), - ] - - def get_param_grad_names(self): - grad_names = [] - for var_name in self.inputs['Params']: - grad_names.append(var_name + core.grad_var_suffix()) - return grad_names - - def check_output_with_place(self, place): - # Step 1. run op - actual_outs = self.calc_dygraph_output(place) - - # Step 2. compare output - for expect_v, actual_v in zip(self.expect_outs, actual_outs): - np.testing.assert_allclose( - expect_v, actual_v.numpy(), rtol=1e-05, atol=1e-05 - ) - - def check_grad_with_place(self, place): - # Step 1. calc grads - actual_grads = self.calc_dygraph_grad(place) - - # Step 2. compare grads - for expect_v, actual_v in zip(self.expect_grads, actual_grads): - np.testing.assert_array_almost_equal(expect_v, actual_v) - np.testing.assert_allclose( - expect_v, actual_v, rtol=1e-05, atol=1e-05 - ) - - def prepare_dygraph_input(self, place, return_param_list=False): - def create_var_base(is_input, name, np_value, stop_gradient): - var = core.eager.Tensor( - value=np_value, name=name, place=place, zero_copy=True - ) - var.stop_gradient = stop_gradient - return var - - # build inputs - inputs = {} - param_list = [] - inputs['X'] = [] - for name, np_value in self.inputs['X'].items(): - var = create_var_base(True, name, np_value, True) - inputs['X'].append(var) - inputs['Params'] = [] - for name, np_value in self.inputs['Params'].items(): - var = create_var_base(True, name, np_value, False) - inputs['Params'].append(var) - if return_param_list: - param_list.append(var) - - if return_param_list: - return inputs, param_list - return inputs - - def prepare_dygraph_output(self): - def create_var_base(is_input, name): - var = framework._create_tensor(dtype=None, shape=None, name=name) - var.stop_gradient = False - return var - - # build outputs - outputs = {} - outputs['Out'] = [] - for name in self.output_names['Out']: - outputs['Out'].append(create_var_base(False, name)) - - outputs['OutScope'] = [core.Scope()] - - return outputs - - def calc_dygraph_output(self, place): - self.program_desc, self.fwd_op_num = self.get_program_desc() - self.attrs = self.prepare_attrs() - - with base.dygraph.guard(place): - inputs = self.prepare_dygraph_input(place) - outputs = self.prepare_dygraph_output() - - ( - forward_program_desc, - backward_program_desc, - ) = self.get_forward_backward_program_desc( - self.program_desc, self.fwd_op_num, len(outputs['Out']) - ) - - use_interpretorcore = True - self.attrs.extend(('use_interpretorcore', use_interpretorcore)) - if use_interpretorcore: - self.attrs.extend( - ( - 'forward_global_block', - forward_program_desc.block(0), - 'backward_global_block', - backward_program_desc.block(0), - ) - ) - - self.attrs.extend( - ( - 'param_grad_names', - [p.name + '@GRAD' for p in inputs['Params']], - 'out_grad_names', - [out.name + '@GRAD' for out in outputs['Out']], - 'x_grad_names', - [p.name + '@GRAD' for p in inputs['X']], - 'x_names', - [t.name for t in inputs['X']], - ) - ) - - _legacy_C_ops.run_program( - inputs['X'], - inputs['Params'], - outputs['Out'], - outputs['OutScope'], - None, - *self.attrs, - ) - - return outputs['Out'] - - def calc_dygraph_grad(self, place): - self.program_desc, self.fwd_op_num = self.get_program_desc() - self.attrs = self.prepare_attrs() - - with base.dygraph.guard(place): - # Step 1. run forward - inputs, input_param_list = self.prepare_dygraph_input(place, True) - outputs = self.prepare_dygraph_output() - - ( - forward_program_desc, - backward_program_desc, - ) = self.get_forward_backward_program_desc( - self.program_desc, self.fwd_op_num, len(outputs['Out']) - ) - - use_interpretorcore = True - self.attrs.extend(('use_interpretorcore', use_interpretorcore)) - if use_interpretorcore: - self.attrs.extend( - ( - 'forward_global_block', - forward_program_desc.block(0), - 'backward_global_block', - backward_program_desc.block(0), - ) - ) - - self.attrs.extend( - ( - 'param_grad_names', - [p.name + '@GRAD' for p in inputs['Params']], - 'out_grad_names', - [out.name + '@GRAD' for out in outputs['Out']], - 'x_grad_names', - [p.name + '@GRAD' for p in inputs['X']], - 'x_names', - [t.name for t in inputs['X']], - ) - ) - - _legacy_C_ops.run_program( - inputs['X'], - inputs['Params'], - outputs['Out'], - outputs['OutScope'], - None, - *self.attrs, - ) - - for param in input_param_list: - var_type = self._get_grad_vartype(param.name) - if var_type is None: - continue - param._set_grad_type(var_type) - - # Step 2. run backward - # NOTE: in unittest, only support single output now - actual_outs = outputs['Out'] - assert len(actual_outs) == 1 - actual_outs[0].backward() - - # Step 3. prepare grads - grads = [] - for param in input_param_list: - grad = param.gradient() - grads.append(grad) - return grads - - def _get_grad_vartype(self, name): - assert self.program_desc is not None - grad_name = name + core.grad_var_suffix() - for i in range(self.program_desc.num_blocks()): - block = self.program_desc.block(i) - var_desc = block.find_var_recursive(grad_name.encode()) - return var_desc.type() if var_desc is not None else None - - -class TestRunProgramOpWithFC(RunProgramOpTest): - def setUp(self): - self.op_type = "run_program" - self.dtype = np.float32 - self.input_names = { - 'X': ['img'], - 'Params': ['weight_param', 'bias_param'], - } - self.output_names = {'Out': ['fc_0.tmp_2']} - - self.inputs = { - 'X': { - self.input_names['X'][0]: np.random.random( - (32, 1, 28, 28) - ).astype(self.dtype) - }, - 'Params': { - self.input_names['Params'][0]: np.random.random( - (784, 10) - ).astype(self.dtype), - self.input_names['Params'][1]: np.random.random( - (32, 10) - ).astype(self.dtype), - }, - } - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad() - - def build_model(self): - # 1. simple model - img = paddle.static.data( - name=self.input_names['X'][0], - shape=[None, 1, 28, 28], - dtype='float32', - ) - weight_attr = base.ParamAttr( - name=self.input_names['Params'][0], - learning_rate=0.5, - initializer=paddle.nn.initializer.Assign( - self.inputs['Params'][self.input_names['Params'][0]] - ), - trainable=True, - ) - bias_attr = base.ParamAttr( - name=self.input_names['Params'][1], - learning_rate=0.5, - initializer=paddle.nn.initializer.Assign( - self.inputs['Params'][self.input_names['Params'][1]] - ), - trainable=True, - ) - pred = paddle.static.nn.fc( - x=img, - size=10, - weight_attr=weight_attr, - bias_attr=bias_attr, - activation='relu', - ) - # 2. get forward op num - fwd_op_num = base.default_main_program().global_block().desc.op_size() - # 3. append backward - grads = base.backward.gradients(targets=[pred], inputs=[img]) - - return fwd_op_num - - -class TestRunProgramOpWithEmbedding(RunProgramOpTest): - def setUp(self): - self.op_type = "run_program" - self.dtype = np.float32 - self.input_names = {'X': ['x'], 'Params': ['emb_weight']} - self.output_names = {'Out': ['sum_0.tmp_0']} - - self.inputs = { - 'X': {'x': np.array([[1, 3, 0, 4, 7]]).astype("int64")}, - 'Params': { - 'emb_weight': np.random.random(size=(10, 16)).astype("float32") - }, - } - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - # NOTE: fetch not support SelectedRows, cannot compare - # sparse gradients with static mode, only run dygraph - places = [] - if ( - os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower() - in ['1', 'true', 'on'] - or not core.is_compiled_with_cuda() - ): - places.append(base.CPUPlace()) - if core.is_compiled_with_cuda(): - places.append(base.CUDAPlace(0)) - for place in places: - # TODO: RunProgramOp is not recommended for use in static graph mode now - self.calc_dygraph_grad(place) - - def build_model(self): - # 1. simple model - x = paddle.static.data( - name=self.input_names['X'][0], shape=[-1, 5], dtype='int64' - ) - emb = paddle.static.nn.embedding( - input=x, - size=[10, 16], - param_attr=base.ParamAttr( - name="emb_weight", - learning_rate=10, - initializer=paddle.nn.initializer.Assign( - self.inputs['Params'][self.input_names['Params'][0]] - ), - ), - is_sparse=True, - ) - y = paddle.sum(emb, axis=-1) - # 2. get forward op num - fwd_op_num = base.default_main_program().global_block().desc.op_size() - # 3. append backward - grads = base.backward.gradients(targets=[y], inputs=[x]) - - return fwd_op_num - - -class Net(paddle.nn.Layer): - def __init__(self): - super().__init__() - self.fc1 = paddle.nn.Linear(10, 10) - self.fc2 = paddle.nn.Linear(10, 1) - - def forward(self, x): - out = self.fc1(x) - out.stop_gradient = True - out = self.fc2(out) - return out - - -class TestParametersWithStopGradient(unittest.TestCase): - def setUp(self): - self.seed = 2021 - self.iter = 5 - - def train(self, to_static): - # prepare env - paddle.seed(self.seed) - - net = Net() - if to_static: - net = paddle.jit.to_static(net, full_graph=True) - sgd = paddle.optimizer.SGD(0.01, parameters=net.parameters()) - - for i in range(self.iter): - x = paddle.rand([4, 10]) - out = net(x) - loss = paddle.mean(out) - - loss.backward() - sgd.minimize(loss) - net.clear_gradients() - - return loss - - def test_stop_gradient(self): - paddle.disable_static() - - dy_loss = self.train(to_static=False) - st_loss = self.train(to_static=True) - self.assertEqual(dy_loss, st_loss) - - paddle.enable_static() - - -if __name__ == "__main__": - unittest.main() From 81c3272a4bb03081a35ce5e3c7663d89777c763f Mon Sep 17 00:00:00 2001 From: SigureMo Date: Thu, 17 Apr 2025 23:22:34 +0800 Subject: [PATCH 3/5] add pir_transforms to test_tensorrt_engine_instruction deps --- test/cpp/inference/tensorrt/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/test/cpp/inference/tensorrt/CMakeLists.txt b/test/cpp/inference/tensorrt/CMakeLists.txt index cb68443c986db3..ebb48c20c836db 100644 --- a/test/cpp/inference/tensorrt/CMakeLists.txt +++ b/test/cpp/inference/tensorrt/CMakeLists.txt @@ -10,6 +10,7 @@ if(${TENSORRT_VERSION_NUMBER} GREATER_EQUAL 85) phi common pir_save_load + pir_transforms pir_tensorrt_plugin) set_tests_properties(test_tensorrt_engine_instruction PROPERTIES TIMEOUT 120) if(WITH_ONNXRUNTIME AND WIN32) From 80bd82aa154057c236e985bc61879b9766ae63ec Mon Sep 17 00:00:00 2001 From: SigureMo Date: Fri, 18 Apr 2025 02:28:12 +0800 Subject: [PATCH 4/5] write a new `ConstructAttrMapForLegacyRunProgram` for legacy run program op --- .../pybind/eager_legacy_custom_python_api.h | 3 +- paddle/fluid/pybind/op_function_common.cc | 74 +++++++++++++++++-- paddle/fluid/pybind/op_function_common.h | 7 ++ 3 files changed, 75 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/pybind/eager_legacy_custom_python_api.h b/paddle/fluid/pybind/eager_legacy_custom_python_api.h index 807008bde97754..9c4797e2b70d8b 100644 --- a/paddle/fluid/pybind/eager_legacy_custom_python_api.h +++ b/paddle/fluid/pybind/eager_legacy_custom_python_api.h @@ -43,8 +43,7 @@ static PyObject *eager_api_run_program(PyObject *self, // TOREMOVE Out = GetTensorPtrListFromArgs("run_program", "Out", args, 2, true, mesh); } framework::AttributeMap attrs; - // TODO(zengjinle): support CUDA Graph on eager mode - ConstructAttrMapFromPyArgs( + ConstructAttrMapForLegacyRunProgram( "run_program", args, 5, PyTuple_GET_SIZE(args), attrs); tstate = PyEval_SaveThread(); diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc index b35da88b61f444..c188c468cfea06 100644 --- a/paddle/fluid/pybind/op_function_common.cc +++ b/paddle/fluid/pybind/op_function_common.cc @@ -1135,7 +1135,7 @@ void ConstructAttrMapFromPyArgs( } } -void ConstructAttrMapForRunProgram( +void ConstructAttrMapForLegacyRunProgram( const std::string& op_type, PyObject* args, ssize_t attr_start, @@ -1156,17 +1156,76 @@ void ConstructAttrMapForRunProgram( ssize_t); // Static map from keys to casting function pointers static const std::unordered_map kAttrFuncMap = { + {"forward_global_block", CastPyArg2AttrBlock}, + {"backward_global_block", CastPyArg2AttrBlock}, + {"is_test", CastPyArg2AttrBoolean}, + {"program_id", CastPyArg2AttrLong}, + {"param_grad_names", CastPyArg2AttrStrings}, + {"x_names", CastPyArg2AttrStrings}, + {"out_grad_names", CastPyArg2AttrStrings}, + {"x_grad_names", CastPyArg2AttrStrings}, {"cuda_graph_capture_mode", CastPyArg2AttrString}, - {"global_block", CastPyArg2AttrIRBlock}, + {"cuda_graph_pool_id", CastPyArg2AttrLong}, + {"in_pir_pt_mode", CastPyArg2AttrBoolean}, + }; + + PyObject* obj = nullptr; + for (ssize_t arg_pos = attr_start; arg_pos < attr_end; arg_pos += 2) { + VLOG(3) << "Start Process " << arg_pos; + Py_ssize_t key_len = 0; + const char* key_ptr = nullptr; + obj = PyTuple_GET_ITEM(args, arg_pos); + if (PyObject_CheckString(obj)) { + key_ptr = PyUnicode_AsUTF8AndSize(obj, &key_len); + } else { + PADDLE_THROW(common::errors::InvalidArgument( + "%s(): argument (position %d) must be str, but got %s", + op_type, + arg_pos, + ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT + } + std::string_view key_view(key_ptr, static_cast(key_len)); + VLOG(3) << "Start Process " << key_view; + obj = PyTuple_GET_ITEM(args, arg_pos + 1); + auto it = kAttrFuncMap.find(std::string(key_view)); + if (it != kAttrFuncMap.end()) { + // Call Cast function + it->second(obj, attrs, std::string(key_view), op_type, arg_pos); + } else { + PADDLE_THROW(common::errors::InvalidArgument( + "%.*s is not defined in this function.", + static_cast(key_view.size()), + key_view.data())); // NOLINT + } + } +} + +void ConstructAttrMapForRunProgram( + const std::string& op_type, + PyObject* args, + ssize_t attr_start, + ssize_t attr_end, + paddle::framework::AttributeMap& attrs) { // NOLINT + PADDLE_ENFORCE_EQ((attr_end - attr_start) % 2, + 0, + common::errors::InvalidArgument( + "The number of arguments for attributes should be even " + "but attr_start = %d, attr_end = %d.", + attr_start, + attr_end)); + + using CastFuncType = void (*)(PyObject*, + paddle::framework::AttributeMap&, + const std::string&, + const std::string&, + ssize_t); + // Static map from keys to casting function pointers + static const std::unordered_map kAttrFuncMap = { {"forward_program", CastPyArg2AttrIRProgram}, {"backward_program", CastPyArg2AttrIRProgram}, {"is_test", CastPyArg2AttrBoolean}, - {"use_interpretorcore", CastPyArg2AttrBoolean}, {"in_sot_mode", CastPyArg2AttrBoolean}, - {"start_op_index", CastPyArg2AttrLong}, - {"end_op_index", CastPyArg2AttrLong}, {"program_id", CastPyArg2AttrLong}, - {"cuda_graph_pool_id", CastPyArg2AttrLong}, {"fx", CastPyArg2AttrValues}, {"fp", CastPyArg2AttrValues}, {"fm", CastPyArg2AttrValues}, @@ -1178,7 +1237,8 @@ void ConstructAttrMapForRunProgram( {"bo_g", CastPyArg2AttrValues}, {"bx_g", CastPyArg2AttrValues}, {"bp_g", CastPyArg2AttrValues}, - {"bo", CastPyArg2AttrValues}}; + {"bo", CastPyArg2AttrValues}, + }; PyObject* obj = nullptr; for (ssize_t arg_pos = attr_start; arg_pos < attr_end; arg_pos += 2) { diff --git a/paddle/fluid/pybind/op_function_common.h b/paddle/fluid/pybind/op_function_common.h index 9940ab4932cd2c..527df7bcc6fa70 100644 --- a/paddle/fluid/pybind/op_function_common.h +++ b/paddle/fluid/pybind/op_function_common.h @@ -206,6 +206,13 @@ void ConstructAttrMapFromPyArgs( ssize_t attr_end, paddle::framework::AttributeMap& attrs); // NOLINT +void ConstructAttrMapForLegacyRunProgram( + const std::string& op_type, + PyObject* args, + ssize_t attr_start, + ssize_t attr_end, + paddle::framework::AttributeMap& attrs); // NOLINT + void ConstructAttrMapForRunProgram( const std::string& op_type, PyObject* args, From 6c4ffb26cdd523c8f38bc0a78a6e1b5f934857a2 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Fri, 18 Apr 2025 09:33:34 +0800 Subject: [PATCH 5/5] add some missing attrs for translated layer --- paddle/fluid/pybind/op_function_common.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc index c188c468cfea06..8f4437cc1c2d22 100644 --- a/paddle/fluid/pybind/op_function_common.cc +++ b/paddle/fluid/pybind/op_function_common.cc @@ -1167,6 +1167,10 @@ void ConstructAttrMapForLegacyRunProgram( {"cuda_graph_capture_mode", CastPyArg2AttrString}, {"cuda_graph_pool_id", CastPyArg2AttrLong}, {"in_pir_pt_mode", CastPyArg2AttrBoolean}, + {"use_interpretorcore", CastPyArg2AttrBoolean}, + {"global_block", CastPyArg2AttrBlock}, + {"start_op_index", CastPyArg2AttrLong}, + {"end_op_index", CastPyArg2AttrLong}, }; PyObject* obj = nullptr;